├── sg_risk_assessment ├── __init__.py ├── metrics.py ├── mrgcn.py ├── relation_extractor.py ├── image_scenegraph.py └── dynkg_trainer.py ├── baseline_risk_assessment ├── __init__.py ├── metrics.py ├── dataset.py ├── models.py └── train.py ├── assets └── archi.png ├── requirements.txt ├── .gitignore ├── README.md ├── baseline_risk_assessment.py └── sg_risk_assessment.py /sg_risk_assessment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_risk_assessment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louisccc/sg-risk-assessment/HEAD/assets/archi.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | backcall==0.1.0 4 | cycler==0.10.0 5 | decorator==4.4.0 6 | gast==0.2.2 7 | google-pasta==0.1.7 8 | grpcio==1.20.1 9 | h5py==2.9.0 10 | ipython==7.8.0 11 | ipython-genutils==0.2.0 12 | jedi==0.15.1 13 | Keras==2.2.4 14 | Keras-Applications==1.0.7 15 | Keras-Preprocessing==1.0.9 16 | kiwisolver==1.1.0 17 | Markdown==3.1.1 18 | matplotlib==3.1.1 19 | networkx==2.4 20 | numpy==1.16.5 21 | opencv-python==4.1.1.26 22 | pandas==0.23.4 23 | parso==0.5.1 24 | pexpect==4.7.0 25 | pickleshare==0.7.5 26 | Pillow==8.1.1 27 | prompt-toolkit==2.0.10 28 | protobuf==3.7.1 29 | ptyprocess==0.6.0 30 | Pygments==2.4.2 31 | pyparsing==2.4.2 32 | python-dateutil==2.8.0 33 | pytz==2019.3 34 | PyYAML==5.4 35 | scikit-image==0.15.0 36 | scikit-learn==0.21.3 37 | scipy==1.1.0 38 | six==1.12.0 39 | tensorboard==1.13.1 40 | tensorflow==2.3.1 41 | tensorflow-estimator==1.13.0 42 | termcolor==1.1.0 43 | tqdm==4.36.1 44 | pytorch-nlp==0.5.0 45 | torch-geometric==1.5.0 46 | traitlets==4.3.3 47 | wcwidth==0.1.7 48 | Werkzeug==0.16.0 49 | wrapt==1.11.2 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | # dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | 
53 | # Translations
54 | *.mo
55 | *.pot
56 | 
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 | 
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | 
67 | # Scrapy stuff:
68 | .scrapy
69 | 
70 | # Sphinx documentation
71 | docs/_build/
72 | 
73 | # PyBuilder
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 | 
83 | # pyenv
84 | .python-version
85 | 
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 | 
93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94 | __pypackages__/
95 | 
96 | # Celery stuff
97 | celerybeat-schedule
98 | celerybeat.pid
99 | 
100 | # SageMath parsed files
101 | *.sage.py
102 | 
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 | 
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 | 
116 | # Rope project settings
117 | .ropeproject
118 | 
119 | # mkdocs documentation
120 | /site
121 | 
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 | 
127 | # Pyre type checker
128 | .pyre/
129 | 
130 | # wandb dir
131 | wandb/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scene-graph Augmented Data-driven Risk Assessment of Autonomous Vehicle Decisions
2 | This repository includes the code and dataset information required to reproduce the results in [our paper](https://arxiv.org/abs/2009.06435). We have also integrated the source code of [our baseline method](https://arxiv.org/abs/1906.02859), [DeepTL-Lane-Change-Classification](https://github.com/Ekim-Yurtsever/DeepTL-Lane-Change-Classification), into this repo. The baseline approach infers the risk level of lane-change video clips with a deep CNN+LSTM model. Our approach incorporates both spatial and temporal modeling for the task of subjective risk assessment.
3 | 
4 | **NOTE:** For a more comprehensive implementation of the code from this project and our other related work, please refer to [roadscene2vec](https://github.com/AICPS/roadscene2vec), our new open-source tool for AV scene-graph generation and embedding.
5 | 
6 | 
7 | The architecture of our approach is illustrated below:
8 | 
9 | ![](https://github.com/louisccc/sg-risk-assessment/blob/master/assets/archi.png?raw=true)
10 | 
11 | To generate the lane-changing datasets, we use [CARLA](https://github.com/carla-simulator/carla) 0.9.8, an open-source autonomous driving simulator, together with the [scenario_runner](https://github.com/carla-simulator/scenario_runner), which was designed for the CARLA Challenge. For real-world driving datasets, we used the Honda Driving Dataset (HDD) in our experiments.
We published the converted scene-graph datasets used in our paper [here](http://ieee-dataport.org/3618).
12 | 
13 | The structure of this repository is as follows:
14 | - **sg_risk_assessment/**: this folder contains the source files for our scene-graph based approach.
15 | - **baseline_risk_assessment/**: this folder contains the source files for the baseline method.
16 | - **sg_risk_assessment.py**: the script that triggers our scene-graph based approach.
17 | - **baseline_risk_assessment.py**: the script that triggers the baseline model.
18 | 
19 | # To Get Started
20 | We recommend using [Anaconda](https://www.anaconda.com/) to manage your virtual environment. The requirements for running this repo are as follows:
21 | - python >= 3.6
22 | - torch == 1.6.0
23 | - torch_geometric == 1.6.1
24 | 
25 | Our recommended command sequence is as follows:
26 | ```shell
27 | $ conda create --name sg_risk_assessment python=3.6
28 | $ conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
29 | $ python -m pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
30 | $ python -m pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
31 | $ python -m pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
32 | $ python -m pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
33 | $ python -m pip install torch-geometric==1.6.1
34 | $ python -m pip install -r requirements.txt
35 | ```
36 | This set of commands assumes you have CUDA 10.1 installed locally. Please refer to the installation guides of [torch](https://pytorch.org/) and [pytorch_geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) if your environment differs.
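To quickly confirm that the pinned dependencies are importable after installation, you can run a minimal check like the one below. This is an illustrative sanity check rather than part of the original instructions; the expected version numbers simply mirror the pins above.
```shell
$ python -c "import torch, torch_geometric; print(torch.__version__, torch_geometric.__version__)"
# should report 1.6.0 and 1.6.1 if the pinned versions were installed
$ python -c "import torch; print(torch.cuda.is_available())"
# should print True if your CUDA 10.1 setup is visible to PyTorch
```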
37 | 
38 | # Usage
39 | To run the scene-graph based approach (sg_risk_assessment), you may use the following commands:
40 | ```shell
41 | $ python sg_risk_assessment.py --pkl_path risk-assessment/scenegraph/synthetic/271_dataset.pkl
42 | 
43 | # --pkl_path: the path where the downloaded pkl is stored
44 | # For tuning hyperparameters, see the Config class in sg_risk_assessment.py
45 | ```
46 | 
47 | To run the baseline approach (baseline_risk_assessment), you may use the following commands:
48 | ```shell
49 | $ python baseline_risk_assessment.py --load_pkl True --pkl_path risk-assessment/scene/synthetic/271_dataset.pkl
50 | 
51 | # --pkl_path: the path where the downloaded pkl is stored
52 | # For tuning hyperparameters, see the Config class in baseline_risk_assessment.py
53 | ```
54 | 
55 | After running these commands, the expected output is a summary of metrics logged by wandb:
56 | ```shell
57 | wandb: train_recall ▁████████████████████
58 | wandb: val_precision █▁▅▄▅▄▆▆▆▅▄▄▇▆▅▆▅▇▆▆▆
59 | wandb: val_recall ▁████████████████████
60 | wandb: train_fpr ▁█▅▅▄▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂
61 | wandb: train_tnr █▁▄▅▅▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇
62 | wandb: train_fnr █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
63 | wandb: val_fpr ▁█▄▅▄▅▃▃▃▄▄▅▂▃▃▃▄▂▃▃▃
64 | wandb: val_tnr █▁▆▄▆▄▆▆▆▆▅▄▇▆▆▆▆▇▆▆▆
65 | wandb: val_fnr █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
66 | wandb: best_epoch ▁▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▅▇▇▇█
67 | wandb: best_val_loss █▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
68 | wandb: best_val_acc ▁▆█▇█████████████████
69 | wandb: best_val_auc ▁▅▆▆▇▇▇▇████▇▇▇▇▇████
70 | wandb: best_val_mcc ▁▇███████████████████
71 | wandb: best_val_acc_balanced ▁████████████████████
72 | wandb: train_mcc ▁▇▇▇▇▇███████████████
73 | wandb: val_mcc ▁▇███████████████████
74 | ```
75 | 
76 | A graphical visualization of the model outputs, including loss and additional metrics, can be viewed by creating and linking your runs to [wandb](https://wandb.ai/home).
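Both entry scripts initialize Weights & Biases with a placeholder `PROJECT_NAME` ("Fill me with wandb id"), so you will need to point it at your own wandb project and authenticate before the metrics above can be logged. The commands below are an illustrative sketch: `wandb login` is the standard wandb CLI login, the pkl path is a placeholder, and the flag values are example settings, but every flag shown is defined in the Config class of sg_risk_assessment.py.
```shell
# one-time wandb setup (assumes you already have a wandb account)
$ wandb login
# then edit PROJECT_NAME in sg_risk_assessment.py / baseline_risk_assessment.py

# example: overriding a few Config hyperparameters from the command line
$ python sg_risk_assessment.py --pkl_path [path to the downloaded pkl] \
    --model mrgcn --temporal_type lstm_attn --epochs 200 --batch_size 32 --learning_rate 0.0001
```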
77 | 
78 | # Citation
79 | Please consider citing our paper if you find our work useful for your research:
80 | ```
81 | @article{yu2020scene,
82 |   title={Scene-graph augmented data-driven risk assessment of autonomous vehicle decisions},
83 |   author={Yu, Shih-Yuan and Malawade, Arnav V and Muthirayan, Deepan and Khargonekar, Pramod P and Faruque, Mohammad A Al},
84 |   journal={arXiv preprint arXiv:2009.06435},
85 |   year={2020}
86 | }
87 | ```
88 | 
--------------------------------------------------------------------------------
/baseline_risk_assessment.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | from argparse import ArgumentParser
4 | from pathlib import Path
5 | import pandas as pd
6 | from pprint import pprint
7 | import wandb
8 | 
9 | from baseline_risk_assessment.dataset import DataSet
10 | from baseline_risk_assessment.models import LSTM_Classifier, CNN_LSTM_Classifier, CNN_Classifier, ResNet50_LSTM_Classifier
11 | from baseline_risk_assessment.train import Trainer
12 | 
13 | PROJECT_NAME = "Fill me with wandb id"
14 | 
15 | class Config:
16 | 
17 |     def __init__(self, args):
18 |         self.parser = ArgumentParser(description='The parameters for configuring and training the baseline Nagoya model(s)')
19 |         self.parser.add_argument('--input_path', type=str, default="./risk-assessment", help="Path to data directory.")
20 |         self.parser.add_argument('--pkl_path', type=str, default="./risk-assessment/scene/synthetic/271_dataset.pkl", help="Path to pickled dataset.")
21 |         self.parser.add_argument('--load_pkl', type=lambda x: (str(x).lower() == 'true'), default=False, help='Set True to load pkl dataset.')
22 |         self.parser.add_argument('--save_pkl_path', type=str, default="./save_dataset.pkl", help="Path to save pickled dataset.")
23 |         self.parser.add_argument('--save_pkl', type=lambda x: (str(x).lower() == 'true'), default=False, help='Set True to save pkl dataset.')
24 | 
25 |         # Training
26 |         self.parser.add_argument('--model_name', type=str, default="cnn_lstm", help="Type of model to run, choices include [gru, lstm, cnn, cnn_lstm, resnet]")
27 |         self.parser.add_argument('--n_folds', type=int, default=1, help="Number of folds for cross validation.")
28 |         self.parser.add_argument('--train_ratio', type=float, default=0.7, help="Ratio of the dataset used for training.")
29 |         self.parser.add_argument('--downsample', type=lambda x: (str(x).lower() == 'true'), default=False, help='Downsample (balance) dataset.')
30 |         self.parser.add_argument('--seed', type=int, default=0, help="Seed for splitting the dataset.")
31 |         self.parser.add_argument('--test_step', type=int, default=10, help='Number of training epochs before testing the model.')
32 |         self.parser.add_argument('--device', type=str, default="cuda", help='The device on which models are run, options: [cuda, cpu].')
33 | 
34 |         # Hyperparameters
35 |         self.parser.add_argument('--epochs', type=int, default=200, help="Number of epochs to train")
36 |         self.parser.add_argument('--batch_size', type=int, default=64, help="Batch size per forward pass")
37 |         self.parser.add_argument('--bnorm', type=lambda x: (str(x).lower() == 'true'), default=False, help="Utilize batch normalization.")
38 |         self.parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate (1 - keep probability).')
39 |         self.parser.add_argument('--learning_rate', default=3e-5, type=float, help='The initial learning rate.')
40 |         self.parser.add_argument('--weight_decay', type=float, default=5e-4,
help='Weight decay (L2 loss on parameters).') 41 | 42 | args_parsed = self.parser.parse_args(args) 43 | 44 | self.wandb = wandb.init(project=PROJECT_NAME) 45 | self.wandb_config = self.wandb.config 46 | 47 | for arg_name in vars(args_parsed): 48 | self.__dict__[arg_name] = getattr(args_parsed, arg_name) 49 | self.wandb_config[arg_name] = getattr(args_parsed, arg_name) 50 | 51 | self.input_base_dir = Path(self.input_path).resolve() 52 | 53 | def load_dataset(raw_image_path, config=None): 54 | ''' 55 | This step is for loading the dataset, preprocessing the video clips 56 | and neccessary scaling and normalizing. Also it reads and converts the labeling info. 57 | ''' 58 | image_path = raw_image_path 59 | 60 | dataset = DataSet() 61 | dataset.read_video(image_path, option='fixed frame amount', number_of_frames=5, scaling='scale', scale_x=0.05, scale_y=0.05) 62 | dataset.risk_scores = dataset.read_risk_data(raw_image_path) 63 | dataset.convert_risk_to_one_hot() 64 | 65 | if config != None and config.save_pkl: 66 | parent_path = '/'.join(config.save_pkl_path.split('/')[:-1]) + '/' 67 | fname = config.save_pkl_path.split('/')[-1] 68 | dataset.save(save_dir=parent_path, filename=fname) 69 | print("Saved pickled dataset") 70 | return dataset 71 | 72 | def load_pickle(pkl_path): 73 | ''' 74 | Read dataset from pickle file. 75 | ''' 76 | dataset = DataSet().loader(str(pkl_path)) 77 | return dataset 78 | 79 | def reshape_dataset(dataset): 80 | ''' 81 | input -> (batch, frames, height, width, channels) 82 | output -> (batch, frames, channels, height, width) 83 | ''' 84 | return np.swapaxes(np.swapaxes(dataset, -1, -3), -1, -2) 85 | 86 | def train_model(dataset, config): 87 | dataset.video = reshape_dataset(dataset.video) 88 | video_sequences = dataset.video 89 | labels = dataset.risk_one_hot 90 | clip_names = np.array(['default_all']*len(video_sequences)) 91 | if hasattr(dataset, 'foldernames'): 92 | clip_names = np.concatenate((clip_names, dataset.foldernames), axis=0) 93 | 94 | if config.model_name == 'gru': 95 | model = LSTM_Classifier(video_sequences.shape, 'gru', config) 96 | elif config.model_name == 'lstm': 97 | model = LSTM_Classifier(video_sequences.shape, 'lstm', config) 98 | elif config.model_name == 'cnn': 99 | model = CNN_Classifier(video_sequences.shape, config) 100 | elif config.model_name == 'cnn_lstm': 101 | model = CNN_LSTM_Classifier(video_sequences.shape, config) 102 | elif config.model_name == 'resnet': 103 | model = ResNet50_LSTM_Classifier(video_sequences.shape, config) 104 | 105 | trainer = Trainer(config) 106 | trainer.init_dataset(video_sequences, labels, clip_names) 107 | trainer.build_model(model) 108 | if config.n_folds > 1: 109 | trainer.train_n_fold_cross_val() 110 | else: 111 | trainer.train_model() 112 | 113 | if __name__ == '__main__': 114 | config = Config(sys.argv[1:]) 115 | raw_image_path = config.input_base_dir 116 | 117 | if config.load_pkl: 118 | dataset = load_pickle(Path(config.pkl_path).resolve()) 119 | else: 120 | dataset = load_dataset(raw_image_path, config=config); 121 | 122 | # train model 123 | model = train_model(dataset, config) -------------------------------------------------------------------------------- /sg_risk_assessment/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve, balanced_accuracy_score, matthews_corrcoef 2 | import torch 3 | from sklearn import preprocessing 4 | 
import numpy as np 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import wandb 8 | 9 | #this file contains functions for scoring the prediction models. 10 | 11 | ''' 12 | #Expected Inputs: 13 | outputs: (n, 2) FloatTensor 14 | labels: (n,) LongTensor 15 | ''' 16 | def get_metrics(outputs, labels): 17 | labels_tensor = labels.cpu() 18 | outputs_tensor = outputs.cpu() 19 | preds = outputs_tensor.max(1)[1].type_as(labels_tensor).cpu() #binarized version of outputs_tensor. 20 | 21 | metrics = {} 22 | metrics['acc'] = accuracy_score(labels_tensor, preds) 23 | metrics['f1'] = f1_score(labels_tensor, preds, average="binary") 24 | conf = confusion_matrix(labels_tensor, preds) 25 | metrics['fpr'] = conf[0][1] / (conf[0][1] + conf[0][0]) #FPR = FP/(FP+TN) 26 | metrics['tnr'] = conf[0][0] / (conf[0][1] + conf[0][0]) #TNR = TN/(FP+TN) 27 | metrics['fnr'] = conf[1][0] / (conf[1][0] + conf[1][1]) #FNR = FN/(FN+TP) 28 | metrics['confusion'] = str(conf).replace('\n', ',') 29 | metrics['precision'] = precision_score(labels_tensor, preds, average="binary") 30 | metrics['recall'] = recall_score(labels_tensor, preds, average="binary") #recall and TPR are the same. TPR = TP/(TP+FN) 31 | metrics['auc'] = get_auc(outputs_tensor, labels_tensor) 32 | metrics['label_distribution'] = str(np.unique(labels_tensor, return_counts=True)[1]) 33 | metrics['balanced_acc'] = balanced_accuracy_score(labels_tensor, preds) 34 | metrics['mcc'] = matthews_corrcoef(labels_tensor, preds) 35 | 36 | return metrics 37 | 38 | #returns onehot version of labels. can specify n_classes to force onehot size. 39 | def encode_onehot(labels, n_classes=None): 40 | if(n_classes): 41 | classes = set(range(n_classes)) 42 | else: 43 | classes = set(labels) 44 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 45 | enumerate(classes)} 46 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 47 | dtype=np.int32) 48 | return labels_onehot 49 | 50 | #log data to to Weights & Biases 51 | def log_wandb(wb, metrics): 52 | wb.log({ 53 | "train_acc": metrics['train']['acc'], 54 | "val_acc": metrics['test']['acc'], 55 | "train_acc_balanced": metrics['train']['balanced_acc'], 56 | "val_acc_balanced": metrics['test']['balanced_acc'], 57 | "train_loss": metrics['train']['loss'], 58 | "val_loss": metrics['test']['loss'], 59 | 'train_auc': metrics['train']['auc'], 60 | 'train_f1': metrics['train']['f1'], 61 | 'val_auc': metrics['test']['auc'], 62 | 'val_f1': metrics['test']['f1'], 63 | 'train_precision': metrics['train']['precision'], 64 | 'train_recall': metrics['train']['recall'], 65 | 'val_precision': metrics['test']['precision'], 66 | 'val_recall': metrics['test']['recall'], 67 | 'train_conf': metrics['train']['confusion'], 68 | 'val_conf': metrics['test']['confusion'], 69 | 'train_fpr': metrics['train']['fpr'], 70 | 'train_tnr': metrics['train']['tnr'], 71 | 'train_fnr': metrics['train']['fnr'], 72 | 'val_fpr': metrics['test']['fpr'], 73 | 'val_tnr': metrics['test']['tnr'], 74 | 'val_fnr': metrics['test']['fnr'], 75 | 'train_avg_seq_len': metrics['train']['avg_seq_len'], 76 | 'val_avg_seq_len': metrics['test']['avg_seq_len'], 77 | 'best_epoch': metrics['best_epoch'], 78 | 'best_val_loss': metrics['best_val_loss'], 79 | 'best_val_acc': metrics['best_val_acc'], 80 | 'best_val_auc': metrics['best_val_auc'], 81 | 'best_val_conf': metrics['best_val_conf'], 82 | 'best_val_mcc': metrics['best_val_mcc'], 83 | 'best_val_acc_balanced': metrics['best_val_acc_balanced'], 84 | 'train_mcc': metrics['train']['mcc'], 85 | 'val_mcc': 
metrics['test']['mcc'], 86 | 'avg_inf_time': metrics['avg_inf_time'], 87 | }) 88 | 89 | def log_wandb_categories(wb, metrics, id): 90 | wb.log({ 91 | "train_acc"+"_"+id: metrics['train'][id]['acc'], 92 | "val_acc"+"_"+id: metrics['test'][id]['acc'], 93 | "train_acc_balanced"+"_"+id: metrics['train'][id]['balanced_acc'], 94 | "val_acc_balanced"+"_"+id: metrics['test'][id]['balanced_acc'], 95 | 'train_auc'+"_"+id: metrics['train'][id]['auc'], 96 | 'train_f1'+"_"+id: metrics['train'][id]['f1'], 97 | 'val_auc'+"_"+id: metrics['test'][id]['auc'], 98 | 'val_f1'+"_"+id: metrics['test'][id]['f1'], 99 | 'train_precision'+"_"+id: metrics['train'][id]['precision'], 100 | 'train_recall'+"_"+id: metrics['train'][id]['recall'], 101 | 'val_precision'+"_"+id: metrics['test'][id]['precision'], 102 | 'val_recall'+"_"+id: metrics['test'][id]['recall'], 103 | 'train_conf'+"_"+id: metrics['train'][id]['confusion'], 104 | 'val_conf'+"_"+id: metrics['test'][id]['confusion'], 105 | 'train_fpr'+"_"+id: metrics['train'][id]['fpr'], 106 | 'train_tnr'+"_"+id: metrics['train'][id]['tnr'], 107 | 'train_fnr'+"_"+id: metrics['train'][id]['fnr'], 108 | 'val_fpr'+"_"+id: metrics['test'][id]['fpr'], 109 | 'val_tnr'+"_"+id: metrics['test'][id]['tnr'], 110 | 'val_fnr'+"_"+id: metrics['test'][id]['fnr'], 111 | 'train_mcc'+"_"+id: metrics['train'][id]['mcc'], 112 | 'val_mcc'+"_"+id: metrics['test'][id]['mcc'], 113 | }) 114 | 115 | #~~~~~~~~~~Scoring Metrics~~~~~~~~~~ 116 | #note: these scoring metrics only work properly for binary classification use cases (graph classification, dyngraph classification) 117 | def get_auc(outputs, labels): 118 | try: 119 | labels = encode_onehot(labels.numpy().tolist(), 2) #binary labels 120 | auc = roc_auc_score(labels, outputs.numpy(), average="micro") 121 | except ValueError as err: 122 | print("error calculating AUC: ", err) 123 | auc = 0.0 124 | return auc -------------------------------------------------------------------------------- /sg_risk_assessment.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import pandas as pd 3 | from argparse import ArgumentParser 4 | from pathlib import Path 5 | import wandb 6 | 7 | from sg_risk_assessment.dynkg_trainer import * 8 | 9 | PROJECT_NAME = "Fill me with wandb id" 10 | 11 | class Config: 12 | '''Argument Parser for script to train scenegraphs.''' 13 | def __init__(self, args): 14 | self.parser = ArgumentParser(description='The parameters for training the scene graph using GCN.') 15 | self.parser.add_argument('--pkl_path', type=str, default="risk-assessment/scenegraph/synthetic/271_dataset.pkl", help="Path to the cache file.") 16 | self.parser.add_argument('--transfer_path', type=str, default="", help="Path to the transfer file.") 17 | self.parser.add_argument('--model_load_path', type=str, default="./model/model_best_val_loss_.vec.pt", help="Path to load cached model file.") 18 | self.parser.add_argument('--model_save_path', type=str, default="./model/model_best_val_loss_.vec.pt", help="Path to save model file.") 19 | 20 | # Model 21 | self.parser.add_argument('--model', type=str, default="mrgcn", help="Model to be used intrinsically. 
options: [mrgcn, mrgin]") 22 | self.parser.add_argument('--conv_type', type=str, default="FastRGCNConv", help="type of RGCNConv to use [RGCNConv, FastRGCNConv].") 23 | self.parser.add_argument('--num_layers', type=int, default=3, help="Number of layers in the network.") 24 | self.parser.add_argument('--hidden_dim', type=int, default=32, help="Hidden dimension in RGCN.") 25 | self.parser.add_argument('--layer_spec', type=str, default=None, help="manually specify the size of each layer in format l1,l2,l3 (no spaces).") 26 | self.parser.add_argument('--pooling_type', type=str, default="sagpool", help="Graph pooling type, options: [sagpool, topk, None].") 27 | self.parser.add_argument('--pooling_ratio', type=float, default=0.5, help="Graph pooling ratio.") 28 | self.parser.add_argument('--readout_type', type=str, default="mean", help="Readout type, options: [max, mean, add].") 29 | self.parser.add_argument('--temporal_type', type=str, default="lstm_attn", help="Temporal type, options: [mean, lstm_last, lstm_sum, lstm_attn].") 30 | self.parser.add_argument('--lstm_input_dim', type=int, default=50, help="LSTM input dimensions.") 31 | self.parser.add_argument('--lstm_output_dim', type=int, default=20, help="LSTM output dimensions.") 32 | 33 | # Training 34 | self.parser.add_argument('--device', type=str, default="cuda", help='The device on which models are run, options: [cuda, cpu].') 35 | self.parser.add_argument('--downsample', type=lambda x: (str(x).lower() == 'true'), default=False, help='Set to true to downsample dataset.') 36 | self.parser.add_argument('--nclass', type=int, default=2, help="The number of classes for dynamic graph classification (currently only supports 2).") 37 | self.parser.add_argument('--seed', type=int, default=0, help='Random seed.') 38 | self.parser.add_argument('--split_ratio', type=float, default=0.3, help="Ratio of dataset withheld for testing.") 39 | self.parser.add_argument('--stats_path', type=str, default="best_stats.csv", help="path to save best test statistics.") 40 | self.parser.add_argument('--test_step', type=int, default=10, help='Number of training epochs before testing the model.') 41 | 42 | # Hyperparameters 43 | self.parser.add_argument('--activation', type=str, default='relu', help='Activation function to use, options: [relu, leaky_relu].') 44 | self.parser.add_argument('--batch_size', type=int, default=32, help='Number of graphs in a batch.') 45 | self.parser.add_argument('--dropout', type=float, default=0.25, help='Dropout rate (1 - keep probability).') 46 | self.parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.') 47 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='The initial learning rate for GCN.') 48 | self.parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).') 49 | 50 | 51 | self.args = args 52 | args_parsed = self.parser.parse_args(args) 53 | self.wandb = wandb.init(project=PROJECT_NAME) 54 | self.wandb_config = self.wandb.config 55 | 56 | for arg_name in vars(args_parsed): 57 | self.__dict__[arg_name] = getattr(args_parsed, arg_name) 58 | self.wandb_config[arg_name] = getattr(args_parsed, arg_name) 59 | 60 | self.pkl_path = Path(self.pkl_path).resolve() 61 | if self.transfer_path != "": 62 | self.transfer_path = Path(self.transfer_path).resolve() 63 | else: 64 | self.transfer_path = None 65 | self.stats_path = Path(self.stats_path.strip()).resolve() 66 | 67 | def train_dynamic_kg(config, iterations=1): 68 | ''' Training 
the dynamic kg algorithm with different attention layer choice.''' 69 | 70 | outputs = [] 71 | labels = [] 72 | metrics = [] 73 | 74 | for i in range(iterations): 75 | trainer = DynKGTrainer(config) 76 | trainer.init_dataset() 77 | trainer.build_model() 78 | trainer.train() 79 | categories_train, categories_test, metric, folder_names_train = trainer.evaluate() 80 | 81 | outputs += categories_train['all']['outputs'] 82 | labels += categories_train['all']['labels'] 83 | metrics.append(metric) 84 | 85 | # Store the prediction results. 86 | store_path = trainer.config.pkl_path.parent 87 | outputs_pd = pd.DataFrame(outputs) 88 | labels_pd = pd.DataFrame(labels) 89 | 90 | labels_pd.to_csv(store_path / "dynkg_training_labels.tsv", sep='\t', header=False, index=False) 91 | outputs_pd.to_csv(store_path / "dynkg_training_outputs.tsv", sep="\t", header=False, index=False) 92 | 93 | # Store the metric results. 94 | metrics_pd = pd.DataFrame(metrics[-1]['test']) 95 | metrics_pd.to_csv(store_path / "dynkg_classification_metrics.csv", header=True) 96 | 97 | 98 | if __name__ == "__main__": 99 | """ the entry of dynkg pipeline training """ 100 | config = Config(sys.argv[1:]) 101 | train_dynamic_kg(config) -------------------------------------------------------------------------------- /baseline_risk_assessment/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import wandb 6 | 7 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve, balanced_accuracy_score, matthews_corrcoef 8 | from sklearn import preprocessing 9 | 10 | #this file contains functions for scoring the prediction models. 11 | 12 | ''' 13 | #Expected Inputs: 14 | outputs: (n, 2) FloatTensor 15 | labels: (n,) LongTensor 16 | ''' 17 | def get_metrics(outputs, labels): 18 | labels_tensor = labels.cpu() 19 | outputs_tensor = outputs.cpu() 20 | preds = outputs_tensor.max(1)[1].type_as(labels_tensor).cpu() #binarized version of outputs_tensor. 21 | 22 | metrics = {} 23 | metrics['acc'] = accuracy_score(labels_tensor, preds) 24 | metrics['f1'] = f1_score(labels_tensor, preds, average="binary") 25 | conf = confusion_matrix(labels_tensor, preds) 26 | metrics['fpr'] = conf[0][1] / (conf[0][1] + conf[0][0]) #FPR = FP/(FP+TN) 27 | metrics['tnr'] = conf[0][0] / (conf[0][1] + conf[0][0]) #TNR = TN/(FP+TN) 28 | metrics['fnr'] = conf[1][0] / (conf[1][0] + conf[1][1]) #FNR = FN/(FN+TP) 29 | metrics['confusion'] = str(conf).replace('\n', ',') 30 | metrics['precision'] = precision_score(labels_tensor, preds, average="binary") 31 | metrics['recall'] = recall_score(labels_tensor, preds, average="binary") #recall and TPR are the same. TPR = TP/(TP+FN) 32 | metrics['auc'] = get_auc(outputs_tensor, labels_tensor) 33 | metrics['label_distribution'] = str(np.unique(labels_tensor, return_counts=True)[1]) 34 | metrics['balanced_acc'] = balanced_accuracy_score(labels_tensor, preds) 35 | metrics['mcc'] = matthews_corrcoef(labels_tensor, preds) 36 | 37 | return metrics 38 | 39 | #returns onehot version of labels. can specify n_classes to force onehot size. 
40 | def encode_onehot(labels, n_classes=None): 41 | if(n_classes): 42 | classes = set(range(n_classes)) 43 | else: 44 | classes = set(labels) 45 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 46 | enumerate(classes)} 47 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 48 | dtype=np.int32) 49 | return labels_onehot 50 | 51 | #log data to to Weights & Biases 52 | def log_wandb(wb, metrics): 53 | wb.log({ 54 | "train_acc": metrics['train']['acc'], 55 | "val_acc": metrics['test']['acc'], 56 | "train_acc_balanced": metrics['train']['balanced_acc'], 57 | "val_acc_balanced": metrics['test']['balanced_acc'], 58 | "train_loss": metrics['train']['loss'], 59 | "val_loss": metrics['test']['loss'], 60 | 'train_auc': metrics['train']['auc'], 61 | 'train_f1': metrics['train']['f1'], 62 | 'val_auc': metrics['test']['auc'], 63 | 'val_f1': metrics['test']['f1'], 64 | 'train_precision': metrics['train']['precision'], 65 | 'train_recall': metrics['train']['recall'], 66 | 'val_precision': metrics['test']['precision'], 67 | 'val_recall': metrics['test']['recall'], 68 | 'train_conf': metrics['train']['confusion'], 69 | 'val_conf': metrics['test']['confusion'], 70 | 'train_fpr': metrics['train']['fpr'], 71 | 'train_tnr': metrics['train']['tnr'], 72 | 'train_fnr': metrics['train']['fnr'], 73 | 'val_fpr': metrics['test']['fpr'], 74 | 'val_tnr': metrics['test']['tnr'], 75 | 'val_fnr': metrics['test']['fnr'], 76 | 'train_avg_seq_len': metrics['train']['avg_seq_len'], 77 | 'val_avg_seq_len': metrics['test']['avg_seq_len'], 78 | 'best_epoch': metrics['best_epoch'], 79 | 'best_val_loss': metrics['best_val_loss'], 80 | 'best_val_acc': metrics['best_val_acc'], 81 | 'best_val_auc': metrics['best_val_auc'], 82 | 'best_val_conf': metrics['best_val_conf'], 83 | 'best_val_mcc': metrics['best_val_mcc'], 84 | 'best_val_acc_balanced': metrics['best_val_acc_balanced'], 85 | 'train_mcc': metrics['train']['mcc'], 86 | 'val_mcc': metrics['test']['mcc'], 87 | 'avg_inf_time': metrics['avg_inf_time'], 88 | }) 89 | 90 | def log_wandb_categories(wb, metrics, id): 91 | wb.log({ 92 | "train_acc"+"_"+id: metrics['train'][id]['acc'], 93 | "val_acc"+"_"+id: metrics['test'][id]['acc'], 94 | "train_acc_balanced"+"_"+id: metrics['train'][id]['balanced_acc'], 95 | "val_acc_balanced"+"_"+id: metrics['test'][id]['balanced_acc'], 96 | 'train_auc'+"_"+id: metrics['train'][id]['auc'], 97 | 'train_f1'+"_"+id: metrics['train'][id]['f1'], 98 | 'val_auc'+"_"+id: metrics['test'][id]['auc'], 99 | 'val_f1'+"_"+id: metrics['test'][id]['f1'], 100 | 'train_precision'+"_"+id: metrics['train'][id]['precision'], 101 | 'train_recall'+"_"+id: metrics['train'][id]['recall'], 102 | 'val_precision'+"_"+id: metrics['test'][id]['precision'], 103 | 'val_recall'+"_"+id: metrics['test'][id]['recall'], 104 | 'train_conf'+"_"+id: metrics['train'][id]['confusion'], 105 | 'val_conf'+"_"+id: metrics['test'][id]['confusion'], 106 | 'train_fpr'+"_"+id: metrics['train'][id]['fpr'], 107 | 'train_tnr'+"_"+id: metrics['train'][id]['tnr'], 108 | 'train_fnr'+"_"+id: metrics['train'][id]['fnr'], 109 | 'val_fpr'+"_"+id: metrics['test'][id]['fpr'], 110 | 'val_tnr'+"_"+id: metrics['test'][id]['tnr'], 111 | 'val_fnr'+"_"+id: metrics['test'][id]['fnr'], 112 | 'train_mcc'+"_"+id: metrics['train'][id]['mcc'], 113 | 'val_mcc'+"_"+id: metrics['test'][id]['mcc'], 114 | }) 115 | 116 | #~~~~~~~~~~Scoring Metrics~~~~~~~~~~ 117 | #note: these scoring metrics only work properly for binary classification use cases (graph classification, dyngraph classification) 118 | def 
get_auc(outputs, labels): 119 | try: 120 | labels = encode_onehot(labels.numpy().tolist(), 2) #binary labels 121 | auc = roc_auc_score(labels, outputs.numpy(), average="micro") 122 | except ValueError as err: 123 | print("error calculating AUC: ", err) 124 | auc = 0.0 125 | return auc 126 | 127 | #NOTE: ROC curve is only generated for positive class (risky label) confidence values 128 | #render parameter determines if the figure is actually generated. If false, it saves the values to a csv file. 129 | def get_roc_curve(outputs, labels, render=False): 130 | risk_scores = [] 131 | outputs = preprocessing.normalize(outputs.numpy(), axis=0) 132 | for i in outputs: 133 | risk_scores.append(i[1]) 134 | fpr, tpr, thresholds = roc_curve(labels.numpy(), risk_scores) 135 | roc = pd.DataFrame() 136 | roc['fpr'] = fpr 137 | roc['tpr'] = tpr 138 | roc['thresholds'] = thresholds 139 | roc.to_csv("ROC_data.csv") 140 | 141 | if(render): 142 | plt.figure(figsize=(8,8)) 143 | plt.xlim((0,1)) 144 | plt.ylim((0,1)) 145 | plt.ylabel("TPR") 146 | plt.xlabel("FPR") 147 | plt.title("Receiver Operating Characteristic") 148 | plt.plot([0,1],[0,1], linestyle='dashed') 149 | plt.plot(fpr,tpr, linewidth=2) 150 | plt.savefig("ROC_curve.svg") -------------------------------------------------------------------------------- /baseline_risk_assessment/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import cv2 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | import pandas as pd 7 | from pathlib import Path 8 | from collections import Counter 9 | 10 | class DataSet: 11 | 12 | def __init__(self): 13 | self.dataset = {} 14 | self.video = [] 15 | self.foldernames = [] 16 | self.image_seq = [] 17 | self.risk_scores = [] 18 | self.risk_one_hot = [] 19 | self.risk_binary = [] 20 | 21 | @classmethod 22 | def loader(cls, file_path): 23 | with open(file_path, 'rb') as f: 24 | return pickle.load(f) 25 | 26 | def preprocess_clips(self, data_dir): 27 | foldernames = [f for f in os.listdir(data_dir) if os.path.isdir(data_dir / f)] 28 | foldernames = [f for f in foldernames if f.split('_')[0].isnumeric()] 29 | foldernames = sorted(foldernames, key=self.get_clipnumber) 30 | # only grab dataset clips (ignore filtered clips) 31 | self.foldernames = self.get_filtered_clips(data_dir, foldernames) 32 | # for visualizing amount of frames in each clip of the dataset 33 | self.get_max_frames(data_dir) 34 | 35 | def read_video(self, data_dir, option='fixed frame amount', number_of_frames=20, max_number_of_frames=500, 36 | scaling='no scaling', scale_x=0.1, scale_y=0.1): 37 | 38 | self.preprocess_clips(data_dir) 39 | is_valid = self.valid_dataset(data_dir/self.foldernames[0], scaling, scale_x, scale_y) 40 | if is_valid: 41 | # shape: (n_videos, n_frames, im_height, im_width, channel) 42 | im_height, im_width, channel = self.image_seq[0].shape 43 | if option == 'fixed frame amount': 44 | self.video = np.zeros([len(self.foldernames), number_of_frames, im_height, im_width, channel]) 45 | elif option == 'all frames': 46 | self.video = np.zeros([len(self.foldernames), max_number_of_frames, im_height, im_width, channel]) 47 | 48 | # todo convert this to a wrapper 49 | for idx, foldername in tqdm(enumerate(self.foldernames)): 50 | if foldername.isnumeric: 51 | is_valid = self.valid_dataset(str(data_dir/foldername), scaling=scaling, scale_x=scale_x, scale_y=scale_y) 52 | print(foldername) 53 | if is_valid: 54 | if option == 'fixed frame amount': 55 | self.video[idx, :, :, :, 
:] = self._read_video_helper(number_of_frames=number_of_frames) 56 | elif option == 'all frames': 57 | self.video[idx, 0:len(self.image_seq), :, :, :] = self.image_seq 58 | else: 59 | raise Exception('Error reading first clip! Check path or contents of {}'.format(data_dir)) 60 | 61 | def _read_video_helper(self, number_of_frames=20): 62 | images = [] 63 | index = 0 64 | # length of image sequence must be greater than or equal to number_of_frames 65 | # if number of frames is less than entire length of image sequence, takes every nth frame (n being modulo) 66 | modulo = int(len(self.image_seq) / number_of_frames) 67 | if modulo == 0: 68 | modulo = 1 69 | for counter, img in enumerate(self.image_seq): 70 | if counter % modulo == 0 and index < number_of_frames: 71 | images.append(img) 72 | index += 1 73 | 74 | return images 75 | 76 | def valid_dataset(self, image_path, scaling, scale_x, scale_y): 77 | self.read_image_data(str(image_path), scaling=scaling, scale_x=scale_x, scale_y=scale_y) 78 | if len(self.image_seq) == 0: 79 | print("No image in %s" % (image_path)) 80 | return False 81 | return True 82 | 83 | def read_image_data(self, data_dir, scaling='no scaling', scale_x=0.1, scale_y=0.1): 84 | data_dir += '/raw_images/' 85 | if scaling == 'scale': 86 | self.image_seq = self.load_images_from_folder(data_dir, scaling='scale', scale_x=scale_x, scale_y=scale_y) 87 | else: 88 | self.image_seq = self.load_images_from_folder(data_dir) 89 | 90 | def load_images_from_folder(self, folder, scaling='no scale', scale_x=0.1, scale_y=0.1): 91 | images = [] 92 | filenames = sorted(os.listdir(folder)) 93 | 94 | for filename in filenames: 95 | if self.valid_image(filename): 96 | img = cv2.imread(os.path.join(folder, filename)).astype(np.float32) 97 | img /= 255.0 98 | if img is not None: 99 | if scaling == 'scale': 100 | img = cv2.resize(img, (0, 0), fx=scale_x, fy=scale_y) 101 | images.append(img) 102 | return images 103 | 104 | @staticmethod 105 | def rescale_images(source_dir, save_dir, scaling='scale', scale_x=0.1, scale_y=0.1): 106 | 107 | foldernames = [f for f in os.listdir(source_dir) if f.isnumeric() and not f.startswith('.')] 108 | 109 | for foldername in tqdm(foldernames): 110 | 111 | if foldername.isnumeric: 112 | newpath = save_dir + "/" + foldername 113 | if not os.path.exists(newpath): 114 | os.makedirs(newpath) 115 | 116 | for filename in os.listdir(source_dir + "/" + foldername): 117 | img = cv2.imread(os.path.join(source_dir + "/" + foldername, filename)) 118 | if img is not None: 119 | if scaling == 'scale': 120 | img = cv2.resize(img, (0, 0), fx=scale_x, fy=scale_y) 121 | cv2.imwrite(os.path.join(newpath, filename), img) 122 | 123 | def read_risk_data(self, parent_dir): 124 | risk_scores = [] 125 | for clip in self.foldernames: 126 | path = parent_dir / clip 127 | label_path = path / "label.txt" 128 | if label_path.exists(): 129 | with open(str(path/"label.txt"), 'r') as label_f: 130 | risk_label = int(float(label_f.read().strip().split(",")[0])) 131 | risk_scores.append(risk_label) 132 | else: 133 | raise FileNotFoundError("No label.txt in %s" % path) 134 | return risk_scores 135 | 136 | def convert_risk_to_one_hot(self): 137 | # sorting risk thresholds from least risky to most risky 138 | indexes = [i[0] for i in sorted(enumerate(self.risk_scores), key=lambda x: x[1])] 139 | self.risk_one_hot = np.zeros([len(indexes), 2]) 140 | 141 | for counter, index in enumerate(indexes[::-1]): 142 | if self.risk_scores[index] >= 0: 143 | self.risk_one_hot[index, :] = [0, 1] 144 | else: 145 | 
self.risk_one_hot[index, :] = [1, 0] 146 | 147 | # Utilities 148 | def get_clipnumber(self, elem): 149 | return int(elem.split('_')[0]) 150 | 151 | def get_filtered_clips(self, clip_dir, foldernames): 152 | filtered_folders = [] 153 | for foldername in foldernames: 154 | clip_path = clip_dir / foldername 155 | if self.ignore_clip(clip_path): continue; 156 | filtered_folders.append(foldername) 157 | 158 | return filtered_folders 159 | 160 | def get_max_frames(self, data_dir): 161 | ''' 162 | Return the longest clip (by amount of frames) 163 | As well as the distribution of frames over all clips 164 | These clips are identifiable with dict clip_lookup 165 | ''' 166 | clip_lookup = {} 167 | num_frames = [] 168 | for foldername in self.foldernames: 169 | foldername = data_dir/foldername/'raw_images' 170 | imgs = [img for img in os.listdir(foldername) if self.valid_image(img)] 171 | num_frames.append(len(imgs)) 172 | clip_lookup[foldername] = len(imgs) 173 | frame_dist = Counter(num_frames).most_common() 174 | # threshold 175 | # for key, val in clip_lookup.items(): 176 | # if val > 150: 177 | # print(key) 178 | return max(num_frames), frame_dist, clip_lookup 179 | 180 | def ignore_clip(self, clip_dir): 181 | ''' 182 | return 1 when ignore.txt is 1 (ignore clip 183 | return 0 when ignore.txt is 0 (do not ignore clip) 184 | ''' 185 | ignore_path = clip_dir / 'ignore.txt' 186 | if ignore_path.exists(): 187 | with open(str(ignore_path), 'r') as label_f: 188 | ignore_label = int(label_f.read()) 189 | return ignore_label 190 | return False # no ignore.txt means to include it in dataset 191 | 192 | def save(self, filename, save_dir): 193 | with open(save_dir + filename, 'wb') as output: 194 | pickle.dump(self, output, pickle.HIGHEST_PROTOCOL) 195 | 196 | def valid_image(self, filename): 197 | return Path(filename).suffix == '.jpg' or Path(filename).suffix == '.png' -------------------------------------------------------------------------------- /baseline_risk_assessment/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | 7 | class LSTM_Classifier(nn.Module): 8 | ''' 9 | Recurrent Network binary classifier 10 | Supports 3 models: {GRU, LSTM, LSTM+Dropout} 11 | 12 | To call module provide the input_shape, model_name, and cfg params 13 | input_shape should be a tensor -> (batch_size, frames, channels, height, width) 14 | model_name must be one of these {gru, lstm} 15 | the lstm model can be configured with dropout if cfg.dropout > 0 16 | ''' 17 | def __init__(self, input_shape, model_name, cfg): 18 | super(LSTM_Classifier, self).__init__() 19 | self.cfg = cfg 20 | self.model_name = model_name 21 | self.batch_size, self.frames, self.channels, self.height, self.width = input_shape 22 | self.dropout = nn.Dropout(self.cfg.dropout) 23 | 24 | if self.model_name == 'gru': 25 | self.l1 = nn.GRU(input_size=self.channels*self.height*self.width, hidden_size=100, batch_first=True) 26 | self.l2 = nn.Linear(in_features=100, out_features=2) 27 | 28 | elif self.model_name == 'lstm': 29 | self.l1 = nn.LSTM(input_size=self.channels*self.height*self.width, hidden_size=512, batch_first=True) 30 | self.l2 = nn.LSTM(input_size=512, hidden_size=512, batch_first=True) 31 | self.l3 = nn.Linear(in_features=512, out_features=1000) 32 | self.l4 = nn.Linear(in_features=1000, out_features=200) 33 | self.l5 = nn.Linear(in_features=200, out_features=2) 
34 | 35 | def forward(self, x): 36 | # format input for lstm 37 | # x = self.reshape(x) 38 | x = torch.reshape(x, (x.shape[0], x.shape[1], -1)) 39 | if self.model_name == 'gru': 40 | _,l1 = self.l1(x) # return only last sequence 41 | l2 = self.l2(l1) 42 | return l2.squeeze() 43 | elif self.model_name == 'lstm': 44 | dropout = lambda curr_layer: self.dropout(curr_layer) if self.cfg.dropout != 0 else curr_layer 45 | l1,_ = self.l1(x) # return all sequences 46 | _,(l2,_) = self.l2(l1) # return only last sequence 47 | l3 = self.l3(dropout(l2)) 48 | l4 = self.l4(dropout(l3)) 49 | l5 = self.l5(l4) 50 | return l5.squeeze() 51 | else: 52 | raise Exception('Unsupported model! Choose between gru or lstm') 53 | 54 | class CNN_LSTM_Classifier(nn.Module): 55 | ''' 56 | CNN+LSTM binary classifier 57 | 58 | To call module provide the input_shape and cfg params 59 | input_shape should be a tensor -> (batch_size, frames, channels, height, width) 60 | ''' 61 | def __init__(self, input_shape, cfg): 62 | super(CNN_LSTM_Classifier, self).__init__() 63 | self.cfg = cfg 64 | self.batch_size, self.frames, self.channels, self.height, self.width = input_shape 65 | self.dropout = nn.Dropout(self.cfg.dropout) 66 | self.kernel_size = (3, 3) 67 | self.lstm_layers = 1 68 | self.conv_size = lambda i, k, p, s: int((i-k+2*p)/s + 1) 69 | self.pool_size = lambda i, k, p, s, pool : conv_size(i, k, p, s) // pool + 1 70 | self.flat_size = lambda f, h, w : f*h*w 71 | self.TimeDistributed = lambda curr_layer, prev_layer : torch.stack([curr_layer(prev_layer[:,i]) for i in range(self.frames)], dim=1) 72 | 73 | # Note: conv_size and pool_size only work for square 2D matrices, if not a square matrix, run once for height dim and another time for width dim 74 | ''' 75 | conv_size = lambda i, k, p, s: int((i-k+2*p)/s + 1) 76 | pool_size = lambda i, k, p, s, pool : conv_size(i, k, p, s) // pool + 1 77 | flat_size = lambda f, h, w : f*h*w 78 | ''' 79 | self.bn1 = nn.BatchNorm3d(num_features=5) 80 | self.bn2 = nn.BatchNorm3d(num_features=5) 81 | self.bn3 = nn.BatchNorm3d(num_features=5) 82 | self.bn4 = nn.BatchNorm1d(num_features=5) 83 | self.bn5 = nn.BatchNorm1d(num_features=5) 84 | 85 | self.c1 = nn.Conv2d(in_channels=self.channels, out_channels=16, kernel_size=self.kernel_size) 86 | self.c2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=self.kernel_size) 87 | self.mp1 = nn.MaxPool2d(kernel_size=2) 88 | self.flat = nn.Flatten(start_dim=1) 89 | self.flat_dim = self.get_flat_dim() 90 | self.l1 = nn.Linear(in_features=self.flat_dim, out_features=200) 91 | self.l2 = nn.Linear(in_features=200, out_features=50) 92 | self.lstm1 = nn.LSTM(input_size=50, hidden_size=20, num_layers=self.lstm_layers, batch_first=True) 93 | self.l3 = nn.Linear(in_features=20, out_features=2) 94 | 95 | def get_flat_dim(self): 96 | c1_h = self.conv_size(self.height, self.kernel_size[-1], 0, 1) 97 | c1_w = self.conv_size(self.width, self.kernel_size[-1], 0, 1) 98 | c2_h = self.conv_size(c1_h, self.kernel_size[-1], 0, 1) 99 | c2_w = self.conv_size(c1_w, self.kernel_size[-1], 0, 1) 100 | mp1_h = c2_h // 2 101 | mp1_w = c2_w // 2 102 | return self.flat_size(16, mp1_h, mp1_w) 103 | 104 | def forward(self, x): 105 | # Distribute learnable layers across all frames with shared weights 106 | if self.cfg.bnorm: # can use a larger learning rate w/ bnorm 107 | c1 = F.relu(self.bn1(self.TimeDistributed(self.c1, x))) 108 | c2 = F.relu(self.bn2(self.TimeDistributed(self.c2, c1))) 109 | mp1 = self.dropout(self.bn3(self.TimeDistributed(self.mp1, c2))) 110 | flat = 
self.TimeDistributed(self.flat, mp1) 111 | l1 = F.relu(self.bn4(self.TimeDistributed(self.l1, flat))) 112 | l2 = F.relu(self.bn5(self.TimeDistributed(self.l2, l1))) 113 | _,(lstm1,_) = self.lstm1(l2) 114 | l3 = self.l3(lstm1) 115 | else: 116 | c1 = F.relu(self.TimeDistributed(self.c1, x)) 117 | c2 = F.relu(self.TimeDistributed(self.c2, c1)) 118 | mp1 = self.dropout(self.TimeDistributed(self.mp1, c2)) 119 | flat = self.TimeDistributed(self.flat, mp1) 120 | l1 = F.relu(self.TimeDistributed(self.l1, flat)) 121 | l2 = F.relu(self.TimeDistributed(self.l2, l1)) 122 | _,(lstm1,_) = self.lstm1(l2) 123 | l3 = self.l3(lstm1) 124 | 125 | self.layer_names = self.ordered_layers = [("c1", self.c1),("c2", self.c2),("mp1", self.mp1),("flat", self.flat), ("l1", self.l1),("l2", self.l2),("lstm1", self.lstm1),("l3", self.l3)] 126 | return l3.squeeze() 127 | 128 | class CNN_Classifier(nn.Module): 129 | ''' 130 | 3D CNN+Linear binary classifier 131 | 132 | To call module provide the input_shape and cfg params 133 | input_shape should be a tensor -> (batch_size, frames, channels, height, width) 134 | ''' 135 | def __init__(self, input_shape, cfg): 136 | super(CNN_Classifier, self).__init__() 137 | self.cfg = cfg 138 | self.batch_size, self.frames, self.channels, self.height, self.width = input_shape 139 | self.kernel_size = (1, 5, 5) 140 | self.conv_size = lambda i, k, p, s: int((i-k+2*p)/s + 1) 141 | self.pool_size = lambda i, k, p, s, pool : conv_size(i, k, p, s) // pool + 1 142 | 143 | self.c1 = nn.Conv3d(in_channels=self.channels, out_channels=32, kernel_size=self.kernel_size) 144 | self.c2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=self.kernel_size) 145 | self.mp1 = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2)) 146 | self.mp2 = nn.MaxPool3d(kernel_size=(1,2,2)) 147 | self.flat = nn.Flatten(start_dim=1) 148 | self.flat_dim = 64*self.frames*self.get_flat_dim() 149 | self.l1 = nn.Linear(in_features=self.flat_dim, out_features=1000) 150 | self.l2 = nn.Linear(in_features=1000, out_features=2) 151 | 152 | def get_flat_dim(self): 153 | c1_h = self.conv_size(self.height, self.kernel_size[-1], 0, 1) 154 | c1_w = self.conv_size(self.width, self.kernel_size[-1], 0, 1) 155 | mp1_h = c1_h // 2 156 | mp1_w = c1_w // 2 157 | c2_h = self.conv_size(mp1_h, self.kernel_size[-1], 0, 1) 158 | c2_w = self.conv_size(mp1_w, self.kernel_size[-1], 0, 1) 159 | mp2_h = c2_h // 2 160 | mp2_w = c2_w // 2 161 | return mp2_h * mp2_w 162 | 163 | def reshape(self, x): 164 | # assumes batch first dim 165 | return x.permute(0, 2, 1, 3, 4) 166 | 167 | def forward(self, x): 168 | # format input for 3d cnn 169 | assert len(x.shape) == 5 170 | x = self.reshape(x) 171 | c1 = F.relu(self.c1(x)) 172 | mp1 = self.mp1(c1) 173 | c2 = F.relu(self.c2(mp1)) 174 | mp2 = self.mp2(c2) 175 | flat1 = self.flat(mp2) 176 | l1 = F.relu(self.l1(flat1)) 177 | l2 = self.l2(l1) 178 | return l2.squeeze() 179 | 180 | class ResNet50_LSTM_Classifier(nn.Module): 181 | ''' 182 | ResNet50+LSTM binary classifier 183 | 184 | To call module provide the input_shape, model_name, and cfg params 185 | input_shape should be a tensor -> (batch_size, frames, channels, height, width) 186 | ''' 187 | def __init__(self, input_shape, cfg): 188 | super(ResNet50_LSTM_Classifier, self).__init__() 189 | self.cfg = cfg 190 | self.batch_size, self.frames, self.channels, self.height, self.width = input_shape 191 | 192 | # “Deep Residual Learning for Image Recognition” 193 | # Using only feature extraction layers shape: (C, H, W) -> (2048, 1, 1) 194 | ''' 195 | self.resent = 
models.resnet50(pretrained=True, progress=True) 196 | nn.Sequential(*list(self.resnet.children())[:-3])(x[0]).shape 197 | torch.Size([16, 512, 28, 28]) 198 | ''' 199 | self.resnet = nn.Sequential(*list(models.resnet50(pretrained=True, progress=True).children())[:-1]) 200 | 201 | # TODO: verify lstm hidden size with louis 202 | # self.lstm1 = nn.LSTM(input_size=512, hidden_size=20) 203 | self.lstm1 = nn.LSTM(input_size=2048, hidden_size=20, batch_first=True) 204 | self.l1 = nn.Linear(in_features=20, out_features=2) 205 | 206 | def forward(self, x): 207 | TimeDistributed = lambda curr_layer, prev_layer : torch.stack([curr_layer(prev_layer[:,i]) for i in range(self.frames)], dim=1) 208 | resnet = TimeDistributed(self.resnet, x) 209 | _,(lstm1,_) = self.lstm1(torch.squeeze(resnet)) 210 | l1 = self.l1(lstm1) 211 | return l1.squeeze() -------------------------------------------------------------------------------- /sg_risk_assessment/mrgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchnlp.nn import Attention 5 | from torch.nn import Linear, LSTM 6 | from torch_geometric.nn import RGCNConv, SAGPooling, TopKPooling, FastRGCNConv 7 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 8 | 9 | from torch_geometric.nn import GraphConv 10 | from torch_geometric.nn.pool.topk_pool import topk, filter_adj 11 | from torch_geometric.utils import softmax 12 | 13 | 14 | class RGCNSAGPooling(torch.nn.Module): 15 | def __init__(self, in_channels, num_relations, ratio=0.5, min_score=None, 16 | multiplier=1, nonlinearity=torch.tanh, rgcn_func="FastRGCNConv", **kwargs): 17 | super(RGCNSAGPooling, self).__init__() 18 | 19 | self.in_channels = in_channels 20 | self.ratio = ratio 21 | self.gnn = FastRGCNConv(in_channels, 1, num_relations, **kwargs) if rgcn_func=="FastRGCNConv" else RGCNConv(in_channels, 1, num_relations, **kwargs) 22 | self.min_score = min_score 23 | self.multiplier = multiplier 24 | self.nonlinearity = nonlinearity 25 | 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | self.gnn.reset_parameters() 30 | 31 | 32 | def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None): 33 | """""" 34 | if batch is None: 35 | batch = edge_index.new_zeros(x.size(0)) 36 | 37 | attn = x if attn is None else attn 38 | attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn 39 | score = self.gnn(attn, edge_index, edge_attr).view(-1) 40 | 41 | if self.min_score is None: 42 | score = self.nonlinearity(score) 43 | else: 44 | score = softmax(score, batch) 45 | 46 | perm = topk(score, self.ratio, batch, self.min_score) 47 | x = x[perm] * score[perm].view(-1, 1) 48 | x = self.multiplier * x if self.multiplier != 1 else x 49 | 50 | batch = batch[perm] 51 | edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm, 52 | num_nodes=score.size(0)) 53 | 54 | return x, edge_index, edge_attr, batch, perm, score[perm] 55 | 56 | 57 | def __repr__(self): 58 | return '{}({}, {}, {}={}, multiplier={})'.format( 59 | self.__class__.__name__, self.gnn.__class__.__name__, 60 | self.in_channels, 61 | 'ratio' if self.min_score is None else 'min_score', 62 | self.ratio if self.min_score is None else self.min_score, 63 | self.multiplier) 64 | 65 | class MRGCN(nn.Module): 66 | 67 | def __init__(self, config): 68 | super(MRGCN, self).__init__() 69 | 70 | self.num_features = config.num_features 71 | self.num_relations = config.num_relations 72 | 
self.num_classes = config.nclass 73 | self.num_layers = config.num_layers #defines number of RGCN conv layers. 74 | self.hidden_dim = config.hidden_dim 75 | self.layer_spec = None if config.layer_spec == None else list(map(int, config.layer_spec.split(','))) 76 | self.lstm_dim1 = config.lstm_input_dim 77 | self.lstm_dim2 = config.lstm_output_dim 78 | self.rgcn_func = FastRGCNConv if config.conv_type == "FastRGCNConv" else RGCNConv 79 | self.activation = F.relu if config.activation == 'relu' else F.leaky_relu 80 | self.pooling_type = config.pooling_type 81 | self.readout_type = config.readout_type 82 | self.temporal_type = config.temporal_type 83 | 84 | self.dropout = config.dropout 85 | self.conv = [] 86 | total_dim = 0 87 | 88 | if self.layer_spec == None: 89 | if self.num_layers > 0: 90 | self.conv.append(self.rgcn_func(self.num_features, self.hidden_dim, self.num_relations).to(config.device)) 91 | total_dim += self.hidden_dim 92 | for i in range(1, self.num_layers): 93 | self.conv.append(self.rgcn_func(self.hidden_dim, self.hidden_dim, self.num_relations).to(config.device)) 94 | total_dim += self.hidden_dim 95 | else: 96 | self.fc0_5 = Linear(self.num_features, self.hidden_dim) 97 | else: 98 | if self.num_layers > 0: 99 | print("using layer specification and ignoring hidden_dim parameter.") 100 | print("layer_spec: " + str(self.layer_spec)) 101 | self.conv.append(self.rgcn_func(self.num_features, self.layer_spec[0], self.num_relations).to(config.device)) 102 | total_dim += self.layer_spec[0] 103 | for i in range(1, self.num_layers): 104 | self.conv.append(self.rgcn_func(self.layer_spec[i-1], self.layer_spec[i], self.num_relations).to(config.device)) 105 | total_dim += self.layer_spec[i] 106 | 107 | else: 108 | self.fc0_5 = Linear(self.num_features, self.hidden_dim) 109 | total_dim += self.hidden_dim 110 | 111 | if self.pooling_type == "sagpool": 112 | self.pool1 = RGCNSAGPooling(total_dim, self.num_relations, ratio=config.pooling_ratio, rgcn_func=config.conv_type) 113 | elif self.pooling_type == "topk": 114 | self.pool1 = TopKPooling(total_dim, ratio=config.pooling_ratio) 115 | 116 | self.fc1 = Linear(total_dim, self.lstm_dim1) 117 | 118 | if "lstm" in self.temporal_type: 119 | self.lstm = LSTM(self.lstm_dim1, self.lstm_dim2, batch_first=True) 120 | self.attn = Attention(self.lstm_dim2) 121 | self.lstm_decoder = LSTM(self.lstm_dim2, self.lstm_dim2, batch_first=True) 122 | else: 123 | self.fc1_5 = Linear(self.lstm_dim1, self.lstm_dim2) 124 | 125 | self.fc2 = Linear(self.lstm_dim2, self.num_classes) 126 | 127 | 128 | def forward(self, x, edge_index, edge_attr, batch=None): 129 | attn_weights = dict() 130 | outputs = [] 131 | if self.num_layers > 0: 132 | for i in range(self.num_layers): 133 | x = self.activation(self.conv[i](x, edge_index, edge_attr)) 134 | x = F.dropout(x, self.dropout, training=self.training) 135 | outputs.append(x) 136 | x = torch.cat(outputs, dim=-1) 137 | else: 138 | x = self.activation(self.fc0_5(x)) 139 | 140 | if self.pooling_type == "sagpool": 141 | x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch) 142 | elif self.pooling_type == "topk": 143 | x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch) 144 | else: 145 | attn_weights['batch'] = batch 146 | 147 | if self.readout_type == "add": 148 | x = global_add_pool(x, attn_weights['batch']) 149 | elif 
self.readout_type == "mean": 150 | x = global_mean_pool(x, attn_weights['batch']) 151 | elif self.readout_type == "max": 152 | x = global_max_pool(x, attn_weights['batch']) 153 | else: 154 | pass 155 | 156 | x = self.activation(self.fc1(x)) 157 | 158 | if self.temporal_type == "mean": 159 | x = self.activation(self.fc1_5(x.mean(axis=0))) 160 | elif self.temporal_type == "lstm_last": 161 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 162 | x = h.flatten() 163 | elif self.temporal_type == "lstm_sum": 164 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 165 | x = x_predicted.sum(dim=1).flatten() 166 | elif self.temporal_type == "lstm_attn": 167 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 168 | x, attn_weights['lstm_attn_weights'] = self.attn(h.view(1,1,-1), x_predicted) 169 | x, (h_decoder, c_decoder) = self.lstm_decoder(x, (h, c)) 170 | x = x.flatten() 171 | else: 172 | pass 173 | 174 | return F.log_softmax(self.fc2(x), dim=-1), attn_weights 175 | 176 | 177 | #implementation of MRGCN using a GIN style readout. 178 | class MRGIN(nn.Module): 179 | def __init__(self, config): 180 | super(MRGIN, self).__init__() 181 | self.num_features = config.num_features 182 | self.num_relations = config.num_relations 183 | self.num_classes = config.nclass 184 | self.num_layers = config.num_layers #defines number of RGCN conv layers. 185 | self.hidden_dim = config.hidden_dim 186 | self.layer_spec = None if config.layer_spec == None else list(map(int, config.layer_spec.split(','))) 187 | self.lstm_dim1 = config.lstm_input_dim 188 | self.lstm_dim2 = config.lstm_output_dim 189 | self.rgcn_func = FastRGCNConv if config.conv_type == "FastRGCNConv" else RGCNConv 190 | self.activation = F.relu if config.activation == 'relu' else F.leaky_relu 191 | self.pooling_type = config.pooling_type 192 | self.readout_type = config.readout_type 193 | self.temporal_type = config.temporal_type 194 | self.dropout = config.dropout 195 | self.conv = [] 196 | self.pool = [] 197 | total_dim = 0 198 | 199 | if self.layer_spec == None: 200 | for i in range(self.num_layers): 201 | if i == 0: 202 | self.conv.append(self.rgcn_func(self.num_features, self.hidden_dim, self.num_relations).to(config.device)) 203 | else: 204 | self.conv.append(self.rgcn_func(self.hidden_dim, self.hidden_dim, self.num_relations).to(config.device)) 205 | if self.pooling_type == "sagpool": 206 | self.pool.append(RGCNSAGPooling(self.hidden_dim, self.num_relations, ratio=config.pooling_ratio, rgcn_func=config.conv_type).to(config.device)) 207 | elif self.pooling_type == "topk": 208 | self.pool.append(TopKPooling(self.hidden_dim, ratio=config.pooling_ratio).to(config.device)) 209 | total_dim += self.hidden_dim 210 | 211 | else: 212 | print("using layer specification and ignoring hidden_dim parameter.") 213 | print("layer_spec: " + str(self.layer_spec)) 214 | for i in range(self.num_layers): 215 | if i == 0: 216 | self.conv.append(self.rgcn_func(self.num_features, self.layer_spec[0], self.num_relations).to(config.device)) 217 | else: 218 | self.conv.append(self.rgcn_func(self.layer_spec[i-1], self.layer_spec[i], self.num_relations).to(config.device)) 219 | if self.pooling_type == "sagpool": 220 | self.pool.append(RGCNSAGPooling(self.layer_spec[i], self.num_relations, ratio=config.pooling_ratio, rgcn_func=config.conv_type).to(config.device)) 221 | elif self.pooling_type == "topk": 222 | self.pool.append(TopKPooling(self.layer_spec[i], ratio=config.pooling_ratio).to(config.device)) 223 | total_dim += self.layer_spec[i] 224 | 225 | self.fc1 = Linear(total_dim, 
self.lstm_dim1) 226 | 227 | if "lstm" in self.temporal_type: 228 | self.lstm = LSTM(self.lstm_dim1, self.lstm_dim2, batch_first=True) 229 | self.attn = Attention(self.lstm_dim2) 230 | 231 | self.fc2 = Linear(self.lstm_dim2, self.num_classes) 232 | 233 | 234 | 235 | def forward(self, x, edge_index, edge_attr, batch=None): 236 | attn_weights = dict() 237 | outputs = [] 238 | 239 | #readout performed after each layer and concatenated 240 | for i in range(self.num_layers): 241 | x = self.activation(self.conv[i](x, edge_index, edge_attr)) 242 | x = F.dropout(x, self.dropout, training=self.training) 243 | if self.pooling_type == "sagpool": 244 | p, _, _, batch2, attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool[i](x, edge_index, edge_attr=edge_attr, batch=batch) 245 | elif self.pooling_type == "topk": 246 | p, _, _, batch2, attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool[i](x, edge_index, edge_attr=edge_attr, batch=batch) 247 | else: 248 | p = x 249 | batch2 = batch 250 | if self.readout_type == "add": 251 | r = global_add_pool(p, batch2) 252 | elif self.readout_type == "mean": 253 | r = global_mean_pool(p, batch2) 254 | elif self.readout_type == "max": 255 | r = global_max_pool(p, batch2) 256 | else: 257 | r = p 258 | outputs.append(r) 259 | 260 | x = torch.cat(outputs, dim=-1) 261 | x = self.activation(self.fc1(x)) 262 | 263 | if self.temporal_type == "mean": 264 | x = self.activation(x.mean(axis=0)) 265 | elif self.temporal_type == "lstm_last": 266 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 267 | x = h.flatten() 268 | elif self.temporal_type == "lstm_sum": 269 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 270 | x = x_predicted.sum(dim=1).flatten() 271 | elif self.temporal_type == "lstm_attn": 272 | x_predicted, (h, c) = self.lstm(x.unsqueeze(0)) 273 | x, attn_weights['lstm_attn_weights'] = self.attn(h.view(1,1,-1), x_predicted) 274 | x = x.flatten() 275 | else: 276 | pass 277 | 278 | return F.log_softmax(self.fc2(x), dim=-1), attn_weights -------------------------------------------------------------------------------- /sg_risk_assessment/relation_extractor.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | 4 | 5 | MOTO_NAMES = ["Harley-Davidson", "Kawasaki", "Yamaha"] 6 | BICYCLE_NAMES = ["Gazelle", "Diamondback", "Bh"] 7 | CAR_NAMES = ["Ford", "Bmw", "Toyota", "Nissan", "Mini", "Tesla", "Seat", "Lincoln", "Audi", "Carlamotors", "Citroen", "Mercedes-Benz", "Chevrolet", "Volkswagen", "Jeep", "Nissan", "Dodge", "Mustang"] 8 | 9 | CAR_PROXIMITY_THRESH_NEAR_COLL = 4 10 | CAR_PROXIMITY_THRESH_SUPER_NEAR = 7 # max number of feet between a car and another entity to build proximity relation 11 | CAR_PROXIMITY_THRESH_VERY_NEAR = 10 12 | CAR_PROXIMITY_THRESH_NEAR = 16 13 | CAR_PROXIMITY_THRESH_VISIBLE = 25 14 | MOTO_PROXIMITY_THRESH = 50 15 | BICYCLE_PROXIMITY_THRESH = 50 16 | PED_PROXIMITY_THRESH = 50 17 | 18 | #defines all types of actors which can exist 19 | #order of enum values is important as this determines which function is called. 
DO NOT CHANGE ENUM ORDER 20 | class ActorType(Enum): 21 | CAR = 0 #26, 142, 137:truck 22 | MOTO = 1 #80 23 | BICYCLE = 2 #11 24 | PED = 3 #90, 91, 98: "player", 78:man, 79:men, 149:woman, 56: guy, 53: girl 25 | LANE = 4 #124:street, 114:sidewalk 26 | LIGHT = 5 # 99: "pole", 76: light 27 | SIGN = 6 28 | ROAD = 7 29 | 30 | ACTOR_NAMES=['car','moto','bicycle','ped','lane','light','sign', 'road'] 31 | 32 | class Relations(Enum): 33 | isIn = 0 34 | near_coll = 1 35 | super_near = 2 36 | very_near = 3 37 | near = 4 38 | visible = 5 39 | inDFrontOf = 6 40 | inSFrontOf = 7 41 | atDRearOf = 8 42 | atSRearOf = 9 43 | toLeftOf = 10 44 | toRightOf = 11 45 | 46 | RELATION_COLORS = ["black", "red", "orange", "yellow", "green", "purple", "blue", 47 | "sienna", "pink", "pink", "pink", "turquoise", "turquoise", "turquoise", "violet", "violet"] 48 | 49 | #This class extracts relations for every pair of entities in a scene 50 | class RelationExtractor: 51 | def __init__(self, ego_node): 52 | self.ego_node = ego_node 53 | 54 | def get_actor_type(self, actor): 55 | if "curr" in actor.attr.keys(): 56 | return ActorType.LANE 57 | if actor.attr["name"] == "Traffic Light": 58 | return ActorType.LIGHT 59 | if actor.attr["name"].split(" ")[0] == "Pedestrian": 60 | return ActorType.PED 61 | if actor.attr["name"].split(" ")[0] in CAR_NAMES: 62 | return ActorType.CAR 63 | if actor.attr["name"].split(" ")[0] in MOTO_NAMES: 64 | return ActorType.MOTO 65 | if actor.attr["name"].split(" ")[0] in BICYCLE_NAMES: 66 | return ActorType.BICYCLE 67 | if "Sign" in actor.attr["name"]: 68 | return ActorType.SIGN 69 | 70 | # import pdb; pdb.set_trace() 71 | raise NameError("Actor name not found for actor with name: " + actor.attr["name"]) 72 | 73 | #takes in two entities and extracts all relations between those two entities. extracted relations are bidirectional 74 | def extract_relations(self, actor1, actor2): 75 | #import pdb; pdb.set_trace() 76 | type1 = self.get_actor_type(actor1) 77 | type2 = self.get_actor_type(actor2) 78 | 79 | low_type = min(type1.value, type2.value) #the lower of the two enums. 80 | high_type = max(type1.value, type2.value) 81 | 82 | function_call = "self.extract_relations_"+ACTOR_NAMES[low_type]+"_"+ACTOR_NAMES[high_type]+"(actor1, actor2) if type1.value <= type2.value "\ 83 | "else self.extract_relations_"+ACTOR_NAMES[low_type]+"_"+ACTOR_NAMES[high_type]+"(actor2, actor1)" 84 | return eval(function_call) 85 | 86 | 87 | #~~~~~~~~~specific relations for each pair of actors possible~~~~~~~~~~~~ 88 | #actor 1 corresponds to the first actor in the function name and actor2 the second 89 | 90 | def extract_relations_car_car(self, actor1, actor2): 91 | relation_list = [] 92 | # consider the proximity relations with neighboring lanes. 
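# Editor's note (illustrative): car-car relations are only extracted when one of the
# two actors is the ego vehicle and their separation is within CAR_PROXIMITY_THRESH_NEAR
# (16 ft here); both directions of the proximity and directional relations are then
# added. For example, an ego/lead-car pair 8 ft apart yields a very_near relation in
# each direction plus directional (front/rear and possibly left/right) relations.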
93 | if actor1.name.startswith("ego:") or actor2.name.startswith("ego:"): 94 | if self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR: 95 | relation_list += self.create_proximity_relations(actor1, actor2) 96 | relation_list += self.create_proximity_relations(actor2, actor1) 97 | relation_list += self.extract_directional_relation(actor1, actor2) 98 | relation_list += self.extract_directional_relation(actor2, actor1) 99 | return relation_list 100 | 101 | def extract_relations_car_lane(self, actor1, actor2): 102 | relation_list = [] 103 | # if(self.in_lane(actor1,actor2)): 104 | # relation_list.append([actor1, Relations.isIn, actor2]) 105 | 106 | return relation_list 107 | 108 | def extract_relations_car_light(self, actor1, actor2): 109 | relation_list = [] 110 | return relation_list 111 | 112 | def extract_relations_car_sign(self, actor1, actor2): 113 | relation_list = [] 114 | return relation_list 115 | 116 | def extract_relations_car_ped(self, actor1, actor2): 117 | relation_list = [] 118 | return relation_list 119 | 120 | def extract_relations_car_bicycle(self, actor1, actor2): 121 | relation_list = [] 122 | return relation_list 123 | 124 | def extract_relations_car_moto(self, actor1, actor2): 125 | relation_list = [] 126 | return relation_list 127 | 128 | 129 | def extract_relations_moto_moto(self, actor1, actor2): 130 | relation_list = [] 131 | return relation_list 132 | 133 | def extract_relations_moto_bicycle(self, actor1, actor2): 134 | relation_list = [] 135 | return relation_list 136 | 137 | def extract_relations_moto_ped(self, actor1, actor2): 138 | relation_list = [] 139 | return relation_list 140 | 141 | def extract_relations_moto_lane(self, actor1, actor2): 142 | relation_list = [] 143 | # if(self.in_lane(actor1,actor2)): 144 | # relation_list.append([actor1, Relations.isIn, actor2]) 145 | # # relation_list.append([actor2, Relations.isIn, actor1]) 146 | return relation_list 147 | 148 | def extract_relations_moto_light(self, actor1, actor2): 149 | relation_list = [] 150 | return relation_list 151 | 152 | def extract_relations_moto_sign(self, actor1, actor2): 153 | relation_list = [] 154 | return relation_list 155 | 156 | 157 | def extract_relations_bicycle_bicycle(self, actor1, actor2): 158 | relation_list = [] 159 | # if(self.euclidean_distance(actor1, actor2) < BICYCLE_PROXIMITY_THRESH): 160 | # relation_list.append([actor1, Relations.near, actor2]) 161 | # relation_list.append([actor2, Relations.near, actor1]) 162 | # #relation_list.append(self.extract_directional_relation(actor1, actor2)) 163 | # #relation_list.append(self.extract_directional_relation(actor2, actor1)) 164 | return relation_list 165 | 166 | def extract_relations_bicycle_ped(self, actor1, actor2): 167 | relation_list = [] 168 | # if(self.euclidean_distance(actor1, actor2) < BICYCLE_PROXIMITY_THRESH): 169 | # relation_list.append([actor1, Relations.near, actor2]) 170 | # relation_list.append([actor2, Relations.near, actor1]) 171 | # #relation_list.append(self.extract_directional_relation(actor1, actor2)) 172 | # #relation_list.append(self.extract_directional_relation(actor2, actor1)) 173 | return relation_list 174 | 175 | def extract_relations_bicycle_lane(self, actor1, actor2): 176 | relation_list = [] 177 | # if(self.in_lane(actor1,actor2)): 178 | # relation_list.append([actor1, Relations.isIn, actor2]) 179 | return relation_list 180 | 181 | def extract_relations_bicycle_light(self, actor1, actor2): 182 | relation_list = [] 183 | #relation_list.append(self.extract_directional_relation(actor1, 
actor2)) 184 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 185 | return relation_list 186 | 187 | def extract_relations_bicycle_sign(self, actor1, actor2): 188 | relation_list = [] 189 | #relation_list.append(self.extract_directional_relation(actor1, actor2)) 190 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 191 | return relation_list 192 | 193 | def extract_relations_ped_ped(self, actor1, actor2): 194 | relation_list = [] 195 | if(self.euclidean_distance(actor1, actor2) < PED_PROXIMITY_THRESH): 196 | relation_list.append([actor1, Relations.near, actor2]) 197 | relation_list.append([actor2, Relations.near, actor1]) 198 | #relation_list.append(self.extract_directional_relation(actor1, actor2)) 199 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 200 | return relation_list 201 | 202 | def extract_relations_ped_lane(self, actor1, actor2): 203 | relation_list = [] 204 | # if(self.in_lane(actor1,actor2)): 205 | # relation_list.append([actor1, Relations.isIn, actor2]) 206 | return relation_list 207 | 208 | def extract_relations_ped_light(self, actor1, actor2): 209 | relation_list = [] 210 | #proximity relation could indicate ped waiting for crosswalk at a light 211 | # if(self.euclidean_distance(actor1, actor2) < PED_PROXIMITY_THRESH): 212 | # relation_list.append([actor1, Relations.near, actor2]) 213 | # relation_list.append([actor2, Relations.near, actor1]) 214 | #relation_list.append(self.extract_directional_relation(actor1, actor2)) 215 | #relation_list.append(self.extract_directional_relation(actor2, actor1)) 216 | return relation_list 217 | 218 | def extract_relations_ped_sign(self, actor1, actor2): 219 | relation_list = [] 220 | # relation_list.append(self.extract_directional_relation(actor1, actor2)) 221 | # relation_list.append(self.extract_directional_relation(actor2, actor1)) 222 | return relation_list 223 | 224 | def extract_relations_lane_lane(self, actor1, actor2): 225 | relation_list = [] 226 | return relation_list 227 | 228 | def extract_relations_lane_light(self, actor1, actor2): 229 | relation_list = [] 230 | return relation_list 231 | 232 | def extract_relations_lane_sign(self, actor1, actor2): 233 | relation_list = [] 234 | return relation_list 235 | 236 | def extract_relations_light_light(self, actor1, actor2): 237 | relation_list = [] 238 | return relation_list 239 | 240 | def extract_relations_light_sign(self, actor1, actor2): 241 | relation_list = [] 242 | return relation_list 243 | 244 | def extract_relations_sign_sign(self, actor1, actor2): 245 | relation_list = [] 246 | return relation_list 247 | 248 | 249 | #~~~~~~~~~~~~~~~~~~UTILITY FUNCTIONS~~~~~~~~~~~~~~~~~~~~~~ 250 | #return euclidean distance between actors 251 | def euclidean_distance(self, actor1, actor2): 252 | #import pdb; pdb.set_trace() 253 | l1 = actor1.attr['location'] 254 | l2 = actor2.attr['location'] 255 | return math.sqrt((l1[0] - l2[0])**2 + (l1[1]- l2[1])**2 + (l1[2] - l2[2])**2) 256 | 257 | #check if an actor is in a certain lane 258 | def in_lane(self, actor1, actor2): 259 | if 'lane_idx' in actor1.attr.keys(): 260 | # calculate the distance bewteen actor1 and actor2 261 | # if it is below 3.5 then they have is in relation. 262 | # if actor1 is ego: if actor2 is not equal to the ego_lane's index then it's invading relation. 
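# Editor's note (illustrative): the checks below treat actor1 as being "in" actor2's
# lane when any of its lane indices match, e.g.
#   actor1.attr = {'lane_idx': 2, 'invading_lane': 3}, actor2.attr = {'lane_idx': 3}
# returns True via the invading_lane branch. Note that if 'lane_idx' is present but
# none of the branches match, the function falls through without an explicit return
# (i.e. it returns None rather than False).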
263 | if actor1.attr['lane_idx'] == actor2.attr['lane_idx']: 264 | return True 265 | if "invading_lane" in actor1.attr: 266 | if actor1.attr['invading_lane'] == actor2.attr['lane_idx']: 267 | return True 268 | if "orig_lane_idx" in actor1.attr: 269 | if actor1.attr['orig_lane_idx'] == actor2.attr['lane_idx']: 270 | return True 271 | else: 272 | return False 273 | 274 | def create_proximity_relations(self, actor1, actor2): 275 | if self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR_COLL: 276 | return [[actor1, Relations.near_coll, actor2]] 277 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_SUPER_NEAR: 278 | return [[actor1, Relations.super_near, actor2]] 279 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VERY_NEAR: 280 | return [[actor1, Relations.very_near, actor2]] 281 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR: 282 | return [[actor1, Relations.near, actor2]] 283 | elif self.euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VISIBLE: 284 | return [[actor1, Relations.visible, actor2]] 285 | return [] 286 | 287 | def extract_directional_relation(self, actor1, actor2): 288 | relation_list = [] 289 | # gives directional relations between actors based on their 2D absolute positions. 290 | x1, y1 = math.cos(math.radians(actor1.attr['rotation'][0])), math.sin(math.radians(actor1.attr['rotation'][0])) 291 | x2, y2 = actor2.attr['location'][0] - actor1.attr['location'][0], actor2.attr['location'][1] - actor1.attr['location'][1] 292 | x2, y2 = x2 / math.sqrt(x2**2+y2**2), y2 / math.sqrt(x2**2+y2**2) 293 | 294 | degree = math.degrees(math.atan2(y1, x1)) - math.degrees(math.atan2(y2, x2)) 295 | if degree < 0: 296 | degree += 360 297 | 298 | if degree <= 45: # actor2 is in front of actor1 299 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 300 | elif degree >= 45 and degree <= 90: 301 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 302 | elif degree >= 90 and degree <= 135: 303 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 304 | elif degree >= 135 and degree <= 180: # actor2 is behind actor1 305 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 306 | elif degree >= 180 and degree <= 225: # actor2 is behind actor1 307 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 308 | elif degree >= 225 and degree <= 270: 309 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 310 | elif degree >= 270 and degree <= 315: 311 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 312 | elif degree >= 315 and degree <= 360: 313 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 314 | 315 | if actor2.attr['lane_idx'] < actor1.attr['lane_idx']: # actor2 to the left of actor1 316 | relation_list.append([actor1, Relations.toRightOf, actor2]) 317 | elif actor2.attr['lane_idx'] > actor1.attr['lane_idx']: # actor2 to the right of actor1 318 | relation_list.append([actor1, Relations.toLeftOf, actor2]) 319 | 320 | return relation_list -------------------------------------------------------------------------------- /sg_risk_assessment/image_scenegraph.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sg_risk_assessment.relation_extractor import ActorType, Relations, RELATION_COLORS 3 | from networkx.drawing.nx_agraph import to_agraph 4 | import matplotlib.pyplot as plt 5 | import networkx as nx 6 | import numpy as np 7 | import sys 8 | import os 9 | 
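# Editor's note (illustrative): besides OpenCV and networkx, the visualize() method
# defined later in this file uses networkx's to_agraph helper, which requires the
# optional pygraphviz/Graphviz dependency to be installed.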
import cv2 10 | import itertools 11 | import math 12 | import matplotlib 13 | matplotlib.use("Agg") 14 | 15 | # SELECT ONE OF THE FOLLOWING: 16 | 17 | # #SETTINGS FOR 1280x720 CARLA IMAGES: 18 | # IMAGE_H = 720 19 | # IMAGE_W = 1280 20 | # CROPPED_H = 350 #height of ROI. crops to lane area of carla images 21 | # BIRDS_EYE_IMAGE_H = 850 22 | # BIRDS_EYE_IMAGE_W = 1280 23 | # Y_SCALE = 0.55 #18 pixels = length of lane line (10 feet) 24 | # X_SCALE = 0.54 #22 pixels = width of lane (12 feet) 25 | 26 | 27 | # SETTINGS FOR 1280x720 HONDA IMAGES: 28 | IMAGE_H = 720 29 | IMAGE_W = 1280 30 | CROPPED_H = 390 31 | BIRDS_EYE_IMAGE_H = 620 32 | BIRDS_EYE_IMAGE_W = 1280 33 | Y_SCALE = 0.45 # 22 pixels = length of lane line (10 feet) 34 | X_SCALE = 0.46 # 26 pixels = width of lane (12 feet) 35 | 36 | 37 | H_OFFSET = IMAGE_H - CROPPED_H # offset from top of image to start of ROI 38 | 39 | CAR_PROXIMITY_THRESH_NEAR_COLL = 4 40 | # max number of feet between a car and another entity to build proximity relation 41 | CAR_PROXIMITY_THRESH_SUPER_NEAR = 7 42 | CAR_PROXIMITY_THRESH_VERY_NEAR = 10 43 | CAR_PROXIMITY_THRESH_NEAR = 16 44 | CAR_PROXIMITY_THRESH_VISIBLE = 25 45 | 46 | LANE_THRESHOLD = 6 # feet. if object's center is more than this distance away from ego's center, build left or right lane relation 47 | # feet. if object's center is within this distance of ego's center, build middle lane relation 48 | CENTER_LANE_THRESHOLD = 9 49 | 50 | 51 | class ObjectNode: 52 | def __init__(self, name, attr, label): 53 | self.name = name # Car-1, Car-2. 54 | self.attr = attr # bounding box info 55 | self.label = label # ActorType 56 | 57 | def __repr__(self): 58 | return "%s" % (self.name) 59 | 60 | 61 | class RealSceneGraph: 62 | ''' 63 | scene graph the real images 64 | arguments: 65 | image_path : path to the image for which the scene graph is generated 66 | 67 | ''' 68 | 69 | def __init__(self, image_path, bounding_boxes, coco_class_names=None, platform='image'): 70 | self.g = nx.MultiDiGraph() # initialize scenegraph as networkx graph 71 | 72 | # road and lane settings. 73 | # we need to define the type of node. 74 | self.road_node = ObjectNode("Root Road", {}, ActorType.ROAD) 75 | self.add_node(self.road_node) # adding the road as the root node 76 | 77 | # specify which type of data to load into model (options: image or honda) 78 | self.platfrom = platform 79 | 80 | # set ego location to middle-bottom of image. 81 | self.ego_location = ((BIRDS_EYE_IMAGE_W/2) * 82 | X_SCALE, BIRDS_EYE_IMAGE_H * Y_SCALE) 83 | self.ego_node = ObjectNode("Ego Car", { 84 | "location_x": self.ego_location[0], "location_y": self.ego_location[1]}, ActorType.CAR) 85 | self.add_node(self.ego_node) 86 | self.extract_relative_lanes() # three lane formulation. 87 | 88 | # convert bounding boxes to nodes and build relations. 
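# Editor's note (illustrative): `bounding_boxes` is expected to be a
# (boxes, labels, image_size) triple, where each box is a tensor of pixel
# coordinates [x1, y1, x2, y2] and each label indexes into coco_class_names;
# only the 'car'/'truck'/'bus' classes are kept by get_nodes_from_bboxes below.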
89 | boxes, labels, image_size = bounding_boxes 90 | self.get_nodes_from_bboxes(boxes, labels, coco_class_names) 91 | 92 | # import pdb; pdb.set_trace() 93 | self.extract_relations() 94 | 95 | def get_nodes_from_bboxes(self, boxes, labels, coco_class_names): 96 | # birds eye view projection 97 | M = get_birds_eye_matrix() 98 | # warped_img = get_birds_eye_warp(image_path, M) 99 | # cv2.imwrite( "./warped.jpg", cv2.cvtColor(warped_img, cv2.COLOR_BGR2RGB)) #plot warped image 100 | 101 | for idx, (box, label) in enumerate(zip(boxes, labels)): 102 | box = box.cpu().numpy().tolist() 103 | class_name = coco_class_names[label] 104 | 105 | if box[1] >= 620: 106 | continue 107 | 108 | if class_name in ['car', 'truck', 'bus']: 109 | actor_type = ActorType.CAR 110 | # elif class_name in ['person']: 111 | # actor_type = ActorType.PED 112 | # elif class_name in ['bicycle']: 113 | # actor_type = ActorType.BICYCLE 114 | # elif class_name in ['motorcycle']: 115 | # actor_type = ActorType.MOTO 116 | # elif class_name in ['traffic light']: 117 | # actor_type = ActorType.LIGHT 118 | # elif class_name in ['stop sign']: 119 | # actor_type = ActorType.SIGN 120 | else: 121 | continue 122 | 123 | attr = {'x1': box[0], 'y1': box[1], 'x2': box[2], 'y2': box[3]} 124 | 125 | # map center-bottom of bounding box to warped image 126 | x_mid = (box[2] + box[0]) / 2 127 | y_bottom = box[3] - H_OFFSET # offset to account for image crop 128 | pt = np.array([[[x_mid, y_bottom]]], dtype='float32') 129 | warp_pt = cv2.perspectiveTransform(pt, M)[0][0] 130 | 131 | #locations/distances in feet 132 | attr['location_x'] = warp_pt[0] * X_SCALE 133 | attr['location_y'] = warp_pt[1] * Y_SCALE 134 | attr['rel_location_x'] = attr['location_x'] - \ 135 | self.ego_node.attr["location_x"] # x position relative to ego 136 | attr['rel_location_y'] = attr['location_y'] - \ 137 | self.ego_node.attr["location_y"] # y position relative to ego 138 | attr['distance_abs'] = math.sqrt( 139 | attr['rel_location_x']**2 + attr['rel_location_y']**2) # absolute distance from ego 140 | node = ObjectNode("%s_%d" % (class_name, idx), attr, actor_type) 141 | self.add_node(node) 142 | self.add_mapping_to_relative_lanes(node) 143 | 144 | # extract relations between all nodes in the graph 145 | # does not build relations with the road node. 146 | # only builds relations between the ego node and other nodes. 147 | # only builds relations if other node is within the distance CAR_PROXIMITY_THRESH_VISIBLE from ego. 148 | 149 | def extract_relations(self): 150 | for node_a, node_b in itertools.combinations(self.g.nodes, 2): 151 | relation_list = [] 152 | if node_a.label == ActorType.ROAD or node_b.label == ActorType.ROAD: 153 | # dont build relations w/ road 154 | continue 155 | if node_a.label == ActorType.CAR and node_b.label == ActorType.CAR: 156 | if node_a.name.startswith("Ego") or node_b.name.startswith("Ego"): 157 | # print(node_a, node_b, self.get_euclidean_distance(node_a, node_b)) 158 | # import pdb; pdb.set_trace() 159 | if self.get_euclidean_distance(node_a, node_b) <= CAR_PROXIMITY_THRESH_VISIBLE: 160 | relation_list += self.extract_proximity_relations( 161 | node_a, node_b) 162 | relation_list += self.extract_directional_relations( 163 | node_a, node_b) 164 | relation_list += self.extract_proximity_relations( 165 | node_b, node_a) 166 | relation_list += self.extract_directional_relations( 167 | node_b, node_a) 168 | self.add_relations(relation_list) 169 | 170 | # returns proximity relations based on the absolute distance between two actors. 
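# Editor's note (worked example): with the thresholds defined at the top of this
# file (4 / 7 / 10 / 16 / 25 ft), a pair of cars 3 ft apart gets near_coll,
# 8 ft gets very_near, 20 ft gets visible, and anything beyond 25 ft produces
# no proximity relation at all.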
171 | 172 | def extract_proximity_relations(self, actor1, actor2): 173 | if self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR_COLL: 174 | return [[actor1, Relations.near_coll, actor2]] 175 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_SUPER_NEAR: 176 | return [[actor1, Relations.super_near, actor2]] 177 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VERY_NEAR: 178 | return [[actor1, Relations.very_near, actor2]] 179 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_NEAR: 180 | return [[actor1, Relations.near, actor2]] 181 | elif self.get_euclidean_distance(actor1, actor2) <= CAR_PROXIMITY_THRESH_VISIBLE: 182 | return [[actor1, Relations.visible, actor2]] 183 | return [] 184 | 185 | # calculates absolute distance between two actors 186 | 187 | def get_euclidean_distance(self, actor1, actor2): 188 | l1 = (actor1.attr['location_x'], actor1.attr['location_y']) 189 | l2 = (actor2.attr['location_x'], actor2.attr['location_y']) 190 | return math.sqrt((l1[0] - l2[0])**2 + (l1[1] - l2[1])**2) 191 | 192 | # returns directional relations between entities based on their relative positions to one another in the scene. 193 | 194 | def extract_directional_relations(self, actor1, actor2): 195 | relation_list = [] 196 | x1, y1 = math.cos(math.radians(0)), math.sin(math.radians(0)) 197 | x2, y2 = actor2.attr['location_x'] - \ 198 | actor1.attr['location_x'], actor2.attr['location_y'] - \ 199 | actor1.attr['location_y'] 200 | x2, y2 = x2 / math.sqrt(x2**2+y2**2), y2 / math.sqrt(x2**2+y2**2) 201 | 202 | degree = math.degrees(math.atan2(y1, x1)) - \ 203 | math.degrees(math.atan2(y2, x2)) 204 | if degree < 0: 205 | degree += 360 206 | 207 | if degree <= 45: # actor2 is in front of actor1 208 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 209 | elif degree >= 45 and degree <= 90: 210 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 211 | elif degree >= 90 and degree <= 135: 212 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 213 | elif degree >= 135 and degree <= 180: # actor2 is behind actor1 214 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 215 | elif degree >= 180 and degree <= 225: # actor2 is behind actor1 216 | relation_list.append([actor1, Relations.inDFrontOf, actor2]) 217 | elif degree >= 225 and degree <= 270: 218 | relation_list.append([actor1, Relations.inSFrontOf, actor2]) 219 | elif degree >= 270 and degree <= 315: 220 | relation_list.append([actor1, Relations.atSRearOf, actor2]) 221 | elif degree >= 315 and degree <= 360: 222 | relation_list.append([actor1, Relations.atDRearOf, actor2]) 223 | 224 | if abs(actor2.attr['location_x'] - actor1.attr['location_x']) <= CENTER_LANE_THRESHOLD: 225 | pass 226 | # actor2 to the left of actor1 227 | elif actor2.attr['location_x'] < actor1.attr['location_x']: 228 | relation_list.append([actor2, Relations.toLeftOf, actor1]) 229 | # actor2 to the right of actor1 230 | elif actor2.attr['location_x'] > actor1.attr['location_x']: 231 | relation_list.append([actor2, Relations.toRightOf, actor1]) 232 | # disable rear relations help the inference. 233 | return relation_list 234 | 235 | # relative lane mapping method. 
Each vehicle is assigned to left, middle, or right lane depending on relative position to ego 236 | 237 | def extract_relative_lanes(self): 238 | self.left_lane = ObjectNode("Left Lane", {}, ActorType.LANE) 239 | self.right_lane = ObjectNode("Right Lane", {}, ActorType.LANE) 240 | self.middle_lane = ObjectNode("Middle Lane", {}, ActorType.LANE) 241 | self.add_node(self.left_lane) 242 | self.add_node(self.right_lane) 243 | self.add_node(self.middle_lane) 244 | self.add_relation([self.left_lane, Relations.isIn, self.road_node]) 245 | self.add_relation([self.right_lane, Relations.isIn, self.road_node]) 246 | self.add_relation([self.middle_lane, Relations.isIn, self.road_node]) 247 | self.add_relation([self.ego_node, Relations.isIn, self.middle_lane]) 248 | 249 | # builds isIn relation between object and lane depending on x-displacement relative to ego 250 | # left/middle and right/middle relations have an overlap area determined by the size of CENTER_LANE_THRESHOLD and LANE_THRESHOLD. 251 | # TODO: move to relation_extractor in replacement of current lane-vehicle relation code 252 | 253 | def add_mapping_to_relative_lanes(self, object_node): 254 | # don't build lane relations with static objects 255 | if object_node.label in [ActorType.LANE, ActorType.LIGHT, ActorType.SIGN, ActorType.ROAD]: 256 | return 257 | if object_node.attr['rel_location_x'] < -LANE_THRESHOLD: 258 | self.add_relation([object_node, Relations.isIn, self.left_lane]) 259 | elif object_node.attr['rel_location_x'] > LANE_THRESHOLD: 260 | self.add_relation([object_node, Relations.isIn, self.right_lane]) 261 | if abs(object_node.attr['rel_location_x']) <= CENTER_LANE_THRESHOLD: 262 | self.add_relation([object_node, Relations.isIn, self.middle_lane]) 263 | 264 | # add single node to graph. node can be any hashable datatype including objects. 265 | 266 | def add_node(self, node): 267 | color = "white" 268 | if "ego" in node.name.lower(): 269 | color = "red" 270 | elif "car" in node.name.lower(): 271 | color = "green" 272 | elif "lane" in node.name.lower(): 273 | color = "yellow" 274 | self.g.add_node(node, attr=node.attr, label=node.name, 275 | style='filled', fillcolor=color) 276 | 277 | # add relation (edge) between nodes on graph. relation is a list containing [subject, relation, object] 278 | 279 | def add_relation(self, relation): 280 | if relation != []: 281 | if relation[0] in self.g.nodes and relation[2] in self.g.nodes: 282 | self.g.add_edge(relation[0], relation[2], object=relation[1], 283 | label=relation[1].name, color=RELATION_COLORS[int(relation[1].value)]) 284 | else: 285 | raise NameError( 286 | "One or both nodes in relation do not exist in graph. Relation: " + str(relation)) 287 | 288 | def add_relations(self, relations_list): 289 | for relation in relations_list: 290 | self.add_relation(relation) 291 | 292 | def visualize(self, to_filename): 293 | A = to_agraph(self.g) 294 | A.layout('dot') 295 | A.draw(to_filename) 296 | 297 | 298 | # ROI: Region of Interest 299 | # returns transformation matrix for warping image to birds eye projection 300 | # birds eye matrix fixed for all images using the assumption that camera perspective does not change over time. 
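# --- Editor's note: the sketch below is illustrative and not part of the original
# file; get_birds_eye_matrix (described by the comments above) follows right after
# it. The hypothetical helper shows how a single pixel of the cropped ROI can be
# pushed through the same perspective matrix and converted to feet with
# X_SCALE/Y_SCALE, mirroring the per-bounding-box projection done in
# RealSceneGraph.get_nodes_from_bboxes. ---
def example_project_roi_point_to_feet(x_pixel, y_pixel, M):
    """Project one (x, y) pixel of the cropped ROI into birds-eye feet coordinates."""
    pt = np.array([[[x_pixel, y_pixel]]], dtype='float32')
    warp_x, warp_y = cv2.perspectiveTransform(pt, M)[0][0]
    return warp_x * X_SCALE, warp_y * Y_SCALE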
301 | def get_birds_eye_matrix(): 302 | # original dimensions (cropped to ROI) 303 | src = np.float32( 304 | [[0, CROPPED_H], [IMAGE_W, CROPPED_H], [0, 0], [IMAGE_W, 0]]) 305 | dst = np.float32([[int(BIRDS_EYE_IMAGE_W*16/33), BIRDS_EYE_IMAGE_H], [int(BIRDS_EYE_IMAGE_W * 306 | 17/33), BIRDS_EYE_IMAGE_H], [0, 0], [BIRDS_EYE_IMAGE_W, 0]]) # warped dimensions 307 | M = cv2.getPerspectiveTransform(src, dst) # The transformation matrix 308 | # Minv = cv2.getPerspectiveTransform(dst, src) # Inverse transformation (if needed) 309 | return M 310 | 311 | 312 | # returns image warped to birds eye projection using M 313 | # returned image is vertically cropped to the ROI (lane area) 314 | def get_birds_eye_warp(image_path, M): 315 | img = cv2.imread(image_path) 316 | img = img[H_OFFSET:IMAGE_H, 0:IMAGE_W] # Apply np slicing for ROI crop 317 | warped_img = cv2.warpPerspective( 318 | img, M, (BIRDS_EYE_IMAGE_W, BIRDS_EYE_IMAGE_H)) # Image warping 319 | warped_img = cv2.cvtColor(warped_img, cv2.COLOR_BGR2RGB) # set to RGB 320 | return warped_img 321 | -------------------------------------------------------------------------------- /sg_risk_assessment/dynkg_trainer.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import pandas as pd 4 | import random 5 | import pickle as pkl 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.optim as optim 11 | from torch_geometric.data import Data, DataLoader, DataListLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from sklearn import preprocessing 15 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve 16 | from sklearn.utils import resample 17 | from sklearn.utils.class_weight import compute_class_weight 18 | from sklearn.model_selection import train_test_split 19 | from matplotlib import pyplot as plt 20 | 21 | from sg_risk_assessment.relation_extractor import Relations 22 | from sg_risk_assessment.mrgcn import * 23 | from sg_risk_assessment.metrics import * 24 | import warnings 25 | warnings.simplefilter(action='ignore', category=FutureWarning) 26 | 27 | 28 | class DynKGTrainer: 29 | 30 | def __init__(self, config): 31 | self.config = config 32 | self.args = config.args 33 | np.random.seed(self.config.seed) 34 | torch.manual_seed(self.config.seed) 35 | 36 | self.summary_writer = SummaryWriter() 37 | 38 | self.best_val_loss = 99999 39 | self.best_epoch = 0 40 | self.best_val_acc = 0 41 | self.best_val_auc = 0 42 | self.best_val_confusion = [] 43 | self.best_val_f1 = 0 44 | self.best_val_mcc = -1.0 45 | self.best_val_acc_balanced = 0 46 | self.unique_clips = {} 47 | self.log = False 48 | 49 | if not self.config.pkl_path.exists(): 50 | raise Exception("The cache file does not exist.") 51 | 52 | def init_dataset(self): 53 | self.training_data, self.testing_data, self.feature_list = self.build_scenegraph_dataset(self.config.pkl_path, self.config.split_ratio, downsample=self.config.downsample, seed=self.config.seed, transfer_path=self.config.transfer_path) 54 | self.training_labels = [data['label'] for data in self.training_data] 55 | self.testing_labels = [data['label'] for data in self.testing_data] 56 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(self.training_labels), self.training_labels)) 57 | print("Number of Sequences Included: ", len(self.training_data)) 58 | print("Num Labels in Each Class: " + 
str(np.unique(self.training_labels, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 59 | 60 | def build_scenegraph_dataset(self, pkl_path, train_to_test_ratio=0.3, downsample=False, seed=0, transfer_path=None): 61 | ''' 62 | scenegraphs_sequence (gnn dataset): 63 | List of scenegraph data structures for evey clip 64 | Keys: {'sequence', 'label', 'folder_name', 'category'} 65 | feature_list: 66 | FILL IN DESCRIPTION HERE! 67 | ''' 68 | dataset_file = open(pkl_path, "rb") 69 | scenegraphs_sequence, feature_list = pkl.load(dataset_file) 70 | 71 | # Store driving categories and their frequencies 72 | self.unique_clips['all'] = 0 73 | for scenegraph in scenegraphs_sequence: 74 | self.unique_clips['all'] += 1 75 | if 'category' in scenegraph: 76 | category = scenegraph['category'] 77 | if category in self.unique_clips: 78 | self.unique_clips[category] += 1 79 | else: 80 | self.unique_clips[category] = 1 81 | else: 82 | scenegraph['category'] = 'all' 83 | print('no category') 84 | print('Total dataset breakdown: {}'.format(self.unique_clips)) 85 | 86 | if transfer_path == None: 87 | class_0 = [] 88 | class_1 = [] 89 | 90 | for g in scenegraphs_sequence: 91 | if g['label'] == 0: 92 | class_0.append(g) 93 | elif g['label'] == 1: 94 | class_1.append(g) 95 | 96 | y_0 = [0]*len(class_0) 97 | y_1 = [1]*len(class_1) 98 | 99 | min_number = min(len(class_0), len(class_1)) 100 | if downsample: 101 | modified_class_0, modified_y_0 = resample(class_0, y_0, n_samples=min_number) 102 | else: 103 | modified_class_0, modified_y_0 = class_0, y_0 104 | 105 | train, test, train_y, test_y = train_test_split(modified_class_0+class_1, modified_y_0+y_1, test_size=train_to_test_ratio, shuffle=True, stratify=modified_y_0+y_1, random_state=seed) 106 | return train, test, feature_list 107 | else: 108 | test, _ = pkl.load(open(transfer_path, "rb")) 109 | return scenegraphs_sequence, test, feature_list 110 | 111 | def build_model(self): 112 | self.config.num_features = len(self.feature_list) 113 | self.config.num_relations = max([r.value for r in Relations])+1 114 | if self.config.model == "mrgcn": 115 | self.model = MRGCN(self.config).to(self.config.device) 116 | elif self.config.model == "mrgin": 117 | self.model = MRGIN(self.config).to(self.config.device) 118 | else: 119 | raise Exception("model selection is invalid: " + self.config.model) 120 | 121 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay) 122 | if self.class_weights.shape[0] < 2: 123 | self.loss_func = nn.CrossEntropyLoss() 124 | else: 125 | self.loss_func = nn.CrossEntropyLoss(weight=self.class_weights.float().to(self.config.device)) 126 | 127 | self.config.wandb.watch(self.model, log="all") 128 | 129 | def train(self): 130 | tqdm_bar = tqdm(range(self.config.epochs)) 131 | for epoch_idx in tqdm_bar: # iterate through epoch 132 | acc_loss_train = 0 133 | self.sequence_loader = DataListLoader(self.training_data, batch_size=self.config.batch_size, shuffle=True) 134 | # TODO: Condense into one for loop 135 | for data_list in self.sequence_loader: # iterate through scenegraphs 136 | labels = torch.empty(0).long().to(self.config.device) 137 | outputs = torch.empty(0,2).to(self.config.device) 138 | self.model.train() 139 | self.optimizer.zero_grad() 140 | 141 | for sequence in data_list: # iterate through sequences 142 | data, label = sequence['sequence'], sequence['label'] 143 | graph_list = [Data(x=g['node_features'], edge_index=g['edge_index'], 
edge_attr=g['edge_attr']) for g in data] 144 | # data is a sequence that consists of serveral graphs 145 | self.train_loader = DataLoader(graph_list, batch_size=len(graph_list)) 146 | sequence = next(iter(self.train_loader)).to(self.config.device) 147 | output, _ = self.model.forward(sequence.x, sequence.edge_index, sequence.edge_attr, sequence.batch) 148 | outputs = torch.cat([outputs, output.view(-1, 2)], dim=0) 149 | labels = torch.cat([labels, torch.LongTensor([label]).to(self.config.device)], dim=0) 150 | # import pdb; pdb.set_trace() 151 | loss_train = self.loss_func(outputs, labels) 152 | loss_train.backward() 153 | acc_loss_train += loss_train.detach().cpu().item() * len(data_list) 154 | self.optimizer.step() 155 | 156 | acc_loss_train /= len(self.training_data) 157 | tqdm_bar.set_description('Epoch: {:04d}, loss_train: {:.4f}'.format(epoch_idx, acc_loss_train)) 158 | 159 | if epoch_idx % self.config.test_step == 0: 160 | _, _, metrics, _ = self.evaluate(epoch_idx) 161 | self.summary_writer.add_scalar('Acc_Loss/train', metrics['train']['loss'], epoch_idx) 162 | self.summary_writer.add_scalar('Acc_Loss/train_acc', metrics['train']['acc'], epoch_idx) 163 | self.summary_writer.add_scalar('F1/train', metrics['train']['f1'], epoch_idx) 164 | # self.summary_writer.add_scalar('Confusion/train', metrics['train']['confusion'], epoch_idx) 165 | self.summary_writer.add_scalar('Precision/train', metrics['train']['precision'], epoch_idx) 166 | self.summary_writer.add_scalar('Recall/train', metrics['train']['recall'], epoch_idx) 167 | self.summary_writer.add_scalar('Auc/train', metrics['train']['auc'], epoch_idx) 168 | 169 | self.summary_writer.add_scalar('Acc_Loss/test', metrics['test']['loss'], epoch_idx) 170 | self.summary_writer.add_scalar('Acc_Loss/test_acc', metrics['test']['acc'], epoch_idx) 171 | self.summary_writer.add_scalar('F1/test', metrics['test']['f1'], epoch_idx) 172 | # self.summary_writer.add_scalar('Confusion/test', metrics['test']['confusion'], epoch_idx) 173 | self.summary_writer.add_scalar('Precision/test', metrics['test']['precision'], epoch_idx) 174 | self.summary_writer.add_scalar('Recall/test', metrics['test']['recall'], epoch_idx) 175 | self.summary_writer.add_scalar('Auc/test', metrics['test']['auc'], epoch_idx) 176 | 177 | def inference(self, X, y): 178 | labels = torch.LongTensor().to(self.config.device) 179 | outputs = torch.FloatTensor().to(self.config.device) 180 | # Dictionary storing (output, label) pair for all driving categories 181 | categories = dict.fromkeys(self.unique_clips) 182 | for key, val in categories.items(): 183 | categories[key] = {'outputs': outputs, 'labels': labels} 184 | acc_loss_test = 0 185 | folder_names = [] 186 | attns_weights = [] 187 | node_attns = [] 188 | inference_time = 0 189 | 190 | with torch.no_grad(): 191 | for i in range(len(X)): # iterate through scenegraphs 192 | data, label, category = X[i]['sequence'], y[i], X[i]['category'] 193 | data_list = [Data(x=g['node_features'], edge_index=g['edge_index'], edge_attr=g['edge_attr']) for g in data] 194 | self.test_loader = DataLoader(data_list, batch_size=len(data_list)) 195 | sequence = next(iter(self.test_loader)).to(self.config.device) 196 | self.model.eval() 197 | 198 | #start = torch.cuda.Event(enable_timing=True) 199 | #end = torch.cuda.Event(enable_timing=True) 200 | #start.record() 201 | output, attns = self.model.forward(sequence.x, sequence.edge_index, sequence.edge_attr, sequence.batch) 202 | #end.record() 203 | #torch.cuda.synchronize() 204 | inference_time += 
0#start.elapsed_time(end) 205 | loss_test = self.loss_func(output.view(-1, 2), torch.LongTensor([label]).to(self.config.device)) 206 | acc_loss_test += loss_test.detach().cpu().item() 207 | label = torch.tensor(label, dtype=torch.long).to(self.config.device) 208 | # store output, label statistics 209 | self.update_categorical_outputs(categories, output, label, category) 210 | 211 | folder_names.append(X[i]['folder_name']) 212 | if 'lstm_attn_weights' in attns: 213 | attns_weights.append(attns['lstm_attn_weights'].squeeze().detach().cpu().numpy().tolist()) 214 | if 'pool_score' in attns: 215 | node_attn = {} 216 | node_attn["original_batch"] = sequence.batch.detach().cpu().numpy().tolist() 217 | node_attn["pool_perm"] = attns['pool_perm'].detach().cpu().numpy().tolist() 218 | node_attn["pool_batch"] = attns['batch'].detach().cpu().numpy().tolist() 219 | node_attn["pool_score"] = attns['pool_score'].detach().cpu().numpy().tolist() 220 | node_attns.append(node_attn) 221 | 222 | sum_seq_len = 0 223 | num_risky_sequences = 0 224 | sequences = len(categories['all']['labels']) 225 | for indices in range(sequences): 226 | seq_output = categories['all']['outputs'][indices] 227 | label = categories['all']['labels'][indices] 228 | pred = torch.argmax(seq_output) 229 | # risky clip 230 | if label == 1: 231 | num_risky_sequences += 1 232 | sum_seq_len += seq_output.shape[0] 233 | 234 | avg_risky_seq_len = sum_seq_len / num_risky_sequences 235 | 236 | return categories, \ 237 | folder_names, \ 238 | acc_loss_test/len(X), \ 239 | avg_risky_seq_len, \ 240 | inference_time, \ 241 | attns_weights, \ 242 | node_attns 243 | 244 | def evaluate(self, current_epoch=None): 245 | metrics = {} 246 | categories_train, \ 247 | folder_names_train, \ 248 | acc_loss_train, \ 249 | train_avg_seq_len, \ 250 | train_inference_time, \ 251 | attns_train, \ 252 | node_attns_train = self.inference(self.training_data, self.training_labels) 253 | 254 | # Collect metrics from all driving categories 255 | for category in self.unique_clips.keys(): 256 | if category == 'all': 257 | metrics['train'] = get_metrics(categories_train['all']['outputs'], categories_train['all']['labels']) 258 | metrics['train']['loss'] = acc_loss_train 259 | metrics['train']['avg_seq_len'] = train_avg_seq_len 260 | else: 261 | metrics['train'][category] = get_metrics(categories_train[category]['outputs'], categories_train[category]['labels']) 262 | 263 | categories_test, \ 264 | folder_names_test, \ 265 | acc_loss_test, \ 266 | val_avg_seq_len, \ 267 | test_inference_time, \ 268 | attns_test, \ 269 | node_attns_test = self.inference(self.testing_data, self.testing_labels) 270 | 271 | # Collect metrics from all driving categories 272 | for category in self.unique_clips.keys(): 273 | if category == 'all': 274 | metrics['test'] = get_metrics(categories_test['all']['outputs'], categories_test['all']['labels']) 275 | metrics['test']['loss'] = acc_loss_test 276 | metrics['test']['avg_seq_len'] = val_avg_seq_len 277 | metrics['avg_inf_time'] = (train_inference_time + test_inference_time) / ((len(self.training_labels) + len(self.testing_labels))) 278 | else: 279 | metrics['test'][category] = get_metrics(categories_test[category]['outputs'], categories_test[category]['labels']) 280 | 281 | print("\ntrain loss: " + str(acc_loss_train) + ", acc:", metrics['train']['acc'], metrics['train']['confusion'], "mcc:", metrics['train']['mcc'], \ 282 | "\ntest loss: " + str(acc_loss_test) + ", acc:", metrics['test']['acc'], metrics['test']['confusion'], "mcc:", 
metrics['test']['mcc']) 283 | 284 | #automatically save the model and metrics with the lowest validation loss 285 | self.update_best_metrics(metrics, current_epoch) 286 | metrics['best_epoch'] = self.best_epoch 287 | metrics['best_val_loss'] = self.best_val_loss 288 | metrics['best_val_acc'] = self.best_val_acc 289 | metrics['best_val_auc'] = self.best_val_auc 290 | metrics['best_val_conf'] = self.best_val_confusion 291 | metrics['best_val_f1'] = self.best_val_f1 292 | metrics['best_val_mcc'] = self.best_val_mcc 293 | metrics['best_val_acc_balanced'] = self.best_val_acc_balanced 294 | 295 | self.log2wandb(metrics) 296 | 297 | return categories_train, categories_test, metrics, folder_names_train 298 | 299 | # Utilities 300 | def update_categorical_outputs(self, categories, outputs, labels, category): 301 | ''' 302 | Aggregates output, label pairs for every driving category 303 | Based on inference setup, only one scenegraph_sequence is updated per call 304 | ''' 305 | if category in categories: 306 | categories[category]['outputs'] = torch.cat([categories[category]['outputs'], torch.unsqueeze(outputs, dim=0)], dim=0) 307 | categories[category]['labels'] = torch.cat([categories[category]['labels'], torch.unsqueeze(labels, dim=0)], dim=0) 308 | # multi category 309 | if category != 'all': 310 | category = 'all' 311 | categories[category]['outputs'] = torch.cat([categories[category]['outputs'], torch.unsqueeze(outputs, dim=0)], dim=0) 312 | categories[category]['labels'] = torch.cat([categories[category]['labels'], torch.unsqueeze(labels, dim=0)], dim=0) 313 | 314 | # reshape outputs 315 | for k, v in categories.items(): 316 | categories[k]['outputs'] = categories[k]['outputs'].reshape(-1, 2) 317 | 318 | def update_best_metrics(self, metrics, current_epoch): 319 | if metrics['test']['loss'] < self.best_val_loss: 320 | self.best_val_loss = metrics['test']['loss'] 321 | self.best_epoch = current_epoch if current_epoch != None else self.config.epochs 322 | self.best_val_acc = metrics['test']['acc'] 323 | self.best_val_auc = metrics['test']['auc'] 324 | self.best_val_confusion = metrics['test']['confusion'] 325 | self.best_val_f1 = metrics['test']['f1'] 326 | self.best_val_mcc = metrics['test']['mcc'] 327 | self.best_val_acc_balanced = metrics['test']['balanced_acc'] 328 | #self.save_model() 329 | 330 | def save_model(self): 331 | """Function to save the model.""" 332 | saved_path = Path(self.config.model_save_path).resolve() 333 | os.makedirs(os.path.dirname(saved_path), exist_ok=True) 334 | torch.save(self.model.state_dict(), str(saved_path)) 335 | with open(os.path.dirname(saved_path) + "/model_parameters.txt", "w+") as f: 336 | f.write(str(self.config)) 337 | f.write('\n') 338 | f.write(str(' '.join(sys.argv))) 339 | 340 | def load_model(self): 341 | """Function to load the model.""" 342 | saved_path = Path(self.config.model_load_path).resolve() 343 | if saved_path.exists(): 344 | self.build_model() 345 | self.model.load_state_dict(torch.load(str(saved_path))) 346 | self.model.eval() 347 | 348 | def log2wandb(self, metrics): 349 | ''' 350 | Log metrics from all driving categories 351 | ''' 352 | for category in self.unique_clips.keys(): 353 | if category == 'all': 354 | log_wandb(self.config.wandb, metrics) 355 | else: 356 | log_wandb_categories(self.config.wandb, metrics, id=category) -------------------------------------------------------------------------------- /baseline_risk_assessment/train.py: -------------------------------------------------------------------------------- 1 | import 
random, cv2 2 | from pathlib import Path 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from sklearn.utils import resample 9 | from sklearn.model_selection import train_test_split, StratifiedKFold 10 | from sklearn.utils.class_weight import compute_class_weight 11 | from baseline_risk_assessment.metrics import * 12 | 13 | 14 | class Trainer: 15 | def __init__(self, config): 16 | self.config = config 17 | self.toGPU = lambda x, dtype: torch.as_tensor(x, dtype=dtype, device=self.config.device) 18 | np.random.seed(self.config.seed) 19 | torch.manual_seed(self.config.seed) 20 | self.best_val_loss = 99999 21 | self.best_epoch = 0 22 | self.best_val_acc = 0 23 | self.best_val_auc = 0 24 | self.best_val_confusion = [] 25 | self.best_val_f1 = 0 26 | self.best_val_mcc = -1.0 27 | self.best_val_acc_balanced = 0 28 | self.unique_clips = {} 29 | self.log = False # for logging to wandb 30 | 31 | def build_model(self, model): 32 | self.model = model.to(self.config.device) 33 | self.loss_fn = nn.CrossEntropyLoss() 34 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay) 35 | if self.class_weights.shape[0] < 2: 36 | self.loss_func = torch.nn.CrossEntropyLoss() 37 | else: 38 | self.loss_func = torch.nn.CrossEntropyLoss(weight=self.class_weights.float().to(self.config.device)) 39 | 40 | def reset_weights(self, model): 41 | for layer in model.children(): 42 | if hasattr(layer, 'reset_parameters'): 43 | layer.reset_parameters() 44 | 45 | def init_dataset(self, data, label, clip_name): 46 | nb_samples = data.shape[0] # 574 samples 47 | class_0 = [] 48 | class_1 = [] 49 | class_0_clip_name = [] 50 | class_1_clip_name = [] 51 | 52 | # Store driving categories and their frequencies 53 | for clip in clip_name: 54 | category = clip.split('_')[-1] 55 | if category in self.unique_clips: 56 | self.unique_clips[category] += 1 57 | else: 58 | self.unique_clips[category] = 1 59 | print('Total dataset breakdown: {}'.format(self.unique_clips)) 60 | 61 | # Remove default_all category if more than one category 62 | if len(self.unique_clips.keys()) > 1: 63 | index = np.argwhere(clip_name == 'default_all') 64 | clip_name = np.delete(clip_name, index) 65 | 66 | for idx, l in enumerate(label): 67 | # (data, label) -> (class_0, y_0) non-risky, (class_1, y_1) risky 68 | if (l == [1.0, 0.0]).all(): 69 | class_0.append(data[idx]) 70 | class_0_clip_name.append(clip_name[idx]) 71 | elif (l == [0.0, 1.0]).all(): 72 | class_1.append(data[idx]) 73 | class_1_clip_name.append(clip_name[idx]) 74 | 75 | y_0 = [0] * len(class_0) 76 | y_1 = [1] * len(class_1) 77 | 78 | self.class_0 = np.array(class_0) 79 | self.class_1 = np.array(class_1) 80 | self.class_0_clip_name = np.array(class_0_clip_name) 81 | self.class_1_clip_name = np.array(class_1_clip_name) 82 | self.y_0 = np.array(y_0, dtype=np.float64) 83 | self.y_1 = np.array(y_1, dtype=np.float64) 84 | 85 | # balance the dataset 86 | min_number = min(len(class_0), len(class_1)) 87 | if self.config.downsample: 88 | if len(class_0) > len(class_1): 89 | self.class_0, self.class_0_clip_name, self.y_0 = resample(class_0, class_0_clip_name, y_0, n_samples=min_number, random_state=self.config.seed); 90 | else: 91 | self.class_1, self.class_1_clip_name, self.y_1 = resample(class_1, class_1_clip_name, y_1, n_samples=min_number, random_state=self.config.seed); 92 | self.split_dataset() 93 | 94 | def split_dataset(self): 95 | self.training_x, 
self.testing_x, self.training_clip_name, self.testing_clip_name, self.training_y, self.testing_y = train_test_split( 96 | np.concatenate([self.class_0, self.class_1], axis=0), 97 | np.concatenate([self.class_0_clip_name, self.class_1_clip_name], axis=0), 98 | np.concatenate([self.y_0, self.y_1], axis=0), 99 | test_size=1-self.config.train_ratio, 100 | shuffle=True, 101 | stratify=np.concatenate([self.y_0, self.y_1], axis=0), 102 | random_state=self.config.seed, 103 | ) 104 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(self.training_y), self.training_y)) 105 | if self.config.n_folds <= 1: 106 | print("Number of Training Sequences Included: ", len(self.training_x)) 107 | print("Number of Testing Sequences Included: ", len(self.testing_x)) 108 | print("Num of Training Labels in Each Class: " + str(np.unique(self.training_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 109 | print("Num of Testing Labels in Each Class: " + str(np.unique(self.testing_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 110 | 111 | def train_n_fold_cross_val(self): 112 | # KFold cross validation with similar class distribution in each fold 113 | skf = StratifiedKFold(n_splits=self.config.n_folds) 114 | X = np.append(self.training_x, self.testing_x, axis=0) 115 | clip_name = np.append(self.training_clip_name, self.testing_clip_name, axis=0) 116 | y = np.append(self.training_y, self.testing_y, axis=0) 117 | 118 | # self.results stores average metrics for the the n_folds 119 | self.results = {} 120 | self.fold = 1 121 | 122 | # Split training and testing data based on n_splits (Folds) 123 | for train_index, test_index in skf.split(X, y): 124 | self.training_x, self.testing_x, self.training_clip_name, self.testing_clip_name, self.training_y, self.testing_y = None, None, None, None, None, None #clear vars to save memory 125 | X_train, X_test = X[train_index], X[test_index] 126 | clip_train, clip_test = clip_name[train_index], clip_name[test_index] 127 | y_train, y_test = y[train_index], y[test_index] 128 | self.class_weights = torch.from_numpy(compute_class_weight('balanced', np.unique(y_train), y_train)) 129 | 130 | # Update dataset 131 | self.training_x = X_train 132 | self.testing_x = X_test 133 | self.training_clip_name = clip_train 134 | self.testing_clip_name = clip_test 135 | self.training_y = y_train 136 | self.testing_y = y_test 137 | 138 | print('\nFold {}'.format(self.fold)) 139 | print("Number of Training Sequences Included: ", len(X_train)) 140 | print("Number of Testing Sequences Included: ", len(X_test)) 141 | print("Num of Training Labels in Each Class: " + str(np.unique(self.training_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 142 | print("Num of Testing Labels in Each Class: " + str(np.unique(self.testing_y, return_counts=True)[1]) + ", Class Weights: " + str(self.class_weights)) 143 | 144 | self.best_val_loss = 99999 145 | self.train_model() 146 | self.log = True 147 | categories_train, categories_test, metrics = self.eval_model(self.fold) 148 | self.update_cross_valid_metrics(categories_train, categories_test, metrics) 149 | self.log = False 150 | 151 | if self.fold != self.config.n_folds: 152 | self.reset_weights(self.model) 153 | del self.optimizer 154 | self.build_model(self.model) 155 | 156 | self.fold += 1 157 | del self.results 158 | 159 | def train_model(self): 160 | tqdm_bar = tqdm(range(self.config.epochs)) 161 | for epoch_idx in tqdm_bar: # iterate through epoch 162 | 
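# Editor's note (illustrative): each epoch below shuffles the training indices with
# np.random.permutation, walks through them in batch_size chunks, moves each chunk to
# the configured device via self.toGPU, and accumulates the class-weighted
# cross-entropy loss. As a hypothetical example, 400 training subsequences with
# batch_size=64 would give 7 optimizer steps per epoch, the last on a partial
# batch of 16.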
            acc_loss_train = 0
            permutation = np.random.permutation(len(self.training_x))  # shuffle dataset before each epoch
            self.model.train()

            for i in range(0, len(self.training_x), self.config.batch_size):  # iterate through batches of the dataset
                batch_index = i + self.config.batch_size if i + self.config.batch_size <= len(self.training_x) else len(self.training_x)
                indices = permutation[i:batch_index]
                batch_x, batch_y = self.training_x[indices], self.training_y[indices]
                batch_x, batch_y = self.toGPU(batch_x, torch.float32), self.toGPU(batch_y, torch.long)
                self.optimizer.zero_grad()  # reset gradients so they do not accumulate across batches
                output = self.model.forward(batch_x).view(-1, 2)
                loss_train = self.loss_func(output, batch_y)
                loss_train.backward()
                acc_loss_train += loss_train.detach().cpu().item() * len(indices)
                self.optimizer.step()
                del loss_train

            acc_loss_train /= len(self.training_x)
            tqdm_bar.set_description('Epoch: {:04d}, loss_train: {:.4f}'.format(epoch_idx, acc_loss_train))

            # periodic evaluation (no cross-validation)
            if epoch_idx % self.config.test_step == 0:
                self.eval_model(epoch_idx)

    def model_inference(self, X, y, clip_name):
        labels = torch.LongTensor().to(self.config.device)
        outputs = torch.FloatTensor().to(self.config.device)
        # Dictionary storing (output, label) pairs for every driving category
        categories = dict.fromkeys(self.unique_clips)
        for key, val in categories.items():
            categories[key] = {'outputs': outputs, 'labels': labels}
        batch_size = self.config.batch_size  # NOTE: set to 1 when profiling or calculating inference time.
        acc_loss = 0
        inference_time = 0
        prof_result = ""

        with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof:
            with torch.no_grad():
                self.model.eval()

                for i in range(0, len(X), batch_size):  # iterate through subsequences
                    batch_index = i + batch_size if i + batch_size <= len(X) else len(X)
                    batch_x, batch_y, batch_clip_name = X[i:batch_index], y[i:batch_index], clip_name[i:batch_index]
                    batch_x, batch_y = self.toGPU(batch_x, torch.float32), self.toGPU(batch_y, torch.long)
                    # start = torch.cuda.Event(enable_timing=True)
                    # end = torch.cuda.Event(enable_timing=True)
                    # start.record()
                    output = self.model.forward(batch_x).view(-1, 2)
                    # end.record()
                    # torch.cuda.synchronize()
                    inference_time += 0  # start.elapsed_time(end)
                    loss_test = self.loss_func(output, batch_y)
                    acc_loss += loss_test.detach().cpu().item() * len(batch_y)
                    # store output, label statistics
                    self.update_categorical_outputs(categories, output, batch_y, batch_clip_name)

        # calculate one risk score per sequence (this is not implemented for each category)
        sum_seq_len = 0
        num_risky_sequences = 0
        num_safe_sequences = 0
        correct_risky_seq = 0
        correct_safe_seq = 0
        incorrect_risky_seq = 0
        incorrect_safe_seq = 0
        sequences = len(categories['all']['labels'])
        for index in range(sequences):
            seq_output = categories['all']['outputs'][index]
            label = categories['all']['labels'][index]
            pred = torch.argmax(seq_output)

            # risky clip
            if label == 1:
                num_risky_sequences += 1
                sum_seq_len += seq_output.shape[0]
                correct_risky_seq += self.correctness(label, pred)
                incorrect_risky_seq += 1 - self.correctness(label, pred)  # count only mispredicted risky clips
            # non-risky clip
            elif label == 0:
                num_safe_sequences += 1
                incorrect_safe_seq += 1 - self.correctness(label, pred)  # count only mispredicted safe clips
                correct_safe_seq += self.correctness(label, pred)

        avg_risky_seq_len = sum_seq_len / num_risky_sequences  # sequence length for comparison with the prediction frame metric.
        seq_tpr = correct_risky_seq / num_risky_sequences
        seq_fpr = incorrect_safe_seq / num_safe_sequences
        seq_tnr = correct_safe_seq / num_safe_sequences
        seq_fnr = incorrect_risky_seq / num_risky_sequences
        if prof is not None:
            prof_result = prof.key_averages().table(sort_by="cuda_time_total")

        return categories, \
            acc_loss / len(X), \
            avg_risky_seq_len, \
            inference_time, \
            prof_result, \
            seq_tpr, \
            seq_fpr, \
            seq_tnr, \
            seq_fnr

    def eval_model(self, current_epoch=None):
        metrics = {}
        categories_train, \
            acc_loss_train, \
            train_avg_seq_len, \
            train_inference_time, \
            train_profiler_result, \
            seq_tpr, seq_fpr, seq_tnr, seq_fnr = self.model_inference(self.training_x, self.training_y, self.training_clip_name)

        # Collect metrics from all driving categories
        for category in self.unique_clips.keys():
            if category == 'all':
                metrics['train'] = get_metrics(categories_train['all']['outputs'], categories_train['all']['labels'])
                metrics['train']['loss'] = acc_loss_train
                metrics['train']['avg_seq_len'] = train_avg_seq_len
                metrics['train']['seq_tpr'] = seq_tpr
                metrics['train']['seq_tnr'] = seq_tnr
                metrics['train']['seq_fpr'] = seq_fpr
                metrics['train']['seq_fnr'] = seq_fnr
            else:
                metrics['train'][category] = get_metrics(categories_train[category]['outputs'], categories_train[category]['labels'])

        categories_test, \
            acc_loss_test, \
            val_avg_seq_len, \
            test_inference_time, \
            test_profiler_result, \
            seq_tpr, seq_fpr, seq_tnr, seq_fnr = self.model_inference(self.testing_x, self.testing_y, self.testing_clip_name)

        # Collect metrics from all driving categories
        for category in self.unique_clips.keys():
            if category == 'all':
                metrics['test'] = get_metrics(categories_test['all']['outputs'], categories_test['all']['labels'])
                metrics['test']['loss'] = acc_loss_test
                metrics['test']['avg_seq_len'] = val_avg_seq_len
                metrics['test']['seq_tpr'] = seq_tpr
                metrics['test']['seq_tnr'] = seq_tnr
                metrics['test']['seq_fpr'] = seq_fpr
                metrics['test']['seq_fnr'] = seq_fnr
                metrics['avg_inf_time'] = (train_inference_time + test_inference_time) / ((len(self.training_y) + len(self.testing_y)) * 5)
            else:
                metrics['test'][category] = get_metrics(categories_test[category]['outputs'], categories_test[category]['labels'])

        print("\ntrain loss: " + str(acc_loss_train) + ", acc:", metrics['train']['acc'], metrics['train']['confusion'], "mcc:", metrics['train']['mcc'],
              "\ntest loss: " + str(acc_loss_test) + ", acc:", metrics['test']['acc'], metrics['test']['confusion'], "mcc:", metrics['test']['mcc'])

        self.update_best_metrics(metrics, current_epoch)
        metrics['best_epoch'] = self.best_epoch
        metrics['best_val_loss'] = self.best_val_loss
        metrics['best_val_acc'] = self.best_val_acc
        metrics['best_val_auc'] = self.best_val_auc
        metrics['best_val_conf'] = self.best_val_confusion
        metrics['best_val_f1'] = self.best_val_f1
        metrics['best_val_mcc'] = self.best_val_mcc
        metrics['best_val_acc_balanced'] = self.best_val_acc_balanced

        if self.config.n_folds <= 1 or self.log:
            self.log2wandb(metrics)

        return categories_train, categories_test, metrics

    # automatically save the model and metrics with the lowest validation loss
    def update_best_metrics(self, metrics, current_epoch):
        if metrics['test']['loss'] < self.best_val_loss:
            self.best_val_loss = metrics['test']['loss']
            self.best_epoch = current_epoch if current_epoch is not None else self.config.epochs
            self.best_val_acc = metrics['test']['acc']
            self.best_val_auc = metrics['test']['auc']
            self.best_val_confusion = metrics['test']['confusion']
            self.best_val_f1 = metrics['test']['f1']
            self.best_val_mcc = metrics['test']['mcc']
            self.best_val_acc_balanced = metrics['test']['balanced_acc']
            # self.save_model()

    def update_cross_valid_metrics(self, categories_train, categories_test, metrics):
        '''
        Stores cross-validation metrics for all driving categories.
        '''
        datasets = ['train', 'test']
        if self.fold == 1:
            for dataset in datasets:
                categories = categories_train if dataset == 'train' else categories_test
                for category in self.unique_clips.keys():
                    if category == 'all':
                        self.results['outputs' + '_' + dataset] = categories['all']['outputs']
                        self.results['labels' + '_' + dataset] = categories['all']['labels']
                        self.results[dataset] = metrics[dataset]
                        self.results[dataset]['loss'] = metrics[dataset]['loss']
                        self.results[dataset]['avg_seq_len'] = metrics[dataset]['avg_seq_len']

                        # Best results
                        self.results['avg_inf_time'] = metrics['avg_inf_time']
                        self.results['best_epoch'] = metrics['best_epoch']
                        self.results['best_val_loss'] = metrics['best_val_loss']
                        self.results['best_val_acc'] = metrics['best_val_acc']
                        self.results['best_val_auc'] = metrics['best_val_auc']
                        self.results['best_val_conf'] = metrics['best_val_conf']
                        self.results['best_val_f1'] = metrics['best_val_f1']
                        self.results['best_val_mcc'] = metrics['best_val_mcc']
                        self.results['best_val_acc_balanced'] = metrics['best_val_acc_balanced']
                    else:
                        self.results[dataset][category]['outputs'] = categories[category]['outputs']
                        self.results[dataset][category]['labels'] = categories[category]['labels']

        else:
            for dataset in datasets:
                categories = categories_train if dataset == 'train' else categories_test
                for category in self.unique_clips.keys():
                    if category == 'all':
                        self.results['outputs' + '_' + dataset] = torch.cat((self.results['outputs' + '_' + dataset], categories['all']['outputs']), dim=0)
                        self.results['labels' + '_' + dataset] = torch.cat((self.results['labels' + '_' + dataset], categories['all']['labels']), dim=0)
                        self.results[dataset]['loss'] = np.append(self.results[dataset]['loss'], metrics[dataset]['loss'])
                        self.results[dataset]['avg_seq_len'] = np.append(self.results[dataset]['avg_seq_len'], metrics[dataset]['avg_seq_len'])

                        # Best results
                        self.results['avg_inf_time'] = np.append(self.results['avg_inf_time'], metrics['avg_inf_time'])
                        self.results['best_epoch'] = np.append(self.results['best_epoch'], metrics['best_epoch'])
                        self.results['best_val_loss'] = np.append(self.results['best_val_loss'], metrics['best_val_loss'])
                        self.results['best_val_acc'] = np.append(self.results['best_val_acc'], metrics['best_val_acc'])
                        self.results['best_val_auc'] = np.append(self.results['best_val_auc'], metrics['best_val_auc'])
                        self.results['best_val_conf'] = np.append(self.results['best_val_conf'], metrics['best_val_conf'])
                        self.results['best_val_f1'] = np.append(self.results['best_val_f1'], metrics['best_val_f1'])
                        self.results['best_val_mcc'] = np.append(self.results['best_val_mcc'], metrics['best_val_mcc'])
                        self.results['best_val_acc_balanced'] = np.append(self.results['best_val_acc_balanced'], metrics['best_val_acc_balanced'])
                    else:
                        self.results[dataset][category]['outputs'] = torch.cat((self.results[dataset][category]['outputs'], categories[category]['outputs']), dim=0)
                        self.results[dataset][category]['labels'] = torch.cat((self.results[dataset][category]['labels'], categories[category]['labels']), dim=0)

        # Log final averaged results
        if self.fold == self.config.n_folds:
            final_metrics = {}
            for dataset in datasets:
                for category in self.unique_clips.keys():
                    if category == 'all':
                        final_metrics[dataset] = get_metrics(self.results['outputs' + '_' + dataset], self.results['labels' + '_' + dataset])
                        final_metrics[dataset]['loss'] = np.average(self.results[dataset]['loss'])
                        final_metrics[dataset]['avg_seq_len'] = np.average(self.results[dataset]['avg_seq_len'])

                        # Best results
                        final_metrics['avg_inf_time'] = np.average(self.results['avg_inf_time'])
                        final_metrics['best_epoch'] = np.average(self.results['best_epoch'])
                        final_metrics['best_val_loss'] = np.average(self.results['best_val_loss'])
                        final_metrics['best_val_acc'] = np.average(self.results['best_val_acc'])
                        final_metrics['best_val_auc'] = np.average(self.results['best_val_auc'])
                        final_metrics['best_val_conf'] = self.results['best_val_conf']
                        final_metrics['best_val_f1'] = np.average(self.results['best_val_f1'])
                        final_metrics['best_val_mcc'] = np.average(self.results['best_val_mcc'])
                        final_metrics['best_val_acc_balanced'] = np.average(self.results['best_val_acc_balanced'])
                    else:
                        final_metrics[dataset][category] = get_metrics(self.results[dataset][category]['outputs'], self.results[dataset][category]['labels'])

            print('\nFinal Averaged Results')
            print("\naverage train loss: " + str(final_metrics['train']['loss']) + ", average acc:", final_metrics['train']['acc'], final_metrics['train']['confusion'], final_metrics['train']['auc'],
                  "\naverage test loss: " + str(final_metrics['test']['loss']) + ", average acc:", final_metrics['test']['acc'], final_metrics['test']['confusion'], final_metrics['test']['auc'])

            self.log2wandb(final_metrics)

            # final combined results and metrics
            return self.results['outputs_train'], self.results['labels_train'], self.results['outputs_test'], self.results['labels_test'], final_metrics

    # Utilities
    def update_categorical_outputs(self, categories, outputs, labels, clip_name):
        '''
        Aggregates (output, label) pairs for every driving category.
        '''
        n = len(clip_name)
        for i in range(n):
            category = clip_name[i].split('_')[-1]
            # FIXME: probably better way to do this
            if category in categories:
                categories[category]['outputs'] = torch.cat([categories[category]['outputs'], torch.unsqueeze(outputs[i], dim=0)], dim=0)
                categories[category]['labels'] = torch.cat([categories[category]['labels'], torch.unsqueeze(labels[i], dim=0)], dim=0)
                # multi category
                if category != 'all':
                    category = 'all'
                    categories[category]['outputs'] = torch.cat([categories[category]['outputs'], torch.unsqueeze(outputs[i], dim=0)], dim=0)
                    categories[category]['labels'] = torch.cat([categories[category]['labels'], torch.unsqueeze(labels[i], dim=0)], dim=0)

        # reshape outputs
        for k, v in categories.items():
            categories[k]['outputs'] = categories[k]['outputs'].reshape(-1, 2)

    def preprocess_batch(self, x):
        '''
        Apply normalization preprocessing to all data.
        '''
        b = []
        for batch in x:
            d = []
            for data in batch:
                data = np.moveaxis(data, 0, -1)  # move channels to last_dim
                d.append(preprocess_image(data))
            b.append(torch.cat(d, axis=0))
        return torch.stack(b, dim=0).type(torch.float32)

    def correctness(self, label, pred):
        # 1 if the prediction matches the label, 0 otherwise
        return 1 if label == pred else 0

    def log2wandb(self, metrics):
        '''
        Log metrics from all driving categories.
        '''
        for category in self.unique_clips.keys():
            if category == 'all':
                log_wandb(self.config.wandb, metrics)
            else:
                log_wandb_categories(self.config.wandb, metrics, id=category)
--------------------------------------------------------------------------------
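The `Trainer` above is driven by a `config` object exposing `device`, `seed`, `learning_rate`, `weight_decay`, `epochs`, `batch_size`, `test_step`, `train_ratio`, `n_folds`, `downsample`, and `wandb`, a model whose forward output can be reshaped to `(-1, 2)`, and `(data, label, clip_name)` arrays with one-hot labels and category-suffixed clip names. Below is a minimal, hypothetical driver sketch of that call order (`init_dataset` → `build_model` → `train_model`), assuming the dependency versions pinned in `requirements.txt`; the `ToyModel`, the synthetic arrays, the concrete config values, and the import path are illustrative assumptions, not the repository's actual entry point (`baseline_risk_assessment.py`).

```python
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn as nn

from baseline_risk_assessment.train import Trainer  # assumed module path for the file above

# Config fields mirror the attributes the Trainer reads (device, seed, optimizer
# and loop settings, and the wandb handle forwarded to log_wandb()).
config = SimpleNamespace(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0,
    learning_rate=1e-4,
    weight_decay=5e-4,
    epochs=10,
    batch_size=8,
    test_step=5,
    train_ratio=0.7,
    n_folds=1,
    downsample=True,
    wandb=None,  # placeholder; eval_model() forwards this to log_wandb()
)


class ToyModel(nn.Module):
    """Stand-in for the CNN+LSTM baseline: maps each flattened clip to 2 class scores."""

    def __init__(self, in_features):
        super().__init__()
        self.fc = nn.Linear(in_features, 2)

    def forward(self, x):
        return self.fc(x.flatten(start_dim=1))


# Synthetic inputs in the format init_dataset() expects: one sample per clip,
# one-hot labels ([1, 0] non-risky, [0, 1] risky), and clip names whose suffix
# after the last '_' is the driving category ('all' here).
data = np.random.rand(20, 16).astype(np.float32)
label = np.eye(2)[np.arange(20) % 2]
clip_name = np.array(['clip{}_all'.format(i) for i in range(20)])

trainer = Trainer(config)
trainer.init_dataset(data, label, clip_name)   # balanced, stratified split + class weights
trainer.build_model(ToyModel(in_features=16))  # weighted cross-entropy + Adam
# trainer.train_model()                        # periodic eval_model() logs through config.wandb,
#                                              # so attach a configured wandb handle before running;
#                                              # with n_folds > 1, use trainer.train_n_fold_cross_val() instead
```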