├── README.md ├── configs └── config.py ├── figures ├── archi.png ├── em.png ├── image52.gif └── pytorch-logo-dark.png ├── msp ├── __init__.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── dashcams.py │ ├── pipelines │ │ ├── __init__.py │ │ └── pipelines.py │ └── utils │ │ ├── __init__.py │ │ └── metrics.py ├── lightning_model.py ├── models │ ├── __init__.py │ ├── accidentblocks │ │ ├── __init__.py │ │ └── accident_lstm.py │ ├── builder.py │ ├── gates │ │ ├── __init__.py │ │ └── s_t_masker.py │ ├── losses │ │ ├── __init__.py │ │ └── mse_loss.py │ ├── nearfuture │ │ ├── __init__.py │ │ └── near_future.py │ ├── predictors │ │ ├── __init__.py │ │ └── all_batch_version.py │ ├── spatialgraphs │ │ ├── __init__.py │ │ └── spatial_graph_batch.py │ ├── temporalattns │ │ ├── __init__.py │ │ └── s_t_attn.py │ ├── temporalgraphs │ │ ├── __init__.py │ │ └── temporal_graph_batch.py │ └── utils │ │ ├── __init__.py │ │ └── utils.py └── utils │ ├── __init__.py │ └── train_utils.py ├── requirements.txt ├── run.py └── scripts ├── preprecess_data_MASKER_MD.py ├── preprecess_data_MD-SG.py └── preprocess_data_MASKER.py /README.md: -------------------------------------------------------------------------------- 1 | # GSC: A Graph and Spatio-temporal Continuity Based Framework for Accident Anticipation 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | This repository contains the source code and pre-trained models for our paper: `GSC: A Graph and Spatio-temporal Continuity Based Framework for Accident Anticipation` 6 | 7 |

8 | 9 |

10 | 11 | ## Architecture 12 | An overview of the network is shown below: 13 | 14 |

15 | 16 |

17 | 18 | ## Prerequisites 19 | 20 | - Python 3.6 21 | - PyTorch 1.7.0 22 | - PyTorch Lightning 0.9.0 23 | - Other required packages are listed in `requirements.txt` 24 | 25 | ## Getting Started 26 | 27 | ### Create conda environment 28 | 29 | ```bash 30 | conda create -n sspm python=3.6 31 | source activate sspm 32 | ``` 33 | 34 | ### Install the required packages 35 | 36 | ```bash 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ### Download the MASKER_MD dataset and unzip it 41 | 42 | - Download the dataset from [BaiduYun](https://pan.baidu.com/s/1TT3AZBBuE-u_zovl6i44iQ) (password: `qo81`) and unzip it. 43 | 44 | - Set `data_root` 45 | in `configs/config.py` to the directory where you unzipped the dataset. 46 | 47 | ## Train the Model 48 | 49 | - Run the following command in a terminal: 50 | ```bash 51 | python run.py --train ./configs/config.py 52 | ``` 53 | 54 | ## Test the Model 55 | 56 | - Set `test_checkpoint` in `configs/config.py` to the path of your trained model checkpoint. 57 | 58 | - Run the following command in a terminal: 59 | ```bash 60 | python run.py --test ./configs/config.py 61 | ``` 62 | 63 | ## Visualize 64 | 65 |

66 | 67 |

68 | -------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | name = "STAttnsGraphDashcam" 2 | version = "2.0.0" 3 | model = dict( 4 | type='STAttnsGraphSimAllBatch', 5 | spatial_graph=dict( 6 | type='SpatialGraphBatch', 7 | node_feature=2048, 8 | hidden_feature=1024, 9 | out_feature=512, 10 | ), 11 | temporal_graph=dict( 12 | type='TemporalGraphBatch', 13 | past_num=5, 14 | input_feature=512, 15 | out_feature=512, 16 | ), 17 | gate=dict( 18 | type='MaskGate', 19 | past_num = 10, 20 | ), 21 | temporal_attn=dict( 22 | type='STAttn' 23 | ), 24 | near_future=dict( 25 | type="NearFuture" 26 | ), 27 | accident_block=dict( 28 | type='AccidentLSTM', 29 | temporal_feature=256, 30 | hidden_feature=64, 31 | num_layers=2 32 | ), 33 | loss=dict( 34 | type="LogLoss" 35 | )) 36 | dataset_type = "DashCam" 37 | data_root = "/media/group1/data/tianhang/MASKER_MD" 38 | data = dict( 39 | batch_size=32, 40 | num_workers=1, 41 | train=dict( 42 | type=dataset_type, 43 | root_dir=data_root, 44 | #pipelines=pipelines, 45 | video_list_file=f"{data_root}/train_video_list.txt"), 46 | val=dict( 47 | type=dataset_type, 48 | root_dir=data_root, 49 | #pipelines=pipelines, 50 | video_list_file=f"{data_root}/valid_video_list.txt")) 51 | # training and testing settings 52 | optimizer_cfg = dict( 53 | type='Adam', 54 | lr=0.01, 55 | betas=(0.9, 0.999), 56 | ) 57 | lr_cfg = dict( 58 | type="StepLR", 59 | step_size=50, 60 | gamma=.7) 61 | warm_up_cfg = dict( 62 | type="Exponential", 63 | step_size=5000) 64 | random_seed = 1234 65 | # GPU 66 | num_gpus = [2] 67 | max_epochs = 200 68 | checkpoint_path = "work_dirs/checkpoints" 69 | log_path = "work_dirs/logs" 70 | result_path = "work_dirs/results" 71 | load_from_checkpoint = None 72 | resume_from_checkpoint = None 73 | test_checkpoint = None 74 | batch_accumulate_size = 1 75 | simple_profiler = True 
-------------------------------------------------------------------------------- /figures/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/archi.png -------------------------------------------------------------------------------- /figures/em.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/em.png -------------------------------------------------------------------------------- /figures/image52.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/image52.gif -------------------------------------------------------------------------------- /figures/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/pytorch-logo-dark.png -------------------------------------------------------------------------------- /msp/__init__.py: -------------------------------------------------------------------------------- 1 | from .lightning_model import LightningModel 2 | 3 | 4 | __all__ = ['LightningModel'] -------------------------------------------------------------------------------- /msp/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import DATASETS, build_dataset 2 | from .dashcams import * 3 | from .pipelines import * 4 | 5 | __all__ = ["DATASETS", "build_dataset"] 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /msp/datasets/builder.py: -------------------------------------------------------------------------------- 1 | from 
typing import Optional 2 | 3 | from mmcv.utils import Registry, build_from_cfg 4 | 5 | DATASETS = Registry('datasets') 6 | PIPELINES = Registry('pipelines') 7 | 8 | 9 | def build_dataset(cfg: dict, default_args: Optional[dict] = None): 10 | return build_from_cfg(cfg, DATASETS, default_args) 11 | 12 | 13 | def build_pipelines(cfg: dict, default_args: Optional[dict] = None): 14 | return build_from_cfg(cfg, PIPELINES, default_args) -------------------------------------------------------------------------------- /msp/datasets/dashcams.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tqdm import tqdm 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from .builder import DATASETS 7 | from .pipelines import Compose 8 | 9 | 10 | @DATASETS.register_module() 11 | class DashCam(Dataset): 12 | 13 | def __init__(self, 14 | root_dir: str, 15 | video_list_file: str = None, 16 | pipelines: dict = None): 17 | 18 | self.root_dir = root_dir 19 | 20 | self.pipelines = Compose(pipelines) if pipelines is not None else None 21 | 22 | self.__sample__ = open(video_list_file, "r").read().splitlines() 23 | 24 | def __getitem__(self, idx: int) -> dict: 25 | data = np.load(os.path.join(self.root_dir, self.__sample__[idx]), allow_pickle=True).item() 26 | 27 | data['location'][:, :, [1, 3]] /= 1280 28 | data['location'][:, :, [2, 4]] /= 720 29 | 30 | for phase in data: 31 | if (phase != "accident") & (phase != "graph_index"): 32 | data[phase] = torch.from_numpy(data[phase]) 33 | elif phase == "graph_index": 34 | data[phase] = torch.tensor(data[phase]).t().contiguous() 35 | 36 | data['video_id'] = self.__sample__[idx] 37 | 38 | return self.pipelines(data) if self.pipelines is not None else data 39 | 40 | def __len__(self) -> int: 41 | return len(self.__sample__) -------------------------------------------------------------------------------- /msp/datasets/pipelines/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .pipelines import * 2 | 3 | 4 | #__all__ = ['Compose', 'ToTensor', 'Normalize'] 5 | __all__ = ['Compose'] -------------------------------------------------------------------------------- /msp/datasets/pipelines/pipelines.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Sequence, Dict, Union, Callable 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | from ..builder import PIPELINES, build_pipelines 6 | 7 | 8 | class Compose(object): 9 | 10 | def __init__(self, transforms: Sequence[Union[dict, Callable]]): 11 | self.transforms = [] 12 | for transform in transforms: 13 | if isinstance(transform, dict): 14 | transform = build_pipelines(transform) 15 | self.transforms.append(transform) 16 | elif callable(transform): 17 | self.transforms.append(transform) 18 | else: 19 | raise TypeError('transform must be callable or a dict') 20 | 21 | def __call__(self, data: Dict[str, Dict[str, Union[Tensor, np.ndarray]]]) \ 22 | -> Dict[str, Dict[str, Union[Tensor, np.ndarray]]]: 23 | 24 | for t in self.transforms: 25 | data = t(data) 26 | return data 27 | 28 | def __repr__(self): 29 | format_string = self.__class__.__name__ + '(' 30 | for t in self.transforms: 31 | format_string += '\n' 32 | format_string += f' {t}' 33 | format_string += '\n)' 34 | return format_string -------------------------------------------------------------------------------- /msp/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import evaluate_metric 2 | 3 | __all__ = ['evaluate_metric'] -------------------------------------------------------------------------------- /msp/datasets/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def evaluate_metric(pred_np, label, fps=25): 5 
| all_pred = pred_np[:,:,1] # shape [batch_size, frames] 6 | accident_status = label['accident'].cpu().numpy() # shape [batch_size, ] 7 | accident_time = 76 8 | pred_eval = [] 9 | min_pred = np.inf 10 | n_frames = 0 11 | 12 | 13 | # access the frames before accident 14 | for idx, toa in enumerate(accident_status): 15 | if toa == True: 16 | pred = all_pred[idx, :int(accident_time)] # positive video 17 | else: 18 | pred = all_pred[idx, :] # negtive video 19 | # find the minimum prediction 20 | min_pred = np.min(pred) if min_pred > np.min(pred) else min_pred 21 | pred_eval.append(pred) 22 | n_frames += len(pred) 23 | total_seconds = all_pred.shape[1] / fps 24 | # iterate a set of thresholds from the minimum predications 25 | threholds = np.arange(max(min_pred,0), 1.0, 0.001) 26 | threholds_num = threholds.shape[0] 27 | Precision = np.zeros((threholds_num)) 28 | Recall = np.zeros((threholds_num)) 29 | Time = np.zeros((threholds_num)) 30 | cnt = 0 31 | for Th in threholds: 32 | Tp = 0.0 33 | Tp_Fp = 0.0 34 | Tp_Tn = 0.0 35 | time = 0.0 36 | counter = 0.0 # number of TP videos 37 | # iterate each video sample 38 | for i in range(len(pred_eval)): 39 | # ture positive frames: (pred->1) & (gt->True) 40 | tp = np.where(pred_eval[i] * accident_status[i] >=Th) 41 | Tp += float(len(tp[0])>0) 42 | if float(len(tp[0])>0) > 0: 43 | time += tp[0][0] / float(accident_time) 44 | counter = counter + 1 45 | Tp_Fp += float(len(np.where(pred_eval[i]>=Th)[0])>0) 46 | 47 | if Tp_Fp == 0: 48 | continue 49 | else: 50 | Precision[cnt] = Tp/Tp_Fp 51 | if np.sum(accident_status)==0: 52 | continue 53 | else: 54 | Recall[cnt] = Tp/np.sum(accident_status) 55 | if counter == 0: 56 | continue 57 | else: 58 | Time[cnt] = (1-time/counter) 59 | cnt += 1 60 | 61 | new_index = np.argsort(Recall) 62 | Precision = Precision[new_index] 63 | Recall = Recall[new_index] 64 | 65 | p_r_plot = {'precision':Precision, 66 | 'recall':Recall} 67 | 68 | Time = Time[new_index] 69 | # Unique the recall 70 | _, rep_index 
= np.unique(Recall, return_index=1) 71 | rep_index = rep_index[1:] 72 | new_Time = np.zeros(len(rep_index)) 73 | new_Precision = np.zeros(len(rep_index)) 74 | for i in range(len(rep_index)-1): 75 | new_Time[i] = np.max(Time[rep_index[i]:rep_index[i+1]]) 76 | new_Precision[i] = np.max(Precision[rep_index[i]:rep_index[i+1]]) 77 | # sort by descending order 78 | if rep_index.size != 0: 79 | new_Time[-1] = Time[rep_index[-1]] 80 | new_Precision[-1] = Precision[rep_index[-1]] 81 | new_Recall = Recall[rep_index] 82 | else: 83 | new_Recall = Recall 84 | # compute AP (area under P-R curve) 85 | AP = 0.0 86 | if new_Recall[0] != 0: 87 | AP += new_Precision[0]*(new_Recall[0]-0) 88 | for i in range(1,len(new_Precision)): 89 | AP += (new_Precision[i-1]+new_Precision[i])*(new_Recall[i]-new_Recall[i-1])/2 90 | 91 | # transform the relative mTTA to seconds 92 | mTTA = np.mean(new_Time) * total_seconds 93 | #print("Average Precision= %.4f, mean Time to accident= %.4f"%(AP, mTTA)) 94 | sort_time = new_Time[np.argsort(new_Recall)] 95 | sort_recall = np.sort(new_Recall) 96 | TTA_R80 = sort_time[np.argmin(np.abs(sort_recall-0.8))] * total_seconds 97 | #print("Recall@80%, Time to accident= " +"{:.4}".format(TTA_R80)) 98 | 99 | return AP, mTTA, TTA_R80, p_r_plot 100 | 101 | 102 | -------------------------------------------------------------------------------- /msp/lightning_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | import torch.optim as optim 5 | import torch 6 | from mmcv.utils import Config 7 | from torch.utils.data import DataLoader 8 | from .models import build_predictor 9 | from .datasets.utils.metrics import evaluate_metric 10 | from .datasets import build_dataset 11 | from matplotlib import pyplot as plt 12 | 13 | import os 14 | from .utils.train_utils import summarize_metric 15 | 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | class 
LightningModel(pl.LightningModule): 19 | 20 | def __init__(self, cfg: Config): 21 | super(LightningModel, self).__init__() 22 | 23 | self.cfg = cfg 24 | self.model_cfg = cfg.model 25 | self.data_cfg = cfg.data 26 | self.optim_cfg = cfg.optimizer_cfg 27 | self.lr_cfg = cfg.lr_cfg 28 | 29 | self.model = build_predictor(self.model_cfg) 30 | self.hparams = dict(lr=self.optim_cfg.lr, 31 | batch_size=self.data_cfg.batch_size * cfg.batch_accumulate_size) 32 | 33 | self.whether_save = True 34 | if self.whether_save: 35 | self.writer = SummaryWriter(comment="Training") 36 | self.writer.add_hparams({'lr':self.optim_cfg.lr, 'bsize':self.data_cfg.batch_size * cfg.batch_accumulate_size}, {}) 37 | self.train_step = 1 38 | self.val_step = 1 39 | self.tests_step = 1 40 | 41 | def forward(self, data, **kwargs): 42 | pred, masker_pred = self.model(data, **kwargs) 43 | return pred, masker_pred 44 | 45 | def training_step(self, batch, batch_idx): 46 | data = batch 47 | pred, masker_pred = self(data, module="Train") 48 | acdt_loss, masker_loss = self.model.loss(pred, masker_pred, data) 49 | loss = acdt_loss + masker_loss 50 | 51 | with torch.no_grad(): 52 | masker_pred_np = masker_pred.cpu().numpy() 53 | 54 | shelter_precision = (masker_pred_np[masker_pred_np[:,-1] == 1., 1] >= 0.5).sum() / (masker_pred_np[:,-1] == 1.).sum() 55 | go_away_precision = (masker_pred_np[masker_pred_np[:,-1] == 0., 0] >= 0.5).sum() / (masker_pred_np[:,-1] == 0.).sum() 56 | precision = (np.sum((np.where(masker_pred_np[:, :2] >= 0.50)[1] == masker_pred_np[:, -1])!=0)) / masker_pred_np.shape[0] 57 | 58 | if self.whether_save: 59 | self.writer.add_scalar('Training/acdt_loss', acdt_loss, self.train_step) 60 | self.writer.add_scalar('Training/masker_loss', masker_loss, self.train_step) 61 | self.writer.add_scalar('Training/shelter_precision', shelter_precision, self.train_step) 62 | self.writer.add_scalar('Training/go_away_precision', go_away_precision, self.train_step) 63 | 
self.writer.add_scalar('Training/precision', precision, self.train_step) 64 | 65 | 66 | logs = {"loss": loss} 67 | self.train_step += 1 68 | 69 | return {"loss": loss, 70 | "log": logs} 71 | 72 | def validation_step(self, batch, batch_idx): 73 | data = batch 74 | pred_np, pred_gate_np = self.model.predict(data, module="Test") 75 | ap, mtta, tta_r80, p_r_plot = evaluate_metric(pred_np, data) 76 | 77 | self.val_step += 1 78 | 79 | ap_list = [] 80 | mtta_list = [] 81 | tta_r80_list = [] 82 | 83 | ap_list.append(ap) 84 | mtta_list.append(mtta) 85 | tta_r80_list.append(tta_r80) 86 | 87 | logs = {"ap":ap_list, "mtta":mtta_list, "tta_r80":tta_r80_list} 88 | return logs 89 | 90 | def validation_epoch_end(self, output): 91 | average_ap, average_mtta, average_tta_r80 = summarize_metric(output) 92 | if self.whether_save: 93 | self.writer.add_scalar('Metrics/AP', average_ap, self.tests_step) 94 | self.writer.add_scalar('Metrics/mtta', average_mtta, self.tests_step) 95 | self.writer.add_scalar('Metrics/TTA_R80', average_tta_r80, self.tests_step) 96 | self.tests_step += 1 97 | return {"output": output} 98 | 99 | def test_step(self, batch, batch_idx): 100 | data = batch 101 | pred_np = self.model.predict(data, module="Test") 102 | for i in range(data['location'].shape[0]): 103 | root_fill = './results/' + str(149) + "/" 104 | if not os.path.exists(root_fill): 105 | os.makedirs(root_fill) 106 | file_name = str(data["video_id"][i]) + ".csv" 107 | np.savetxt(root_fill + file_name, pred_np[i][:,1], delimiter=',') 108 | 109 | return None 110 | 111 | def prepare_data(self): 112 | self.train_dataset = build_dataset(self.data_cfg.train) 113 | self.val_dataset = build_dataset(self.data_cfg.val) 114 | 115 | def train_dataloader(self): 116 | return DataLoader(self.train_dataset, 117 | batch_size=self.data_cfg.batch_size, 118 | shuffle=False, 119 | num_workers=self.data_cfg.num_workers) 120 | 121 | def val_dataloader(self): 122 | return DataLoader(self.val_dataset, 123 | 
batch_size=self.data_cfg.batch_size, 124 | shuffle=False, 125 | num_workers=self.data_cfg.num_workers) 126 | 127 | def test_dataloader(self): 128 | return DataLoader(self.val_dataset, 129 | batch_size=1, 130 | shuffle=False, 131 | num_workers=self.data_cfg.num_workers) 132 | 133 | def configure_optimizers(self): 134 | 135 | optim_cfg = self.optim_cfg.copy() 136 | optim_class = getattr(optim, optim_cfg.pop("type")) 137 | optimizer = optim_class(self.parameters(), **optim_cfg) 138 | 139 | lr_cfg = self.lr_cfg.copy() 140 | lr_sheduler_class = getattr(optim.lr_scheduler, lr_cfg.pop("type")) 141 | scheduler = { 142 | "scheduler": lr_sheduler_class(optimizer, **lr_cfg), 143 | "monitor": 'avg_val_loss', 144 | "interval": "epoch", 145 | "frequency": 1 146 | } 147 | 148 | return [optimizer], [scheduler] 149 | 150 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure = None, 151 | on_tpu: bool = False, using_native_amp=False, using_lbfgs: bool = False): 152 | 153 | warm_up_type, warm_up_step = self.cfg.warm_up_cfg.type, self.cfg.warm_up_cfg.step_size 154 | if warm_up_type == 'Exponential': 155 | lr_scale = self.model_cfg.spatial_graph.hidden_feature ** -0.5 156 | lr_scale *= min((self.trainer.global_step + 1) ** (-0.5), 157 | (self.trainer.global_step + 1) * warm_up_step ** (-1.5)) 158 | elif warm_up_type == "Linear": 159 | lr_scale = min(1., float(self.trainer.global_step + 1) / warm_up_step) 160 | else: 161 | raise NotImplementedError 162 | 163 | for pg in optimizer.param_groups: 164 | # import pdb; pdb.set_trace() 165 | if self.whether_save: 166 | self.writer.add_scalar('Training/lr', pg['lr'], self.train_step) 167 | pg['lr'] = lr_scale * self.hparams.lr 168 | 169 | optimizer.step() 170 | optimizer.zero_grad() 171 | 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /msp/models/__init__.py: -------------------------------------------------------------------------------- 
1 | from .builder import (SGMODEL, TGMODEL, ATTNMODEL, ACCIDENTMODEL, PREDICTORS, LOSSES, GATEMODEL, NEARFUTUREMODEL, 2 | build_spatial, build_temporal, build_attn, build_accident_block, build_predictor, build_loss, build_gate, build_nearfuture) 3 | 4 | from .spatialgraphs import * 5 | from .temporalgraphs import * 6 | from .temporalattns import * 7 | from .nearfuture import * 8 | from .accidentblocks import * 9 | from .gates import * 10 | from .predictors import * 11 | from .losses import * 12 | 13 | 14 | __all__ = ["SGMODEL", "TGMODEL", "PREDICTORS", "ATTNMODEL", "ACCIDENTMODEL", "LOSSES", "GATEMODEL", "NEARFUTUREMODEL", 15 | "build_spatial", "build_temporal", "build_predictor", "build_attn", "build_accident_block", "build_loss", "build_gate", "build_nearfuture"] 16 | 17 | 18 | -------------------------------------------------------------------------------- /msp/models/accidentblocks/__init__.py: -------------------------------------------------------------------------------- 1 | from .accident_lstm import AccidentLSTM 2 | 3 | __all__ = ['AccidentLSTM'] -------------------------------------------------------------------------------- /msp/models/accidentblocks/accident_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..builder import ACCIDENTMODEL 5 | 6 | 7 | @ACCIDENTMODEL.register_module() 8 | class AccidentLSTM(nn.Module): 9 | def __init__(self, 10 | temporal_feature: int, 11 | hidden_feature: int, 12 | num_layers: int): 13 | super(AccidentLSTM, self).__init__() 14 | 15 | self.acc_lstm = torch.nn.LSTM(temporal_feature, hidden_feature, num_layers=num_layers) 16 | self.hidden_feature = hidden_feature 17 | self.pred = nn.Sequential( 18 | nn.Linear(hidden_feature, 64), 19 | nn.ReLU(), 20 | nn.Linear(64, 2), 21 | nn.Softmax(dim=-1)) 22 | 23 | def forward(self, 24 | inputs, 25 | h0, 26 | c0): 27 | 28 | output, (hn, cn) = 
self.acc_lstm(inputs, (h0, c0)) 29 | 30 | return self.pred(output) 31 | 32 | def _reset_parameters(self): 33 | for p in self.parameters(): 34 | if p.dim() > 1: 35 | nn.init.xavier_uniform_(p) 36 | 37 | -------------------------------------------------------------------------------- /msp/models/builder.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, Any, Dict, List 2 | import torch.nn as nn 3 | from mmcv.utils import Registry, Config, build_from_cfg 4 | 5 | 6 | SGMODEL = Registry("spatialgraphs") 7 | TGMODEL = Registry("temporalgraphs") 8 | ATTNMODEL = Registry("temporalattns") 9 | NEARFUTUREMODEL = Registry("nearfuture") 10 | ACCIDENTMODEL = Registry("accidentblocks") 11 | GATEMODEL = Registry("gates") 12 | LOSSES = Registry("losses") 13 | PREDICTORS = Registry("predictors") 14 | 15 | 16 | def build(cfg: Union[Dict, List[Dict]], 17 | registry: Registry, 18 | default_args: Optional[Dict] = None) -> Any: 19 | """Build a module. 20 | 21 | Args: 22 | cfg: The config of modules, is is either a dict or a list of configs. 23 | registry: A registry the module belongs to. 24 | default_args: Default arguments to build the module. Defaults to None. 25 | 26 | Returns: 27 | nn.Module: A built nn module. 
28 | """ 29 | if isinstance(cfg, list): 30 | modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg] 31 | return nn.Sequential(*modules) 32 | else: 33 | return build_from_cfg(cfg, registry, default_args) 34 | 35 | 36 | def build_spatial(cfg: Union[Dict, List[Dict]]) -> Any: 37 | """Build spatial graph""" 38 | return build(cfg, SGMODEL) 39 | 40 | 41 | def build_temporal(cfg: Union[Dict, List[Dict]]) -> Any: 42 | """Build temporal graph""" 43 | return build(cfg, TGMODEL) 44 | 45 | 46 | def build_attn(cfg: Union[Dict, List[Dict]]) -> Any: 47 | """Build temporal attn""" 48 | return build(cfg, ATTNMODEL) 49 | 50 | def build_nearfuture(cfg: Union[Dict, List[Dict]]) -> Any: 51 | """Build temporal attn""" 52 | return build(cfg, NEARFUTUREMODEL) 53 | 54 | 55 | def build_accident_block(cfg: Union[Dict, List[Dict]]) -> Any: 56 | """Build accident block""" 57 | return build(cfg, ACCIDENTMODEL) 58 | 59 | 60 | def build_gate(cfg: Union[Dict, List[Dict]]) -> Any: 61 | """Build gate""" 62 | return build(cfg, GATEMODEL) 63 | 64 | 65 | def build_loss(cfg: Union[Dict, List[Dict]]) -> Any: 66 | """Build loss""" 67 | return build(cfg, LOSSES) 68 | 69 | 70 | def build_predictor(cfg: Union[Dict, List[Dict]]) -> Any: 71 | """Build model""" 72 | return build(cfg, PREDICTORS) 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /msp/models/gates/__init__.py: -------------------------------------------------------------------------------- 1 | from .s_t_masker import MaskGate 2 | 3 | __all__ = ['MaskGate'] -------------------------------------------------------------------------------- /msp/models/gates/s_t_masker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..builder import GATEMODEL 4 | 5 | 6 | @GATEMODEL.register_module() 7 | class MaskGate(torch.nn.Module): 8 | def __init__(self, past_num): 9 | 10 | super(MaskGate, 
self).__init__() 11 | self.gate_cnn = nn.Sequential( 12 | nn.Conv2d(512, 256, (1,1), stride=1), 13 | nn.ReLU(), 14 | nn.Conv2d(256, 128, (1,1), stride=1), 15 | nn.ReLU() 16 | ) 17 | self.gate_fc = nn.Sequential( 18 | nn.Linear(128, 64), 19 | nn.ReLU(), 20 | nn.Linear(64, 2), 21 | nn.Softmax(dim=-1) 22 | ) 23 | self.past_num = past_num 24 | 25 | def forward(self, 26 | param_dict, 27 | module): 28 | 29 | if module == "Train": 30 | pred = self.masker_train(param_dict["select_feature"]) 31 | return pred 32 | 33 | elif module == "Test": 34 | m_t, m_d, fill = self.masker_test( 35 | param_dict["all_mask"], 36 | param_dict["present_mask"], 37 | param_dict["feature_bank"], 38 | param_dict["time"], 39 | param_dict["location_all"], 40 | param_dict["masker_dict"]) 41 | return m_t, m_d, fill 42 | 43 | def masker_train(self, 44 | select_feature): 45 | 46 | select_feature_un = select_feature.unsqueeze(-1).unsqueeze(-1) 47 | temp = self.gate_cnn(select_feature_un) 48 | temp_sq = temp.squeeze(-1).squeeze(-1) 49 | pred_status = self.gate_fc(temp_sq) 50 | 51 | return pred_status 52 | 53 | 54 | def masker_test(self, 55 | all_mask, 56 | present_mask, 57 | temporal_feature_bank, 58 | t, 59 | location_all, 60 | masker_dict): 61 | 62 | if len(all_mask) == 0: 63 | 64 | return present_mask, masker_dict, (torch.empty(([0,])).type_as(location_all), 65 | torch.empty(([0,])).type_as(location_all)) 66 | else: 67 | # mask before 68 | before_mask = all_mask[-1] # shape: batch_size x objects 69 | # compare from True to False 70 | index = torch.where((before_mask==True) & (present_mask==False)) 71 | 72 | if min(index[0].shape) != 0: 73 | un_index, indexed = self.get_index(t, index, masker_dict, present_mask) 74 | present_mask[indexed] = True 75 | temporal_feature_bank_tensor = torch.stack(temporal_feature_bank)[-1,:,:,:] 76 | 77 | tfbt_un = temporal_feature_bank_tensor[un_index].unsqueeze(-1).unsqueeze(-1) 78 | temp = self.gate_cnn(tfbt_un) 79 | temp_sq = temp.squeeze(-1).squeeze(-1) 80 | 
gate_status = self.gate_fc(temp_sq) 81 | 82 | near_feature_infor = location_all[:,t+1:t+self.past_num+1,:,:] 83 | n_index = near_feature_infor[un_index[0],:,un_index[1], :] 84 | test_label = torch.sum(n_index[:,:,1:5].flatten(start_dim=1), dim=1, keepdim=True).ge(0.01) 85 | 86 | shelter_objects = torch.where(gate_status[:, 1] >= 0.50) 87 | 88 | shelter_index = (un_index[0][shelter_objects], un_index[1][shelter_objects]) 89 | present_mask[shelter_index] = True 90 | fill_position_index = (torch.cat((indexed[0], shelter_index[0])), 91 | torch.cat((indexed[1], shelter_index[1]))) 92 | 93 | training_test_samples = torch.hstack((gate_status, test_label)) 94 | 95 | masker_dict['index'][0] = torch.cat((masker_dict['index'][0], un_index[0])) 96 | masker_dict['index'][1] = torch.cat((masker_dict['index'][1], un_index[1])) 97 | masker_dict['pred_label'] = torch.cat((masker_dict['pred_label'], training_test_samples)) 98 | 99 | for i in range(un_index[0].shape[0]): 100 | masker_dict['timestamp'].append(t) 101 | 102 | return present_mask, masker_dict, fill_position_index 103 | 104 | else: 105 | 106 | return present_mask, masker_dict, (torch.empty(([0,])).type_as(location_all), 107 | torch.empty(([0,])).type_as(location_all)) 108 | 109 | def get_index(self, 110 | t, 111 | present_index, 112 | index_dict, 113 | present_mask): 114 | 115 | un_index = [] 116 | indexed = [] 117 | for i in range(present_index[0].shape[0]): 118 | batch_index = torch.where(index_dict['index'][0]==present_index[0][i]) 119 | if min(batch_index[0].shape) != 0: 120 | object_index = torch.where(index_dict['index'][1][batch_index]==present_index[1][i]) 121 | if min(object_index[0].shape) != 0 : 122 | # check time 123 | nearest_timestamps_index = batch_index[0][object_index[0]][-1] 124 | if (t - index_dict['timestamp'][nearest_timestamps_index] <= self.past_num) & \ 125 | (present_mask[present_index[0][i].item(), present_index[1][i].item()].item() == False): 126 | indexed.append(i) 127 | else: 128 | 
un_index.append(i) 129 | else: 130 | un_index.append(i) 131 | else: 132 | un_index.append(i) 133 | 134 | return (present_index[0][un_index], present_index[1][un_index]), \ 135 | (present_index[0][indexed], present_index[1][indexed]) 136 | 137 | def _reset_parameters(self): 138 | for p in self.parameters(): 139 | if p.dim() > 1: 140 | nn.init.xavier_uniform_(p) -------------------------------------------------------------------------------- /msp/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .mse_loss import LogLoss 2 | 3 | 4 | __all__ = ['LogLoss'] 5 | 6 | -------------------------------------------------------------------------------- /msp/models/losses/mse_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from ..builder import LOSSES 9 | 10 | 11 | @LOSSES.register_module() 12 | class LogLoss(nn.Module): 13 | 14 | def __init__(self): 15 | super(LogLoss, self).__init__() 16 | self.accident = True 17 | 18 | def forward(self, frame, labels): 19 | if self.accident: 20 | loss = (- math.exp(-max(0, 76 - frame)) * torch.log(labels)) 21 | else: 22 | loss = - torch.log(1 - labels) 23 | 24 | return loss 25 | 26 | -------------------------------------------------------------------------------- /msp/models/nearfuture/__init__.py: -------------------------------------------------------------------------------- 1 | from .near_future import NearFuture 2 | 3 | 4 | __all__ = ['NearFuture'] 5 | 6 | -------------------------------------------------------------------------------- /msp/models/nearfuture/near_future.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..builder import NEARFUTUREMODEL 5 | 6 | 7 | 
@NEARFUTUREMODEL.register_module()
class NearFuture(torch.nn.Module):
    """Fill the objects that the masker gate flagged at frame ``t``.

    For every ``(batch, object)`` pair in ``fill_position_index`` the box
    columns ``1:5`` of ``location_all[:, t]`` are linearly extrapolated from
    up to ``past_num`` previous frames, and ``feature_all[:, t]`` is copied
    from frame ``t - 1``. Both tensors are written in place.

    Args:
        past_num: maximum number of previous frames used to estimate the
            per-frame box displacement (5 was previously hard-coded).
    """

    def __init__(self, past_num: int = 5) -> None:
        super().__init__()
        self.past_num = past_num

    def forward(self,
                location_all,
                feature_all,
                t,
                fill_position_index):
        """Fill frame ``t`` in place and return that frame.

        Args:
            location_all: box tensor indexed ``[batch, frame, object, column]``;
                columns ``1:5`` are the box coordinates that get filled.
            feature_all: feature tensor indexed ``[batch, frame, object, channel]``.
            t: index of the frame being filled.
            fill_position_index: pair ``(batch_indices, object_indices)`` of
                1-D index tensors naming the objects to fill; may be empty.

        Returns:
            tuple: ``(location_all[:, t], feature_all[:, t])`` after filling.
        """
        batch_idx, object_idx = fill_position_index[0], fill_position_index[1]
        # Nothing to extrapolate from at the first frame, or nothing to fill.
        if t == 0 or batch_idx.numel() == 0:
            return location_all[:, t], feature_all[:, t]

        start = max(0, t - self.past_num)
        with torch.no_grad():
            # Two advanced indices separated by a slice put the K selected
            # objects first: history is (K, frames_used, columns).
            history = location_all[batch_idx, start:t, object_idx, :]
            last_box = history[:, -1, 1:5]
            steps = history.shape[1]
            if steps > 1:
                velocity = (last_box - history[:, 0, 1:5]) / (steps - 1)
            else:
                # A single past frame gives no motion estimate (the old code
                # divided by zero here); hold the last position instead.
                velocity = torch.zeros_like(last_box)
            # Step forward from frame t - 1 to frame t. Index assignment is
            # required: advanced indexing returns a copy, so the previous
            # ``<copy>.data = value`` never reached location_all/feature_all.
            location_all[batch_idx, t, object_idx, 1:5] = last_box + velocity
            feature_all[batch_idx, t, object_idx, :] = \
                feature_all[batch_idx, t - 1, object_idx, :]

        return location_all[:, t], feature_all[:, t]

    def _reset_parameters(self):
        """Xavier-initialise every weight matrix (this module defines none)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)


# ---- /msp/models/predictors/__init__.py ----
from .all_batch_version import STAttnsGraphSimAllBatch

__all__ = ['STAttnsGraphSimAllBatch']

# ---- /msp/models/predictors/all_batch_version.py ----
import torch
import torch.nn as nn
from mmcv.utils.config import ConfigDict
from ..utils import *
from ..builder import PREDICTORS, build_spatial, build_temporal, build_attn, \
    build_accident_block, build_gate, build_loss, build_nearfuture


@PREDICTORS.register_module()
class STAttnsGraphSimAllBatch(nn.Module):
    """Accident-anticipation model assembled from registry-built sub-modules.

    ``"Train"`` mode encodes every frame at once: spatial graph -> temporal
    graph -> temporal attention -> accident block, plus a masker-gate
    prediction for objects with a known missing status. ``"Test"`` mode walks
    the clip frame by frame, letting the masker gate and the near-future
    module decide and fill missing objects before each frame is encoded.
    """

    def __init__(self,
                 spatial_graph: ConfigDict,
                 temporal_graph: ConfigDict,
                 gate: ConfigDict,
                 temporal_attn: ConfigDict,
                 near_future: ConfigDict,
                 accident_block: ConfigDict,
                 loss: ConfigDict):

        super(STAttnsGraphSimAllBatch, self).__init__()

        # Build order is unchanged on purpose: sub-module construction may
        # draw from the RNG for weight initialisation.
        self.spatial_graph = build_spatial(spatial_graph)
        self.temporal_graph = build_temporal(temporal_graph)
        self.mask_gate = build_gate(gate)
        self.near_future = build_nearfuture(near_future)
        self.temporal_attn = build_attn(temporal_attn)
        self.accident_block = build_accident_block(accident_block)
        self.loss_func = build_loss(loss)

    def forward(self, data, module="Train"):
        """Run the model.

        Args:
            data: batch dict. Reads ``'location'`` (batch x frames x 19 x 6),
                ``'maskers'`` (batch x frames x 19), ``'missing_status'``,
                ``'feature'`` (its channel 0 is dropped) and, in ``"Train"``
                mode, ``'graph_index'`` / ``'graph_weight'``.
            module: ``"Train"`` or ``"Test"``.

        Returns:
            ``(accident_predictions, masker_gate_predictions)``.

        Raises:
            ValueError: for any other ``module``. The old code returned
                ``None`` instead, which failed later when callers unpacked it.
        """
        location = data['location']  # batch x frames x 19 x 6
        maskers = data["maskers"]  # batch x frames x 19
        missing_status = data["missing_status"]
        feature = data['feature'][:, :, :, 1:]
        batch_size = location.shape[0]
        hidden_size = self.accident_block.hidden_feature
        # Random initial LSTM state, drawn on every call.
        h0 = torch.randn(2, batch_size, hidden_size).type_as(location)
        c0 = torch.randn(2, batch_size, hidden_size).type_as(location)

        if module == "Train":
            return self._forward_train(data, feature, missing_status, h0, c0)
        if module == "Test":
            return self._forward_test(location, maskers, feature, h0, c0)
        raise ValueError("module must be 'Train' or 'Test', got {!r}".format(module))

    def _forward_train(self, data, feature, missing_status, h0, c0):
        """Batched pass over all frames; also returns gate predictions + labels."""
        train_dict = {
            "feature_all": feature,
            "graph_index": data["graph_index"],
            "graph_weight": data["graph_weight"]}
        spatial_feature_bank = self.spatial_graph(train_dict, "Train")
        temporal_feature_bank = self.temporal_graph(
            {"spatial_feature": spatial_feature_bank}, "Train")

        # The gate is supervised only on objects whose missing status is known.
        select_feature, label = select_gate_feature(temporal_feature_bank, missing_status)
        pred_status = self.mask_gate({"select_feature": select_feature}, "Train")
        pred_label = torch.hstack((pred_status, label.unsqueeze(-1)))

        attn_feature = self.temporal_attn(temporal_feature_bank, "Train")
        pred_accident = self.accident_block(attn_feature, h0, c0).transpose(1, 0)
        return pred_accident, pred_label

    def _forward_test(self, location, maskers, feature, h0, c0):
        """Frame-by-frame inference that lets the gate fill missing objects."""
        # First frame repeated 5 times in front: batch x (frames + 5) x 19 x 6.
        location_past = torch.cat([location[:, :1]] * 5 + [location], dim=1)

        spatial_feature_bank = []
        temporal_feature_bank = []
        temporal_feature_single_bank = []
        masks = []
        masker_dict = {
            'index': [torch.empty((0,)).type_as(location), torch.empty((0,)).type_as(location)],
            'pred_label': torch.empty((0, 3)).type_as(location),
            'timestamp': []  # frame index of every entry in 'index'
        }

        for t in range(location.size(1)):
            # The five frames before t (first frame repeated near the start).
            past_location_t = location_past[:, t:(t + 5)]

            # Decide which missing objects should be filled at this frame.
            masker_param_dict = {
                "all_mask": masks,
                "present_mask": maskers[:, t],
                "feature_bank": temporal_feature_single_bank,
                "time": t,
                "location_all": location,
                "masker_dict": masker_dict}
            mask_t_filled, masker_dict, fill_position_index = self.mask_gate(masker_param_dict, "Test")
            masks.append(mask_t_filled)

            # The former ``location[:, t].data = ...`` / ``feature[:, t].data = ...``
            # lines re-bound the data of a temporary view and never reached
            # ``location``/``feature``; any write-back must happen in near_future.
            location_t_filled, feature_t_filled = self.near_future(location,
                                                                   feature,
                                                                   t,
                                                                   fill_position_index)

            # Spatial GCN over the (filled) current frame.
            test_param_dict = {
                "location_t": location_t_filled,
                "feature_t": feature_t_filled,
                "mask_t": mask_t_filled,
                "past_location_t": past_location_t
            }
            enc = self.spatial_graph(module="Test", param_dict=test_param_dict)
            spatial_feature_bank.append(enc)

            # Temporal GCN over the spatial history, then attention pooling.
            temporal_feature_t = self.temporal_graph(
                {"spatial_feature": spatial_feature_bank, "mask": mask_t_filled}, "Test")
            temporal_feature_single_bank.append(temporal_feature_t)
            temporal_feature_bank.append(self.temporal_attn(temporal_feature_t, "Test"))

        temporal_feature_bank = torch.stack(temporal_feature_bank, dim=0)
        pred = self.accident_block(temporal_feature_bank, h0, c0).transpose(1, 0)
        return pred, masker_dict['pred_label']

    def loss(self, pred, masker_pred, data):
        """Return ``(accident_loss, masker_gate_loss)``."""
        pred_loss = caculate_softmax_e_loss(pred, data["accident"])
        mask_gate_loss = caculate_masker_gate_loss(masker_pred)
        return pred_loss, mask_gate_loss

    def predict(self, data, module):
        """Return the accident predictions as a NumPy array, without gradients."""
        with torch.no_grad():
            pred, _ = self(data, module)
        return pred.cpu().numpy()


# ---- /msp/models/spatialgraphs/__init__.py ----
from .spatial_graph_batch import SpatialGraphBatch


__all__ = ['SpatialGraphBatch']

# ---- /msp/models/spatialgraphs/spatial_graph_batch.py ----
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.data import Batch, batch
import torch
import torch.nn as nn
import numpy as np
from collections import deque
import itertools
import math
import networkx


from ..builder import SGMODEL


@SGMODEL.register_module()
class SpatialGraphBatch(nn.Module):
    """Two-layer GCN over the objects of each frame (sigmoid after each layer).

    Args:
        node_feature: input feature size per object.
        hidden_feature: output size of the first GCN layer.
        out_feature: output size of the second GCN layer.
    """

    def __init__(self,
                 node_feature: int,
                 hidden_feature: int,
                 out_feature: int,
                 ):

        super(SpatialGraphBatch, self).__init__()

        self.g_conv1 = GCNConv(node_feature, hidden_feature)
        self.g_conv2 = GCNConv(hidden_feature, out_feature)

    def forward(self,
                param_dict,
                module):
        """Dispatch to the batched training path or the per-frame test path.

        Raises:
            ValueError: if ``module`` is neither ``"Train"`` nor ``"Test"``
                (previously this surfaced as an UnboundLocalError).
        """
        if module == "Train":
            return self.gate_training(param_dict["feature_all"],
                                      param_dict["graph_index"],
                                      param_dict["graph_weight"])
        if module == "Test":
            return self.module_testing(param_dict["location_t"],
                                       param_dict["feature_t"],
                                       param_dict["mask_t"],
                                       param_dict["past_location_t"])
        raise ValueError("module must be 'Train' or 'Test', got {!r}".format(module))

    def gate_training(self,
                      feature_all,
                      graph_index,
                      graph_weight):
        """Encode every frame of every clip in a single GCN batch.

        Args:
            feature_all: ``(batch, frames, objects, node_feature)`` features.
            graph_index: per-clip edge indices; only ``graph_index[0]`` is used,
                i.e. all clips are assumed to share one edge list.
            graph_weight: per-frame edge weights, flattened over its first two
                dimensions (batch, frames).

        Returns:
            ``(batch, frames, objects, out_feature)`` encoded features.
        """
        batch_size, frames, num_objects = feature_all.shape[:3]
        feature_flatten = torch.flatten(feature_all, start_dim=0, end_dim=1)
        graph_weight_flatten = torch.flatten(graph_weight, start_dim=0, end_dim=1)
        single_graph = graph_index[0]
        train_loader = self.generate_gcn_batch(feature_flatten, single_graph, graph_weight_flatten)
        # The loader holds every frame in one batch, so this loop runs once.
        for graph_batch in train_loader:
            hidden = torch.sigmoid(self.g_conv1(graph_batch.x, graph_batch.edge_index,
                                                graph_batch.edge_attr))
            env_feature = torch.sigmoid(self.g_conv2(hidden, graph_batch.edge_index,
                                                     graph_batch.edge_attr))
        # Sizes come from the inputs instead of the former hard-coded 19 / 512.
        return env_feature.reshape(batch_size, frames, num_objects, env_feature.shape[-1])

    def generate_gcn_batch(self,
                           x,
                           edge_index,
                           edge_weight):
        """Wrap each graph ``x[i]`` in a ``Data`` object; return a one-batch loader."""
        data_list = [Data(x=x[i], edge_index=edge_index, edge_attr=edge_weight[i])
                     for i in range(x.shape[0])]
        return DataLoader(data_list, batch_size=x.shape[0], shuffle=False)

    def module_testing(self,
                       location_t,
                       feature_t,
                       mask,
                       past_location_t):
        """Encode one frame per clip over a complete graph of its objects.

        Edge weights come from :meth:`generate_weight`; edges touching an
        object that is absent in ``mask`` get weight 0.

        Returns:
            Stacked per-clip encodings, ``(batch, objects, out_feature)``.
        """
        num_boxes = location_t.shape[1]
        graph_edges = self.generate_graph_from_list(range(num_boxes))
        graph_weight = self.generate_weight(location_t, graph_edges, mask, past_location_t)

        # (2, num_edges) long edge index; the reshape keeps that layout even
        # when there are no edges (a single object).
        edge_index = torch.tensor(graph_edges, dtype=torch.long,
                                  device=location_t.device).reshape(-1, 2).t().contiguous()
        graph_weight = graph_weight.type_as(location_t)

        out_batch = []
        for i in range(location_t.shape[0]):
            hidden = torch.sigmoid(self.g_conv1(feature_t[i], edge_index, graph_weight[i]))
            out_batch.append(torch.sigmoid(self.g_conv2(hidden, edge_index, graph_weight[i])))
        return torch.stack(out_batch)

    def generate_weight(self, location_t, graph_edges, mask, past_location_t):
        """Per-edge weights from box-centre distance and motion distance.

        For each edge ``(i, j)`` the raw weight is ``exp(-(0.7 * d + 0.3 * md))``.
        ``d`` is the squared distance between the centres of the boxes whose
        corners are columns ``(1, 2)`` and ``(3, 4)``. ``md`` comes from
        :meth:`get_motion_distance`. Edges with an endpoint absent in ``mask``
        are set to 0, then each row is min-max normalised.

        Returns:
            ``(batch, num_edges)`` weights in ``[0, 1]``. Rows whose weights are
            all equal (e.g. fewer than two objects present) become zeros
            instead of NaN. This is the same result the preprocessing
            script's ``np.nan_to_num`` produces.
        """
        weights = torch.zeros((location_t.shape[0], len(graph_edges), ),
                              dtype=torch.float32).type_as(location_t)
        if len(graph_edges) == 0:
            return weights

        for i, (src, dst) in enumerate(graph_edges):
            c1 = [0.5 * (location_t[:, src, 1] + location_t[:, src, 3]),
                  0.5 * (location_t[:, src, 2] + location_t[:, src, 4])]
            c2 = [0.5 * (location_t[:, dst, 1] + location_t[:, dst, 3]),
                  0.5 * (location_t[:, dst, 2] + location_t[:, dst, 4])]
            d = (c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2
            md = self.get_motion_distance(src, dst, past_location_t)
            weights[:, i] = torch.exp(-(0.7 * d + 0.3 * md))
            # Edges touching an absent object carry no weight.
            both_present = (mask[:, src] == True) & (mask[:, dst] == True)
            weights[both_present == False, i] = 0

        w_min = torch.min(weights, dim=1, keepdim=True)[0]
        w_span = torch.max(weights, dim=1, keepdim=True)[0] - w_min
        # A zero span means every weight in the row equals w_min, so the
        # numerator is 0 as well; dividing by 1 there gives 0, not NaN.
        w_span = torch.where(w_span > 0, w_span, torch.ones_like(w_span))
        return (weights - w_min) / w_span

    def get_motion_distance(self, id_0, id_1, past_location):
        """Change in distance between two objects after one unit step along their motion.

        Each object's direction is the normalised displacement of its
        reference point between the first and last frame of ``past_location``.
        ``d1`` is the distance after moving both first-frame points one unit
        along their directions; ``d2`` is the first-frame distance.

        Returns:
            ``d1 - d2`` per batch element (negative when the step brings them closer).
        """
        # NOTE(review): despite its name this is (x2 - x1, y2 - y1), i.e. the
        # box extent, not the box centre. The preprocessing script uses the
        # same expression, so it is kept to stay consistent with training
        # data. Confirm the intent before changing both.
        center_point = past_location[:, :, :, 3:-1] - past_location[:, :, :, 1:3]
        pts = center_point * center_point.new_tensor([1280.0, 720.0])

        delta_0 = pts[:, -1, id_0] - pts[:, 0, id_0]
        delta_1 = pts[:, -1, id_1] - pts[:, 0, id_1]

        length_0 = (delta_0[:, 0] ** 2 + delta_0[:, 1] ** 2) ** 0.5
        length_1 = (delta_1[:, 0] ** 2 + delta_1[:, 1] ** 2) ** 0.5
        with torch.no_grad():
            # Static objects: avoid 0/0; their direction becomes (0, 0).
            length_0[length_0 == 0.] = 1.0
            length_1[length_1 == 0.] = 1.0

        dir_0 = delta_0 / length_0.unsqueeze(-1)
        dir_1 = delta_1 / length_1.unsqueeze(-1)

        start_0 = pts[:, 0, id_0]
        start_1 = pts[:, 0, id_1]
        moved = start_0 + dir_0 - start_1 - dir_1
        d1 = (moved[:, 0] ** 2 + moved[:, 1] ** 2) ** 0.5
        gap = start_0 - start_1
        d2 = (gap[:, 0] ** 2 + gap[:, 1] ** 2) ** 0.5

        return d1 - d2

    def generate_graph_from_list(self, L, create_using=None):
        """Return the edge list of a complete graph on ``len(L)`` nodes.

        Undirected graphs list each unordered pair once; a directed
        ``create_using`` lists both orientations.
        """
        G = networkx.empty_graph(len(L), create_using)
        if len(L) > 1:
            if G.is_directed():
                edges = itertools.permutations(L, 2)
            else:
                edges = itertools.combinations(L, 2)
            G.add_edges_from(edges)
        return list(G.edges())


# ---- /msp/models/temporalattns/__init__.py ----
from .s_t_attn import STAttn


__all__ = ['STAttn']

# ---- /msp/models/temporalattns/s_t_attn.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F


from ..builder import ATTNMODEL


@ATTNMODEL.register_module()
class STAttn(nn.Module):
    """Additive attention that pools the object axis into one vector.

    Scores ``w(leaky_relu(ue(x) + be))`` are soft-maxed over dim -2, the
    weighted features are summed over dim 1 and projected 512 -> 256.
    """

    def __init__(self):

        super(STAttn, self).__init__()

        # Creation order kept: it fixes the parameter-initialisation order.
        self.ue = nn.Linear(512, 64)
        self.be = nn.Parameter(torch.zeros(64))
        self.w = nn.Linear(64, 1)

        self.fc1 = nn.Linear(512, 256)

        self.softmax = nn.Softmax(dim=-2)

    def forward(self,
                inputs,
                module):
        """Pool ``inputs`` in ``"Train"`` or ``"Test"`` layout.

        Raises:
            ValueError: for any other ``module`` (previously UnboundLocalError).
        """
        if module == "Train":
            return self.attn_train(inputs)
        if module == "Test":
            return self.attn_test(inputs)
        raise ValueError("module must be 'Train' or 'Test', got {!r}".format(module))

    def _pool(self, inputs):
        """Attention-weighted sum over dim 1 followed by the 512 -> 256 projection."""
        e_j = self.w(F.leaky_relu((self.ue(inputs) + self.be), negative_slope=0.2))
        a_j = self.softmax(e_j)
        return self.fc1(torch.mul(a_j, inputs).sum(dim=1))

    def attn_train(self, inputs):
        """Pool a 4-D ``(batch, frames, objects, 512)`` bank into ``(frames, batch, 256)``."""
        pooled = self._pool(torch.flatten(inputs, start_dim=0, end_dim=1))
        return pooled.reshape(inputs.shape[0], inputs.shape[1], -1).permute(1, 0, 2)

    def attn_test(self, inputs):
        """Pool a single frame ``(batch, objects, 512)`` into ``(batch, 256)``."""
        return self._pool(inputs)

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)


# ---- /msp/models/temporalgraphs/__init__.py ----
from .temporal_graph_batch import TemporalGraphBatch

__all__ = ['TemporalGraphBatch']

# ---- /msp/models/temporalgraphs/temporal_graph_batch.py ----
from torch_geometric.nn import GCNConv
import torch
import torch.nn as nn
import math


from ..builder import TGMODEL


@TGMODEL.register_module()
class TemporalGraphBatch(nn.Module):
    """Per-object temporal GCN over a sliding window of ``past_num`` frames.

    Node ``past_num - 1`` of each window is the current frame. Every older
    node ``i`` sends it an edge weighted ``exp(-(past_num - 1 - i))``
    (see :meth:`generate_graph`). The current node's encoding is the output
    for that frame. Frames before the clip start are padded with copies of
    the first frame, in training and in testing alike.
    """

    def __init__(self,
                 past_num: int,
                 input_feature: int,
                 out_feature: int):

        super(TemporalGraphBatch, self).__init__()

        self.past_num = past_num
        self.out_feature = out_feature
        self.in_feature = input_feature

        edge_index, edge_weight = self.generate_graph(self.past_num)

        self.register_buffer("t_edge_index", edge_index)
        self.register_buffer("t_edge_weight", edge_weight)

        self.conv1 = GCNConv(input_feature, out_feature)

    def forward(self,
                param_dict,
                module):
        """Dispatch to :meth:`ts_train` or :meth:`ts_test`.

        Raises:
            ValueError: for any other ``module`` (previously UnboundLocalError).
        """
        if module == "Train":
            return self.ts_train(param_dict["spatial_feature"])
        if module == "Test":
            return self.ts_test(param_dict["spatial_feature"],
                                param_dict["mask"])
        raise ValueError("module must be 'Train' or 'Test', got {!r}".format(module))

    def ts_train(self,
                 spatial_feature_bank
                 ):
        """Encode every frame of every clip at once.

        Args:
            spatial_feature_bank: ``(batch, frames, objects, in_feature)``.

        Returns:
            ``(batch, frames, objects, out_feature)``. Entry ``t`` is centred on
            frame ``t`` and aggregates frames ``t - past_num + 1 .. t``.
        """
        batch_size, frames, num_objects, channels = spatial_feature_bank.shape
        # Pad in front, as ts_test does. The old code appended the padding
        # after the last frame, so output t was centred on frame
        # t + past_num - 1 (looking ahead), and the last outputs were centred
        # on copies of frame 0. Frame and object counts are no longer fixed at
        # 100 and 19.
        pad = spatial_feature_bank[:, :1].expand(-1, self.past_num - 1, -1, -1)
        padded = torch.cat((pad, spatial_feature_bank), dim=1)
        # One window per frame without the former quadratic torch.cat loop:
        # unfold -> (batch, frames, objects, channels, past_num).
        windows = padded.unfold(1, self.past_num, 1)
        windows = windows.permute(0, 1, 2, 4, 3).reshape(-1, self.past_num, channels)
        enc = self.conv1(windows, self.t_edge_index, self.t_edge_weight)
        # Keep the current-frame node of every window.
        return enc[:, -1].reshape(batch_size, frames, num_objects, self.out_feature)

    def ts_test(self,
                spatial_feature_bank,
                mask):
        """Encode the latest frame from the list of per-frame spatial features.

        Args:
            spatial_feature_bank: list of ``(batch, objects, in_feature)``
                tensors, one per frame seen so far; the last is the current frame.
            mask: ``(batch, objects)``; only objects where it is True get an
                encoding, the rest stay zero.

        Returns:
            ``(batch, objects, out_feature)`` features for the current frame.
        """
        bank = torch.stack(spatial_feature_bank, dim=1)  # (batch, frames, objects, in)
        batch_size, frames, num_objects, channels = bank.shape

        if frames < self.past_num:
            pad = bank[:, :1].expand(-1, self.past_num - frames, -1, -1)
            window = torch.cat((pad, bank), dim=1)
        else:
            # window length follows past_num (5 used to be hard-coded here)
            window = bank[:, -self.past_num:]

        # One temporal graph per object: (batch * objects, past_num, in_feature).
        nodes = window.permute(0, 2, 1, 3).reshape(-1, self.past_num, channels)
        enc = self.conv1(nodes, self.t_edge_index, self.t_edge_weight)
        current = enc[:, -1].reshape(batch_size, num_objects, self.out_feature)

        temporal_feature_t = torch.zeros((batch_size, num_objects, self.out_feature),
                                         dtype=torch.float32).type_as(bank)
        present = torch.where(mask == True)
        temporal_feature_t[present] = current[present]
        return temporal_feature_t

    def generate_graph(self,
                       past_num: int):
        """Star graph over a ``past_num``-frame window.

        Returns:
            ``(edge_index, edge_weight)``: edges ``i -> past_num - 1`` for every
            older frame ``i``, weighted ``exp(-(past_num - 1 - i))``. For
            ``past_num >= 2`` these are a ``(2, past_num - 1)`` long tensor and
            a float32 tensor.
        """
        edge_index = []
        edge_weight = []

        for i in range(past_num - 1):
            edge_index.append([i, past_num - 1])
            edge_weight.append(math.exp(-(past_num - 1 - i)))
        return torch.tensor(edge_index, dtype=torch.long).t().contiguous(), \
            torch.tensor(edge_weight, dtype=torch.float32)


# ---- /msp/models/utils/__init__.py ----
from .utils import get_non_zero, generate_label_matrix, caculate_loss, caculate_softmax_e_loss, caculate_masker_gate_loss, select_gate_feature

__all__ = ['get_non_zero', 'generate_label_matrix', 'caculate_loss', 'caculate_softmax_e_loss', 'caculate_masker_gate_loss','select_gate_feature']

# ---- /msp/models/utils/utils.py ----
from networkx.algorithms.shortest_paths.weighted import negative_edge_cycle
from networkx.utils import decorators
from networkx.utils.misc import flatten
import numpy as np
import torch
import torch.nn
import math
import torch.nn.functional as F


def get_non_zero(id_with_zero):
    """Return the non-zero elements of ``id_with_zero`` as a 1-D tensor."""
    return id_with_zero[id_with_zero != 0]


def generate_label_matrix(accident):
    """Per-frame log-loss weights for 40-frame clips.

    Positive clips (``accident == True``) get ``-exp(-max(0, (36 - t) / 25))``
    at frame ``t``; negative clips get ``-1`` everywhere.

    Returns:
        ``(len(accident), 40)`` tensor on ``accident``'s device (previously
        always the current CUDA device, which failed on CPU-only machines).
    """
    device = accident.device
    label_matrix = torch.zeros((accident.shape[0], 40), device=device)
    pos = torch.tensor([-math.exp(-max(0, (36 - t) / 25)) for t in range(40)], device=device)
    label_matrix[accident == True] = pos
    label_matrix[accident == False] = -1.
    return label_matrix


def caculate_loss(pred, labels):
    """Mean over rows of ``sum(log(pred) * labels)`` along dim 1."""
    return torch.mul(torch.log(pred), labels).sum(dim=1).mean()


def caculate_softmax_e_loss(pred, accident):
    """Frame-weighted cross-entropy for accident anticipation.

    Positive clips are scored against class 1, with frame ``t`` weighted by
    ``exp(-max(0, (76 - t) / 25))``; the weighted per-clip sums are averaged
    over positive clips. Negative clips are scored against class 0 with unit
    weights, summed per clip and averaged. Result: ``positive + 0.2 * negative``.

    Args:
        pred: ``(batch, frames, classes)`` scores fed to cross-entropy.
        accident: ``(batch,)`` accident flags.

    Returns:
        Scalar loss tensor (a missing group contributes 0).
    """
    frames = pred.shape[1]
    device = pred.device

    def _per_clip_ce(clips, target_class):
        # (clips, frames) unreduced cross-entropy against one constant class.
        targets = torch.full((clips.shape[0] * frames,), target_class,
                             dtype=torch.long, device=device)
        ce = F.cross_entropy(clips.reshape(-1, clips.shape[-1]), targets, reduction='none')
        return ce.reshape(clips.shape[0], frames)

    # positive clips
    pos = pred[accident == True]
    if pos.shape[0] != 0:
        steps = torch.arange(frames, device=device, dtype=pred.dtype)
        # Float weights. They used to be built with dtype=torch.long, which
        # truncated every weight below 1 to 0, leaving only frames t >= 76.
        pos_penalty = torch.exp(-torch.clamp((76 - steps) / 25, min=0))
        all_positive_loss_t = (_per_clip_ce(pos, 1) * pos_penalty).sum(dim=1).mean()
    else:
        all_positive_loss_t = torch.tensor(0., dtype=torch.float32).type_as(pred)

    # negative clips
    neg = pred[accident == False]
    if neg.shape[0] != 0:
        all_neg_loss_t = _per_clip_ce(neg, 0).sum(dim=1).mean()
    else:
        all_neg_loss_t = torch.tensor(0., dtype=torch.float32).type_as(pred)

    return all_positive_loss_t + 0.2 * all_neg_loss_t


def caculate_masker_gate_loss(masker_pred):
    """Cross-entropy loss for the masker gate.

    Args:
        masker_pred: ``(n, 3)`` rows of two gate scores followed by a label.
            Label 1 marks an occluded ("shelter") object, which the gate
            flags through class 1. Label 0 marks an object that went away.

    Returns:
        ``0.02 * go_away_ce_sum + 0.02 * shelter_ce_sum``.
    """
    pred_results = masker_pred[:, :2]
    label = masker_pred[:, -1].long()

    crterion = torch.nn.CrossEntropyLoss(reduction='none')
    # The old code scored every row against both ``1 - label`` and
    # ``label``. Since CE(p, 1 - y) + CE(p, y) = -log p0 - log p1 for any y,
    # the labels had no effect and the gate was only pushed towards 50/50.
    # Each row is now scored against its own class only.
    per_sample = crterion(pred_results, label)
    go_away_loss = per_sample[label == 0].sum()
    shelter_loss = per_sample[label == 1].sum()

    return 0.02 * go_away_loss + 0.02 * shelter_loss


def select_gate_feature(feature_bank, missing_status):
    """Gather features and gate labels for objects with a known missing status.

    ``missing_status == 1`` marks objects that went away (label 0);
    ``missing_status == 2`` marks occluded objects (label 1).

    Returns:
        ``(features, labels)`` with go-away rows first, then occluded rows.
    """
    go_away_index = torch.where(missing_status == 1.0)
    go_away_feature = feature_bank[go_away_index]
    go_away_label = torch.zeros((go_away_index[0].shape[0], )).type_as(feature_bank)
    # occluded objects get label 1
    shelter_index = torch.where(missing_status == 2.0)
    shelter_feature = feature_bank[shelter_index]
    shelter_label = torch.ones((shelter_index[0].shape[0], )).type_as(feature_bank)
    # concatenate features and labels
    missing_feature = torch.cat((go_away_feature, shelter_feature), dim=0)
    labels = torch.cat((go_away_label, shelter_label), dim=0)

    return missing_feature, labels


# ---- /msp/utils/__init__.py ----
from .train_utils import setup_seed, partial_state_dict, summarize_metric

| 4 | __all__ = ['setup_seed', 'partial_state_dict', 'summarize_metric'] 5 | -------------------------------------------------------------------------------- /msp/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def setup_seed(seed: int): 8 | torch.manual_seed(seed) 9 | torch.cuda.manual_seed(seed) 10 | torch.cuda.manual_seed_all(seed) 11 | np.random.seed(seed) 12 | random.seed(seed) 13 | torch.backends.cudnn.deterministic = True 14 | 15 | 16 | def partial_state_dict(model: torch.nn.Module, ckpt_path: str): 17 | pre_ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['state_dict'] 18 | model_ckpt = model.state_dict() 19 | pre_ckpt = {k: v for k, v in pre_ckpt.items() if k in model_ckpt} 20 | model_ckpt.update(pre_ckpt) 21 | 22 | return model_ckpt 23 | 24 | def summarize_metric(output): 25 | # the output is list 26 | average_ap = [] 27 | average_mtta = [] 28 | average_tta_r80 = [] 29 | 30 | for item in range(len(output)): 31 | average_ap.append(output[item]['ap']) 32 | average_mtta.append(output[item]['mtta']) 33 | average_tta_r80.append(output[item]['tta_r80']) 34 | 35 | return np.array(average_ap).mean(), np.array(average_mtta).mean(), np.array(average_tta_r80).mean() 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.14.1 2 | addict==2.4.0 3 | cachetools==4.2.4 4 | certifi==2020.12.5 5 | charset-normalizer==2.0.7 6 | cycler==0.10.0 7 | decorator==4.4.2 8 | future==0.18.2 9 | google-auth==1.35.0 10 | google-auth-oauthlib==0.4.6 11 | googledrivedownloader==0.4 12 | grpcio==1.41.0 13 | idna==3.3 14 | importlib-metadata==4.8.1 15 | isodate==0.6.0 16 | Jinja2==3.0.2 17 | joblib==1.1.0 18 | kiwisolver==1.3.1 19 | Markdown==3.3.4 20 | MarkupSafe==2.0.1 21 | 
matplotlib==3.3.4 22 | mmcv==1.3.0 23 | networkx==2.5.1 24 | numpy 25 | oauthlib==3.1.1 26 | olefile==0.46 27 | opencv-python==4.5.3.56 28 | packaging==21.0 29 | pandas==1.1.5 30 | Pillow 31 | protobuf==3.18.1 32 | pyasn1==0.4.8 33 | pyasn1-modules==0.2.8 34 | pyparsing==2.4.7 35 | python-dateutil==2.8.2 36 | pytorch-lightning==0.9.0 37 | pytz==2021.3 38 | PyYAML==6.0 39 | rdflib==5.0.0 40 | requests==2.26.0 41 | requests-oauthlib==1.3.0 42 | rsa==4.7.2 43 | scikit-learn==0.24.2 44 | scipy==1.5.4 45 | six==1.16.0 46 | tensorboard==2.2.0 47 | tensorboard-plugin-wit==1.8.0 48 | threadpoolctl==3.0.0 49 | torch==1.7.0 50 | torch-cluster==1.5.9 51 | torch-geometric==2.0.1 52 | torch-scatter==2.0.7 53 | torch-sparse==0.6.9 54 | torchaudio==0.7.0a0+ac17b64 55 | torchvision==0.8.0 56 | tqdm==4.62.3 57 | typing==3.7.4.3 58 | typing-extensions 59 | urllib3==1.26.7 60 | Werkzeug==2.0.2 61 | yacs==0.1.8 62 | yapf==0.31.0 63 | zipp==3.6.0 64 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | from mmcv import Config 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateLogger 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler 10 | 11 | from msp import LightningModel 12 | from msp.utils import setup_seed, partial_state_dict 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="Train or test a detector.") 17 | parser.add_argument("config", help="Train config file path.") 18 | parser.add_argument("--train", action="store_true") 19 | parser.add_argument("--test", action="store_true") 20 | 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | cfg = Config.fromfile(args.config) 28 | 
setup_seed(cfg.random_seed) 29 | 30 | model = LightningModel(cfg) 31 | 32 | checkpoint_callback = ModelCheckpoint( 33 | filepath=f"{cfg.checkpoint_path}/{cfg.name}/{cfg.version}/" 34 | f"{cfg.name}_{cfg.version}_{{epoch}}_{{avg_val_loss:.3f}}_{{ade:.3f}}_{{fde:.3f}}_{{fiou:.3f}}", 35 | save_last=None, 36 | save_top_k=-1, 37 | verbose=True, 38 | monitor='fiou', 39 | mode='max', 40 | prefix='' 41 | ) 42 | 43 | lr_logger_callback = LearningRateLogger(logging_interval='step') 44 | profiler = SimpleProfiler() if cfg.simple_profiler else AdvancedProfiler() 45 | logger = TensorBoardLogger(save_dir=cfg.log_path, name=cfg.name, version=cfg.version) 46 | 47 | trainer = pl.Trainer( 48 | gpus=cfg.num_gpus, 49 | #distributed_backend='dp', 50 | max_epochs=cfg.max_epochs, 51 | logger=logger, 52 | profiler=profiler, 53 | callbacks=[lr_logger_callback], 54 | #gradient_clip_val=cfg.gradient_clip_val,\ 55 | checkpoint_callback=checkpoint_callback, 56 | check_val_every_n_epoch=10, 57 | resume_from_checkpoint=cfg.resume_from_checkpoint, 58 | accumulate_grad_batches=cfg.batch_accumulate_size) # 由于每个batch内只有一个样本,所以采用这样的处理方法 59 | 60 | if (not (args.train or args.test)) or args.train: 61 | 62 | shutil.copy(args.config, os.path.join(cfg.log_path, cfg.name, cfg.version, args.config.split('/')[-1])) 63 | 64 | if cfg.load_from_checkpoint is not None: 65 | model_ckpt = partial_state_dict(model, cfg.load_from_checkpoint) 66 | model.load_state_dict(model_ckpt) 67 | 68 | trainer.fit(model) 69 | 70 | if args.test: 71 | if cfg.test_checkpoint is not None: 72 | model_ckpt = partial_state_dict(model, cfg.test_checkpoint) 73 | model.load_state_dict(model_ckpt) 74 | 75 | trainer.test(model) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /scripts/preprecess_data_MASKER_MD.py: -------------------------------------------------------------------------------- 1 | from networkx.algorithms import mis 2 | from 
networkx.algorithms.distance_measures import center 3 | import numpy as np 4 | import os 5 | import random 6 | import pandas as pd 7 | import networkx 8 | import itertools 9 | random.seed(1234) 10 | 11 | data_root = '/media/group1/data/tianhang/sorted_samples_pn_2048/' 12 | train_root = '/media/group1/data/tianhang/MASKER_MD/' 13 | test_root = '/media/group1/data/tianhang/MASKER_MD/' 14 | 15 | global miss_wait 16 | miss_wait = 10 17 | 18 | def proprecess_data(data_root, train_root, test_root): 19 | 20 | if not os.path.exists(train_root): 21 | os.makedirs(train_root) 22 | if not os.path.exists(test_root): 23 | os.makedirs(test_root) 24 | 25 | all_file_list = os.listdir(data_root) 26 | random.shuffle(all_file_list) 27 | 28 | train_number = int(len(all_file_list) * 0.7) 29 | train_file_list = all_file_list[:train_number] 30 | val_file_list = all_file_list[train_number:] 31 | generate_corresponding_txt(train_file_list, train_root, module="train") 32 | generate_corresponding_txt(val_file_list, test_root, module="valid") 33 | 34 | print("Start to process training set!") 35 | graph_edges = generate_graph_from_list(range(19)) 36 | 37 | for number, i in enumerate(train_file_list): 38 | 39 | filled_before = np.load(os.path.join(data_root, i), allow_pickle=True).item() 40 | frames = filled_before["location"].shape[0] 41 | 42 | maskers = [] 43 | mask_dict = np.zeros((frames, 19), dtype=np.float32) 44 | 45 | filled_before_location = filled_before['location'] # frames x 19 x 6 46 | filled_before_feature = filled_before['feature'] # frames x 19 x 2048 47 | for t in range(frames): 48 | #print("time", t) 49 | #print("location_present",filled_before_location[t,:,1:5]) 50 | mask_t = filled_before_location[t,:,1:5].sum(-1) != 0. 
51 | #print("mask_t", mask_t) 52 | maskers, mask_dict, filled_after_location, filled_after_feature = fill_during_training(mask_t, 53 | maskers, 54 | mask_dict, 55 | t, 56 | location_all=filled_before_location, 57 | feature_all=filled_before_feature) 58 | 59 | maskers_np = np.array(maskers) 60 | #print("maskers_np", maskers_np) 61 | graph_weight = generate_weight(filled_after_location, graph_edges, maskers_np) 62 | filled_before["graph_index"] = graph_edges 63 | filled_before["graph_weight"] = graph_weight 64 | filled_before["maskers"] = maskers_np 65 | filled_before["missing_status"] = mask_dict 66 | filled_before["location"] = filled_after_location 67 | filled_before["feature"] = filled_after_feature 68 | 69 | np.save(os.path.join(train_root, i), filled_before) 70 | print("The Training process has been {:.2f} % done!".format((number + 1) / len(train_file_list) * 100)) 71 | 72 | print("#" * 30) 73 | print("Start to process Testing set!") 74 | for val_number, j in enumerate(val_file_list): 75 | 76 | val_infor = np.load(os.path.join(data_root, j), allow_pickle=True).item() 77 | frames = val_infor["location"].shape[0] 78 | val_maskers = [] 79 | val_mask_dict = np.zeros((frames, 19), dtype=np.float32) 80 | val_location = val_infor["location"] # frames x 19 x 6 81 | for t in range(frames): 82 | 83 | val_mask_t = val_location[t,:,1:5].sum(-1) != 0. 
84 | #val_maskers.append(val_mask_t) 85 | val_maskers, val_mask_dict = spect_during_testing(val_mask_t, val_maskers, val_mask_dict, t, val_location) 86 | 87 | val_maskers_np = np.array(val_maskers) 88 | val_infor["maskers"] = val_maskers_np 89 | val_infor["missing_status"] = val_mask_dict 90 | np.save(os.path.join(test_root, j), val_infor) 91 | print("The Testing process has been {:.2f} % done!".format((val_number + 1) / len(val_file_list) * 100)) 92 | 93 | def generate_weight(location, graph_edges, mask): 94 | # location: frames x 19 x 16 95 | # graph_weight frames x num_edge 96 | 97 | weights = np.zeros((location.shape[0], len(graph_edges), ), dtype=np.float32) 98 | location_n = location.copy() 99 | location_n[:, :, [1,3]] /= 1280 100 | location_n[:, :, [2,4]] /= 720 101 | for i, edge in enumerate(graph_edges): 102 | 103 | c1 = [0.5 * (location_n[:, edge[0], 1] + location_n[:, edge[0], 3]), 104 | 0.5 * (location_n[:, edge[0], 2] + location_n[:, edge[0], 4])] 105 | c2 = [0.5 * (location_n[:, edge[1], 1] + location_n[:, edge[1], 3]), 106 | 0.5 * (location_n[:, edge[1], 2] + location_n[:, edge[1], 4])] 107 | 108 | d = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2 109 | 110 | md = get_motion_distance(edge[0], edge[1], location) 111 | 112 | weights[:, i] = np.exp(-(0.7 * d + 0.3 * md)) 113 | frame_index = ((mask[:, edge[0]]==True) & (mask[:, edge[1]] == True)) 114 | weights[frame_index==False, i] = 0 115 | 116 | # normalize 117 | weights_nor = (weights - np.min(weights, axis=1, keepdims=True)) / (np.max(weights, axis=1, keepdims=True) - np.min(weights, axis=1, keepdims=True)) 118 | # account for NaN 119 | weights_nor = np.nan_to_num(weights_nor) 120 | 121 | return weights_nor 122 | 123 | def get_motion_distance(id_0, id_1, location_all): 124 | md = [] 125 | center_point = location_all[:,:,3:-1] - location_all[:,:,1:3] 126 | for t in range(100): 127 | if t <= 4: 128 | d = get_normolize_distance(id_0, id_1, 0, t, center_point) 129 | else: 130 | d = 
get_normolize_distance(id_0, id_1, t-4, t, center_point) 131 | md.append(d) 132 | 133 | md = np.array(md) 134 | return md 135 | 136 | def get_normolize_distance(id_0, id_1, t1, t2, center_point): 137 | 138 | v_id_0 = normalize(((center_point[t1, id_0, 0], center_point[t1, id_0, 1]),(center_point[t2, id_0, 0], center_point[t2, id_0, 1]))) 139 | v_id_1 = normalize(((center_point[t1, id_1, 0], center_point[t1, id_1, 1]),(center_point[t2, id_1, 0], center_point[t2, id_1, 1]))) 140 | 141 | # Eud_distance 142 | d_1 = ((v_id_0[1][0] - v_id_1[1][0])**2 + (v_id_0[1][1] - v_id_1[1][1])**2)**0.5 #箭头距离 143 | d_2 = ((v_id_0[0][0] - v_id_1[0][0])**2 + (v_id_0[0][1] - v_id_1[0][1])**2)**0.5 #箭尾距离 144 | 145 | md = d_1 - d_2 146 | 147 | 148 | return md 149 | 150 | def normalize(v): 151 | delta_x = v[1][0] - v[0][0] 152 | delta_y = v[1][1] - v[0][1] 153 | length = (delta_x**2 + delta_y**2)**0.5 154 | if length != 0: 155 | nor_delta_x = delta_x/length 156 | nor_delta_y = delta_y/length 157 | return ((v[0][0], v[0][1]),(v[0][0]+nor_delta_x, v[0][1]+nor_delta_y)) 158 | else: 159 | return v 160 | 161 | def generate_graph_from_list(L, create_using=None): 162 | G = networkx.empty_graph(len(L), create_using) 163 | if len(L) > 1: 164 | if G.is_directed(): 165 | edges = itertools.permutations(L,2) 166 | else: 167 | edges = itertools.combinations(L,2) 168 | G.add_edges_from(edges) 169 | graph_edges = list(G.edges()) 170 | 171 | return graph_edges 172 | 173 | 174 | def spect_during_testing(mask_present, 175 | all_mask, 176 | mask_dict, 177 | t, 178 | location_all): 179 | 180 | if len(all_mask) == 0: 181 | all_mask.append(mask_present) 182 | 183 | return all_mask, mask_dict 184 | 185 | else: 186 | before_mask = all_mask[-1] 187 | # compare from True to False 188 | index = np.where((before_mask == True) & (mask_present ==False)) 189 | if min(index[0].shape) != 0: 190 | for missing_object in index[0]: 191 | #print("missing_object", missing_object) 192 | near_future_infor = 
location_all[(t+1):(t+miss_wait), missing_object, 1:5] 193 | #print("near_future_infor", near_future_infor) 194 | near_future_status = np.where((near_future_infor.sum(-1) != 0) == True) 195 | if min(near_future_status[0].shape) != 0: 196 | mask_dict[t-1, missing_object] = 2. 197 | else: 198 | mask_dict[t-1, missing_object] = 1. 199 | 200 | all_mask.append(mask_present) 201 | return all_mask, mask_dict 202 | 203 | else: 204 | all_mask.append(mask_present) 205 | return all_mask, mask_dict 206 | 207 | def fill_during_training(mask_present, 208 | all_mask, 209 | mask_dict, 210 | t, 211 | location_all, 212 | feature_all 213 | ): 214 | # params for waiting time 215 | 216 | if len(all_mask) == 0: 217 | all_mask.append(mask_present) 218 | 219 | return all_mask, mask_dict, location_all, feature_all 220 | else: 221 | # mask_before 222 | before_mask= all_mask[-1] 223 | # compare from True to False 224 | index = np.where((before_mask==True) & (mask_present==False)) 225 | if min(index[0].shape) != 0: 226 | #print("missing index", index) 227 | for missing_object in index[0]: 228 | #print("Missing objects", missing_object) 229 | near_future_infor = location_all[(t+1):(t+miss_wait), missing_object,1:5] 230 | near_future_status = np.where((near_future_infor.sum(-1) != 0) == True) 231 | if min(near_future_status[0].shape) != 0: 232 | mask_dict[t-1, missing_object] = 2. 233 | nearest_future = near_future_status[0][0] 234 | location_all[t:(t+nearest_future+1),missing_object, 1:5] = fill_functions( 235 | location_all[(t-5 if t>5 else 0):t, missing_object,1:5], 236 | location_all[t+nearest_future+1, missing_object, 1:5], 237 | nearest_future+1) 238 | feature_all[t:(t+nearest_future+1), missing_object, 1:] = feature_all[t-1,missing_object, 1:] 239 | else: 240 | mask_dict[t-1, missing_object] = 1. 241 | 242 | mask_present = location_all[t,:,1:5].sum(-1) != 0. 
243 | 244 | all_mask.append(mask_present) 245 | 246 | return all_mask, mask_dict, location_all, feature_all 247 | 248 | else: 249 | 250 | all_mask.append(mask_present) 251 | return all_mask, mask_dict, location_all, feature_all 252 | 253 | def fill_functions(past_movement, near_future, fill_number): 254 | fill_matrix = np.full((fill_number, 4), np.nan) 255 | interplot = np.vstack((past_movement, fill_matrix, near_future)) 256 | df_interplot = pd.DataFrame(interplot) 257 | s = df_interplot.interpolate() 258 | interplot = np.around(np.array(s)) 259 | fill_matrix = interplot[(-(fill_number+1)):-1] 260 | 261 | return fill_matrix 262 | 263 | def generate_corresponding_txt(file_list, root, module="train"): 264 | """ 265 | all valid module is train or valid 266 | """ 267 | str = '\n' 268 | f = open(os.path.join(root, module + "_video_list.txt"), "w") 269 | f.write(str.join(file_list)) 270 | f.close() 271 | 272 | proprecess_data(data_root, train_root, test_root) -------------------------------------------------------------------------------- /scripts/preprecess_data_MD-SG.py: -------------------------------------------------------------------------------- 1 | from networkx.algorithms import mis 2 | import numpy as np 3 | import os 4 | import random 5 | import pandas as pd 6 | import networkx 7 | import itertools 8 | random.seed(1234) 9 | 10 | 11 | data_root = '/media/group1/data/tianhang/sorted_samples_pn_2048/' 12 | train_root = '/media/group1/data/tianhang/MD/' 13 | test_root = '/media/group1/data/tianhang/MD/' 14 | 15 | 16 | def proprecess_data(data_root, train_root, test_root): 17 | if not os.path.exists(train_root): 18 | os.makedirs(train_root) 19 | if not os.path.exists(test_root): 20 | os.makedirs(test_root) 21 | 22 | all_file_list = os.listdir(data_root) 23 | random.shuffle(all_file_list) 24 | train_number = int(len(all_file_list) * 0.7) 25 | train_file_list = all_file_list[:train_number] 26 | val_file_list = all_file_list[train_number:] 27 | 
generate_corresponding_txt(train_file_list, train_root, module="train") 28 | generate_corresponding_txt(val_file_list, test_root, module="valid") 29 | print("Start to process training set!") 30 | graph_edges = generate_graph_from_list(range(19)) 31 | 32 | for number, i in enumerate(train_file_list): 33 | filled_before = np.load(os.path.join(data_root, i), allow_pickle=True).item() 34 | frames = filled_before["location"].shape[0] 35 | maskers = [] 36 | 37 | filled_before_location = filled_before['location'] 38 | for t in range(frames): 39 | mask_t = filled_before_location[t,:,1:5].sum(-1) != 0. 40 | maskers.append(mask_t) 41 | maskers_np = np.array(maskers) 42 | graph_weight = generate_weight(filled_before_location, graph_edges, maskers_np) 43 | filled_before["graph_index"] = graph_edges 44 | filled_before["graph_weight"] = graph_weight 45 | filled_before["maskers"] = maskers_np 46 | np.save(os.path.join(train_root, i), filled_before) 47 | print("The Training process has been {:.2f} % done!".format((number + 1) / len(train_file_list) * 100)) 48 | 49 | print("#" * 30) 50 | print("Start to process Testing set!") 51 | for val_number, j in enumerate(val_file_list): 52 | 53 | val_infor = np.load(os.path.join(data_root, j), allow_pickle=True).item() 54 | frames = val_infor["location"].shape[0] 55 | val_maskers = [] 56 | val_location = val_infor["location"] 57 | for t in range(frames): 58 | val_mask_t = val_location[t,:,1:5].sum(-1) != 0. 
59 | val_maskers.append(val_mask_t) 60 | 61 | val_maskers_np = np.array(val_maskers) 62 | val_infor["maskers"] = val_maskers_np 63 | np.save(os.path.join(test_root, j), val_infor) 64 | print("The Testing process has been {:.2f} % done!".format((val_number + 1) / len(val_file_list) * 100)) 65 | 66 | def generate_weight(location, graph_edges, mask): 67 | # location: frames x 19 x 16 68 | # graph_weight frames x num_edge 69 | 70 | weights = np.zeros((location.shape[0], len(graph_edges), ), dtype=np.float32) 71 | # normolize 72 | location_n = location.copy() 73 | location_n[:, :, [1,3]] /= 1280 74 | location_n[:, :, [2,4]] /= 720 75 | 76 | for i, edge in enumerate(graph_edges): 77 | 78 | c1 = [0.5 * (location_n[:, edge[0], 1] + location_n[:, edge[0], 3]), 79 | 0.5 * (location_n[:, edge[0], 2] + location_n[:, edge[0], 4])] 80 | c2 = [0.5 * (location_n[:, edge[1], 1] + location_n[:, edge[1], 3]), 81 | 0.5 * (location_n[:, edge[1], 2] + location_n[:, edge[1], 4])] 82 | 83 | d = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2 84 | 85 | md = get_motion_distance(edge[0], edge[1], location) 86 | 87 | weights[:, i] = np.exp(-(0.7*d + 0.3*md)) 88 | frame_index = ((mask[:, edge[0]]==True) & (mask[:, edge[1]] == True)) 89 | weights[frame_index==False, i] = 0 90 | 91 | # normalize 92 | weights_nor = (weights - np.min(weights, axis=1, keepdims=True)) / (np.max(weights, axis=1, keepdims=True) - np.min(weights, axis=1, keepdims=True)) 93 | # account for NaN 94 | weights_nor = np.nan_to_num(weights_nor) 95 | 96 | return weights_nor 97 | 98 | def get_motion_distance(id_0, id_1, location_all): 99 | md = [] 100 | center_point = location_all[:,:,3:-1] - location_all[:,:,1:3] 101 | for t in range(100): 102 | if t <= 4: 103 | d = get_normolize_distance(id_0, id_1, 0, t, center_point) 104 | else: 105 | d = get_normolize_distance(id_0, id_1, t-4, t, center_point) 106 | md.append(d) 107 | 108 | md = np.array(md) 109 | return md 110 | 111 | def get_normolize_distance(id_0, id_1, t1, t2, center_point): 
112 | 113 | v_id_0 = normalize(((center_point[t1, id_0, 0], center_point[t1, id_0, 1]),(center_point[t2, id_0, 0], center_point[t2, id_0, 1]))) 114 | v_id_1 = normalize(((center_point[t1, id_1, 0], center_point[t1, id_1, 1]),(center_point[t2, id_1, 0], center_point[t2, id_1, 1]))) 115 | 116 | # Eud_distance 117 | d_1 = ((v_id_0[1][0] - v_id_1[1][0])**2 + (v_id_0[1][1] - v_id_1[1][1])**2)**0.5 #箭头距离 118 | d_2 = ((v_id_0[0][0] - v_id_1[0][0])**2 + (v_id_0[0][1] - v_id_1[0][1])**2)**0.5 #箭尾距离 119 | 120 | md = d_1 - d_2 121 | 122 | return md 123 | 124 | def normalize(v): 125 | delta_x = v[1][0] - v[0][0] 126 | delta_y = v[1][1] - v[0][1] 127 | length = (delta_x**2 + delta_y**2)**0.5 128 | if length != 0: 129 | nor_delta_x = delta_x/length 130 | nor_delta_y = delta_y/length 131 | return ((v[0][0], v[0][1]),(v[0][0]+nor_delta_x, v[0][1]+nor_delta_y)) 132 | else: 133 | return v 134 | 135 | def generate_graph_from_list(L, create_using=None): 136 | G = networkx.empty_graph(len(L), create_using) 137 | if len(L) > 1: 138 | if G.is_directed(): 139 | edges = itertools.permutations(L,2) 140 | else: 141 | edges = itertools.combinations(L,2) 142 | G.add_edges_from(edges) 143 | graph_edges = list(G.edges()) 144 | 145 | return graph_edges 146 | 147 | def generate_corresponding_txt(file_list, root, module="train"): 148 | """ 149 | all valid module is train or valid 150 | """ 151 | str = '\n' 152 | f = open(os.path.join(root, module + "_video_list.txt"), "w") 153 | f.write(str.join(file_list)) 154 | f.close() 155 | 156 | proprecess_data(data_root, train_root, test_root) -------------------------------------------------------------------------------- /scripts/preprocess_data_MASKER.py: -------------------------------------------------------------------------------- 1 | from networkx.algorithms import mis 2 | from networkx.algorithms.distance_measures import center 3 | import numpy as np 4 | import os 5 | import random 6 | import pandas as pd 7 | import networkx 8 | import itertools 9 | 
random.seed(1234)


data_root = '/media/group1/data/tianhang/sorted_samples_pn_2048/'
train_root = '/media/group1/data/tianhang/MASKER/'
test_root = '/media/group1/data/tianhang/MASKER/'

# Look-ahead for deciding whether a vanished object comes back: frames
# t+1 .. t+miss_wait-1 are inspected. (A module-level `global` statement is a
# no-op, so the former `global miss_wait` line was dropped.)
miss_wait = 10
# Object slots per frame, and the frame size used to scale box coordinates.
NUM_OBJECTS = 19
IMAGE_WIDTH = 1280
IMAGE_HEIGHT = 720


def proprecess_data(data_root, train_root, test_root):
    """Split the raw samples 70/30 and write the MASKER train/valid sets.

    Training samples: boxes of objects that briefly vanish are back-filled by
    interpolation (``fill_during_training``), then a complete object graph and
    per-frame distance-based edge weights are attached. Validation samples are
    NOT back-filled; only their per-frame masks and missing status are recorded.
    """
    os.makedirs(train_root, exist_ok=True)
    os.makedirs(test_root, exist_ok=True)

    # os.listdir order is filesystem dependent; sort first so that the seeded
    # shuffle yields the same split on every machine.
    all_file_list = sorted(os.listdir(data_root))
    random.shuffle(all_file_list)
    train_number = int(len(all_file_list) * 0.7)
    train_file_list = all_file_list[:train_number]
    val_file_list = all_file_list[train_number:]
    generate_corresponding_txt(train_file_list, train_root, module="train")
    generate_corresponding_txt(val_file_list, test_root, module="valid")
    print("Start to process training set!")
    graph_edges = generate_graph_from_list(range(NUM_OBJECTS))

    for number, name in enumerate(train_file_list):
        sample = np.load(os.path.join(data_root, name), allow_pickle=True).item()
        location = sample['location']
        feature = sample['feature']
        frames = location.shape[0]

        maskers = []
        mask_dict = np.zeros((frames, NUM_OBJECTS), dtype=np.float32)
        for t in range(frames):
            mask_t = location[t, :, 1:5].sum(-1) != 0.
            # Back-fills `location` / `feature` in place for gaps starting at t.
            maskers, mask_dict, location, feature = fill_during_training(
                mask_t, maskers, mask_dict, t,
                location_all=location, feature_all=feature)

        maskers_np = np.array(maskers)
        sample["graph_index"] = graph_edges
        sample["graph_weight"] = generate_weight(location, graph_edges, maskers_np)
        sample["maskers"] = maskers_np
        sample["missing_status"] = mask_dict
        sample["location"] = location
        sample["feature"] = feature

        np.save(os.path.join(train_root, name), sample)
        print("The Training process has been {:.2f} % done!".format((number + 1) / len(train_file_list) * 100))

    print("#" * 30)
    print("Start to process Testing set!")
    for val_number, name in enumerate(val_file_list):
        val_infor = np.load(os.path.join(data_root, name), allow_pickle=True).item()
        val_location = val_infor["location"]  # original comment: frames x 19 x 6
        frames = val_location.shape[0]

        val_maskers = []
        val_mask_dict = np.zeros((frames, NUM_OBJECTS), dtype=np.float32)
        for t in range(frames):
            val_mask_t = val_location[t, :, 1:5].sum(-1) != 0.
            # No back-filling at test time: only the missing status is recorded.
            val_maskers, val_mask_dict = spect_during_testing(
                val_mask_t, val_maskers, val_mask_dict, t, val_location)

        # NOTE(review): unlike training samples, no graph_index / graph_weight
        # is written for validation samples -- confirm the loader does not need them.
        val_infor["maskers"] = np.array(val_maskers)
        val_infor["missing_status"] = val_mask_dict
        np.save(os.path.join(test_root, name), val_infor)
        print("The Testing process has been {:.2f} % done!".format((val_number + 1) / len(val_file_list) * 100))


def generate_weight(location, graph_edges, mask, width=IMAGE_WIDTH, height=IMAGE_HEIGHT):
    """Return per-frame edge weights, min-max normalised across edges.

    Args:
        location: array indexed [frame, object, column]; columns 1/3 are x and
            2/4 are y pixel coordinates (presumably x1, y1, x2, y2 box corners).
        graph_edges: sequence of (object_a, object_b) index pairs.
        mask: boolean array [frame, object], True where the object is present.
        width, height: frame size used to scale x and y coordinates.

    Returns:
        float32 array [frame, edge]. The raw weight is exp(-d), where d is the
        squared distance between the scaled box centres. It is 0 where either
        object is absent, and each frame is then min-max scaled (frames whose
        weights are all equal become 0).
    """
    # astype copies; integer input is promoted so that the in-place division works.
    location_n = location.astype(np.result_type(location.dtype, np.float32))
    location_n[:, :, [1, 3]] /= width
    location_n[:, :, [2, 4]] /= height
    centers = 0.5 * (location_n[:, :, 1:3] + location_n[:, :, 3:5])

    weights = np.zeros((location.shape[0], len(graph_edges)), dtype=np.float32)
    for i, (a, b) in enumerate(graph_edges):
        diff = centers[:, a] - centers[:, b]
        d = diff[:, 0] ** 2 + diff[:, 1] ** 2
        weights[:, i] = np.exp(-d)
        weights[~np.logical_and(mask[:, a], mask[:, b]), i] = 0

    w_min = weights.min(axis=1, keepdims=True)
    w_range = weights.max(axis=1, keepdims=True) - w_min
    with np.errstate(divide="ignore", invalid="ignore"):
        weights_nor = (weights - w_min) / w_range
    # Constant frames give 0/0 -> NaN; map them to 0.
    return np.nan_to_num(weights_nor)


def get_motion_distance(id_0, id_1, location_all, window=4):
    """Return the per-frame motion distance between objects id_0 and id_1.

    For frame t, each object's box-centre displacement from frame
    max(t - window, 0) to t is compared via ``get_normolize_distance``.
    The result has one entry per frame of ``location_all``. It is not used
    by ``generate_weight`` in this script.
    """
    # Box centres [frame, object, 2]. The previous code used
    # location[..., 3:-1] - location[..., 1:3], i.e. box width/height, so the
    # "motion" tracked changes in box size rather than position. It also
    # hard-coded 100 frames.
    center_point = 0.5 * (location_all[:, :, 1:3] + location_all[:, :, 3:5])
    md = [get_normolize_distance(id_0, id_1, max(t - window, 0), t, center_point)
          for t in range(location_all.shape[0])]
    return np.array(md)


def get_normolize_distance(id_0, id_1, t1, t2, center_point):
    """Return the motion distance of two objects between frames t1 and t2.

    Each object's centre displacement t1 -> t2 becomes a unit-length arrow
    anchored at its t1 centre (see ``normalize``). Returns
    |head_0 - head_1| - |tail_0 - tail_1|, which is negative when the arrow
    heads are closer together than the tails.
    """
    tail_0, head_0 = normalize((tuple(center_point[t1, id_0, :2]), tuple(center_point[t2, id_0, :2])))
    tail_1, head_1 = normalize((tuple(center_point[t1, id_1, :2]), tuple(center_point[t2, id_1, :2])))

    # Euclidean distance between the arrow heads and between the arrow tails.
    head_gap = ((head_0[0] - head_1[0]) ** 2 + (head_0[1] - head_1[1]) ** 2) ** 0.5
    tail_gap = ((tail_0[0] - tail_1[0]) ** 2 + (tail_0[1] - tail_1[1]) ** 2) ** 0.5
    return head_gap - tail_gap


def normalize(v):
    """Rescale segment v = ((x0, y0), (x1, y1)) to unit length, keeping the tail.

    Zero-length segments are returned unchanged.
    """
    (x0, y0), (x1, y1) = v
    dx, dy = x1 - x0, y1 - y0
    length = (dx ** 2 + dy ** 2) ** 0.5
    if length == 0:
        return v
    return ((x0, y0), (x0 + dx / length, y0 + dy / length))


def generate_graph_from_list(L, create_using=None):
    """Return the edge list of a complete graph over the node labels in L.

    Uses ordered pairs when ``create_using`` yields a directed graph and
    unordered pairs otherwise.
    """
    G = networkx.empty_graph(len(L), create_using)
    if len(L) > 1:
        if G.is_directed():
            edges = itertools.permutations(L, 2)
        else:
            edges = itertools.combinations(L, 2)
        G.add_edges_from(edges)
    return list(G.edges())


def spect_during_testing(mask_present, all_mask, mask_dict, t, location_all):
    """Record why objects vanished at frame t, without modifying any boxes.

    An object visible at t-1 but missing at t gets mask_dict[t-1, obj] = 2.
    if its box reappears within frames t+1 .. t+miss_wait-1, else 1.
    ``mask_present`` is appended to ``all_mask``. Both containers are updated
    in place and returned.
    """
    if all_mask:
        vanished = np.flatnonzero(all_mask[-1] & ~mask_present)
        for obj in vanished:
            future = location_all[t + 1:t + miss_wait, obj, 1:5]
            reappears = bool(np.any(future.sum(-1) != 0))
            mask_dict[t - 1, obj] = 2. if reappears else 1.
    all_mask.append(mask_present)
    return all_mask, mask_dict


def fill_during_training(mask_present, all_mask, mask_dict, t, location_all, feature_all):
    """Record vanished objects at frame t and back-fill short gaps (training only).

    An object visible at t-1 but missing at t is looked for in frames
    t+1 .. t+miss_wait-1. If it reappears, mask_dict[t-1, obj] = 2., its box
    columns 1-4 for the missing frames are linearly interpolated from up to 5
    preceding frames and the reappearance frame, and feature_all[..., 1:] for
    those frames is copied from frame t-1. Index 0 of the feature axis is not
    copied, presumably an id; confirm. Otherwise mask_dict[t-1, obj] = 1.
    The presence mask for t (recomputed if anything was filled) is appended to
    ``all_mask``. All array/list arguments are modified in place and returned.
    """
    if all_mask:
        vanished = np.flatnonzero(all_mask[-1] & ~mask_present)
        for obj in vanished:
            future = location_all[t + 1:t + miss_wait, obj, 1:5]
            visible = np.flatnonzero(future.sum(-1) != 0)
            if visible.size == 0:
                mask_dict[t - 1, obj] = 1.
                continue
            mask_dict[t - 1, obj] = 2.
            gap = visible[0] + 1   # missing frames: t .. t + visible[0]
            reappear = t + gap     # first frame where the box is back
            location_all[t:reappear, obj, 1:5] = fill_functions(
                location_all[max(t - 5, 0):t, obj, 1:5],
                location_all[reappear, obj, 1:5],
                gap)
            feature_all[t:reappear, obj, 1:] = feature_all[t - 1, obj, 1:]
        if vanished.size:
            mask_present = location_all[t, :, 1:5].sum(-1) != 0.
    all_mask.append(mask_present)
    return all_mask, mask_dict, location_all, feature_all


def fill_functions(past_movement, near_future, fill_number):
    """Linearly interpolate ``fill_number`` rows of 4 box coordinates.

    Interpolation runs between the known ``past_movement`` rows and the
    ``near_future`` row; the result is rounded to whole numbers.
    """
    gap = np.full((fill_number, 4), np.nan)
    stacked = np.vstack((past_movement, gap, near_future))
    interpolated = np.asarray(pd.DataFrame(stacked).interpolate())
    return np.around(interpolated[-(fill_number + 1):-1])


def generate_corresponding_txt(file_list, root, module="train"):
    """Write ``file_list`` one name per line to <root>/<module>_video_list.txt.

    ``module`` is expected to be "train" or "valid".
    """
    with open(os.path.join(root, module + "_video_list.txt"), "w") as f:
        f.write("\n".join(file_list))


if __name__ == "__main__":
    proprecess_data(data_root, train_root, test_root)
# --------------------------------------------------------------------------------