├── .gitignore
├── README.md
├── baseline_risk_assessment
│   ├── __init__.py
│   ├── camera.py
│   ├── convlstm.py
│   ├── dpm_model.py
│   ├── dpm_preprocessor.py
│   └── dpm_trainer.py
├── requirements.txt
├── scripts
│   ├── preprocess_dpm_images.py
│   ├── train_dpm.py
│   └── train_sg2vec.py
└── sg_risk_assessment
    ├── __init__.py
    ├── image_scenegraph.py
    ├── metrics.py
    ├── mrgcn.py
    ├── relation_extractor.py
    ├── scene_graph.py
    └── sg2vec_trainer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *wandb*
2 | *.0
3 | *.json
4 | *.wandb
5 | *.txt
6 | *.log
7 | *.pyc
8 | *.pkl
9 | *.tsv
10 | *.csv
11 | *.pt
12 | *.png
13 | *.jpg
14 | *.project
15 | *.pydevproject
16 | *.settings\*
17 | *.prefs
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatio-Temporal Scene-Graph Embedding for Autonomous Vehicle Collision Prediction
2 | 
3 | This repository includes the code and dataset information required to reproduce the results in our paper, *Spatio-Temporal Scene-Graph Embedding for Autonomous Vehicle Collision Prediction*. We have also integrated the source code of [the baseline method](https://arxiv.org/abs/1711.10453) we compared against into this repo. The baseline approach infers the likelihood of a future collision using deep ConvLSTMs. Our approach incorporates both spatial and temporal modeling for collision prediction using an MR-GCN and an LSTM.
4 | 
5 | To generate the lane-changing datasets, we use [CARLA](https://github.com/carla-simulator/carla) 0.9.8, an open-source autonomous driving simulator. We also use [scenario_runner](https://github.com/carla-simulator/scenario_runner), which was designed for the CARLA challenge event. For real-world driving data, our experiments use the Honda Driving Dataset (HDD).
6 | 
7 | The scene-graph dataset used in our paper is published [here](http://ieee-dataport.org/3618). \
8 | The corresponding datasets for the baseline model can be found [here](https://drive.google.com/file/d/1YfU_DVdYNVNYhoiuqlYZRUWRbHLZCE7l/view?usp=sharing).
9 | 
10 | The repository is organized as follows:
11 | - **sg_risk_assessment/**: source files for our scene-graph based approach (SG2VEC).
12 | - **baseline_risk_assessment/**: source files for the baseline method (DPM).
13 | - **scripts/train_sg2vec.py**: entry script for training and evaluating our scene-graph based approach.
14 | - **scripts/train_dpm.py**: entry script for training and evaluating the baseline model.
15 | 
16 | # To Get Started
17 | We recommend using [Anaconda](https://www.anaconda.com/) to manage your virtual environment. The core requirements are:
18 | - python >= 3.6
19 | - torch == 1.6.0
20 | - torch_geometric == 1.6.1
21 | 
22 | First, download and install Anaconda here:
23 | https://www.anaconda.com/products/individual
24 | 
25 | If you are using a GPU, install the corresponding CUDA toolkit for your hardware from Nvidia here:
26 | https://developer.nvidia.com/cuda-toolkit
27 | 
28 | Next, create a conda virtual environment running Python 3.6:
29 | ```shell
30 | conda create --name av python=3.6
31 | ```
32 | 
33 | After creating the environment, activate it with the following command:
34 | 
35 | ```shell
36 | conda activate av
37 | ```
38 | 
39 | Install PyTorch into your conda environment by following the instructions for your CUDA version here:
40 | https://pytorch.org/get-started/locally/
41 | 
42 | In our experiments we used PyTorch 1.5 and 1.6, but later versions should also work.
43 | 
44 | Next, install the PyTorch Geometric library by running the commands corresponding to your Torch and CUDA versions:
45 | https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
46 | 
47 | Once this setup is complete, install the remaining dependencies from requirements.txt:
48 | 
49 | ```shell
50 | pip install -r requirements.txt
51 | ```
52 | ---
53 | 
54 | # Usage
55 | To run our scene-graph based collision prediction pipeline (SG2VEC), use the following command:
56 | ```shell
57 | $ python train_sg2vec.py --cache_path risk-assessment/scenegraph/synthetic/271_dataset.pkl
58 | 
59 | # --cache_path: path to the downloaded scene-graph dataset .pkl
60 | # For tuning hyperparameters, see the Config class in sg2vec_trainer.py
61 | ```
62 | 
63 | To run the baseline risk assessment pipeline (DPM), use the following command:
64 | ```shell
65 | $ python train_dpm.py --cache_path risk-assessment/scene/synthetic/271_dataset.pkl
66 | 
67 | # --cache_path: path to the downloaded image dataset .pkl
68 | # For tuning hyperparameters, see the Config class in dpm_trainer.py
69 | ```
70 | 
71 | After running these commands, the expected output is a dump of metrics logged by wandb:
72 | ```shell
73 | wandb: train_recall ▁████████████████████
74 | wandb: val_precision █▁▅▄▅▄▆▆▆▅▄▄▇▆▅▆▅▇▆▆▆
75 | wandb: val_recall ▁████████████████████
76 | wandb: train_fpr ▁█▅▅▄▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂
77 | wandb: train_tnr █▁▄▅▅▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇
78 | wandb: train_fnr █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
79 | wandb: val_fpr ▁█▄▅▄▅▃▃▃▄▄▅▂▃▃▃▄▂▃▃▃
80 | wandb: val_tnr █▁▆▄▆▄▆▆▆▆▅▄▇▆▆▆▆▇▆▆▆
81 | wandb: val_fnr █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
82 | wandb: best_epoch ▁▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▅▇▇▇█
83 | wandb: best_val_loss █▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
84 | wandb: best_val_acc ▁▆█▇█████████████████
85 | wandb: best_val_auc ▁▅▆▆▇▇▇▇████▇▇▇▇▇████
86 | wandb: best_val_mcc ▁▇███████████████████
87 | wandb: best_val_acc_balanced ▁████████████████████
88 | wandb: train_mcc ▁▇▇▇▇▇███████████████
89 | wandb: val_mcc ▁▇███████████████████
90 | ```
91 | 
92 | A graphical visualization of the model outputs, including loss and additional metrics, can be viewed by creating a [wandb](https://wandb.ai/home) account and linking your runs to it.
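Both training scripts expose their hyperparameters as command-line flags. As a minimal sketch (the flag names below come from the `Config` class in `baseline_risk_assessment/dpm_trainer.py`, and the values are only illustrative; `sg2vec_trainer.py` defines its own analogous `Config` class for SG2VEC), a baseline run with a custom learning rate, batch size, and epoch count might look like:

```shell
# illustrative values only; see the Config class in dpm_trainer.py for all flags and defaults
$ python train_dpm.py \
    --cache_path risk-assessment/scene/synthetic/271_dataset.pkl \
    --learning_rate 0.0001 \
    --batch_size 16 \
    --epochs 100 \
    --downsample true \
    --device cuda
```

Setting `--n_folds` to a value greater than 1 switches the trainer from a single train/test split to stratified k-fold cross-validation, and `--transfer_path` points evaluation at a second cached dataset for transfer experiments.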
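The baseline's cached `.pkl` dataset can also be regenerated from raw image sequences with `scripts/preprocess_dpm_images.py`. Below is a sketch of an invocation, assuming a hypothetical `data/lane-change-100` directory of CARLA sequences; the flags are those defined in that script's argument parser, and each sequence folder is expected to contain a `raw_images/` directory of frame-numbered images and a `label.txt` (plus an optional `ignore.txt`), as read by `dpm_preprocessor.py`:

```shell
# data/lane-change-100 is a placeholder path for your own raw sequences
$ python preprocess_dpm_images.py \
    --input_path data/lane-change-100 \
    --cache_path dpm_data.pkl \
    --subseq_len 5 \
    --rescale_shape 64,64 \
    --grayscale true \
    --preprocess true \
    --num_processes 4
```

Passing `--preprocess true` first rescales and recolors the raw frames into `--image_output_dir`; the script then assembles subsequences of length `--subseq_len` and pickles them to `--cache_path` for use with `train_dpm.py`.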
93 | -------------------------------------------------------------------------------- /baseline_risk_assessment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AICPS/sg-collision-prediction/f7ff6d8a70d4d9414f80f5187171d6d77d65ae51/baseline_risk_assessment/__init__.py -------------------------------------------------------------------------------- /baseline_risk_assessment/camera.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | from .convlstm import ConvLSTM 5 | 6 | class Camera(nn.Module): 7 | ''' 8 | Image pipeline implementation of the model mentioned in this paper: Deep Predictive Models for Collision Risk Assessment in Autonomous Driving 9 | We also reference this implementation as well: https://gitlab.com/avces/avces_tensorflow/-/blob/master/src/dpm_keras_general.py 10 | ''' 11 | def __init__(self, input_shape, config): 12 | super(Camera, self).__init__() 13 | self.batch_size, self.frames, self.channels, self.height, self.width = input_shape # ex: (1, 5, 1, 64, 64) 14 | self.config = config 15 | 16 | # Hyper-parameters 17 | self.num_features = 8 18 | self.kernal_size = (5, 5) 19 | self.stride = (2, 2) 20 | self.num_layers = 1 21 | 22 | # Activation functions 23 | self.activation = nn.ReLU() if self.config.activation == 'relu' else nn.LeakyReLU(0.1) 24 | self.dropout = nn.Dropout(self.config.dropout) 25 | 26 | # Normalization functions 27 | self.bnorm3d_1 = nn.BatchNorm3d(num_features=self.frames) 28 | self.bnorm3d_2 = nn.BatchNorm3d(num_features=self.frames) 29 | 30 | # Layers 31 | self.convlstm1 = ConvLSTM(input_dim=self.channels, hidden_dim=self.num_features, kernel_size=self.kernal_size, num_layers=self.num_layers, batch_first=True) 32 | self.convlstm2 = ConvLSTM(input_dim=self.num_features, hidden_dim=self.num_features, kernel_size=self.kernal_size, stride=self.stride, num_layers=self.num_layers, batch_first=True) 33 | self.convlstm3 = ConvLSTM(input_dim=self.num_features, hidden_dim=self.num_features, kernel_size=self.kernal_size, stride=self.stride, num_layers=self.num_layers, batch_first=True) 34 | self.flatten = nn.Flatten(start_dim=1) 35 | 36 | def forward(self, x): 37 | 38 | l1, (l1_h, l1_c) = self.convlstm1(x) 39 | l2, (l2_h, l2_c) = self.convlstm2(self.bnorm3d_1(l1)) 40 | l3, (l3_h, l3_c) = self.convlstm3(self.bnorm3d_2(l2)) 41 | #l2, (l2_h, l2_c) = self.convlstm2(l1) 42 | #l3, (l3_h, l3_c) = self.convlstm3(l2) 43 | l4 = self.flatten(l3) 44 | return l4 45 | 46 | if __name__ == '__main__': 47 | from types import SimpleNamespace 48 | cfg = {'dropout': 0.1, 'device': 'cpu', 'activation': 'relu'} 49 | config = SimpleNamespace(**cfg) 50 | image = torch.rand((16, 5, 1, 64, 64)) 51 | model = Camera(image.shape, config) 52 | 53 | pred = model(image) 54 | print(pred.shape) 55 | -------------------------------------------------------------------------------- /baseline_risk_assessment/convlstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ConvLSTMCell(nn.Module): 6 | ''' 7 | https://github.com/ndrplz/ConvLSTM_pytorch/blob/master/convlstm.py 8 | ''' 9 | 10 | def __init__(self, input_dim, hidden_dim, kernel_size, stride, bias): 11 | """ 12 | Initialize ConvLSTM cell. 13 | Parameters 14 | ---------- 15 | input_dim: int 16 | Number of channels of input tensor. 
17 | hidden_dim: int 18 | Number of channels of hidden state. 19 | kernel_size: (int, int) 20 | Size of the convolutional kernel. 21 | stride: (int, int) 22 | Stride between convolution filters. 23 | bias: bool 24 | Whether or not to add the bias. 25 | """ 26 | 27 | super(ConvLSTMCell, self).__init__() 28 | 29 | self.input_dim = input_dim 30 | self.hidden_dim = hidden_dim 31 | 32 | self.kernel_size = kernel_size 33 | self.stride = stride 34 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 35 | self.bias = bias 36 | 37 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 38 | out_channels=4 * self.hidden_dim, 39 | kernel_size=self.kernel_size, 40 | padding=self.padding, 41 | bias=self.bias) 42 | 43 | 44 | self.conv_stride = nn.Conv2d(in_channels=self.input_dim, 45 | out_channels=self.input_dim, 46 | kernel_size=(1, 1), 47 | stride=self.stride, 48 | bias=self.bias) 49 | 50 | def forward(self, input_tensor, cur_state): 51 | h_cur, c_cur = cur_state 52 | 53 | conv_input = self.conv_stride(input_tensor) 54 | combined = torch.cat([conv_input, h_cur], dim=1) # concatenate along channel axis 55 | 56 | combined_conv = self.conv(combined) 57 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 58 | 59 | i = torch.sigmoid(cc_i) 60 | f = torch.sigmoid(cc_f) 61 | o = torch.sigmoid(cc_o) 62 | g = torch.tanh(cc_g) 63 | 64 | c_next = f * c_cur + i * g 65 | h_next = o * torch.tanh(c_next) 66 | 67 | return h_next, c_next 68 | 69 | def init_hidden(self, batch_size, image_size): 70 | height, width = image_size 71 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 72 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)) 73 | 74 | 75 | class ConvLSTM(nn.Module): 76 | 77 | """ 78 | Parameters: 79 | input_dim: Number of channels in input 80 | hidden_dim: Number of hidden channels 81 | kernel_size: Size of kernel in convolutions 82 | stride: Stride between convolution filters 83 | num_layers: Number of LSTM layers stacked on each other 84 | batch_first: Whether or not dimension 0 is the batch or not 85 | bias: Bias or no bias in Convolution 86 | return_all_layers: Return the list of computations for all layers 87 | Note: Will do same padding. 88 | Input: 89 | A tensor of size B, T, C, H, W or T, B, C, H, W 90 | Output: 91 | A tuple of two lists of length num_layers (or length 1 if return_all_layers is False). 
92 | 0 - layer_output_list is the list of lists of length T of each output 93 | 1 - last_state_list is the list of last states 94 | each element of the list is a tuple (h, c) for hidden state and memory 95 | Example: 96 | >> x = torch.rand((32, 10, 64, 128, 128)) 97 | >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False) 98 | >> _, last_states = convlstm(x) 99 | >> h = last_states[0][0] # 0 for layer index, 0 for h index 100 | """ 101 | 102 | def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, stride=(1, 1), 103 | batch_first=False, bias=True, return_all_layers=False): 104 | super(ConvLSTM, self).__init__() 105 | 106 | self._check_kernel_size_consistency(kernel_size) 107 | 108 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 109 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 110 | stride = self._extend_for_multilayer(stride, num_layers) 111 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 112 | if not len(kernel_size) == len(hidden_dim) == num_layers: 113 | raise ValueError('Inconsistent list length.') 114 | 115 | self.input_dim = input_dim 116 | self.hidden_dim = hidden_dim 117 | self.kernel_size = kernel_size 118 | self.stride = stride 119 | self.num_layers = num_layers 120 | self.batch_first = batch_first 121 | self.bias = bias 122 | self.return_all_layers = return_all_layers 123 | 124 | cell_list = [] 125 | for i in range(0, self.num_layers): 126 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] 127 | 128 | cell_list.append(ConvLSTMCell(input_dim=cur_input_dim, 129 | hidden_dim=self.hidden_dim[i], 130 | kernel_size=self.kernel_size[i], 131 | stride=self.stride[i], 132 | bias=self.bias)) 133 | 134 | self.cell_list = nn.ModuleList(cell_list) 135 | 136 | def forward(self, input_tensor, hidden_state=None): 137 | """ 138 | Parameters 139 | ---------- 140 | input_tensor: todo 141 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 142 | hidden_state: todo 143 | None. todo implement stateful 144 | Returns 145 | ------- 146 | last_state_list, layer_output 147 | """ 148 | if not self.batch_first: 149 | # (t, b, c, h, w) -> (b, t, c, h, w) 150 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 151 | 152 | b, _, _, h, w = input_tensor.size() 153 | h //= self.stride[0][0] 154 | w //= self.stride[0][1] 155 | 156 | # Implement stateful ConvLSTM 157 | if hidden_state is not None: 158 | raise NotImplementedError() 159 | else: 160 | # Since the init is done in forward. 
Can send image size here 161 | hidden_state = self._init_hidden(batch_size=b, 162 | image_size=(h, w)) 163 | 164 | layer_output_list = [] 165 | last_state_list = [] 166 | 167 | seq_len = input_tensor.size(1) 168 | cur_layer_input = input_tensor 169 | 170 | for layer_idx in range(self.num_layers): 171 | 172 | h, c = hidden_state[layer_idx] 173 | output_inner = [] 174 | for t in range(seq_len): 175 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 176 | cur_state=[h, c]) 177 | output_inner.append(h) 178 | 179 | layer_output = torch.stack(output_inner, dim=1) 180 | cur_layer_input = layer_output 181 | 182 | layer_output_list.append(layer_output) 183 | last_state_list.append([h, c]) 184 | 185 | if not self.return_all_layers: 186 | layer_output_list = layer_output_list[-1:] 187 | last_state_list = last_state_list[-1:] 188 | 189 | # Torch tensors (returns last layer for all frames) 190 | all_layers = lambda x: x[-1] 191 | last_hidden = lambda x: x[-1][0] 192 | last_cell = lambda x: x[-1][1] 193 | 194 | all_l = all_layers(layer_output_list) 195 | last_h = last_hidden(last_state_list) 196 | last_c = last_cell(last_state_list) 197 | 198 | return all_l, (last_h, last_c) 199 | 200 | def _init_hidden(self, batch_size, image_size): 201 | init_states = [] 202 | for i in range(self.num_layers): 203 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size)) 204 | return init_states 205 | 206 | @staticmethod 207 | def _check_kernel_size_consistency(kernel_size): 208 | if not (isinstance(kernel_size, tuple) or 209 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 210 | raise ValueError('`kernel_size` must be tuple or list of tuples') 211 | 212 | @staticmethod 213 | def _extend_for_multilayer(param, num_layers): 214 | if not isinstance(param, list): 215 | param = [param] * num_layers 216 | return param 217 | -------------------------------------------------------------------------------- /baseline_risk_assessment/dpm_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .camera import Camera 4 | import pdb 5 | 6 | class DeepPredictiveModel(nn.Module): 7 | ''' 8 | This baseline model is an implementation of the model mentioned in this paper: Deep Predictive Models for Collision Risk Assessment in Autonomous Driving 9 | We also reference this implementation as well: https://gitlab.com/avces/avces_tensorflow/-/blob/master/src/dpm_keras_general.py 10 | ''' 11 | def __init__(self, input_shape, config): 12 | super(DeepPredictiveModel, self).__init__() 13 | 14 | self.cameras, self.batch_size, self.frames, self.channels, self.height, self.width = input_shape # ex: (1, 1, 5, 1, 64, 64) 15 | self.config = config 16 | 17 | # Activation functions 18 | self.activation = nn.ReLU() if self.config.activation == 'relu' else nn.LeakyReLU(0.1) 19 | self.softmax = nn.Softmax(dim=-1) 20 | self.dropout = nn.Dropout(self.config.dropout) 21 | 22 | # Image processing pipeline 23 | self.camera_models = [Camera(input_shape[1:], self.config).to(self.config.device) for i in range(self.cameras)] 24 | 25 | # TODO: add vehicle state information 26 | 27 | 28 | # TODO: Dynamically calculate in_features 29 | self.linear1 = nn.Linear(in_features=10240*self.cameras, out_features=64) 30 | self.linear2 = nn.Linear(in_features=64, out_features=2) 31 | 32 | def forward(self, x): 33 | 34 | # Image models 35 | all_cameras = torch.cat([self.camera_models[index](value) for 
index, value in enumerate(x)], dim=-1) #flatten 36 | 37 | # TODO: Vehicle state models 38 | l1 = self.dropout(self.activation(self.linear1(all_cameras))) 39 | l2 = self.softmax(self.linear2(l1)) 40 | #l2 = self.linear2(l1) 41 | return l2 42 | 43 | if __name__ == '__main__': 44 | from types import SimpleNamespace 45 | cfg = {'dropout': 0.1, 'device': 'cuda', 'activation': 'relu'} 46 | config = SimpleNamespace(**cfg) 47 | image = torch.rand((1, 32, 5, 1, 64, 64)) # (cameras, batch, time_steps, channels, height, width) 48 | 49 | # Send to GPU 50 | model = DeepPredictiveModel(image.shape, config).to(config.device) 51 | pred = model(image.to(config.device)) 52 | 53 | print(pred.shape) 54 | 55 | -------------------------------------------------------------------------------- /baseline_risk_assessment/dpm_preprocessor.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 2 | import pickle as pkl 3 | import cv2 4 | from tqdm import tqdm 5 | from pathlib import Path, PurePath 6 | import multiprocessing 7 | import torch 8 | from collections import defaultdict 9 | import numpy as np 10 | 11 | 12 | #initializes globals for multiprocessing pool 13 | def initializer(imsettings, outputdir, rescaleshape): 14 | global image_settings 15 | image_settings = imsettings 16 | global output_dir 17 | output_dir = outputdir 18 | global rescale_shape 19 | rescale_shape = rescaleshape 20 | 21 | # TODO: Add honda support 22 | #preprocesses all raw_images in a directory tree 23 | def preprocess_directory(input_path=None, output_dir='dpm_images', image_settings=cv2.IMREAD_GRAYSCALE, rescale_shape=(64,64), num_processes=4): 24 | if(input_path is None): 25 | raise ValueError("please pass a valid input path.") 26 | all_video_clip_dirs = [x for x in input_path.iterdir() if x.is_dir()] 27 | all_video_clip_dirs = sorted(all_video_clip_dirs, key=lambda x: int(x.stem.split('_')[0])) 28 | pool = multiprocessing.Pool(num_processes, initializer, initargs=(image_settings, output_dir, rescale_shape)) 29 | pool.map(preprocess_sequence, all_video_clip_dirs) 30 | print("Image preprocessing completed.") 31 | 32 | 33 | #preprocesses a single sequence. 34 | def preprocess_sequence(path): 35 | print("processing " + str(path)) 36 | os.makedirs(str(path/output_dir), exist_ok=True) 37 | # read all frame numbers from raw_images. and store image_frames (list). 38 | raw_images = sorted(list(path.glob("raw_images/*.jpg")) + 39 | list(path.glob("raw_images/*.png")), key=lambda x: int(x.stem)) 40 | for raw_image_path in raw_images: 41 | frame = raw_image_path.stem 42 | image = cv2.imread(str(raw_image_path), image_settings) 43 | resized_image = cv2.resize(image, rescale_shape) 44 | cv2.imwrite(str(path/output_dir/frame)+".png", resized_image) 45 | return 1 46 | 47 | 48 | #this class preprocesses labeled image data into input sequences for the DPM model. 49 | class DPMPreprocessor(): 50 | 51 | def __init__(self, input_path, cache_path="dpm_data.pkl", subseq_len=5, rescale_shape=(64,64), convert2gray=True, image_output_dir="dpm_images", num_processes=4): 52 | self.input_path = input_path 53 | self.cache_path = cache_path 54 | self.subseq_len = subseq_len 55 | self.rescale_shape = rescale_shape 56 | self.image_output_dir = image_output_dir 57 | self.num_processes = num_processes 58 | if convert2gray: 59 | self.image_settings = cv2.IMREAD_GRAYSCALE 60 | else: 61 | self.image_settings = cv2.IMREAD_UNCHANGED 62 | 63 | 64 | #preprocesses raw images into rescaled and recolored format for DPM. 
65 | def preprocess_images(self): 66 | preprocess_directory(self.input_path, self.image_output_dir, self.image_settings, self.rescale_shape, self.num_processes) 67 | 68 | # TODO: Add honda support 69 | #load raw image data from directory 70 | def process_dataset(self): 71 | all_video_clip_dirs = [x for x in self.input_path.iterdir() if x.is_dir()] 72 | all_video_clip_dirs = sorted(all_video_clip_dirs, key=lambda x: int(x.stem.split('_')[0])) 73 | new_sequences = [] 74 | for path in tqdm(all_video_clip_dirs): 75 | 76 | ignore_path = (path/"ignore.txt").resolve() 77 | if ignore_path.exists(): 78 | with open(str(path/"ignore.txt"), 'r') as label_f: 79 | ignore_label = int(label_f.read()) 80 | if ignore_label: continue; 81 | 82 | label_path = (path/"label.txt").resolve() 83 | 84 | if label_path.exists(): 85 | with open(str(path/"label.txt"), 'r') as label_f: 86 | risk_label = float(label_f.read().strip().split(",")[0]) 87 | 88 | risk_label = 1 if risk_label >= 0 else 0 #binarize float value. 89 | else: 90 | print("Label not found for path: " + str(path)) 91 | continue #skip paths that dont have labels 92 | 93 | subseqs, labels = self.process_sequence(path, risk_label) 94 | new_sequences.append((subseqs, labels, PurePath(path).name)) 95 | 96 | with open(self.cache_path, 'wb') as f: 97 | pkl.dump(new_sequences, f, fix_imports=False) 98 | 99 | 100 | #generates a list of subsequences of length subseq_len from a top level sequence. 101 | def process_sequence(self, seq_path, label): 102 | # read all frame numbers from raw_images. and store image_frames (list). 103 | images = sorted(list(seq_path.glob(self.image_output_dir+"/*.jpg")) + 104 | list(seq_path.glob(self.image_output_dir+"/*.png")), key=lambda x: int(x.stem)) 105 | ims = [] 106 | for image_path in images: 107 | ims.append(cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)) #read images from file 108 | 109 | dim1 = len(ims) - self.subseq_len + 1 110 | dim2 = self.subseq_len 111 | subseqs = np.zeros((dim1, dim2, self.rescale_shape[0], self.rescale_shape[1])) 112 | labels = np.full((dim1), label) 113 | ims = np.array(ims) 114 | 115 | #TODO optimize 116 | for i in range(dim1): 117 | subseqs[i,:] = ims[i:i+self.subseq_len] #fill array with subsequences of images 118 | 119 | return subseqs, labels 120 | -------------------------------------------------------------------------------- /baseline_risk_assessment/dpm_trainer.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 2 | sys.path.append(os.path.dirname(sys.path[0])) 3 | from argparse import ArgumentParser 4 | from .dpm_model import DeepPredictiveModel 5 | from pathlib import Path 6 | import torch 7 | import torch.optim as optim 8 | from torch_geometric.data import DataListLoader 9 | from torch.utils.data import DataLoader, TensorDataset 10 | import wandb 11 | import pandas as pd 12 | import numpy as np 13 | import pickle as pkl 14 | from tqdm import tqdm 15 | from sklearn.utils import resample 16 | from sklearn.model_selection import train_test_split, StratifiedKFold 17 | from sklearn.utils.class_weight import compute_class_weight 18 | from sg_risk_assessment.metrics import * 19 | 20 | INPUT_SHAPE = (1, 1, 5, 1, 64, 64) # (num_camera, batch_size, frames, channels, height, width) 21 | 22 | #model configuration settings. 
specified on the command line 23 | class Config: 24 | def __init__(self, args): 25 | self.parser = ArgumentParser() 26 | self.parser.add_argument('--cache_path', type=str, default="../scripts/dpm_271_seqlen_5.pkl", help="Path to the cache file.") 27 | self.parser.add_argument('--transfer_path', type=str, default="", help="Path to the transfer file.") 28 | self.parser.add_argument('--model_load_path', type=str, default="./model/model_best_val_loss_.vec.pt", help="Path to load cached model file.") 29 | self.parser.add_argument('--model_save_path', type=str, default="./model/model_best_val_loss_.vec.pt", help="Path to save model file.") 30 | self.parser.add_argument('--n_folds', type=int, default=1, help='Number of folds for cross validation') 31 | self.parser.add_argument('--split_ratio', type=float, default=0.3, help="Ratio of dataset withheld for testing.") 32 | self.parser.add_argument('--downsample', type=lambda x: (str(x).lower() == 'true'), default=False, help='Set to true to downsample dataset.') 33 | self.parser.add_argument('--learning_rate', default=0.00005, type=float, help='The initial learning rate.') 34 | self.parser.add_argument('--seed', type=int, default=0, help='Random seed.') 35 | self.parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.') 36 | self.parser.add_argument('--activation', type=str, default='relu', help='Activation function to use, options: [relu, leaky_relu].') 37 | self.parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).') 38 | self.parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate (1 - keep probability).') 39 | self.parser.add_argument('--batch_size', type=int, default=8, help='Number of sequences in a batch.') 40 | self.parser.add_argument('--device', type=str, default="cuda", help='The device on which models are run, options: [cuda, cpu].') 41 | self.parser.add_argument('--test_step', type=int, default=5, help='Number of training epochs before testing the model.') 42 | 43 | parsed_args = self.parser.parse_args(args) 44 | wandb.init(project="av-dpm") 45 | wandb_config = wandb.config 46 | 47 | for arg_name in vars(parsed_args): 48 | self.__dict__[arg_name] = getattr(parsed_args, arg_name) 49 | wandb_config[arg_name] = getattr(parsed_args, arg_name) 50 | 51 | self.cache_path = Path(self.cache_path).resolve() 52 | 53 | if os.path.exists(self.transfer_path) and os.path.splitext(self.transfer_path)[-1] == '.pkl': 54 | self.transfer_path = Path(self.transfer_path).resolve() 55 | else: 56 | self.transfer_path = None 57 | print('Not using transfer learning') 58 | 59 | 60 | #This class trains and evaluates the DPM model 61 | class DPMTrainer: 62 | def __init__(self, args): 63 | self.config = Config(args) 64 | self.args = args 65 | np.random.seed(self.config.seed) 66 | torch.manual_seed(self.config.seed) 67 | self.best_val_loss = 99999 68 | self.best_epoch = 0 69 | self.best_val_acc = 0 70 | self.best_val_auc = 0 71 | self.best_val_confusion = [] 72 | self.best_val_f1 = 0 73 | self.best_val_mcc = -1.0 74 | self.best_val_acc_balanced = 0 75 | self.log = False 76 | 77 | if not self.config.cache_path.exists(): 78 | raise Exception("The cache file does not exist.") 79 | 80 | with open(self.config.cache_path, 'rb') as f: 81 | self.dataset = pkl.load(f) 82 | 83 | if self.config.transfer_path != None and self.config.transfer_path.exists(): 84 | with open(self.config.transfer_path, 'rb') as f: 85 | self.transfer = pkl.load(f) 86 | 87 | # Class balancer 88 | 
if self.config.downsample == True: 89 | self.dataset = self.balance_dataset(self.dataset) 90 | # Transfer balancer 91 | # if self.config.transfer_path != None: 92 | # self.transfer = self.balance_dataset(self.transfer) 93 | 94 | self.toGPU = lambda x, dtype: torch.as_tensor(x, dtype=dtype, device=self.config.device) 95 | self.split_dataset() 96 | self.build_model() 97 | 98 | 99 | # TODO: Ensure dataset has a diverse representation of risk and non risk lane changes 100 | # assumes label is the same for all frames in a scene 101 | def balance_dataset(self, dataset): 102 | # binary classes 103 | seq_label = lambda x: x[1][0] 104 | risk = [seq_label(sequence) for sequence in dataset if seq_label(sequence) == 1].count(1) 105 | non_risk = len(dataset) - risk 106 | min_number = min(risk, non_risk) 107 | risk = min_number 108 | non_risk = min_number 109 | 110 | balanced = [] 111 | for sequence in dataset: 112 | label = seq_label(sequence) 113 | if label == 1 and risk > 0: 114 | risk -= 1 115 | balanced.append(sequence) 116 | if label == 0 and non_risk > 0: 117 | non_risk -= 1 118 | balanced.append(sequence) 119 | if risk == 0 and non_risk == 0: 120 | break 121 | 122 | return balanced 123 | 124 | def get_fileid(self, fname): 125 | try: 126 | return int(fname) 127 | except: 128 | return int(fname.split('_')[0]) 129 | 130 | def split_dataset(self): 131 | training_data, testing_data = train_test_split(self.dataset, test_size=self.config.split_ratio, shuffle=True, random_state=self.config.seed, stratify=None) 132 | self.dataset = None #clearing to save memory 133 | # transfer learning 134 | if self.config.transfer_path != None: 135 | training_data = np.append(training_data, testing_data, axis=0) 136 | testing_data = self.transfer 137 | 138 | self.training_x, self.training_y, self.training_filenames = list(zip(*training_data)) 139 | del training_data 140 | 141 | self.testing_x, self.testing_y, self.testing_filenames = list(zip(*testing_data)) 142 | del testing_data 143 | 144 | if self.config.n_folds <= 1: 145 | print("Number of Training Sequences Included: ", len(self.training_x)) 146 | print("Number of Testing Sequences Included: ", len(self.testing_x)) 147 | 148 | self.training_filenames = np.concatenate([np.full(y.shape[0], self.get_fileid(fname)) for y,fname in zip(self.training_y, self.training_filenames)]) 149 | self.testing_filenames = np.concatenate([np.full(y.shape[0], self.get_fileid(fname)) for y,fname in zip(self.testing_y, self.testing_filenames)]) 150 | self.training_x = np.concatenate(self.training_x) 151 | self.testing_x = np.concatenate(self.testing_x) 152 | self.training_x = np.expand_dims(self.training_x, axis=-3) # color channels = 1 153 | self.testing_x = np.expand_dims(self.testing_x, axis=-3) # color channels = 1 154 | self.training_y = np.concatenate(self.training_y) 155 | self.testing_y = np.concatenate(self.testing_y) 156 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(self.training_y), self.training_y)) 157 | 158 | if self.config.n_folds <= 1: 159 | print("Num of Training Labels in Each Class: " + str(np.unique(self.training_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 160 | print("Num of Testing Labels in Each Class: " + str(np.unique(self.testing_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 161 | 162 | 163 | def build_model(self): 164 | self.model = DeepPredictiveModel(INPUT_SHAPE, self.config).to(self.config.device) 165 | self.optimizer = optim.Adam(self.model.parameters(), 
lr=self.config.learning_rate, weight_decay=self.config.weight_decay) 166 | if self.class_weights.shape[0] < 2: 167 | self.loss_func = torch.nn.CrossEntropyLoss() 168 | else: 169 | self.loss_func = torch.nn.CrossEntropyLoss(weight=self.class_weights.float().to(self.config.device)) 170 | wandb.watch(self.model, log="all") 171 | 172 | 173 | # Pick between Standard Training and KFold Cross Validation Training 174 | def learn(self): 175 | if self.config.n_folds <= 1 or self.config.transfer_path != None: 176 | print('Running Standard Training Loop\n') 177 | self.train() 178 | else: 179 | print(torch.cuda.get_device_name(0)) 180 | print('Running {}-Fold Cross Validation Training Loop\n'.format(self.config.n_folds)) 181 | self.cross_valid() 182 | 183 | 184 | def cross_valid(self): 185 | 186 | # KFold cross validation with similar class distribution in each fold 187 | skf = StratifiedKFold(n_splits=self.config.n_folds) 188 | X = np.append(self.training_x, self.testing_x, axis=0) 189 | y = np.append(self.training_y, self.testing_y, axis=0) 190 | filenames = np.append(self.training_filenames, self.testing_filenames, axis=0) 191 | 192 | # self.results stores average metrics for the the n_folds 193 | self.results = {} 194 | self.fold = 1 195 | 196 | # Split training and testing data based on n_splits (Folds) 197 | for train_index, test_index in skf.split(X, y): 198 | self.training_x, self.testing_x, self.training_y, self.testing_y = None, None, None, None #clear vars to save memory 199 | X_train, X_test = X[train_index], X[test_index] 200 | y_train, y_test = y[train_index], y[test_index] 201 | training_filenames, testing_filenames = filenames[train_index], filenames[test_index] 202 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(y_train), y_train)) 203 | 204 | # Update dataset 205 | self.training_x = X_train 206 | self.testing_x = X_test 207 | self.training_y = y_train 208 | self.testing_y = y_test 209 | self.training_filenames = training_filenames 210 | self.testing_filenames = testing_filenames 211 | 212 | print('\nFold {}'.format(self.fold)) 213 | print("Number of Training Sequences Included: ", len(np.unique(training_filenames))) 214 | print("Number of Testing Sequences Included: ", len(np.unique(testing_filenames))) 215 | print("Num of Training Labels in Each Class: " + str(np.unique(self.training_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 216 | print("Num of Testing Labels in Each Class: " + str(np.unique(self.testing_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 217 | 218 | self.best_val_loss = 99999 219 | self.train() 220 | self.log = True 221 | outputs_train, labels_train, outputs_test, labels_test, metrics = self.evaluate(self.fold) 222 | self.update_cross_valid_metrics(outputs_train, labels_train, outputs_test, labels_test, metrics) 223 | self.log = False 224 | 225 | if self.fold != self.config.n_folds: 226 | del self.model 227 | del self.optimizer 228 | self.build_model() 229 | 230 | self.fold += 1 231 | del self.results 232 | 233 | 234 | def train(self): 235 | tqdm_bar = tqdm(range(self.config.epochs)) 236 | for epoch_idx in tqdm_bar: # iterate through epoch 237 | acc_loss_train = 0 238 | permutation = np.random.permutation(len(self.training_x)) # shuffle dataset before each epoch 239 | self.model.train() 240 | 241 | for i in range(0, len(self.training_x), self.config.batch_size): # iterate through batches of the dataset 242 | batch_index = i + self.config.batch_size if i + 
self.config.batch_size <= len(self.training_x) else len(self.training_x) 243 | indices = permutation[i:batch_index] 244 | batch_x, batch_y = self.training_x[indices], self.training_y[indices] 245 | batch_x, batch_y = self.toGPU(batch_x, torch.float32), self.toGPU(batch_y, torch.long) 246 | batch_x = torch.unsqueeze(batch_x, 0) #cameras = 1 247 | output = self.model.forward(batch_x).view(-1, 2) 248 | loss_train = self.loss_func(output, batch_y) 249 | loss_train.backward() 250 | acc_loss_train += loss_train.detach().cpu().item() * len(indices) 251 | self.optimizer.step() 252 | del loss_train 253 | 254 | acc_loss_train /= len(self.training_x) 255 | tqdm_bar.set_description('Epoch: {:04d}, loss_train: {:.4f}'.format(epoch_idx, acc_loss_train)) 256 | 257 | # no cross validation 258 | if epoch_idx % self.config.test_step == 0: 259 | self.evaluate(epoch_idx) 260 | 261 | 262 | def inference(self, testing_x, testing_y, testing_filenames): 263 | labels = torch.LongTensor().to(self.config.device) 264 | outputs = torch.FloatTensor().to(self.config.device) 265 | testing_filenames = torch.as_tensor(testing_filenames) 266 | acc_loss_test = 0 267 | sum_prediction_frame = 0 268 | sum_seq_len = 0 269 | num_risky_sequences = 0 270 | num_safe_sequences = 0 271 | sum_predicted_risky_indices = 0 #sum is calculated as (value * (index+1))/sum(range(seq_len)) for each value and index in the sequence. 272 | sum_predicted_safe_indices = 0 #sum is calculated as ((1-value) * (index+1))/sum(range(seq_len)) for each value and index in the sequence. 273 | inference_time = 0 274 | prof_result = "" 275 | batch_size = self.config.batch_size #NOTE: set to 1 when profiling or calculating inference time. 276 | correct_risky_seq = 0 277 | correct_safe_seq = 0 278 | incorrect_risky_seq = 0 279 | incorrect_safe_seq = 0 280 | 281 | with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof: 282 | with torch.no_grad(): 283 | self.model.eval() 284 | 285 | for i in range(0, len(testing_x), batch_size): # iterate through subsequences 286 | batch_index = i + batch_size if i + batch_size <= len(testing_x) else len(testing_x) 287 | batch_x, batch_y = testing_x[i:batch_index], testing_y[i:batch_index] 288 | batch_x, batch_y = self.toGPU(batch_x, torch.float32), self.toGPU(batch_y, torch.long) 289 | batch_x = torch.unsqueeze(batch_x, 0) #cameras = 1 290 | #start = torch.cuda.Event(enable_timing=True) 291 | #end = torch.cuda.Event(enable_timing=True) 292 | #start.record() 293 | output = self.model.forward(batch_x).view(-1, 2) 294 | #end.record() 295 | #torch.cuda.synchronize() 296 | inference_time += 0#start.elapsed_time(end) 297 | loss_test = self.loss_func(output, batch_y) 298 | acc_loss_test += loss_test.detach().cpu().item() * len(batch_y) 299 | outputs = torch.cat([outputs, output], dim=0) 300 | labels = torch.cat([labels,batch_y], dim=0) 301 | 302 | #extract list of sequences and their associated predictions. calculate metrics over sequences. 303 | sequences = torch.unique(testing_filenames) 304 | for seq in sequences: 305 | indices = torch.where(testing_filenames == seq)[0] 306 | seq_outputs = outputs[indices] 307 | seq_labels = labels[indices] 308 | 309 | #log metrics for risky and non-risky clips separately. 
310 | if(1 in seq_labels): 311 | preds = seq_outputs.max(1)[1].type_as(seq_labels) 312 | num_risky_sequences += 1 313 | sum_seq_len += seq_outputs.shape[0] 314 | if (1 in preds): 315 | correct_risky_seq += 1 #sequence level metrics 316 | sum_prediction_frame += torch.where(preds == 1)[0][0].item() #returns the first index of a "risky" prediction in this sequence. 317 | sum_predicted_risky_indices += torch.sum(torch.where(preds==1)[0]+1).item()/np.sum(range(seq_outputs.shape[0]+1)) 318 | else: 319 | incorrect_risky_seq += 1 320 | sum_prediction_frame += seq_outputs.shape[0] #if no risky predictions are made, then add the full sequence length to running avg. 321 | elif(0 in seq_labels): 322 | preds = seq_outputs.max(1)[1].type_as(seq_labels) 323 | num_safe_sequences += 1 324 | if(1 in preds): 325 | incorrect_safe_seq += 1 326 | else: 327 | correct_safe_seq += 1 328 | 329 | if (0 in preds): 330 | sum_predicted_safe_indices += torch.sum(torch.where(preds==0)[0]+1).item()/np.sum(range(seq_outputs.shape[0]+1)) 331 | 332 | avg_risky_prediction_frame = sum_prediction_frame / num_risky_sequences #avg of first indices in a sequence that a risky frame is first correctly predicted. 333 | avg_risky_seq_len = sum_seq_len / num_risky_sequences #sequence length for comparison with the prediction frame metric. 334 | avg_predicted_risky_indices = sum_predicted_risky_indices / num_risky_sequences 335 | avg_predicted_safe_indices = sum_predicted_safe_indices / num_safe_sequences 336 | seq_tpr = correct_risky_seq / num_risky_sequences 337 | seq_fpr = incorrect_safe_seq / num_safe_sequences 338 | seq_tnr = correct_safe_seq / num_safe_sequences 339 | seq_fnr = incorrect_risky_seq / num_risky_sequences 340 | if prof != None: 341 | prof_result = prof.key_averages().table(sort_by="cuda_time_total") 342 | 343 | return outputs, \ 344 | labels, \ 345 | acc_loss_test/len(testing_x), \ 346 | avg_risky_prediction_frame, \ 347 | avg_risky_seq_len, \ 348 | avg_predicted_risky_indices, \ 349 | avg_predicted_safe_indices, \ 350 | inference_time, \ 351 | prof_result, \ 352 | seq_tpr, \ 353 | seq_fpr, \ 354 | seq_tnr, \ 355 | seq_fnr 356 | 357 | def evaluate(self, current_epoch=None): 358 | metrics = {} 359 | outputs_train, \ 360 | labels_train, \ 361 | acc_loss_train, \ 362 | train_avg_prediction_frame, \ 363 | train_avg_seq_len, \ 364 | avg_predicted_risky_indices, \ 365 | avg_predicted_safe_indices, \ 366 | train_inference_time, \ 367 | train_profiler_result, \ 368 | seq_tpr, seq_fpr, seq_tnr, seq_fnr = self.inference(self.training_x, 369 | self.training_y, 370 | self.training_filenames) 371 | metrics['train'] = get_metrics(outputs_train, labels_train) 372 | metrics['train']['loss'] = acc_loss_train 373 | metrics['train']['avg_prediction_frame'] = train_avg_prediction_frame 374 | metrics['train']['avg_seq_len'] = train_avg_seq_len 375 | metrics['train']['avg_predicted_risky_indices'] = avg_predicted_risky_indices 376 | metrics['train']['avg_predicted_safe_indices'] = avg_predicted_safe_indices 377 | metrics['train']['seq_tpr'] = seq_tpr 378 | metrics['train']['seq_tnr'] = seq_tnr 379 | metrics['train']['seq_fpr'] = seq_fpr 380 | metrics['train']['seq_fnr'] = seq_fnr 381 | with open("dpm_profile_metrics.txt", mode='w') as f: 382 | f.write(train_profiler_result) 383 | 384 | outputs_test, \ 385 | labels_test, \ 386 | acc_loss_test, \ 387 | val_avg_prediction_frame, \ 388 | val_avg_seq_len, \ 389 | avg_predicted_risky_indices, \ 390 | avg_predicted_safe_indices, \ 391 | test_inference_time, \ 392 | test_profiler_result, \ 393 | 
seq_tpr, seq_fpr, seq_tnr, seq_fnr = self.inference(self.testing_x, 394 | self.testing_y, 395 | self.testing_filenames) 396 | metrics['test'] = get_metrics(outputs_test, labels_test) 397 | metrics['test']['loss'] = acc_loss_test 398 | metrics['test']['avg_prediction_frame'] = val_avg_prediction_frame 399 | metrics['test']['avg_seq_len'] = val_avg_seq_len 400 | metrics['test']['avg_predicted_risky_indices'] = avg_predicted_risky_indices 401 | metrics['test']['avg_predicted_safe_indices'] = avg_predicted_safe_indices 402 | metrics['test']['seq_tpr'] = seq_tpr 403 | metrics['test']['seq_tnr'] = seq_tnr 404 | metrics['test']['seq_fpr'] = seq_fpr 405 | metrics['test']['seq_fnr'] = seq_fnr 406 | metrics['avg_inf_time'] = (train_inference_time + test_inference_time) / ((len(self.training_y) + len(self.testing_y))*5) 407 | 408 | print("\ntrain loss: " + str(acc_loss_train) + ", acc:", metrics['train']['acc'], metrics['train']['confusion'], "mcc:", metrics['train']['mcc'], \ 409 | "\ntest loss: " + str(acc_loss_test) + ", acc:", metrics['test']['acc'], metrics['test']['confusion'], "mcc:", metrics['test']['mcc']) 410 | 411 | self.update_best_metrics(metrics, current_epoch) 412 | metrics['best_epoch'] = self.best_epoch 413 | metrics['best_val_loss'] = self.best_val_loss 414 | metrics['best_val_acc'] = self.best_val_acc 415 | metrics['best_val_auc'] = self.best_val_auc 416 | metrics['best_val_conf'] = self.best_val_confusion 417 | metrics['best_val_f1'] = self.best_val_f1 418 | metrics['best_val_mcc'] = self.best_val_mcc 419 | metrics['best_val_acc_balanced'] = self.best_val_acc_balanced 420 | metrics['best_avg_pred_frame'] = self.best_avg_pred_frame 421 | 422 | if self.config.n_folds <= 1 or self.log: 423 | log_wandb(metrics) 424 | 425 | return outputs_train, labels_train, outputs_test, labels_test, metrics 426 | 427 | 428 | #automatically save the model and metrics with the lowest validation loss 429 | def update_best_metrics(self, metrics, current_epoch): 430 | if metrics['test']['loss'] < self.best_val_loss: 431 | self.best_val_loss = metrics['test']['loss'] 432 | self.best_epoch = current_epoch if current_epoch != None else self.config.epochs 433 | self.best_val_acc = metrics['test']['acc'] 434 | self.best_val_auc = metrics['test']['auc'] 435 | self.best_val_confusion = metrics['test']['confusion'] 436 | self.best_val_f1 = metrics['test']['f1'] 437 | self.best_val_mcc = metrics['test']['mcc'] 438 | self.best_val_acc_balanced = metrics['test']['balanced_acc'] 439 | self.best_avg_pred_frame = metrics['test']['avg_prediction_frame'] 440 | #self.save_model() 441 | 442 | # Averages metrics after the end of each cross validation fold 443 | def update_cross_valid_metrics(self, outputs_train, labels_train, outputs_test, labels_test, metrics): 444 | if self.fold == 1: 445 | self.results['outputs_train'] = outputs_train 446 | self.results['labels_train'] = labels_train 447 | self.results['train'] = metrics['train'] 448 | self.results['train']['loss'] = metrics['train']['loss'] 449 | self.results['train']['avg_prediction_frame'] = metrics['train']['avg_prediction_frame'] 450 | self.results['train']['avg_seq_len'] = metrics['train']['avg_seq_len'] 451 | self.results['train']['avg_predicted_risky_indices'] = metrics['train']['avg_predicted_risky_indices'] 452 | self.results['train']['avg_predicted_safe_indices'] = metrics['train']['avg_predicted_safe_indices'] 453 | 454 | self.results['outputs_test'] = outputs_test 455 | self.results['labels_test'] = labels_test 456 | self.results['test'] = metrics['test'] 
457 | self.results['test']['loss'] = metrics['test']['loss'] 458 | self.results['test']['avg_prediction_frame'] = metrics['test']['avg_prediction_frame'] 459 | self.results['test']['avg_seq_len'] = metrics['test']['avg_seq_len'] 460 | self.results['test']['avg_predicted_risky_indices'] = metrics['test']['avg_predicted_risky_indices'] 461 | self.results['test']['avg_predicted_safe_indices'] = metrics['test']['avg_predicted_safe_indices'] 462 | self.results['avg_inf_time'] = metrics['avg_inf_time'] 463 | 464 | self.results['best_epoch'] = metrics['best_epoch'] 465 | self.results['best_val_loss'] = metrics['best_val_loss'] 466 | self.results['best_val_acc'] = metrics['best_val_acc'] 467 | self.results['best_val_auc'] = metrics['best_val_auc'] 468 | self.results['best_val_conf'] = metrics['best_val_conf'] 469 | self.results['best_val_f1'] = metrics['best_val_f1'] 470 | self.results['best_val_mcc'] = metrics['best_val_mcc'] 471 | self.results['best_val_acc_balanced'] = metrics['best_val_acc_balanced'] 472 | self.results['best_avg_pred_frame'] = metrics['best_avg_pred_frame'] 473 | 474 | else: 475 | self.results['outputs_train'] = torch.cat((self.results['outputs_train'], outputs_train), dim=0) 476 | self.results['labels_train'] = torch.cat((self.results['labels_train'], labels_train), dim=0) 477 | self.results['train']['loss'] = np.append(self.results['train']['loss'], metrics['train']['loss']) 478 | self.results['train']['avg_prediction_frame'] = np.append(self.results['train']['avg_prediction_frame'], metrics['train']['avg_prediction_frame']) 479 | self.results['train']['avg_seq_len'] = np.append(self.results['train']['avg_seq_len'], metrics['train']['avg_seq_len']) 480 | self.results['train']['avg_predicted_risky_indices'] = np.append(self.results['train']['avg_predicted_risky_indices'], metrics['train']['avg_predicted_risky_indices']) 481 | self.results['train']['avg_predicted_safe_indices'] = np.append(self.results['train']['avg_predicted_safe_indices'], metrics['train']['avg_predicted_safe_indices']) 482 | 483 | self.results['outputs_test'] = torch.cat((self.results['outputs_test'], outputs_test), dim=0) 484 | self.results['labels_test'] = torch.cat((self.results['labels_test'], labels_test), dim=0) 485 | self.results['test']['loss'] = np.append(self.results['test']['loss'], metrics['test']['loss']) 486 | self.results['test']['avg_prediction_frame'] = np.append(self.results['test']['avg_prediction_frame'], metrics['test']['avg_prediction_frame']) 487 | self.results['test']['avg_seq_len'] = np.append(self.results['test']['avg_seq_len'], metrics['test']['avg_seq_len']) 488 | self.results['test']['avg_predicted_risky_indices'] = np.append(self.results['test']['avg_predicted_risky_indices'], metrics['test']['avg_predicted_risky_indices']) 489 | self.results['test']['avg_predicted_safe_indices'] = np.append(self.results['test']['avg_predicted_safe_indices'], metrics['test']['avg_predicted_safe_indices']) 490 | self.results['avg_inf_time'] = np.append(self.results['avg_inf_time'], metrics['avg_inf_time']) 491 | 492 | self.results['best_epoch'] = np.append(self.results['best_epoch'], metrics['best_epoch']) 493 | self.results['best_val_loss'] = np.append(self.results['best_val_loss'], metrics['best_val_loss']) 494 | self.results['best_val_acc'] = np.append(self.results['best_val_acc'], metrics['best_val_acc']) 495 | self.results['best_val_auc'] = np.append(self.results['best_val_auc'], metrics['best_val_auc']) 496 | self.results['best_val_conf'] = np.append(self.results['best_val_conf'], 
metrics['best_val_conf']) 497 | self.results['best_val_f1'] = np.append(self.results['best_val_f1'], metrics['best_val_f1']) 498 | self.results['best_val_mcc'] = np.append(self.results['best_val_mcc'], metrics['best_val_mcc']) 499 | self.results['best_val_acc_balanced'] = np.append(self.results['best_val_acc_balanced'], metrics['best_val_acc_balanced']) 500 | self.results['best_avg_pred_frame'] = np.append(self.results['best_avg_pred_frame'], metrics['best_avg_pred_frame']) 501 | 502 | # Log final averaged results 503 | if self.fold == self.config.n_folds: 504 | final_results = {} 505 | final_results['train'] = get_metrics(self.results['outputs_train'], self.results['labels_train']) 506 | final_results['train']['loss'] = np.average(self.results['train']['loss']) 507 | final_results['train']['avg_prediction_frame'] = np.average(self.results['train']['avg_prediction_frame']) 508 | final_results['train']['avg_seq_len'] = np.average(self.results['train']['avg_seq_len']) 509 | final_results['train']['avg_predicted_risky_indices'] = np.average(self.results['train']['avg_predicted_risky_indices']) 510 | final_results['train']['avg_predicted_safe_indices'] = np.average(self.results['train']['avg_predicted_safe_indices']) 511 | 512 | final_results['test'] = get_metrics(self.results['outputs_test'], self.results['labels_test']) 513 | final_results['test']['loss'] = np.average(self.results['test']['loss']) 514 | final_results['test']['avg_prediction_frame'] = np.average(self.results['test']['avg_prediction_frame']) 515 | final_results['test']['avg_seq_len'] = np.average(self.results['test']['avg_seq_len']) 516 | final_results['test']['avg_predicted_risky_indices'] = np.average(self.results['test']['avg_predicted_risky_indices']) 517 | final_results['test']['avg_predicted_safe_indices'] = np.average(self.results['test']['avg_predicted_safe_indices']) 518 | final_results['avg_inf_time'] = np.average(self.results['avg_inf_time']) 519 | 520 | # Best results 521 | final_results['best_epoch'] = np.average(self.results['best_epoch']) 522 | final_results['best_val_loss'] = np.average(self.results['best_val_loss']) 523 | final_results['best_val_acc'] = np.average(self.results['best_val_acc']) 524 | final_results['best_val_auc'] = np.average(self.results['best_val_auc']) 525 | final_results['best_val_conf'] = self.results['best_val_conf'] 526 | final_results['best_val_f1'] = np.average(self.results['best_val_f1']) 527 | final_results['best_val_mcc'] = np.average(self.results['best_val_mcc']) 528 | final_results['best_val_acc_balanced'] = np.average(self.results['best_val_acc_balanced']) 529 | final_results['best_avg_pred_frame'] = np.average(self.results['best_avg_pred_frame']) 530 | 531 | print('\nFinal Averaged Results') 532 | print("\naverage train loss: " + str(final_results['train']['loss']) + ", average acc:", final_results['train']['acc'], final_results['train']['confusion'], final_results['train']['auc'], \ 533 | "\naverage test loss: " + str(final_results['test']['loss']) + ", average acc:", final_results['test']['acc'], final_results['test']['confusion'], final_results['test']['auc']) 534 | 535 | log_wandb(final_results) 536 | 537 | return self.results['outputs_train'], self.results['labels_train'], self.results['outputs_test'], self.results['labels_test'], final_results 538 | 539 | 540 | #UNTESTED 541 | def save_model(self): 542 | """Function to save the model.""" 543 | saved_path = Path(self.config.model_save_path).resolve() 544 | os.makedirs(os.path.dirname(saved_path), exist_ok=True) 545 | 
torch.save(self.model.state_dict(), str(saved_path)) 546 | with open(os.path.dirname(saved_path) + "/model_parameters.txt", "w+") as f: 547 | f.write(str(self.config)) 548 | f.write('\n') 549 | f.write(str(' '.join(sys.argv))) 550 | 551 | #UNTESTED 552 | def load_model(self): 553 | """Function to load the model.""" 554 | saved_path = Path(self.config.model_load_path).resolve() 555 | if saved_path.exists(): 556 | self.build_model() 557 | self.model.load_state_dict(torch.load(str(saved_path))) 558 | self.model.eval() 559 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | backcall==0.1.0 4 | cycler==0.10.0 5 | decorator==4.4.0 6 | gast==0.2.2 7 | google-pasta==0.1.7 8 | grpcio==1.20.1 9 | h5py==2.9.0 10 | ipython==7.8.0 11 | ipython-genutils==0.2.0 12 | jedi==0.15.1 13 | kiwisolver==1.1.0 14 | Markdown==3.1.1 15 | matplotlib==3.1.1 16 | networkx==2.4 17 | numpy==1.16.5 18 | opencv-python==4.1.1.26 19 | pandas==0.23.4 20 | parso==0.5.1 21 | pexpect==4.7.0 22 | pickleshare==0.7.5 23 | Pillow==8.1.1 24 | prompt-toolkit==2.0.10 25 | protobuf==3.7.1 26 | ptyprocess==0.6.0 27 | Pygments==2.4.2 28 | pyparsing==2.4.2 29 | python-dateutil==2.8.0 30 | pytz==2019.3 31 | PyYAML==5.4 32 | scikit-image==0.15.0 33 | scikit-learn==0.21.3 34 | scipy==1.1.0 35 | six==1.12.0 36 | termcolor==1.1.0 37 | tqdm==4.36.1 38 | pytorch-nlp==0.5.0 39 | torch-geometric==1.5.0 40 | traitlets==4.3.3 41 | wcwidth==0.1.7 42 | Werkzeug==0.16.0 43 | wrapt==1.11.2 44 | -------------------------------------------------------------------------------- /scripts/preprocess_dpm_images.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 2 | sys.path.append(os.path.dirname(sys.path[0])) 3 | from baseline_risk_assessment.dpm_preprocessor import DPMPreprocessor 4 | from argparse import ArgumentParser 5 | from pathlib import Path 6 | 7 | #This script runs pre-processing of image data for use in the DPM pipeline 8 | def preprocess_dpm_data(args): 9 | parser = ArgumentParser() 10 | parser.add_argument("--input_path",type=str,default='M:/louisccc/av/synthesis_data/legacy_dataset/lane-change-100-balanced',help='directory containing the raw data sequences and labels.') 11 | parser.add_argument("--cache_path", type=str, default='dpm_data.pkl', help="path to save processed sequence data.") 12 | parser.add_argument("--subseq_len", type=int, default=5, help="length of output subsequences") 13 | parser.add_argument("--preprocess", help="use this option to preprocess images before subsequencing.") 14 | parser.add_argument("--rescale_shape", type=str, default="64,64", help="reshaped images will be this size. 
Format: x,y ") 15 | parser.add_argument("--image_output_dir", type=str, default="dpm_images_64x64", help="directory where processed images will be saved.") 16 | parser.add_argument("--grayscale", help="use this option to convert images to grayscale during processing.") 17 | parser.add_argument("--num_processes", type=int, default=4, help="number of processes to run in parallel") 18 | config = parser.parse_args(args) 19 | config.input_path = Path(config.input_path).resolve() 20 | config.cache_path = Path(config.cache_path).resolve() 21 | config.rescale_shape = (int(config.rescale_shape.split(',')[0]), int(config.rescale_shape.split(',')[1])) 22 | preprocessor = DPMPreprocessor(config.input_path, 23 | config.cache_path, 24 | config.subseq_len, 25 | config.rescale_shape, 26 | config.grayscale, 27 | config.image_output_dir, 28 | config.num_processes) 29 | if(config.preprocess): 30 | preprocessor.preprocess_images() 31 | preprocessor.process_dataset() 32 | 33 | if __name__ == "__main__": 34 | preprocess_dpm_data(sys.argv[1:]) -------------------------------------------------------------------------------- /scripts/train_dpm.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.dirname(sys.path[0])) 3 | from baseline_risk_assessment.dpm_trainer import DPMTrainer 4 | 5 | def train_dpm_model(args): 6 | trainer = DPMTrainer(args) 7 | trainer.learn() 8 | 9 | 10 | if __name__ == "__main__": 11 | train_dpm_model(sys.argv[1:]) 12 | -------------------------------------------------------------------------------- /scripts/train_sg2vec.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(sys.path[0])) 3 | from sg_risk_assessment.sg2vec_trainer import SG2VECTrainer 4 | import pandas as pd 5 | 6 | 7 | def train_sg2vec_model(args, iterations=1): 8 | ''' Training the dynamic kg algorithm with different attention layer choice.''' 9 | 10 | outputs = [] 11 | labels = [] 12 | metrics = [] 13 | 14 | for i in range(iterations): 15 | trainer = SG2VECTrainer(args) 16 | trainer.split_dataset() 17 | trainer.build_model() 18 | trainer.learn() 19 | if trainer.config.n_folds <= 1: 20 | outputs_train, labels_train, outputs_test, labels_test, metric = trainer.evaluate() 21 | 22 | outputs += outputs_test 23 | labels += labels_test 24 | metrics.append(metric) 25 | 26 | if len(outputs) and len(labels) and len(metrics): 27 | # Store the prediction results. 28 | store_path = trainer.config.cache_path.parent 29 | outputs_pd = pd.DataFrame(outputs) 30 | labels_pd = pd.DataFrame(labels) 31 | 32 | labels_pd.to_csv(store_path / "dynkg_training_labels.tsv", sep='\t', header=False, index=False) 33 | outputs_pd.to_csv(store_path / "dynkg_training_outputs.tsv", sep="\t", header=False, index=False) 34 | 35 | # Store the metric results. 
36 | metrics_pd = pd.DataFrame(metrics[-1]['test'], index=[0]) 37 | metrics_pd.to_csv(store_path / "dynkg_classification_metrics.csv", header=True) 38 | 39 | 40 | if __name__ == "__main__": 41 | # the entry of dynkg pipeline training 42 | train_sg2vec_model(sys.argv[1:]) 43 | -------------------------------------------------------------------------------- /sg_risk_assessment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AICPS/sg-collision-prediction/f7ff6d8a70d4d9414f80f5187171d6d77d65ae51/sg_risk_assessment/__init__.py -------------------------------------------------------------------------------- /sg_risk_assessment/image_scenegraph.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sg_risk_assessment.relation_extractor import ActorType, Relations, RELATION_COLORS 3 | from networkx.drawing.nx_agraph import to_agraph 4 | import matplotlib.pyplot as plt 5 | import networkx as nx 6 | import numpy as np 7 | import sys 8 | import os 9 | import cv2 10 | import itertools 11 | import math 12 | import matplotlib 13 | matplotlib.use("Agg") 14 | sys.path.append(os.path.dirname(sys.path[0])) 15 | 16 | # SELECT ONE OF THE FOLLOWING: 17 | 18 | # #SETTINGS FOR 1280x720 CARLA IMAGES: 19 | # IMAGE_H = 720 20 | # IMAGE_W = 1280 21 | # CROPPED_H = 350 #height of ROI. crops to lane area of carla images 22 | # BIRDS_EYE_IMAGE_H = 850 23 | # BIRDS_EYE_IMAGE_W = 1280 24 | # Y_SCALE = 0.55 #18 pixels = length of lane line (10 feet) 25 | # X_SCALE = 0.54 #22 pixels = width of lane (12 feet) 26 | 27 | 28 | # SETTINGS FOR 1280x720 HONDA IMAGES: 29 | IMAGE_H = 720 30 | IMAGE_W = 1280 31 | CROPPED_H = 390 32 | BIRDS_EYE_IMAGE_H = 620 33 | BIRDS_EYE_IMAGE_W = 1280 34 | Y_SCALE = 0.45 # 22 pixels = length of lane line (10 feet) 35 | X_SCALE = 0.46 # 26 pixels = width of lane (12 feet) 36 | 37 | 38 | H_OFFSET = IMAGE_H - CROPPED_H # offset from top of image to start of ROI 39 | 40 | CAR_PROXIMITY_THRESH_NEAR_COLL = 4 41 | # max number of feet between a car and another entity to build proximity relation 42 | CAR_PROXIMITY_THRESH_SUPER_NEAR = 7 43 | CAR_PROXIMITY_THRESH_VERY_NEAR = 10 44 | CAR_PROXIMITY_THRESH_NEAR = 16 45 | CAR_PROXIMITY_THRESH_VISIBLE = 25 46 | 47 | LANE_THRESHOLD = 6 # feet. if object's center is more than this distance away from ego's center, build left or right lane relation 48 | # feet. if object's center is within this distance of ego's center, build middle lane relation 49 | CENTER_LANE_THRESHOLD = 9 50 | 51 | 52 | class ObjectNode: 53 | def __init__(self, name, attr, label): 54 | self.name = name # Car-1, Car-2. 55 | self.attr = attr # bounding box info 56 | self.label = label # ActorType 57 | 58 | def __repr__(self): 59 | return "%s" % (self.name) 60 | 61 | 62 | class RealSceneGraph: 63 | ''' 64 | scene graph the real images 65 | arguments: 66 | image_path : path to the image for which the scene graph is generated 67 | 68 | ''' 69 | 70 | def __init__(self, image_path, bounding_boxes, coco_class_names=None, platform='image'): 71 | self.g = nx.MultiDiGraph() # initialize scenegraph as networkx graph 72 | 73 | # road and lane settings. 74 | # we need to define the type of node. 
75 | self.road_node = ObjectNode("Root Road", {}, ActorType.ROAD) 76 | self.add_node(self.road_node) # adding the road as the root node 77 | 78 | # specify which type of data to load into model (options: image or honda) 79 | self.platfrom = platform 80 | 81 | # set ego location to middle-bottom of image. 82 | self.ego_location = ((BIRDS_EYE_IMAGE_W/2) * 83 | X_SCALE, BIRDS_EYE_IMAGE_H * Y_SCALE) 84 | self.ego_node = ObjectNode("Ego Car", { 85 | "location_x": self.ego_location[0], "location_y": self.ego_location[1]}, ActorType.CAR) 86 | self.add_node(self.ego_node) 87 | self.extract_relative_lanes() # three lane formulation. 88 | 89 | # convert bounding boxes to nodes and build relations. 90 | boxes, labels, image_size = bounding_boxes 91 | self.get_nodes_from_bboxes(boxes, labels, coco_class_names) 92 | 93 | # import pdb; pdb.set_trace() 94 | self.extract_relations() 95 | 96 | def get_nodes_from_bboxes(self, boxes, labels, coco_class_names): 97 | # birds eye view projection 98 | M = get_birds_eye_matrix() 99 | # warped_img = get_birds_eye_warp(image_path, M) 100 | # cv2.imwrite( "./warped.jpg", cv2.cvtColor(warped_img, cv2.COLOR_BGR2RGB)) #plot warped image 101 | 102 | for idx, (box, label) in enumerate(zip(boxes, labels)): 103 | box = box.cpu().numpy().tolist() 104 | class_name = coco_class_names[label] 105 | 106 | if box[1] >= 620: 107 | continue 108 | 109 | if class_name in ['car', 'truck', 'bus']: 110 | actor_type = ActorType.CAR 111 | # elif class_name in ['person']: 112 | # actor_type = ActorType.PED 113 | # elif class_name in ['bicycle']: 114 | # actor_type = ActorType.BICYCLE 115 | # elif class_name in ['motorcycle']: 116 | # actor_type = ActorType.MOTO 117 | # elif class_name in ['traffic light']: 118 | # actor_type = ActorType.LIGHT 119 | # elif class_name in ['stop sign']: 120 | # actor_type = ActorType.SIGN 121 | else: 122 | continue 123 | 124 | attr = {'x1': box[0], 'y1': box[1], 'x2': box[2], 'y2': box[3]} 125 | 126 | # map center-bottom of bounding box to warped image 127 | x_mid = (box[2] + box[0]) / 2 128 | y_bottom = box[3] - H_OFFSET # offset to account for image crop 129 | pt = np.array([[[x_mid, y_bottom]]], dtype='float32') 130 | warp_pt = cv2.perspectiveTransform(pt, M)[0][0] 131 | 132 | #locations/distances in feet 133 | attr['location_x'] = warp_pt[0] * X_SCALE 134 | attr['location_y'] = warp_pt[1] * Y_SCALE 135 | attr['rel_location_x'] = attr['location_x'] - \ 136 | self.ego_node.attr["location_x"] # x position relative to ego 137 | attr['rel_location_y'] = attr['location_y'] - \ 138 | self.ego_node.attr["location_y"] # y position relative to ego 139 | attr['distance_abs'] = math.sqrt( 140 | attr['rel_location_x']**2 + attr['rel_location_y']**2) # absolute distance from ego 141 | node = ObjectNode("%s_%d" % (class_name, idx), attr, actor_type) 142 | self.add_node(node) 143 | self.add_mapping_to_relative_lanes(node) 144 | 145 | # extract relations between all nodes in the graph 146 | # does not build relations with the road node. 147 | # only builds relations between the ego node and other nodes. 148 | # only builds relations if other node is within the distance CAR_PROXIMITY_THRESH_VISIBLE from ego. 
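# Illustrative example (added for clarity, not part of the original source): with the
# threshold constants defined above, a detected car whose warped location is 8 ft from
# the ego node is not within CAR_PROXIMITY_THRESH_SUPER_NEAR (7) but is within
# CAR_PROXIMITY_THRESH_VERY_NEAR (10), so extract_proximity_relations(ego, car) below
# returns [[ego, Relations.very_near, car]]; a car 30 ft away exceeds
# CAR_PROXIMITY_THRESH_VISIBLE (25), so extract_relations() builds no edges for it.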
149 | 150 | def extract_relations(self): 151 | for node_a, node_b in itertools.combinations(self.g.nodes, 2): 152 | relation_list = [] 153 | if node_a.label == ActorType.ROAD or node_b.label == ActorType.ROAD: 154 | # dont build relations w/ road 155 | continue 156 | if node_a.label == ActorType.CAR and node_b.label == ActorType.CAR: 157 | if node_a.name.startswith("Ego") or node_b.name.startswith("Ego"): 158 | # print(node_a, node_b, self.get_euclidean_distance(node_a, node_b)) 159 | # import pdb; pdb.set_trace() 160 | if self.get_euclidean_distance(node_a, node_b) <= CAR_PROXIMITY_THRESH_VISIBLE: 161 | relation_list += self.extract_proximity_relations( 162 | node_a, node_b) 163 | relation_list += self.extract_directional_relations( 164 | node_a, node_b) 165 | relation_list += self.extract_proximity_relations( 166 | node_b, node_a) 167 | relation_list += self.extract_directional_relations( 168 | node_b, node_a) 169 | self.add_relations(relation_list) 170 | 171 | # returns proximity relations based on the absolute distance between two actors. 172 | 173 | def extract_proximity_relations(self, actor1, actor2): 174 | if self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR_COLL: 175 | return [[actor1, Relations.near_coll, actor2]] 176 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_SUPER_NEAR: 177 | return [[actor1, Relations.super_near, actor2]] 178 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VERY_NEAR: 179 | return [[actor1, Relations.very_near, actor2]] 180 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR: 181 | return [[actor1, Relations.near, actor2]] 182 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VISIBLE: 183 | return [[actor1, Relations.visible, actor2]] 184 | return [] 185 | 186 | # calculates absolute distance between two actors 187 | 188 | def get_euclidean_distance(self, actor1, actor2): 189 | l1 = (actor1.attr['location_x'], actor1.attr['location_y']) 190 | l2 = (actor2.attr['location_x'], actor2.attr['location_y']) 191 | return math.sqrt((l1[0] - l2[0])**2 + (l1[1] - l2[1])**2) 192 | 193 | # returns directional relations between entities based on their relative positions to one another in the scene. 
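# Worked example (editorial illustration, assuming the logic below): the bearing between
# the two actors is bucketed into 45-degree sectors to choose among atDRearOf, atSRearOf,
# inSFrontOf and inDFrontOf; independently, if a detected car's location_x is 12 ft less
# than ego's, the lateral check emits [car, Relations.toLeftOf, ego], while a 5 ft offset
# falls inside CENTER_LANE_THRESHOLD (9 ft) and produces no left/right edge.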
194 | 195 | def extract_directional_relations(self, actor1, actor2): 196 | relation_list = [] 197 | x1, y1 = math.cos(math.radians(0)), math.sin(math.radians(0)) 198 | x2, y2 = actor2.attr['location_x'] - \ 199 | actor1.attr['location_x'], actor2.attr['location_y'] - \ 200 | actor1.attr['location_y'] 201 | x2, y2 = x2 / math.sqrt(x2**2+y2**2), y2 / math.sqrt(x2**2+y2**2) 202 | 203 | degree = math.degrees(math.atan2(y1, x1)) - \ 204 | math.degrees(math.atan2(y2, x2)) 205 | if degree < 0: 206 | degree += 360 207 | 208 | if degree <= 45: # actor2 is in front of actor1 209 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 210 | elif degree >= 45 and degree <= 90: 211 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 212 | elif degree >= 90 and degree <= 135: 213 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 214 | elif degree >= 135 and degree <= 180: # actor2 is behind actor1 215 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 216 | elif degree >= 180 and degree <= 225: # actor2 is behind actor1 217 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 218 | elif degree >= 225 and degree <= 270: 219 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 220 | elif degree >= 270 and degree <= 315: 221 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 222 | elif degree >= 315 and degree <= 360: 223 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 224 | 225 | if abs(actor2.attr['location_x'] - actor1.attr['location_x']) <= CENTER_LANE_THRESHOLD: 226 | pass 227 | # actor2 to the left of actor1 228 | elif actor2.attr['location_x'] < actor1.attr['location_x']: 229 | relation_list.append([actor2, Relations.toLeftOf, actor1]) 230 | # actor2 to the right of actor1 231 | elif actor2.attr['location_x'] > actor1.attr['location_x']: 232 | relation_list.append([actor2, Relations.toRightOf, actor1]) 233 | # disable rear relations help the inference. 234 | return relation_list 235 | 236 | # relative lane mapping method. Each vehicle is assigned to left, middle, or right lane depending on relative position to ego 237 | 238 | def extract_relative_lanes(self): 239 | self.left_lane = ObjectNode("Left Lane", {}, ActorType.LANE) 240 | self.right_lane = ObjectNode("Right Lane", {}, ActorType.LANE) 241 | self.middle_lane = ObjectNode("Middle Lane", {}, ActorType.LANE) 242 | self.add_node(self.left_lane) 243 | self.add_node(self.right_lane) 244 | self.add_node(self.middle_lane) 245 | self.add_relation([self.left_lane, Relations.isIn, self.road_node]) 246 | self.add_relation([self.right_lane, Relations.isIn, self.road_node]) 247 | self.add_relation([self.middle_lane, Relations.isIn, self.road_node]) 248 | self.add_relation([self.ego_node, Relations.isIn, self.middle_lane]) 249 | 250 | # builds isIn relation between object and lane depending on x-displacement relative to ego 251 | # left/middle and right/middle relations have an overlap area determined by the size of CENTER_LANE_THRESHOLD and LANE_THRESHOLD. 
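# Example (added; uses the constants defined at the top of this file): with
# LANE_THRESHOLD = 6 and CENTER_LANE_THRESHOLD = 9, an object at rel_location_x = -8 ft
# is beyond the left-lane threshold and still inside the center band, so it receives
# isIn edges to both the Left Lane and the Middle Lane; an object at rel_location_x = 3 ft
# maps to the Middle Lane only.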
252 | # TODO: move to relation_extractor in replacement of current lane-vehicle relation code 253 | 254 | def add_mapping_to_relative_lanes(self, object_node): 255 | # don't build lane relations with static objects 256 | if object_node.label in [ActorType.LANE, ActorType.LIGHT, ActorType.SIGN, ActorType.ROAD]: 257 | return 258 | if object_node.attr['rel_location_x'] < -LANE_THRESHOLD: 259 | self.add_relation([object_node, Relations.isIn, self.left_lane]) 260 | elif object_node.attr['rel_location_x'] > LANE_THRESHOLD: 261 | self.add_relation([object_node, Relations.isIn, self.right_lane]) 262 | if abs(object_node.attr['rel_location_x']) <= CENTER_LANE_THRESHOLD: 263 | self.add_relation([object_node, Relations.isIn, self.middle_lane]) 264 | 265 | # add single node to graph. node can be any hashable datatype including objects. 266 | 267 | def add_node(self, node): 268 | color = "white" 269 | if "ego" in node.name.lower(): 270 | color = "red" 271 | elif "car" in node.name.lower(): 272 | color = "green" 273 | elif "lane" in node.name.lower(): 274 | color = "yellow" 275 | self.g.add_node(node, attr=node.attr, label=node.name, 276 | style='filled', fillcolor=color) 277 | 278 | # add relation (edge) between nodes on graph. relation is a list containing [subject, relation, object] 279 | 280 | def add_relation(self, relation): 281 | if relation != []: 282 | if relation[0] in self.g.nodes and relation[2] in self.g.nodes: 283 | self.g.add_edge(relation[0], relation[2], object=relation[1], 284 | label=relation[1].name, color=RELATION_COLORS[int(relation[1].value)]) 285 | else: 286 | raise NameError( 287 | "One or both nodes in relation do not exist in graph. Relation: " + str(relation)) 288 | 289 | def add_relations(self, relations_list): 290 | for relation in relations_list: 291 | self.add_relation(relation) 292 | 293 | def visualize(self, to_filename): 294 | A = to_agraph(self.g) 295 | A.layout('dot') 296 | A.draw(to_filename) 297 | 298 | 299 | # ROI: Region of Interest 300 | # returns transformation matrix for warping image to birds eye projection 301 | # birds eye matrix fixed for all images using the assumption that camera perspective does not change over time. 
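# Usage sketch (editorial addition; the image path is hypothetical):
#   M = get_birds_eye_matrix()
#   warped = get_birds_eye_warp("frame_0001.jpg", M)  # bird's-eye RGB view of the ROI
#   pt = cv2.perspectiveTransform(np.array([[[640.0, 300.0]]], dtype='float32'), M)[0][0]
#   x_feet, y_feet = pt[0] * X_SCALE, pt[1] * Y_SCALE  # cropped-ROI pixel -> approximate feet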
302 | def get_birds_eye_matrix(): #edit this 303 | # original dimensions (cropped to ROI) 304 | src = np.float32( 305 | [[0, CROPPED_H], [IMAGE_W, CROPPED_H], [0, 0], [IMAGE_W, 0]]) 306 | dst = np.float32([[int(BIRDS_EYE_IMAGE_W*16/33), BIRDS_EYE_IMAGE_H], [int(BIRDS_EYE_IMAGE_W * 307 | 17/33), BIRDS_EYE_IMAGE_H], [0, 0], [BIRDS_EYE_IMAGE_W, 0]]) # warped dimensions 308 | M = cv2.getPerspectiveTransform(src, dst) # The transformation matrix 309 | # Minv = cv2.getPerspectiveTransform(dst, src) # Inverse transformation (if needed) 310 | return M 311 | 312 | 313 | # returns image warped to birds eye projection using M 314 | # returned image is vertically cropped to the ROI (lane area) 315 | def get_birds_eye_warp(image_path, M): 316 | img = cv2.imread(image_path) 317 | img = img[H_OFFSET:IMAGE_H, 0:IMAGE_W] # Apply np slicing for ROI crop 318 | warped_img = cv2.warpPerspective( 319 | img, M, (BIRDS_EYE_IMAGE_W, BIRDS_EYE_IMAGE_H)) # Image warping 320 | warped_img = cv2.cvtColor(warped_img, cv2.COLOR_BGR2RGB) # set to RGB 321 | return warped_img 322 | -------------------------------------------------------------------------------- /sg_risk_assessment/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve, balanced_accuracy_score, matthews_corrcoef 2 | import torch 3 | from sklearn import preprocessing 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import wandb 8 | 9 | #this file contains functions for scoring the prediction models. 10 | 11 | ''' 12 | #Expected Inputs: 13 | outputs: (n, 2) FloatTensor 14 | labels: (n,) LongTensor 15 | ''' 16 | def get_metrics(outputs, labels): 17 | labels_tensor = labels.cpu() 18 | outputs_tensor = outputs.cpu() 19 | preds = outputs_tensor.max(1)[1].type_as(labels_tensor).cpu() #binarized version of outputs_tensor. 20 | 21 | metrics = {} 22 | metrics['acc'] = accuracy_score(labels_tensor, preds) 23 | metrics['f1'] = f1_score(labels_tensor, preds, average="binary") 24 | conf = confusion_matrix(labels_tensor, preds) 25 | metrics['fpr'] = conf[0][1] / (conf[0][1] + conf[0][0]) #FPR = FP/(FP+TN) 26 | metrics['tnr'] = conf[0][0] / (conf[0][1] + conf[0][0]) #TNR = TN/(FP+TN) 27 | metrics['fnr'] = conf[1][0] / (conf[1][0] + conf[1][1]) #FNR = FN/(FN+TP) 28 | metrics['confusion'] = str(conf).replace('\n', ',') 29 | metrics['precision'] = precision_score(labels_tensor, preds, average="binary") 30 | metrics['recall'] = recall_score(labels_tensor, preds, average="binary") #recall and TPR are the same. TPR = TP/(TP+FN) 31 | metrics['auc'] = get_auc(outputs_tensor, labels_tensor) 32 | metrics['label_distribution'] = str(np.unique(labels_tensor, return_counts=True)[1]) 33 | metrics['balanced_acc'] = balanced_accuracy_score(labels_tensor, preds) 34 | metrics['mcc'] = matthews_corrcoef(labels_tensor, preds) 35 | 36 | return metrics 37 | 38 | #returns onehot version of labels. can specify n_classes to force onehot size. 
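# e.g. (illustration added by the editor): encode_onehot([0, 1, 1], n_classes=2)
# -> array([[1, 0], [0, 1], [0, 1]], dtype=int32)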
39 | def encode_onehot(labels, n_classes=None): 40 | if(n_classes): 41 | classes = set(range(n_classes)) 42 | else: 43 | classes = set(labels) 44 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 45 | enumerate(classes)} 46 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 47 | dtype=np.int32) 48 | return labels_onehot 49 | 50 | #log data to to Weights & Biases 51 | def log_wandb(metrics): 52 | wandb.log({ 53 | "train_acc": metrics['train']['acc'], 54 | "val_acc": metrics['test']['acc'], 55 | "train_acc_balanced": metrics['train']['balanced_acc'], 56 | "val_acc_balanced": metrics['test']['balanced_acc'], 57 | "train_loss": metrics['train']['loss'], 58 | "val_loss": metrics['test']['loss'], 59 | 'train_auc': metrics['train']['auc'], 60 | 'train_f1': metrics['train']['f1'], 61 | 'val_auc': metrics['test']['auc'], 62 | 'val_f1': metrics['test']['f1'], 63 | 'train_precision': metrics['train']['precision'], 64 | 'train_recall': metrics['train']['recall'], 65 | 'val_precision': metrics['test']['precision'], 66 | 'val_recall': metrics['test']['recall'], 67 | 'train_conf': metrics['train']['confusion'], 68 | 'val_conf': metrics['test']['confusion'], 69 | 'train_fpr': metrics['train']['fpr'], 70 | 'train_tnr': metrics['train']['tnr'], 71 | 'train_fnr': metrics['train']['fnr'], 72 | 'val_fpr': metrics['test']['fpr'], 73 | 'val_tnr': metrics['test']['tnr'], 74 | 'val_fnr': metrics['test']['fnr'], 75 | 'train_avg_seq_len': metrics['train']['avg_seq_len'], 76 | 'train_avg_pred_frame': metrics['train']['avg_prediction_frame'], 77 | 'val_avg_seq_len': metrics['test']['avg_seq_len'], 78 | 'val_avg_pred_frame': metrics['test']['avg_prediction_frame'], 79 | 'train_avg_pred_risky_indices': metrics['train']['avg_predicted_risky_indices'], 80 | 'train_avg_pred_safe_indices': metrics['train']['avg_predicted_safe_indices'], 81 | 'val_avg_pred_risky_indices': metrics['test']['avg_predicted_risky_indices'], 82 | 'val_avg_pred_safe_indices': metrics['test']['avg_predicted_safe_indices'], 83 | 'best_epoch': metrics['best_epoch'], 84 | 'best_val_loss': metrics['best_val_loss'], 85 | 'best_val_acc': metrics['best_val_acc'], 86 | 'best_val_auc': metrics['best_val_auc'], 87 | 'best_val_conf': metrics['best_val_conf'], 88 | 'best_val_mcc': metrics['best_val_mcc'], 89 | 'best_val_acc_balanced': metrics['best_val_acc_balanced'], 90 | 'train_mcc': metrics['train']['mcc'], 91 | 'val_mcc': metrics['test']['mcc'], 92 | 'avg_inf_time': metrics['avg_inf_time'], 93 | 'best_avg_pred_frame': metrics['best_avg_pred_frame'], 94 | # 'test_seq_tpr': metrics['test']['seq_tpr'], 95 | # 'test_seq_tnr': metrics['test']['seq_tnr'], 96 | # 'test_seq_fpr': metrics['test']['seq_fpr'], 97 | # 'test_seq_fnr': metrics['test']['seq_fnr'], 98 | # 'train_seq_tpr': metrics['train']['seq_tpr'], 99 | # 'train_seq_tnr': metrics['train']['seq_tnr'], 100 | # 'train_seq_fpr': metrics['train']['seq_fpr'], 101 | # 'train_seq_fnr': metrics['train']['seq_fnr'] 102 | }) 103 | 104 | #~~~~~~~~~~Scoring Metrics~~~~~~~~~~ 105 | #note: these scoring metrics only work properly for binary classification use cases (graph classification, dyngraph classification) 106 | def get_auc(outputs, labels): 107 | try: 108 | labels = encode_onehot(labels.numpy().tolist(), 2) #binary labels 109 | auc = roc_auc_score(labels, outputs.numpy(), average="micro") 110 | except ValueError as err: 111 | print("error calculating AUC: ", err) 112 | auc = 0.0 113 | return auc 114 | 115 | #NOTE: ROC curve is only generated for positive class (risky label) confidence 
values 116 | #render parameter determines if the figure is actually generated. If false, it saves the values to a csv file. 117 | def get_roc_curve(outputs, labels, render=False): 118 | risk_scores = [] 119 | outputs = preprocessing.normalize(outputs.numpy(), axis=0) 120 | for i in outputs: 121 | risk_scores.append(i[1]) 122 | fpr, tpr, thresholds = roc_curve(labels.numpy(), risk_scores) 123 | roc = pd.DataFrame() 124 | roc['fpr'] = fpr 125 | roc['tpr'] = tpr 126 | roc['thresholds'] = thresholds 127 | roc.to_csv("ROC_data.csv") 128 | 129 | if(render): 130 | plt.figure(figsize=(8,8)) 131 | plt.xlim((0,1)) 132 | plt.ylim((0,1)) 133 | plt.ylabel("TPR") 134 | plt.xlabel("FPR") 135 | plt.title("Receiver Operating Characteristic") 136 | plt.plot([0,1],[0,1], linestyle='dashed') 137 | plt.plot(fpr,tpr, linewidth=2) 138 | plt.savefig("ROC_curve.svg") 139 | -------------------------------------------------------------------------------- /sg_risk_assessment/mrgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchnlp.nn import Attention 5 | from torch.nn import Linear, LSTM 6 | from torch_geometric.nn import RGCNConv, TopKPooling, FastRGCNConv 7 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 8 | 9 | from torch_geometric.nn.pool.topk_pool import topk, filter_adj 10 | from torch_geometric.utils import softmax 11 | import pdb 12 | 13 | 14 | class RGCNSAGPooling(torch.nn.Module): 15 | def __init__(self, in_channels, num_relations, ratio=0.5, min_score=None, 16 | multiplier=1, nonlinearity=torch.tanh, rgcn_func="FastRGCNConv", **kwargs): 17 | super(RGCNSAGPooling, self).__init__() 18 | 19 | self.in_channels = in_channels 20 | self.ratio = ratio 21 | self.gnn = FastRGCNConv(in_channels, 1, num_relations, **kwargs) if rgcn_func=="FastRGCNConv" else RGCNConv(in_channels, 1, num_relations, **kwargs) 22 | self.min_score = min_score 23 | self.multiplier = multiplier 24 | self.nonlinearity = nonlinearity 25 | 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | self.gnn.reset_parameters() 30 | 31 | 32 | def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None): 33 | """""" 34 | if batch is None: 35 | batch = edge_index.new_zeros(x.size(0)) 36 | 37 | attn = x if attn is None else attn 38 | attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn 39 | score = self.gnn(attn, edge_index, edge_attr).view(-1) 40 | 41 | if self.min_score is None: 42 | score = self.nonlinearity(score) 43 | else: 44 | score = softmax(score, batch) 45 | 46 | perm = topk(score, self.ratio, batch, self.min_score) 47 | x = x[perm] * score[perm].view(-1, 1) 48 | x = self.multiplier * x if self.multiplier != 1 else x 49 | 50 | batch = batch[perm] 51 | edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm, 52 | num_nodes=score.size(0)) 53 | 54 | return x, edge_index, edge_attr, batch, perm, score[perm] 55 | 56 | 57 | def __repr__(self): 58 | return '{}({}, {}, {}={}, multiplier={})'.format( 59 | self.__class__.__name__, self.gnn.__class__.__name__, 60 | self.in_channels, 61 | 'ratio' if self.min_score is None else 'min_score', 62 | self.ratio if self.min_score is None else self.min_score, 63 | self.multiplier) 64 | 65 | class MRGCN(nn.Module): 66 | 67 | def __init__(self, config): 68 | super(MRGCN, self).__init__() 69 | 70 | self.num_features = config.num_features 71 | self.num_relations = config.num_relations 72 | self.num_classes = 
config.nclass 73 | self.num_layers = config.num_layers #defines number of RGCN conv layers. 74 | self.hidden_dim = config.hidden_dim 75 | self.layer_spec = None if config.layer_spec == None else list(map(int, config.layer_spec.split(','))) 76 | if self.layer_spec != None and self.num_layers != len(self.layer_spec): 77 | raise ValueError("num_layers does not match the length of layer_spec") #we want this to break here because our data logging will not be accurate if config.num_layers != len(config.layer_spec) 78 | self.lstm_dim1 = config.lstm_input_dim 79 | self.lstm_dim2 = config.lstm_output_dim 80 | self.rgcn_func = FastRGCNConv if config.conv_type == "FastRGCNConv" else RGCNConv 81 | self.activation = F.relu if config.activation == 'relu' else F.leaky_relu 82 | self.pooling_type = config.pooling_type 83 | self.readout_type = config.readout_type 84 | self.temporal_type = config.temporal_type 85 | self.lstm_layers = config.lstm_layers #defines number of lstm layers 86 | self.dropout = config.dropout 87 | self.conv = [] 88 | total_dim = 0 89 | 90 | if self.layer_spec == None: 91 | if self.num_layers > 0: 92 | self.conv.append(self.rgcn_func(self.num_features, self.hidden_dim, self.num_relations).to(config.device)) 93 | total_dim += self.hidden_dim 94 | for i in range(1, self.num_layers): 95 | self.conv.append(self.rgcn_func(self.hidden_dim, self.hidden_dim, self.num_relations).to(config.device)) 96 | total_dim += self.hidden_dim 97 | else: 98 | self.fc0_5 = Linear(self.num_features, self.hidden_dim) 99 | total_dim += self.hidden_dim 100 | else: 101 | if self.num_layers > 0: 102 | print("using layer specification and ignoring hidden_dim parameter.") 103 | print("layer_spec: " + str(self.layer_spec)) 104 | self.conv.append(self.rgcn_func(self.num_features, self.layer_spec[0], self.num_relations).to(config.device)) 105 | total_dim += self.layer_spec[0] 106 | for i in range(1, self.num_layers): 107 | self.conv.append(self.rgcn_func(self.layer_spec[i-1], self.layer_spec[i], self.num_relations).to(config.device)) 108 | total_dim += self.layer_spec[i] 109 | 110 | else: 111 | self.fc0_5 = Linear(self.num_features, self.hidden_dim) 112 | total_dim += self.hidden_dim 113 | 114 | if self.pooling_type == "sagpool": 115 | self.pool1 = RGCNSAGPooling(total_dim, self.num_relations, ratio=config.pooling_ratio, rgcn_func=config.conv_type) 116 | elif self.pooling_type == "topk": 117 | self.pool1 = TopKPooling(total_dim, ratio=config.pooling_ratio) 118 | 119 | self.fc1 = Linear(total_dim, self.lstm_dim1) 120 | 121 | if "lstm" in self.temporal_type: 122 | self.lstm = LSTM(self.lstm_dim1, self.lstm_dim2, batch_first=True, num_layers=config.lstm_layers) 123 | self.attn = Attention(self.lstm_dim2) 124 | self.lstm_decoder = LSTM(self.lstm_dim2, self.lstm_dim2, batch_first=True) 125 | else: 126 | self.fc1_5 = Linear(self.lstm_dim1, self.lstm_dim2) 127 | 128 | self.fc2 = Linear(self.lstm_dim2, self.num_classes) 129 | 130 | 131 | def forward(self, x, edge_index, edge_attr, batch=None): 132 | attn_weights = dict() 133 | outputs = [] 134 | if self.num_layers > 0: 135 | for i in range(self.num_layers): 136 | x = self.activation(self.conv[i](x, edge_index, edge_attr)) 137 | x = F.dropout(x, self.dropout, training=self.training) 138 | outputs.append(x) 139 | x = torch.cat(outputs, dim=-1) 140 | else: 141 | x = self.activation(self.fc0_5(x)) 142 | 143 | if self.pooling_type == "sagpool": 144 | x, edge_index, _, attn_weights['batch'], _, _ = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch) 145 | elif 
self.pooling_type == "topk": 146 | x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch) 147 | else: 148 | attn_weights['batch'] = batch 149 | 150 | if self.readout_type == "add": 151 | x = global_add_pool(x, attn_weights['batch']) 152 | elif self.readout_type == "mean": 153 | x = global_mean_pool(x, attn_weights['batch']) 154 | elif self.readout_type == "max": 155 | x = global_max_pool(x, attn_weights['batch']) 156 | else: 157 | pass 158 | 159 | x = self.activation(self.fc1(x)) 160 | 161 | if self.temporal_type == "mean": 162 | x = self.activation(self.fc1_5(x.mean(axis=0))) 163 | elif self.temporal_type == "lstm_last": 164 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 165 | x = h.flatten() 166 | elif self.temporal_type == "lstm_sum": 167 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 168 | x = x_predicted.sum(dim=1).flatten() 169 | elif self.temporal_type == "lstm_attn": 170 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 171 | x, attn_weights['lstm_attn_weights'] = self.attn(h.view(1,1,-1), x_predicted) 172 | x, (h_decoder, c_decoder) = self.lstm_decoder(x, (h, c)) 173 | x = x.flatten() 174 | elif self.temporal_type == "lstm_seq": #used for step-by-step sequence prediction. 175 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) #x_predicted is sequence of predictions for each frame, h is hidden state of last item, c is last cell state 176 | x = x_predicted.squeeze(0) #we return x_predicted as we want to know the output of the LSTM for each value in the sequence 177 | elif self.temporal_type == 'none': #this option uses no temporal modeling at all. 178 | x = self.activation(self.fc1_5(x)) 179 | else: 180 | pass 181 | 182 | return F.log_softmax(self.fc2(x), dim=-1), attn_weights 183 | 184 | 185 | 186 | 187 | 188 | #implementation of MRGCN using a GIN style readout. 189 | class MRGIN(nn.Module): 190 | def __init__(self, config): 191 | super(MRGIN, self).__init__() 192 | self.num_features = config.num_features 193 | self.num_relations = config.num_relations 194 | self.num_classes = config.nclass 195 | self.num_layers = config.num_layers #defines number of RGCN conv layers. 
196 | self.hidden_dim = config.hidden_dim 197 | self.layer_spec = None if config.layer_spec == None else list(map(int, config.layer_spec.split(','))) 198 | self.lstm_dim1 = config.lstm_input_dim 199 | self.lstm_dim2 = config.lstm_output_dim 200 | self.rgcn_func = FastRGCNConv if config.conv_type == "FastRGCNConv" else RGCNConv 201 | self.activation = F.relu if config.activation == 'relu' else F.leaky_relu 202 | self.pooling_type = config.pooling_type 203 | self.readout_type = config.readout_type 204 | self.temporal_type = config.temporal_type 205 | self.dropout = config.dropout 206 | self.conv = [] 207 | self.pool = [] 208 | total_dim = 0 209 | 210 | if self.layer_spec == None: 211 | for i in range(self.num_layers): 212 | if i == 0: 213 | self.conv.append(self.rgcn_func(self.num_features, self.hidden_dim, self.num_relations).to(config.device)) 214 | else: 215 | self.conv.append(self.rgcn_func(self.hidden_dim, self.hidden_dim, self.num_relations).to(config.device)) 216 | if self.pooling_type == "sagpool": 217 | self.pool.append(RGCNSAGPooling(self.hidden_dim, self.num_relations, ratio=config.pooling_ratio, rgcn_func=config.conv_type).to(config.device)) 218 | elif self.pooling_type == "topk": 219 | self.pool.append(TopKPooling(self.hidden_dim, ratio=config.pooling_ratio).to(config.device)) 220 | total_dim += self.hidden_dim 221 | 222 | else: 223 | print("using layer specification and ignoring hidden_dim parameter.") 224 | print("layer_spec: " + str(self.layer_spec)) 225 | for i in range(self.num_layers): 226 | if i == 0: 227 | self.conv.append(self.rgcn_func(self.num_features, self.layer_spec[0], self.num_relations).to(config.device)) 228 | else: 229 | self.conv.append(self.rgcn_func(self.layer_spec[i-1], self.layer_spec[i], self.num_relations).to(config.device)) 230 | if self.pooling_type == "sagpool": 231 | self.pool.append(RGCNSAGPooling(self.layer_spec[i], self.num_relations, ratio=config.pooling_ratio, rgcn_func=config.conv_type).to(config.device)) 232 | elif self.pooling_type == "topk": 233 | self.pool.append(TopKPooling(self.layer_spec[i], ratio=config.pooling_ratio).to(config.device)) 234 | total_dim += self.layer_spec[i] 235 | 236 | self.fc1 = Linear(total_dim, self.lstm_dim1) 237 | 238 | if "lstm" in self.temporal_type: 239 | self.lstm = LSTM(self.lstm_dim1, self.lstm_dim2, batch_first=True) 240 | self.attn = Attention(self.lstm_dim2) 241 | 242 | self.fc2 = Linear(self.lstm_dim2, self.num_classes) 243 | 244 | 245 | 246 | def forward(self, x, edge_index, edge_attr, batch=None): 247 | attn_weights = dict() 248 | outputs = [] 249 | 250 | #readout performed after each layer and concatenated 251 | for i in range(self.num_layers): 252 | x = self.activation(self.conv[i](x, edge_index, edge_attr)) 253 | x = F.dropout(x, self.dropout, training=self.training) 254 | if self.pooling_type == "sagpool": 255 | p, _, _, batch2, attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool[i](x, edge_index, edge_attr=edge_attr, batch=batch) 256 | elif self.pooling_type == "topk": 257 | p, _, _, batch2, attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool[i](x, edge_index, edge_attr=edge_attr, batch=batch) 258 | else: 259 | p = x 260 | batch2 = batch 261 | if self.readout_type == "add": 262 | r = global_add_pool(p, batch2) 263 | elif self.readout_type == "mean": 264 | r = global_mean_pool(p, batch2) 265 | elif self.readout_type == "max": 266 | r = global_max_pool(p, batch2) 267 | else: 268 | r = p 269 | outputs.append(r) 270 | 271 | x = torch.cat(outputs, dim=-1) 272 | x = 
self.activation(self.fc1(x)) 273 | 274 | if self.temporal_type == "mean": 275 | x = self.activation(x.mean(axis=0)) 276 | elif self.temporal_type == "lstm_last": 277 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 278 | x = h.flatten() 279 | elif self.temporal_type == "lstm_sum": 280 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 281 | x = x_predicted.sum(dim=1).flatten() 282 | elif self.temporal_type == "lstm_attn": 283 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 284 | x, attn_weights['lstm_attn_weights'] = self.attn(h.view(1,1,-1), x_predicted) 285 | x = x.flatten() 286 | elif self.temporal_type == "lstm_seq": #used for step-by-step sequence prediction. 287 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) #x_predicted is sequence of predictions for each frame, h is hidden state of last item, c is last cell state 288 | x = x_predicted.squeeze(0) #we return x_predicted as we want to know the output of the LSTM for each value in the sequence 289 | else: 290 | pass 291 | 292 | return F.log_softmax(self.fc2(x), dim=-1), attn_weights -------------------------------------------------------------------------------- /sg_risk_assessment/relation_extractor.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | 4 | 5 | MOTO_NAMES = ["Harley-Davidson", "Kawasaki", "Yamaha"] 6 | BICYCLE_NAMES = ["Gazelle", "Diamondback", "Bh"] 7 | CAR_NAMES = ["Ford", "Bmw", "Toyota", "Nissan", "Mini", "Tesla", "Seat", "Lincoln", "Audi", "Carlamotors", "Citroen", "Mercedes-Benz", "Chevrolet", "Volkswagen", "Jeep", "Nissan", "Dodge", "Mustang"] 8 | 9 | CAR_PROXIMITY_THRESH_NEAR_COLL = 4 10 | CAR_PROXIMITY_THRESH_SUPER_NEAR = 7 # max number of feet between a car and another entity to build proximity relation 11 | CAR_PROXIMITY_THRESH_VERY_NEAR = 10 12 | CAR_PROXIMITY_THRESH_NEAR = 16 13 | CAR_PROXIMITY_THRESH_VISIBLE = 25 14 | MOTO_PROXIMITY_THRESH = 50 15 | BICYCLE_PROXIMITY_THRESH = 50 16 | PED_PROXIMITY_THRESH = 50 17 | 18 | #defines all types of actors which can exist 19 | #order of enum values is important as this determines which function is called. 
DO NOT CHANGE ENUM ORDER 20 | class ActorType(Enum): 21 | CAR = 0 #26, 142, 137:truck 22 | MOTO = 1 #80 23 | BICYCLE = 2 #11 24 | PED = 3 #90, 91, 98: "player", 78:man, 79:men, 149:woman, 56: guy, 53: girl 25 | LANE = 4 #124:street, 114:sidewalk 26 | LIGHT = 5 # 99: "pole", 76: light 27 | SIGN = 6 28 | ROAD = 7 29 | 30 | ACTOR_NAMES=['car','moto','bicycle','ped','lane','light','sign', 'road'] 31 | 32 | class Relations(Enum): 33 | isIn = 0 34 | near_coll = 1 35 | super_near = 2 36 | very_near = 3 37 | near = 4 38 | visible = 5 39 | inDFrontOf = 6 40 | inSFrontOf = 7 41 | atDRearOf = 8 42 | atSRearOf = 9 43 | toLeftOf = 10 44 | toRightOf = 11 45 | 46 | RELATION_COLORS = ["black", "red", "orange", "yellow", "green", "purple", "blue", 47 | "sienna", "pink", "pink", "pink", "turquoise", "turquoise", "turquoise", "violet", "violet"] 48 | 49 | #This class extracts relations for every pair of entities in a scene 50 | class RelationExtractor: 51 | def __init__(self, ego_node): 52 | self.ego_node = ego_node 53 | 54 | def get_actor_type(self, actor): 55 | if "curr" in actor.attr.keys(): 56 | return ActorType.LANE 57 | if actor.attr["name"] == "Traffic Light": 58 | return ActorType.LIGHT 59 | if actor.attr["name"].split(" ")[0] == "Pedestrian": 60 | return ActorType.PED 61 | if actor.attr["name"].split(" ")[0] in CAR_NAMES: 62 | return ActorType.CAR 63 | if actor.attr["name"].split(" ")[0] in MOTO_NAMES: 64 | return ActorType.MOTO 65 | if actor.attr["name"].split(" ")[0] in BICYCLE_NAMES: 66 | return ActorType.BICYCLE 67 | if "Sign" in actor.attr["name"]: 68 | return ActorType.SIGN 69 | 70 | # import pdb; pdb.set_trace() 71 | raise NameError("Actor name not found for actor with name: " + actor.attr["name"]) 72 | 73 | #takes in two entities and extracts all relations between those two entities. extracted relations are bidirectional 74 | def extract_relations(self, actor1, actor2): 75 | #import pdb; pdb.set_trace() 76 | type1 = self.get_actor_type(actor1) 77 | type2 = self.get_actor_type(actor2) 78 | 79 | low_type = min(type1.value, type2.value) #the lower of the two enums. 80 | high_type = max(type1.value, type2.value) 81 | 82 | function_call = "self.extract_relations_"+ACTOR_NAMES[low_type]+"_"+ACTOR_NAMES[high_type]+"(actor1, actor2) if type1.value <= type2.value "\ 83 | "else self.extract_relations_"+ACTOR_NAMES[low_type]+"_"+ACTOR_NAMES[high_type]+"(actor2, actor1)" 84 | return eval(function_call) 85 | 86 | 87 | #~~~~~~~~~specific relations for each pair of actors possible~~~~~~~~~~~~ 88 | #actor 1 corresponds to the first actor in the function name and actor2 the second 89 | 90 | def extract_relations_car_car(self, actor1, actor2): 91 | relation_list = [] 92 | # consider the proximity relations with neighboring lanes. 
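# Illustrative note (added): car-car relations are only extracted when one of the two
# actors is the ego vehicle; e.g. at an ego-to-car distance of 12
# (<= CAR_PROXIMITY_THRESH_NEAR = 16) the pair gets near proximity edges plus directional
# edges in both directions, while more distant pairs produce an empty relation list.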
93 | if actor1.name.startswith("ego:") or actor2.name.startswith("ego:"): 94 | if self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR: 95 | relation_list += self.create_proximity_relations(actor1, actor2) 96 | relation_list += self.create_proximity_relations(actor2, actor1) 97 | relation_list += self.extract_directional_relation(actor1, actor2) 98 | relation_list += self.extract_directional_relation(actor2, actor1) 99 | return relation_list 100 | 101 | def extract_relations_car_lane(self, actor1, actor2): 102 | relation_list = [] 103 | # if(self.in_lane(actor1,actor2)): 104 | # relation_list.append([actor1, Relations.isIn, actor2]) 105 | 106 | return relation_list 107 | 108 | def extract_relations_car_light(self, actor1, actor2): 109 | relation_list = [] 110 | return relation_list 111 | 112 | def extract_relations_car_sign(self, actor1, actor2): 113 | relation_list = [] 114 | return relation_list 115 | 116 | def extract_relations_car_ped(self, actor1, actor2): 117 | relation_list = [] 118 | return relation_list 119 | 120 | def extract_relations_car_bicycle(self, actor1, actor2): 121 | relation_list = [] 122 | return relation_list 123 | 124 | def extract_relations_car_moto(self, actor1, actor2): 125 | relation_list = [] 126 | return relation_list 127 | 128 | 129 | def extract_relations_moto_moto(self, actor1, actor2): 130 | relation_list = [] 131 | return relation_list 132 | 133 | def extract_relations_moto_bicycle(self, actor1, actor2): 134 | relation_list = [] 135 | return relation_list 136 | 137 | def extract_relations_moto_ped(self, actor1, actor2): 138 | relation_list = [] 139 | return relation_list 140 | 141 | def extract_relations_moto_lane(self, actor1, actor2): 142 | relation_list = [] 143 | # if(self.in_lane(actor1,actor2)): 144 | # relation_list.append([actor1, Relations.isIn, actor2]) 145 | # # relation_list.append([actor2, Relations.isIn, actor1]) 146 | return relation_list 147 | 148 | def extract_relations_moto_light(self, actor1, actor2): 149 | relation_list = [] 150 | return relation_list 151 | 152 | def extract_relations_moto_sign(self, actor1, actor2): 153 | relation_list = [] 154 | return relation_list 155 | 156 | 157 | def extract_relations_bicycle_bicycle(self, actor1, actor2): 158 | relation_list = [] 159 | # if(self.euclidean_distance(actor1, actor2) < BICYCLE_PROXIMITY_THRESH): 160 | # relation_list.append([actor1, Relations.near, actor2]) 161 | # relation_list.append([actor2, Relations.near, actor1]) 162 | # #relation_list.append(self.extract_directional_relation(actor1, actor2)) 163 | # #relation_list.append(self.extract_directional_relation(actor2, actor1)) 164 | return relation_list 165 | 166 | def extract_relations_bicycle_ped(self, actor1, actor2): 167 | relation_list = [] 168 | # if(self.euclidean_distance(actor1, actor2) < BICYCLE_PROXIMITY_THRESH): 169 | # relation_list.append([actor1, Relations.near, actor2]) 170 | # relation_list.append([actor2, Relations.near, actor1]) 171 | # #relation_list.append(self.extract_directional_relation(actor1, actor2)) 172 | # #relation_list.append(self.extract_directional_relation(actor2, actor1)) 173 | return relation_list 174 | 175 | def extract_relations_bicycle_lane(self, actor1, actor2): 176 | relation_list = [] 177 | # if(self.in_lane(actor1,actor2)): 178 | # relation_list.append([actor1, Relations.isIn, actor2]) 179 | return relation_list 180 | 181 | def extract_relations_bicycle_light(self, actor1, actor2): 182 | relation_list = [] 183 | #relation_list.append(self.extract_directional_relation(actor1, 
actor2)) 184 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 185 | return relation_list 186 | 187 | def extract_relations_bicycle_sign(self, actor1, actor2): 188 | relation_list = [] 189 | #relation_list.append(self.extract_directional_relation(actor1, actor2)) 190 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 191 | return relation_list 192 | 193 | def extract_relations_ped_ped(self, actor1, actor2): 194 | relation_list = [] 195 | if(self.euclidean_distance(actor1, actor2) < PED_PROXIMITY_THRESH): 196 | relation_list.append([actor1, Relations.near, actor2]) 197 | relation_list.append([actor2, Relations.near, actor1]) 198 | #relation_list.append(self.extract_directional_relation(actor1, actor2)) 199 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 200 | return relation_list 201 | 202 | def extract_relations_ped_lane(self, actor1, actor2): 203 | relation_list = [] 204 | # if(self.in_lane(actor1,actor2)): 205 | # relation_list.append([actor1, Relations.isIn, actor2]) 206 | return relation_list 207 | 208 | def extract_relations_ped_light(self, actor1, actor2): 209 | relation_list = [] 210 | #proximity relation could indicate ped waiting for crosswalk at a light 211 | # if(self.euclidean_distance(actor1, actor2) < PED_PROXIMITY_THRESH): 212 | # relation_list.append([actor1, Relations.near, actor2]) 213 | # relation_list.append([actor2, Relations.near, actor1]) 214 | #relation_list.append(self.extract_directional_relation(actor1, actor2)) 215 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 216 | return relation_list 217 | 218 | def extract_relations_ped_sign(self, actor1, actor2): 219 | relation_list = [] 220 | # relation_list.append(self.extract_directional_relation(actor1, actor2)) 221 | # relation_list.append(self.extract_directional_relation(actor2, actor1)) 222 | return relation_list 223 | 224 | def extract_relations_lane_lane(self, actor1, actor2): 225 | relation_list = [] 226 | return relation_list 227 | 228 | def extract_relations_lane_light(self, actor1, actor2): 229 | relation_list = [] 230 | return relation_list 231 | 232 | def extract_relations_lane_sign(self, actor1, actor2): 233 | relation_list = [] 234 | return relation_list 235 | 236 | def extract_relations_light_light(self, actor1, actor2): 237 | relation_list = [] 238 | return relation_list 239 | 240 | def extract_relations_light_sign(self, actor1, actor2): 241 | relation_list = [] 242 | return relation_list 243 | 244 | def extract_relations_sign_sign(self, actor1, actor2): 245 | relation_list = [] 246 | return relation_list 247 | 248 | 249 | #~~~~~~~~~~~~~~~~~~UTILITY FUNCTIONS~~~~~~~~~~~~~~~~~~~~~~ 250 | #return euclidean distance between actors 251 | def euclidean_distance(self, actor1, actor2): 252 | #import pdb; pdb.set_trace() 253 | l1 = actor1.attr['location'] 254 | l2 = actor2.attr['location'] 255 | return math.sqrt((l1[0] - l2[0])**2 + (l1[1]- l2[1])**2 + (l1[2] - l2[2])**2) 256 | 257 | #check if an actor is in a certain lane 258 | def in_lane(self, actor1, actor2): 259 | if 'lane_idx' in actor1.attr.keys(): 260 | # calculate the distance bewteen actor1 and actor2 261 | # if it is below 3.5 then they have is in relation. 262 | # if actor1 is ego: if actor2 is not equal to the ego_lane's index then it's invading relation. 
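# For instance (editorial example): if actor1 and actor2 share the same lane_idx the check
# below returns True; it also returns True when actor1's invading_lane or orig_lane_idx
# matches actor2's lane_idx, which covers frames captured mid lane change.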
263 | if actor1.attr['lane_idx'] == actor2.attr['lane_idx']: 264 | return True 265 | if "invading_lane" in actor1.attr: 266 | if actor1.attr['invading_lane'] == actor2.attr['lane_idx']: 267 | return True 268 | if "orig_lane_idx" in actor1.attr: 269 | if actor1.attr['orig_lane_idx'] == actor2.attr['lane_idx']: 270 | return True 271 | else: 272 | return False 273 | 274 | def create_proximity_relations(self, actor1, actor2): 275 | if self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR_COLL: 276 | return [[actor1, Relations.near_coll, actor2]] 277 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_SUPER_NEAR: 278 | return [[actor1, Relations.super_near, actor2]] 279 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VERY_NEAR: 280 | return [[actor1, Relations.very_near, actor2]] 281 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR: 282 | return [[actor1, Relations.near, actor2]] 283 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VISIBLE: 284 | return [[actor1, Relations.visible, actor2]] 285 | return [] 286 | 287 | def extract_directional_relation(self, actor1, actor2): 288 | relation_list = [] 289 | # gives directional relations between actors based on their 2D absolute positions. 290 | x1, y1 = math.cos(math.radians(actor1.attr['rotation'][0])), math.sin(math.radians(actor1.attr['rotation'][0])) 291 | x2, y2 = actor2.attr['location'][0] - actor1.attr['location'][0], actor2.attr['location'][1] - actor1.attr['location'][1] 292 | x2, y2 = x2 / math.sqrt(x2**2+y2**2), y2 / math.sqrt(x2**2+y2**2) 293 | 294 | degree = math.degrees(math.atan2(y1, x1)) - math.degrees(math.atan2(y2, x2)) 295 | if degree < 0: 296 | degree += 360 297 | 298 | if degree <= 45: # actor2 is in front of actor1 299 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 300 | elif degree >= 45 and degree <= 90: 301 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 302 | elif degree >= 90 and degree <= 135: 303 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 304 | elif degree >= 135 and degree <= 180: # actor2 is behind actor1 305 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 306 | elif degree >= 180 and degree <= 225: # actor2 is behind actor1 307 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 308 | elif degree >= 225 and degree <= 270: 309 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 310 | elif degree >= 270 and degree <= 315: 311 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 312 | elif degree >= 315 and degree <= 360: 313 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 314 | 315 | if actor2.attr['lane_idx'] < actor1.attr['lane_idx']: # actor2 to the left of actor1 316 | relation_list.append([actor1, Relations.toRightOf, actor2]) 317 | elif actor2.attr['lane_idx'] > actor1.attr['lane_idx']: # actor2 to the right of actor1 318 | relation_list.append([actor1, Relations.toLeftOf, actor2]) 319 | 320 | return relation_list -------------------------------------------------------------------------------- /sg_risk_assessment/scene_graph.py: -------------------------------------------------------------------------------- 1 | import matplotlib, math, itertools 2 | matplotlib.use("Agg") 3 | import networkx as nx 4 | from networkx.drawing.nx_agraph import to_agraph 5 | from sg_risk_assessment.relation_extractor import Relations, ActorType, RelationExtractor, RELATION_COLORS 6 | 7 | 8 | LANE_THRESHOLD = 6 #feet. 
if object's center is more than this distance away from ego's center, build left or right lane relation 9 | CENTER_LANE_THRESHOLD = 9 #feet. if object's center is within this distance of ego's center, build middle lane relation 10 | 11 | 12 | #class representing a node in the scene graph. this is mainly used for holding the data for each node. 13 | class Node: 14 | def __init__(self, name, attr, type=None): 15 | self.name = name 16 | self.attr = attr 17 | self.label = name 18 | self.type = type.value if type != None else None 19 | 20 | def __repr__(self): 21 | return "%s" % self.name 22 | 23 | 24 | #class defining scene graph and its attributes. contains functions for construction and operations 25 | class SceneGraph: 26 | 27 | #graph can be initialized with a framedict to load all objects at once 28 | def __init__(self, framedict, framenum=None): 29 | self.g = nx.MultiDiGraph() #initialize scenegraph as networkx graph 30 | self.road_node = Node("Root Road", {}, ActorType.ROAD) 31 | self.add_node(self.road_node) #adding the road as the root node 32 | self.parse_json(framedict) # processing json framedict 33 | 34 | #add single node to graph. node can be any hashable datatype including objects. 35 | def add_node(self, node): 36 | color = "white" 37 | if node.name.startswith("ego"): 38 | color = "red" 39 | elif node.name.startswith("car"): 40 | color = "blue" 41 | elif node.name.startswith("lane"): 42 | color = "yellow" 43 | self.g.add_node(node, attr=node.attr, label=node.name, style='filled', fillcolor=color) 44 | 45 | #add relation (edge) between nodes on graph. relation is a list containing [subject, relation, object] 46 | def add_relation(self, relation): 47 | if relation != []: 48 | if relation[0] in self.g.nodes and relation[2] in self.g.nodes: 49 | self.g.add_edge(relation[0], relation[2], object=relation[1], label=relation[1].name, color=RELATION_COLORS[int(relation[1].value)]) 50 | else: 51 | raise NameError("One or both nodes in relation do not exist in graph. Relation: " + str(relation)) 52 | 53 | def add_relations(self, relations_list): 54 | for relation in relations_list: 55 | self.add_relation(relation) 56 | 57 | #parses actor dict and adds nodes to graph. this can be used for all actor types. 58 | def add_actor_dict(self, actordict): 59 | for actor_id, attr in actordict.items(): 60 | # filter actors behind ego #TODO remove this or make it configurable. 61 | x1, y1 = math.cos(math.radians(self.egoNode.attr['rotation'][0])), math.sin(math.radians(self.egoNode.attr['rotation'][0])) 62 | x2, y2 = attr['location'][0] - self.egoNode.attr['location'][0], attr['location'][1] - self.egoNode.attr['location'][1] 63 | inner_product = x1*x2 + y1*y2 64 | length_product = math.sqrt(x1**2+y1**2) + math.sqrt(x2**2+y2**2) 65 | degree = math.degrees(math.acos(inner_product / length_product)) 66 | 67 | if degree <= 80 or (degree >=280 and degree <= 360): 68 | # if abs(self.egoNode.attr['lane_idx'] - attr['lane_idx']) <= 1 \ 69 | # or ("invading_lane" in self.egoNode.attr and (2*self.egoNode.attr['invading_lane'] - self.egoNode.attr['orig_lane_idx']) == attr['lane_idx']): 70 | n = Node(actor_id, attr, None) #using the actor key as the node name and the dict as its attributes. 71 | n.name = self.relation_extractor.get_actor_type(n).name.lower() + ":" + actor_id 72 | n.type = self.relation_extractor.get_actor_type(n).value 73 | self.add_node(n) 74 | self.add_mapping_to_relative_lanes(n) 75 | 76 | #adds lanes and their dicts. constructs relation between each lane and the root road node. 
77 | def add_lane_dict(self, lanedict): 78 | #TODO: can we filter out the lane that has no car on it? 79 | for idx, lane in enumerate(lanedict['lanes']): 80 | lane['lane_idx'] = idx 81 | n = Node("lane:"+str(idx), lane, ActorType.LANE) 82 | self.add_node(n) 83 | self.add_relation([n, Relations.isIn, self.road_node]) 84 | 85 | #add signs as entities of the road. 86 | def add_sign_dict(self, signdict): 87 | for sign_id, signattr in signdict.items(): 88 | n = Node(sign_id, signattr, ActorType.SIGN) 89 | self.add_node(n) 90 | self.add_relation([n, Relations.isIn, self.road_node]) 91 | 92 | #add the contents of a whole framedict to the graph 93 | def parse_json(self, framedict): 94 | self.egoNode = Node("ego:"+framedict['ego']['name'], framedict['ego'], ActorType.CAR) 95 | self.add_node(self.egoNode) 96 | 97 | #rotating axes to align with ego. yaw axis is the primary rotation axis in vehicles 98 | self.ego_yaw = math.radians(self.egoNode.attr['rotation'][0]) 99 | self.ego_cos_term = math.cos(self.ego_yaw) 100 | self.ego_sin_term = math.sin(self.ego_yaw) 101 | self.extract_relative_lanes() 102 | 103 | self.relation_extractor = RelationExtractor(self.egoNode) 104 | for key, attrs in framedict.items(): 105 | # if key == "lane": 106 | # self.add_lane_dict(attrs) 107 | if key == "sign": 108 | self.add_sign_dict(attrs) 109 | elif key == "actors": 110 | self.add_actor_dict(attrs) 111 | self.extract_semantic_relations() 112 | 113 | #calls RelationExtractor to build semantic relations between every pair of entity nodes in graph. call this function after all nodes have been added to graph. 114 | def extract_semantic_relations(self): 115 | for node1, node2 in itertools.combinations(self.g.nodes, 2): 116 | if node1.name != node2.name: #dont build self-relations 117 | if node1.type != ActorType.ROAD.value and node2.type != ActorType.ROAD.value: # dont build relations w/ road 118 | self.add_relations(self.relation_extractor.extract_relations(node1, node2)) 119 | 120 | def visualize(self, filename=None): 121 | A = to_agraph(self.g) 122 | A.layout('dot') 123 | A.draw(filename) 124 | 125 | 126 | # TODO refactor after testing 127 | # relative lane mapping method. Each vehicle is assigned to left, middle, or right lane depending on relative position to ego 128 | def extract_relative_lanes(self): 129 | self.left_lane = Node("lane_left", {"curr":"lane_left"}, ActorType.LANE) 130 | self.right_lane = Node("lane_right", {"curr":"lane_right"}, ActorType.LANE) 131 | self.middle_lane = Node("lane_middle", {"curr":"lane_middle"}, ActorType.LANE) 132 | self.add_node(self.left_lane) 133 | self.add_node(self.right_lane) 134 | self.add_node(self.middle_lane) 135 | self.add_relation([self.left_lane, Relations.isIn, self.road_node]) 136 | self.add_relation([self.right_lane, Relations.isIn, self.road_node]) 137 | self.add_relation([self.middle_lane, Relations.isIn, self.road_node]) 138 | self.add_relation([self.egoNode, Relations.isIn, self.middle_lane]) 139 | 140 | #builds isIn relation between object and lane depending on x-displacement relative to ego 141 | #left/middle and right/middle relations have an overlap area determined by the size of CENTER_LANE_THRESHOLD and LANE_THRESHOLD. 
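# Example (added; mirrors the image-based variant): if the ego-aligned lateral offset
# y_diff computed below is -7, the object is related isIn the left lane and, since
# |-7| <= CENTER_LANE_THRESHOLD (9), also isIn the middle lane; a y_diff of +12 maps to
# the right lane only.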
142 | #TODO: move to relation_extractor in replacement of current lane-vehicle relation code 143 | def add_mapping_to_relative_lanes(self, object_node): 144 | if object_node.label in [ActorType.LANE, ActorType.LIGHT, ActorType.SIGN, ActorType.ROAD]: #don't build lane relations with static objects 145 | return 146 | _, ego_y = self.rotate_coords(self.egoNode.attr['location'][0], self.egoNode.attr['location'][1]) #NOTE: X corresponds to forward/back displacement and Y corresponds to left/right displacement 147 | _, new_y = self.rotate_coords(object_node.attr['location'][0], object_node.attr['location'][1]) 148 | y_diff = new_y - ego_y 149 | if y_diff < -LANE_THRESHOLD: 150 | self.add_relation([object_node, Relations.isIn, self.left_lane]) 151 | elif y_diff > LANE_THRESHOLD: 152 | self.add_relation([object_node, Relations.isIn, self.right_lane]) 153 | if abs(y_diff) <= CENTER_LANE_THRESHOLD: 154 | self.add_relation([object_node, Relations.isIn, self.middle_lane]) 155 | 156 | 157 | #copied from get_node_embeddings(). rotates coordinates to be relative to ego vector. 158 | def rotate_coords(self, x, y): 159 | new_x = (x*self.ego_cos_term) + (y*self.ego_sin_term) 160 | new_y = ((-x)*self.ego_sin_term) + (y*self.ego_cos_term) 161 | return new_x, new_y -------------------------------------------------------------------------------- /sg_risk_assessment/sg2vec_trainer.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 2 | sys.path.append(os.path.dirname(sys.path[0])) 3 | import torch 4 | import torch.optim as optim 5 | import numpy as np 6 | import pandas as pd 7 | import random 8 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve 9 | from sklearn import preprocessing 10 | from matplotlib import pyplot as plt 11 | 12 | from sg_risk_assessment.relation_extractor import Relations 13 | from argparse import ArgumentParser 14 | from pathlib import Path 15 | from tqdm import tqdm 16 | from sg_risk_assessment.mrgcn import * 17 | from torch_geometric.data import Data, DataLoader, DataListLoader 18 | from sklearn.utils.class_weight import compute_class_weight 19 | import warnings 20 | warnings.simplefilter(action='ignore', category=FutureWarning) 21 | from sklearn.utils import resample 22 | import pickle as pkl 23 | from sklearn.model_selection import train_test_split, StratifiedKFold 24 | from sg_risk_assessment.metrics import * 25 | 26 | from collections import Counter 27 | import wandb 28 | 29 | class Config: 30 | '''Argument Parser for script to train scenegraphs.''' 31 | def __init__(self, args): 32 | self.parser = ArgumentParser(description='The parameters for training the scene graph using GCN.') 33 | self.parser.add_argument('--cache_path', type=str, default="../script/image_dataset.pkl", help="Path to the cache file.") 34 | self.parser.add_argument('--transfer_path', type=str, default="", help="Path to the transfer file.") 35 | self.parser.add_argument('--model_load_path', type=str, default="./model/model_best_val_loss_.vec.pt", help="Path to load cached model file.") 36 | self.parser.add_argument('--model_save_path', type=str, default="./model/model_best_val_loss_.vec.pt", help="Path to save model file.") 37 | self.parser.add_argument('--split_ratio', type=float, default=0.3, help="Ratio of dataset withheld for testing.") 38 | self.parser.add_argument('--downsample', type=lambda x: (str(x).lower() == 'true'), default=False, help='Set to true to downsample 
dataset.') 39 | self.parser.add_argument('--learning_rate', default=0.00005, type=float, help='The initial learning rate for GCN.') 40 | self.parser.add_argument('--n_folds', type=int, default=1, help='Number of folds for cross validation') 41 | self.parser.add_argument('--seed', type=int, default=0, help='Random seed.') 42 | self.parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.') 43 | self.parser.add_argument('--activation', type=str, default='relu', help='Activation function to use, options: [relu, leaky_relu].') 44 | self.parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).') 45 | self.parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate (1 - keep probability).') 46 | self.parser.add_argument('--nclass', type=int, default=2, help="The number of classes for dynamic graph classification (currently only supports 2).") 47 | self.parser.add_argument('--batch_size', type=int, default=16, help='Number of graphs in a batch.') 48 | self.parser.add_argument('--device', type=str, default="cuda", help='The device on which models are run, options: [cuda, cpu].') 49 | self.parser.add_argument('--test_step', type=int, default=5, help='Number of training epochs before testing the model.') 50 | self.parser.add_argument('--inference_mode', type=str, default="5_frames", help='Window size of frames before making one prediction (all_frames for per-frame prediction).') 51 | self.parser.add_argument('--model', type=str, default="mrgcn", help="Model to be used intrinsically. options: [mrgcn, mrgin]") 52 | self.parser.add_argument('--conv_type', type=str, default="FastRGCNConv", help="type of RGCNConv to use [RGCNConv, FastRGCNConv].") 53 | self.parser.add_argument('--num_layers', type=int, default=2, help="Number of layers in the network.") 54 | self.parser.add_argument('--hidden_dim', type=int, default=64, help="Hidden dimension in RGCN.") 55 | self.parser.add_argument('--layer_spec', type=str, default=None, help="manually specify the size of each layer in format l1,l2,l3 (no spaces).") 56 | self.parser.add_argument('--pooling_type', type=str, default="sagpool", help="Graph pooling type, options: [sagpool, topk, None].") 57 | self.parser.add_argument('--pooling_ratio', type=float, default=0.5, help="Graph pooling ratio.") 58 | self.parser.add_argument('--readout_type', type=str, default="add", help="Readout type, options: [max, mean, add].") 59 | self.parser.add_argument('--temporal_type', type=str, default="lstm_seq", help="Temporal type, options: [mean, lstm_last, lstm_sum, lstm_attn, lstm_seq].") 60 | self.parser.add_argument('--lstm_input_dim', type=int, default=50, help="LSTM input dimensions.") 61 | self.parser.add_argument('--lstm_output_dim', type=int, default=20, help="LSTM output dimensions.") 62 | self.parser.add_argument('--lstm_layers', type=int, default=1, help="LSTM layers.") 63 | 64 | args_parsed = self.parser.parse_args(args) 65 | wandb.init(project="av-scenegraph") 66 | wandb_config = wandb.config 67 | 68 | for arg_name in vars(args_parsed): 69 | self.__dict__[arg_name] = getattr(args_parsed, arg_name) 70 | wandb_config[arg_name] = getattr(args_parsed, arg_name) 71 | 72 | self.cache_path = Path(self.cache_path).resolve() 73 | self.transfer_path = Path(self.transfer_path).resolve() if self.transfer_path != "" else None 74 | 75 | def build_scenegraph_dataset(cache_path, train_to_test_ratio=0.3, downsample=False, seed=0, transfer_path=None): 76 | ''' 77 | Dataset format 78 | 
scenegraphs_sequence: dict_keys(['sequence', 'label', 'folder_name']) 79 | 'sequence': scenegraph metadata 80 | 'label': classification output [0 -> non_risky (negative), 1 -> risky (positive)] 81 | 'folder_name': foldername storing sequence data 82 | 83 | Dataset modes 84 | no downsample 85 | all sequences used for the train and test set regardless of class distribution 86 | downsample 87 | equal amount of positive and negative sequences used for the train and test set 88 | transfer 89 | replaces original test set with another dataset 90 | ''' 91 | dataset_file = open(cache_path, "rb") 92 | scenegraphs_sequence, feature_list = pkl.load(dataset_file) 93 | 94 | class_0 = [] 95 | class_1 = [] 96 | 97 | for g in scenegraphs_sequence: 98 | if g['label'] == 0: 99 | class_0.append(g) 100 | elif g['label'] == 1: 101 | class_1.append(g) 102 | 103 | y_0 = [0]*len(class_0) 104 | y_1 = [1]*len(class_1) 105 | min_number = min(len(class_0), len(class_1)) 106 | 107 | # dataset class distribution 108 | if downsample: 109 | modified_class_0, modified_y_0 = resample(class_0, y_0, n_samples=min_number) 110 | else: 111 | modified_class_0, modified_y_0 = class_0, y_0 112 | train, test, _, _ = train_test_split(modified_class_0+class_1, modified_y_0+y_1, test_size=train_to_test_ratio, shuffle=True, stratify=modified_y_0+y_1, random_state=seed) 113 | 114 | # transfer learning 115 | if transfer_path != None: 116 | train = np.append(train, test, axis=0) 117 | test, _ = pkl.load(open(transfer_path, "rb")) 118 | 119 | return train, test, feature_list 120 | 121 | 122 | class SG2VECTrainer: 123 | 124 | def __init__(self, args): 125 | self.config = Config(args) 126 | self.args = args 127 | np.random.seed(self.config.seed) 128 | torch.manual_seed(self.config.seed) 129 | 130 | if not self.config.cache_path.exists(): 131 | raise Exception("The cache file does not exist.") 132 | 133 | if not self.config.temporal_type in ["lstm_seq", 'none']: 134 | raise NotImplementedError("This version of dynkg_trainer does not support temporal types other than step-by-step sequence prediction (lstm_seq) or 'none'.") 135 | 136 | self.best_val_loss = 99999 137 | self.best_epoch = 0 138 | self.best_val_acc = 0 139 | self.best_val_auc = 0 140 | self.best_val_confusion = [] 141 | self.best_val_f1 = 0 142 | self.best_val_mcc = -1.0 143 | self.best_val_acc_balanced = 0 144 | self.best_avg_pred_frame = 0 145 | self.log = False 146 | 147 | 148 | def split_dataset(self): 149 | self.training_data, self.testing_data, self.feature_list = build_scenegraph_dataset(self.config.cache_path, self.config.split_ratio, downsample=self.config.downsample, seed=self.config.seed, transfer_path=self.config.transfer_path) 150 | total_train_labels = np.concatenate([np.full(len(data['sequence']), data['label']) for data in self.training_data]) # used to compute frame-level class weighting 151 | total_test_labels = np.concatenate([np.full(len(data['sequence']), data['label']) for data in self.testing_data]) 152 | self.training_labels = [data['label'] for data in self.training_data] 153 | self.testing_labels = [data['label'] for data in self.testing_data] 154 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(total_train_labels), total_train_labels)) 155 | if self.config.n_folds <= 1: 156 | print("Number of Training Sequences Included: ", len(self.training_data)) 157 | print("Number of Testing Sequences Included: ", len(self.testing_data)) 158 | print("Number of Training Labels in Each Class: " + str(np.unique(total_train_labels, 
return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 159 | print("Number of Testing Labels in Each Class: " + str(np.unique(total_test_labels, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 160 | 161 | 162 | def build_model(self): 163 | self.config.num_features = len(self.feature_list) 164 | self.config.num_relations = max([r.value for r in Relations])+1 165 | if self.config.model == "mrgcn": 166 | self.model = MRGCN(self.config).to(self.config.device) 167 | elif self.config.model == "mrgin": 168 | self.model = MRGIN(self.config).to(self.config.device) 169 | else: 170 | raise Exception("model selection is invalid: " + self.config.model) 171 | 172 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay) 173 | if self.class_weights.shape[0] < 2: 174 | self.loss_func = nn.CrossEntropyLoss() 175 | else: 176 | self.loss_func = nn.CrossEntropyLoss(weight=self.class_weights.float().to(self.config.device)) 177 | 178 | wandb.watch(self.model, log="all") 179 | 180 | 181 | # Pick between Standard Training and KFold Cross Validation Training 182 | def learn(self): 183 | if self.config.n_folds <= 1 or self.config.transfer_path != None: 184 | print('\nRunning Standard Training Loop\n') 185 | self.train() 186 | else: 187 | print('\nRunning {}-Fold Cross Validation Training Loop\n'.format(self.config.n_folds)) 188 | self.cross_valid() 189 | 190 | 191 | def cross_valid(self): 192 | 193 | # KFold cross validation with similar class distribution in each fold 194 | skf = StratifiedKFold(n_splits=self.config.n_folds) 195 | X = np.array(self.training_data + self.testing_data) 196 | y = np.array(self.training_labels + self.testing_labels) 197 | 198 | # self.results stores average metrics for the the n_folds 199 | self.results = {} 200 | self.fold = 1 201 | 202 | # Split training and testing data based on n_splits (Folds) 203 | for train_index, test_index in skf.split(X, y): 204 | X_train, X_test = X[train_index], X[test_index] 205 | y_train, y_test = y[train_index], y[test_index] 206 | 207 | self.training_data = X_train 208 | self.testing_data = X_test 209 | self.training_labels = y_train 210 | self.testing_labels = y_test 211 | 212 | # To compute frame-level class weighting 213 | total_train_labels = np.concatenate([np.full(len(data['sequence']), data['label']) for data in self.training_data]) 214 | total_test_labels = np.concatenate([np.full(len(data['sequence']), data['label']) for data in self.testing_data]) 215 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(total_train_labels), total_train_labels)) 216 | 217 | print('\nFold {}'.format(self.fold)) 218 | print("Number of Training Sequences Included: ", len(self.training_data)) 219 | print("Number of Testing Sequences Included: ", len(self.testing_data)) 220 | print("Number of Training Labels in Each Class: " + str(np.unique(total_train_labels, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 221 | print("Number of Testing Labels in Each Class: " + str(np.unique(total_test_labels, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 222 | 223 | self.best_val_loss = 99999 224 | self.train() 225 | self.log = True 226 | outputs_train, labels_train, outputs_test, labels_test, metrics = self.evaluate(self.fold) 227 | self.update_cross_valid_metrics(outputs_train, labels_train, outputs_test, labels_test, metrics) 228 | self.log = False 229 | 230 | if self.fold != 
self.config.n_folds: 231 | del self.model 232 | del self.optimizer 233 | self.build_model() 234 | 235 | self.fold += 1 236 | del self.results 237 | 238 | 239 | def train(self): 240 | tqdm_bar = tqdm(range(self.config.epochs)) 241 | 242 | for epoch_idx in tqdm_bar: # iterate through epoch 243 | acc_loss_train = 0 244 | self.sequence_loader = DataListLoader(self.training_data, batch_size=self.config.batch_size) 245 | 246 | for data_list in self.sequence_loader: # iterate through batches of the dataset 247 | self.model.train() 248 | self.optimizer.zero_grad() 249 | labels = torch.empty(0).long().to(self.config.device) 250 | outputs = torch.empty(0,2).to(self.config.device) 251 | 252 | for sequence in data_list: # iterate through scene-graph sequences in the batch 253 | data, label = sequence['sequence'], sequence['label'] 254 | graph_list = [Data(x=g['node_features'], edge_index=g['edge_index'], edge_attr=g['edge_attr']) for g in data] 255 | self.train_loader = DataLoader(graph_list, batch_size=len(graph_list)) 256 | sequence = next(iter(self.train_loader)).to(self.config.device) 257 | output, _ = self.model.forward(sequence.x, sequence.edge_index, sequence.edge_attr, sequence.batch) 258 | label = torch.LongTensor(np.full(output.shape[0], label)).to(self.config.device) #fill label to length of the sequence. shape (len_input_sequence, 1) 259 | labels = torch.cat([labels, label], dim=0) 260 | outputs = torch.cat([outputs, output.view(-1, 2)], dim=0) #in this case the output is of shape (len_input_sequence, 2) 261 | 262 | loss_train = self.loss_func(outputs, labels) 263 | loss_train.backward() 264 | acc_loss_train += loss_train.detach().cpu().item() * len(data_list) 265 | self.optimizer.step() 266 | del loss_train 267 | 268 | acc_loss_train /= len(self.training_data) 269 | tqdm_bar.set_description('Epoch: {:04d}, loss_train: {:.4f}'.format(epoch_idx, acc_loss_train)) 270 | 271 | if epoch_idx % self.config.test_step == 0: 272 | self.evaluate(epoch_idx) 273 | 274 | def inference(self, testing_data, testing_labels, mode='5_frames'): # change mode='all_frames' to run per-frame prediction 275 | labels = torch.LongTensor().to(self.config.device) 276 | outputs = torch.FloatTensor().to(self.config.device) 277 | acc_loss_test = 0 278 | attns_weights = [] 279 | node_attns = [] 280 | sum_prediction_frame = 0 281 | sum_seq_len = 0 282 | num_risky_sequences = 0 283 | num_safe_sequences = 0 284 | sum_predicted_risky_indices = 0 #sum is calculated as (value * (index+1))/sum(range(seq_len)) for each value and index in the sequence. 285 | sum_predicted_safe_indices = 0 #sum is calculated as ((1-value) * (index+1))/sum(range(seq_len)) for each value and index in the sequence. 
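# The counters that follow track wall-clock inference time, optional profiler output, and sequence-level confusion counts (risky clips are the positive class).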
286 | inference_time = 0 287 | prof_result = "" 288 | correct_risky_seq = 0 289 | correct_safe_seq = 0 290 | incorrect_risky_seq = 0 291 | incorrect_safe_seq = 0 292 | num_sequences = 0 293 | 294 | with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof: 295 | with torch.no_grad(): 296 | for i in range(len(testing_data)): # iterate through sequences of scenegraphs 297 | 298 | # determine number of frames per clip and amount of frames to evaluate 299 | frames_per_clip = len(testing_data[i]['sequence']) 300 | frames_to_evaluate = mode.split('_')[0] 301 | if frames_to_evaluate.isdigit(): 302 | frames_to_evaluate = int(frames_to_evaluate) 303 | else: 304 | frames_to_evaluate = frames_per_clip 305 | 306 | pred_all = frames_to_evaluate == frames_per_clip # determine to use all outputs (True) or last output of lstm (False) 307 | 308 | # run model inference 309 | for j in range(frames_per_clip - frames_to_evaluate + 1): 310 | data, label = testing_data[i]['sequence'][j:j+frames_to_evaluate], testing_labels[i] 311 | data_list = [Data(x=g['node_features'], edge_index=g['edge_index'], edge_attr=g['edge_attr']) for g in data] 312 | self.test_loader = DataLoader(data_list, batch_size=len(data_list)) 313 | sequence = next(iter(self.test_loader)).to(self.config.device) 314 | self.model.eval() 315 | #start = torch.cuda.Event(enable_timing=True) 316 | #end = torch.cuda.Event(enable_timing=True) 317 | #start.record() 318 | output, attns = self.model.forward(sequence.x, sequence.edge_index, sequence.edge_attr, sequence.batch) 319 | #end.record() 320 | #torch.cuda.synchronize() 321 | inference_time += 0#start.elapsed_time(end) 322 | output = output.view(-1,2) 323 | seq_len = output.shape[0] 324 | label = torch.LongTensor(np.full(seq_len, label)).to(self.config.device) #fill label to length of the sequence. 325 | 326 | 327 | if not pred_all: 328 | # currently not supporting the attention weights when mode != 'all_frames' or pred_all == False 329 | output = output[-1].unsqueeze(dim=0) 330 | label = label[-1].unsqueeze(dim=0) 331 | 332 | outputs = torch.cat([outputs, output], dim=0) 333 | labels = torch.cat([labels, label], dim=0) 334 | loss_test = self.loss_func(output, label) 335 | acc_loss_test += loss_test.detach().cpu().item() 336 | num_sequences += 1 337 | 338 | # if 'lstm_attn_weights' in attns: 339 | # attns_weights.append(attns['lstm_attn_weights'].squeeze().detach().cpu().numpy().tolist()) 340 | # if 'pool_score' in attns: 341 | # node_attn = {} 342 | # node_attn["original_batch"] = sequence.batch.detach().cpu().numpy().tolist() 343 | # node_attn["pool_perm"] = attns['pool_perm'].detach().cpu().numpy().tolist() 344 | # node_attn["pool_batch"] = attns['batch'].detach().cpu().numpy().tolist() 345 | # node_attn["pool_score"] = attns['pool_score'].detach().cpu().numpy().tolist() 346 | # node_attns.append(node_attn) 347 | 348 | # log metrics for risky and non-risky clips separately. 349 | if not pred_all: 350 | preds = torch.argmax(output) 351 | else: 352 | preds = output.max(1)[1].type_as(label) 353 | 354 | # ---------------------------------------- omg... ---------------------------------------- 355 | if(1 in label): 356 | num_risky_sequences += 1 357 | sum_seq_len += seq_len 358 | if (1 in preds): 359 | correct_risky_seq += 1 #sequence level metrics 360 | if not pred_all: 361 | sum_prediction_frame = 0 362 | sum_predicted_risky_indices = 0 363 | else: 364 | sum_prediction_frame += torch.where(preds == 1)[0][0].item() #returns the first index of a "risky" prediction in this sequence. 
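# The next accumulator normalizes where the "risky" flags land within the clip: it sums (frame_index + 1) over frames predicted risky and divides by sum(range(seq_len + 1)), so 1.0 means every frame was flagged and smaller values mean the flags cluster near the start of the clip.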
365 | sum_predicted_risky_indices += torch.sum(torch.where(preds==1)[0] + 1).item() / np.sum(range(seq_len + 1)) #(1*index)/seq_len added to sum. 366 | else: 367 | incorrect_risky_seq += 1 368 | if not pred_all: 369 | sum_prediction_frame = 0 370 | else: 371 | sum_prediction_frame += seq_len #if no risky predictions are made, then add the full sequence length to running avg. 372 | elif(0 in label): 373 | num_safe_sequences += 1 374 | if (0 in preds): 375 | correct_safe_seq += 1 #sequence level metrics 376 | if not pred_all: 377 | sum_predicted_safe_indices = 0 378 | else: 379 | sum_predicted_safe_indices += torch.sum(torch.where(preds==0)[0] + 1).item() / np.sum(range(seq_len + 1)) #(1*index)/seq_len added to sum. 380 | else: 381 | incorrect_safe_seq += 1 382 | # ---------------------------------------- omg... ---------------------------------------- 383 | 384 | avg_risky_prediction_frame = sum_prediction_frame / num_risky_sequences #avg of first indices in a sequence that a risky frame is first correctly predicted. 385 | avg_risky_seq_len = sum_seq_len / num_risky_sequences #sequence length for comparison with the prediction frame metric. 386 | avg_predicted_risky_indices = sum_predicted_risky_indices / num_risky_sequences 387 | avg_predicted_safe_indices = sum_predicted_safe_indices / num_safe_sequences 388 | seq_tpr = correct_risky_seq / num_risky_sequences 389 | seq_fpr = incorrect_safe_seq / num_safe_sequences 390 | seq_tnr = correct_safe_seq / num_safe_sequences 391 | seq_fnr = incorrect_risky_seq / num_risky_sequences 392 | if prof != None: 393 | prof_result = prof.key_averages().table(sort_by="cuda_time_total") 394 | 395 | return outputs, \ 396 | labels, \ 397 | acc_loss_test / num_sequences, \ 398 | attns_weights, \ 399 | node_attns, \ 400 | avg_risky_prediction_frame, \ 401 | avg_risky_seq_len, \ 402 | avg_predicted_risky_indices, \ 403 | avg_predicted_safe_indices, \ 404 | inference_time, \ 405 | prof_result, \ 406 | seq_tpr, \ 407 | seq_fpr, \ 408 | seq_tnr, \ 409 | seq_fnr 410 | 411 | def evaluate(self, current_epoch=None): 412 | metrics = {} 413 | outputs_train, \ 414 | labels_train, \ 415 | acc_loss_train, \ 416 | attns_train, \ 417 | node_attns_train, \ 418 | train_avg_prediction_frame, \ 419 | train_avg_seq_len, \ 420 | avg_predicted_risky_indices, \ 421 | avg_predicted_safe_indices, \ 422 | train_inference_time, \ 423 | train_profiler_result, \ 424 | seq_tpr, \ 425 | seq_fpr, \ 426 | seq_tnr, \ 427 | seq_fnr = self.inference(self.training_data, self.training_labels, mode=self.config.inference_mode) 428 | 429 | metrics['train'] = get_metrics(outputs_train, labels_train) 430 | metrics['train']['loss'] = acc_loss_train 431 | metrics['train']['avg_prediction_frame'] = train_avg_prediction_frame 432 | metrics['train']['avg_seq_len'] = train_avg_seq_len 433 | metrics['train']['avg_predicted_risky_indices'] = avg_predicted_risky_indices 434 | metrics['train']['avg_predicted_safe_indices'] = avg_predicted_safe_indices 435 | metrics['train']['seq_tpr'] = seq_tpr 436 | metrics['train']['seq_tnr'] = seq_tnr 437 | metrics['train']['seq_fpr'] = seq_fpr 438 | metrics['train']['seq_fnr'] = seq_fnr 439 | with open("graph_profile_metrics.txt", mode='w') as f: 440 | f.write(train_profiler_result) 441 | 442 | outputs_test, \ 443 | labels_test, \ 444 | acc_loss_test, \ 445 | attns_test, \ 446 | node_attns_test, \ 447 | val_avg_prediction_frame, \ 448 | val_avg_seq_len, \ 449 | avg_predicted_risky_indices, \ 450 | avg_predicted_safe_indices, \ 451 | test_inference_time, \ 452 | 
test_profiler_result, \ 453 | seq_tpr, \ 454 | seq_fpr, \ 455 | seq_tnr, \ 456 | seq_fnr = self.inference(self.testing_data, self.testing_labels, mode=self.config.inference_mode) 457 | 458 | metrics['test'] = get_metrics(outputs_test, labels_test) 459 | metrics['test']['loss'] = acc_loss_test 460 | metrics['test']['avg_prediction_frame'] = val_avg_prediction_frame 461 | metrics['test']['avg_seq_len'] = val_avg_seq_len 462 | metrics['test']['avg_predicted_risky_indices'] = avg_predicted_risky_indices 463 | metrics['test']['avg_predicted_safe_indices'] = avg_predicted_safe_indices 464 | metrics['test']['seq_tpr'] = seq_tpr 465 | metrics['test']['seq_tnr'] = seq_tnr 466 | metrics['test']['seq_fpr'] = seq_fpr 467 | metrics['test']['seq_fnr'] = seq_fnr 468 | metrics['avg_inf_time'] = (train_inference_time + test_inference_time) / (len(labels_train) + len(labels_test)) 469 | 470 | print("\ntrain loss: " + str(acc_loss_train) + ", acc:", metrics['train']['acc'], metrics['train']['confusion'], "mcc:", metrics['train']['mcc'], \ 471 | "\ntest loss: " + str(acc_loss_test) + ", acc:", metrics['test']['acc'], metrics['test']['confusion'], "mcc:", metrics['test']['mcc']) 472 | 473 | self.update_best_metrics(metrics, current_epoch) 474 | metrics['best_epoch'] = self.best_epoch 475 | metrics['best_val_loss'] = self.best_val_loss 476 | metrics['best_val_acc'] = self.best_val_acc 477 | metrics['best_val_auc'] = self.best_val_auc 478 | metrics['best_val_conf'] = self.best_val_confusion 479 | metrics['best_val_f1'] = self.best_val_f1 480 | metrics['best_val_mcc'] = self.best_val_mcc 481 | metrics['best_val_acc_balanced'] = self.best_val_acc_balanced 482 | metrics['best_avg_pred_frame'] = self.best_avg_pred_frame 483 | 484 | if self.config.n_folds <= 1 or self.log: 485 | log_wandb(metrics) 486 | 487 | return outputs_train, labels_train, outputs_test, labels_test, metrics 488 | 489 | 490 | #automatically save the model and metrics with the lowest validation loss 491 | def update_best_metrics(self, metrics, current_epoch): 492 | if metrics['test']['loss'] < self.best_val_loss: 493 | self.best_val_loss = metrics['test']['loss'] 494 | self.best_epoch = current_epoch if current_epoch != None else self.config.epochs 495 | self.best_val_acc = metrics['test']['acc'] 496 | self.best_val_auc = metrics['test']['auc'] 497 | self.best_val_confusion = metrics['test']['confusion'] 498 | self.best_val_f1 = metrics['test']['f1'] 499 | self.best_val_mcc = metrics['test']['mcc'] 500 | self.best_val_acc_balanced = metrics['test']['balanced_acc'] 501 | self.best_avg_pred_frame = metrics['test']['avg_prediction_frame'] 502 | #self.save_model() 503 | 504 | 505 | # Averages metrics after the end of each cross validation fold 506 | def update_cross_valid_metrics(self, outputs_train, labels_train, outputs_test, labels_test, metrics): 507 | if self.fold == 1: 508 | self.results['outputs_train'] = outputs_train 509 | self.results['labels_train'] = labels_train 510 | self.results['train'] = metrics['train'] 511 | self.results['train']['loss'] = metrics['train']['loss'] 512 | self.results['train']['avg_prediction_frame'] = metrics['train']['avg_prediction_frame'] 513 | self.results['train']['avg_seq_len'] = metrics['train']['avg_seq_len'] 514 | self.results['train']['avg_predicted_risky_indices'] = metrics['train']['avg_predicted_risky_indices'] 515 | self.results['train']['avg_predicted_safe_indices'] = metrics['train']['avg_predicted_safe_indices'] 516 | 517 | self.results['outputs_test'] = outputs_test 518 | self.results['labels_test'] 
= labels_test 519 | self.results['test'] = metrics['test'] 520 | self.results['test']['loss'] = metrics['test']['loss'] 521 | self.results['test']['avg_prediction_frame'] = metrics['test']['avg_prediction_frame'] 522 | self.results['test']['avg_seq_len'] = metrics['test']['avg_seq_len'] 523 | self.results['test']['avg_predicted_risky_indices'] = metrics['test']['avg_predicted_risky_indices'] 524 | self.results['test']['avg_predicted_safe_indices'] = metrics['test']['avg_predicted_safe_indices'] 525 | self.results['avg_inf_time'] = metrics['avg_inf_time'] 526 | 527 | self.results['best_epoch'] = metrics['best_epoch'] 528 | self.results['best_val_loss'] = metrics['best_val_loss'] 529 | self.results['best_val_acc'] = metrics['best_val_acc'] 530 | self.results['best_val_auc'] = metrics['best_val_auc'] 531 | self.results['best_val_conf'] = metrics['best_val_conf'] 532 | self.results['best_val_f1'] = metrics['best_val_f1'] 533 | self.results['best_val_mcc'] = metrics['best_val_mcc'] 534 | self.results['best_val_acc_balanced'] = metrics['best_val_acc_balanced'] 535 | self.results['best_avg_pred_frame'] = metrics['best_avg_pred_frame'] 536 | else: 537 | self.results['outputs_train'] = torch.cat((self.results['outputs_train'], outputs_train), dim=0) 538 | self.results['labels_train'] = torch.cat((self.results['labels_train'], labels_train), dim=0) 539 | self.results['train']['loss'] = np.append(self.results['train']['loss'], metrics['train']['loss']) 540 | self.results['train']['avg_prediction_frame'] = np.append(self.results['train']['avg_prediction_frame'], 541 | metrics['train']['avg_prediction_frame']) 542 | self.results['train']['avg_seq_len'] = np.append(self.results['train']['avg_seq_len'], metrics['train']['avg_seq_len']) 543 | self.results['train']['avg_predicted_risky_indices'] = np.append(self.results['train']['avg_predicted_risky_indices'], 544 | metrics['train']['avg_predicted_risky_indices']) 545 | self.results['train']['avg_predicted_safe_indices'] = np.append(self.results['train']['avg_predicted_safe_indices'], 546 | metrics['train']['avg_predicted_safe_indices']) 547 | 548 | self.results['outputs_test'] = torch.cat((self.results['outputs_test'], outputs_test), dim=0) 549 | self.results['labels_test'] = torch.cat((self.results['labels_test'], labels_test), dim=0) 550 | self.results['test']['loss'] = np.append(self.results['test']['loss'], metrics['test']['loss']) 551 | self.results['test']['avg_prediction_frame'] = np.append(self.results['test']['avg_prediction_frame'], 552 | metrics['test']['avg_prediction_frame']) 553 | self.results['test']['avg_seq_len'] = np.append(self.results['test']['avg_seq_len'], metrics['test']['avg_seq_len']) 554 | self.results['test']['avg_predicted_risky_indices'] = np.append(self.results['test']['avg_predicted_risky_indices'], 555 | metrics['test']['avg_predicted_risky_indices']) 556 | self.results['test']['avg_predicted_safe_indices'] = np.append(self.results['test']['avg_predicted_safe_indices'], 557 | metrics['test']['avg_predicted_safe_indices']) 558 | self.results['avg_inf_time'] = np.append(self.results['avg_inf_time'], metrics['avg_inf_time']) 559 | 560 | self.results['best_epoch'] = np.append(self.results['best_epoch'], metrics['best_epoch']) 561 | self.results['best_val_loss'] = np.append(self.results['best_val_loss'], metrics['best_val_loss']) 562 | self.results['best_val_acc'] = np.append(self.results['best_val_acc'], metrics['best_val_acc']) 563 | self.results['best_val_auc'] = np.append(self.results['best_val_auc'], 
metrics['best_val_auc']) 564 | self.results['best_val_conf'] = np.append(self.results['best_val_conf'], metrics['best_val_conf']) 565 | self.results['best_val_f1'] = np.append(self.results['best_val_f1'], metrics['best_val_f1']) 566 | self.results['best_val_mcc'] = np.append(self.results['best_val_mcc'], metrics['best_val_mcc']) 567 | self.results['best_val_acc_balanced'] = np.append(self.results['best_val_acc_balanced'], metrics['best_val_acc_balanced']) 568 | self.results['best_avg_pred_frame'] = np.append(self.results['best_avg_pred_frame'], metrics['best_avg_pred_frame']) 569 | 570 | # Log final averaged results 571 | if self.fold == self.config.n_folds: 572 | final_results = {} 573 | final_results['train'] = get_metrics(self.results['outputs_train'], self.results['labels_train']) 574 | final_results['train']['loss'] = np.average(self.results['train']['loss']) 575 | final_results['train']['avg_prediction_frame'] = np.average(self.results['train']['avg_prediction_frame']) 576 | final_results['train']['avg_seq_len'] = np.average(self.results['train']['avg_seq_len']) 577 | final_results['train']['avg_predicted_risky_indices'] = np.average(self.results['train']['avg_predicted_risky_indices']) 578 | final_results['train']['avg_predicted_safe_indices'] = np.average(self.results['train']['avg_predicted_safe_indices']) 579 | 580 | final_results['test'] = get_metrics(self.results['outputs_test'], self.results['labels_test']) 581 | final_results['test']['loss'] = np.average(self.results['test']['loss']) 582 | final_results['test']['avg_prediction_frame'] = np.average(self.results['test']['avg_prediction_frame']) 583 | final_results['test']['avg_seq_len'] = np.average(self.results['test']['avg_seq_len']) 584 | final_results['test']['avg_predicted_risky_indices'] = np.average(self.results['test']['avg_predicted_risky_indices']) 585 | final_results['test']['avg_predicted_safe_indices'] = np.average(self.results['test']['avg_predicted_safe_indices']) 586 | final_results['avg_inf_time'] = np.average(self.results['avg_inf_time']) 587 | 588 | # Best results 589 | final_results['best_epoch'] = np.average(self.results['best_epoch']) 590 | final_results['best_val_loss'] = np.average(self.results['best_val_loss']) 591 | final_results['best_val_acc'] = np.average(self.results['best_val_acc']) 592 | final_results['best_val_auc'] = np.average(self.results['best_val_auc']) 593 | final_results['best_val_conf'] = self.results['best_val_conf'] 594 | final_results['best_val_f1'] = np.average(self.results['best_val_f1']) 595 | final_results['best_val_mcc'] = np.average(self.results['best_val_mcc']) 596 | final_results['best_val_acc_balanced'] = np.average(self.results['best_val_acc_balanced']) 597 | final_results['best_avg_pred_frame'] = np.average(self.results['best_avg_pred_frame']) 598 | 599 | print('\nFinal Averaged Results') 600 | print("\naverage train loss: " + str(final_results['train']['loss']) + ", average acc:", final_results['train']['acc'], final_results['train']['confusion'], final_results['train']['auc'], \ 601 | "\naverage test loss: " + str(final_results['test']['loss']) + ", average acc:", final_results['test']['acc'], final_results['test']['confusion'], final_results['test']['auc']) 602 | 603 | log_wandb(final_results) 604 | 605 | return self.results['outputs_train'], self.results['labels_train'], self.results['outputs_test'], self.results['labels_test'], final_results 606 | 607 | def save_model(self): 608 | """Function to save the model.""" 609 | saved_path = 
Path(self.config.model_save_path).resolve() 610 | os.makedirs(os.path.dirname(saved_path), exist_ok=True) 611 | torch.save(self.model.state_dict(), str(saved_path)) 612 | with open(os.path.dirname(saved_path) + "/model_parameters.txt", "w+") as f: 613 | f.write(str(self.config)) 614 | f.write('\n') 615 | f.write(str(' '.join(sys.argv))) 616 | 617 | def load_model(self): 618 | """Function to load the model.""" 619 | saved_path = Path(self.config.model_load_path).resolve() 620 | if saved_path.exists(): 621 | self.build_model() 622 | self.model.load_state_dict(torch.load(str(saved_path))) 623 | self.model.eval() 624 | --------------------------------------------------------------------------------
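The trainer above is normally driven by a thin entry-point script (`scripts/train_sg2vec.py`, whose contents are not reproduced here). The sketch below is a minimal, hypothetical driver built only from the `SG2VECTrainer` methods visible in this file; the real script may differ in argument handling and logging.

```python
# Minimal, hypothetical driver for SG2VECTrainer (illustrative only).
# Assumes the CLI flags defined in Config above, e.g. --cache_path.
import sys

from sg_risk_assessment.sg2vec_trainer import SG2VECTrainer

if __name__ == "__main__":
    trainer = SG2VECTrainer(sys.argv[1:])  # e.g. ['--cache_path', '<path to downloaded .pkl>']
    trainer.split_dataset()   # load the cached dataset and build train/test splits + class weights
    trainer.build_model()     # instantiate MRGCN/MRGIN, Adam optimizer, and weighted cross-entropy loss
    trainer.learn()           # runs train() or k-fold cross_valid() depending on --n_folds
```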