├── README.md
├── configs
│   └── config.py
├── figures
│   ├── archi.png
│   ├── em.png
│   ├── image52.gif
│   └── pytorch-logo-dark.png
├── msp
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── dashcams.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   └── pipelines.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── metrics.py
│   ├── lightning_model.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── accidentblocks
│   │   │   ├── __init__.py
│   │   │   └── accident_lstm.py
│   │   ├── builder.py
│   │   ├── gates
│   │   │   ├── __init__.py
│   │   │   └── s_t_masker.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   └── mse_loss.py
│   │   ├── nearfuture
│   │   │   ├── __init__.py
│   │   │   └── near_future.py
│   │   ├── predictors
│   │   │   ├── __init__.py
│   │   │   └── all_batch_version.py
│   │   ├── spatialgraphs
│   │   │   ├── __init__.py
│   │   │   └── spatial_graph_batch.py
│   │   ├── temporalattns
│   │   │   ├── __init__.py
│   │   │   └── s_t_attn.py
│   │   ├── temporalgraphs
│   │   │   ├── __init__.py
│   │   │   └── temporal_graph_batch.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── utils.py
│   └── utils
│       ├── __init__.py
│       └── train_utils.py
├── requirements.txt
├── run.py
└── scripts
    ├── preprecess_data_MASKER_MD.py
    ├── preprecess_data_MD-SG.py
    └── preprocess_data_MASKER.py
/README.md:
--------------------------------------------------------------------------------
1 | # GSC: A Graph and Spatio-temporal Continuity Based Framework for Accident Anticipation
2 |
3 | [MIT License](https://opensource.org/licenses/MIT)
4 |
5 | This repository contains the source code and pre-trained models for our paper: `GSC: A Graph and Spatio-temporal Continuity Based Framework for Accident Anticipation`.
6 |
7 |
8 |
9 |
10 |
11 | ## Architecture
12 | An overview of the network is shown below:
13 |
14 |
15 |
16 |
17 |
18 | ## Prerequisites
19 |
20 | - Python 3.6
21 | - PyTorch 1.7.0
22 | - PyTorch-Lightning 0.9.0
23 | - Other required packages in `requirements.txt`
24 |
25 | ## Getting Started
26 |
27 | ### Create conda environment
28 |
29 | ```bash
30 | conda create -n sspm python=3.6
31 | source activate sspm
32 | ```
33 |
34 | ### Install the required packages
35 |
36 | ```bash
37 | pip install -r requirements.txt
38 | ```
39 |
40 | ### Download the MASKER_MD dataset and unzip it
41 |
42 | - Access the dataset via [BaiduYun](https://pan.baidu.com/s/1TT3AZBBuE-u_zovl6i44iQ) [Password: `qo81`] and unzip it.
43 |
44 | - Change the `data_root` in
45 | `configs/config.py` to the path where you unzipped the dataset.
46 |
47 | ## Train the Model
48 |
49 | - Run the following command in Terminal:
50 | ```bash
51 | python run.py --train ./configs/config.py
52 | ```
53 |
54 | ## Test the Model
55 |
56 | - Change the `test_checkpoint` of `configs/config.py` to your trained checkpoint path.
57 |
58 | - Run the following command in Terminal:
59 | ```bash
60 | python run.py --test ./configs/config.py
61 | ```
62 |
63 | ## Visualize
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/configs/config.py:
--------------------------------------------------------------------------------
1 | name = "STAttnsGraphDashcam"
2 | version = "2.0.0"
3 | model = dict(
4 | type='STAttnsGraphSimAllBatch',
5 | spatial_graph=dict(
6 | type='SpatialGraphBatch',
7 | node_feature=2048,
8 | hidden_feature=1024,
9 | out_feature=512,
10 | ),
11 | temporal_graph=dict(
12 | type='TemporalGraphBatch',
13 | past_num=5,
14 | input_feature=512,
15 | out_feature=512,
16 | ),
17 | gate=dict(
18 | type='MaskGate',
19 |         past_num=10,
20 | ),
21 | temporal_attn=dict(
22 | type='STAttn'
23 | ),
24 | near_future=dict(
25 | type="NearFuture"
26 | ),
27 | accident_block=dict(
28 | type='AccidentLSTM',
29 | temporal_feature=256,
30 | hidden_feature=64,
31 | num_layers=2
32 | ),
33 | loss=dict(
34 | type="LogLoss"
35 | ))
36 | dataset_type = "DashCam"
37 | data_root = "/media/group1/data/tianhang/MASKER_MD"
38 | data = dict(
39 | batch_size=32,
40 | num_workers=1,
41 | train=dict(
42 | type=dataset_type,
43 | root_dir=data_root,
44 | #pipelines=pipelines,
45 | video_list_file=f"{data_root}/train_video_list.txt"),
46 | val=dict(
47 | type=dataset_type,
48 | root_dir=data_root,
49 | #pipelines=pipelines,
50 | video_list_file=f"{data_root}/valid_video_list.txt"))
51 | # training and testing settings
52 | optimizer_cfg = dict(
53 | type='Adam',
54 | lr=0.01,
55 | betas=(0.9, 0.999),
56 | )
57 | lr_cfg = dict(
58 | type="StepLR",
59 | step_size=50,
60 | gamma=.7)
61 | warm_up_cfg = dict(
62 | type="Exponential",
63 | step_size=5000)
64 | random_seed = 1234
65 | # GPU
66 | num_gpus = [2]
67 | max_epochs = 200
68 | checkpoint_path = "work_dirs/checkpoints"
69 | log_path = "work_dirs/logs"
70 | result_path = "work_dirs/results"
71 | load_from_checkpoint = None
72 | resume_from_checkpoint = None
73 | test_checkpoint = None
74 | batch_accumulate_size = 1
75 | simple_profiler = True
--------------------------------------------------------------------------------
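The config above is a plain Python file consumed through mmcv's `Config` utility (see `run.py` below). A minimal sketch of loading and inspecting it, assuming mmcv is installed:

```python
from mmcv import Config

# Attribute access mirrors the nested dicts defined in configs/config.py.
cfg = Config.fromfile("configs/config.py")

print(cfg.model.type)           # "STAttnsGraphSimAllBatch"
print(cfg.model.spatial_graph)  # {'type': 'SpatialGraphBatch', 'node_feature': 2048, ...}
print(cfg.data.batch_size)      # 32
```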
/figures/archi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/archi.png
--------------------------------------------------------------------------------
/figures/em.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/em.png
--------------------------------------------------------------------------------
/figures/image52.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/image52.gif
--------------------------------------------------------------------------------
/figures/pytorch-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/GSC/09fa815f1d092bd5a42a5d45b9266a210677dee7/figures/pytorch-logo-dark.png
--------------------------------------------------------------------------------
/msp/__init__.py:
--------------------------------------------------------------------------------
1 | from .lightning_model import LightningModel
2 |
3 |
4 | __all__ = ['LightningModel']
--------------------------------------------------------------------------------
/msp/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import DATASETS, build_dataset
2 | from .dashcams import *
3 | from .pipelines import *
4 |
5 | __all__ = ["DATASETS", "build_dataset"]
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/msp/datasets/builder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from mmcv.utils import Registry, build_from_cfg
4 |
5 | DATASETS = Registry('datasets')
6 | PIPELINES = Registry('pipelines')
7 |
8 |
9 | def build_dataset(cfg: dict, default_args: Optional[dict] = None):
10 | return build_from_cfg(cfg, DATASETS, default_args)
11 |
12 |
13 | def build_pipelines(cfg: dict, default_args: Optional[dict] = None):
14 | return build_from_cfg(cfg, PIPELINES, default_args)
--------------------------------------------------------------------------------
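For illustration, a hedged sketch of how the registry pattern above is used: `DummyClips` is a made-up dataset, not part of this repo. `build_dataset` resolves the `type` key against the `DATASETS` registry and forwards the remaining keys as constructor kwargs.

```python
import torch
from torch.utils.data.dataset import Dataset
from msp.datasets.builder import DATASETS, build_dataset


@DATASETS.register_module()
class DummyClips(Dataset):  # hypothetical dataset, for demonstration only
    def __init__(self, length: int = 8):
        self.length = length

    def __getitem__(self, idx):
        return {"feature": torch.zeros(10), "accident": False}

    def __len__(self):
        return self.length


# Looks up cfg["type"] in DATASETS and calls DummyClips(length=4).
dataset = build_dataset(dict(type="DummyClips", length=4))
assert len(dataset) == 4
```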
/msp/datasets/dashcams.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from tqdm import tqdm
4 | import torch
5 | from torch.utils.data.dataset import Dataset
6 | from .builder import DATASETS
7 | from .pipelines import Compose
8 |
9 |
10 | @DATASETS.register_module()
11 | class DashCam(Dataset):
12 |
13 | def __init__(self,
14 | root_dir: str,
15 | video_list_file: str = None,
16 | pipelines: dict = None):
17 |
18 | self.root_dir = root_dir
19 |
20 | self.pipelines = Compose(pipelines) if pipelines is not None else None
21 |
22 | self.__sample__ = open(video_list_file, "r").read().splitlines()
23 |
24 | def __getitem__(self, idx: int) -> dict:
25 | data = np.load(os.path.join(self.root_dir, self.__sample__[idx]), allow_pickle=True).item()
26 |
27 | data['location'][:, :, [1, 3]] /= 1280
28 | data['location'][:, :, [2, 4]] /= 720
29 |
30 | for phase in data:
31 | if (phase != "accident") & (phase != "graph_index"):
32 | data[phase] = torch.from_numpy(data[phase])
33 | elif phase == "graph_index":
34 | data[phase] = torch.tensor(data[phase]).t().contiguous()
35 |
36 | data['video_id'] = self.__sample__[idx]
37 |
38 | return self.pipelines(data) if self.pipelines is not None else data
39 |
40 | def __len__(self) -> int:
41 | return len(self.__sample__)
--------------------------------------------------------------------------------
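`DashCam` expects `root_dir` to hold one `.npy` dict per video and `video_list_file` to list those files, one relative path per line. A usage sketch with placeholder paths (they must point at an unpacked MASKER_MD copy):

```python
from torch.utils.data import DataLoader
from msp.datasets.dashcams import DashCam

# Placeholder paths: substitute the directory where MASKER_MD was unzipped.
dataset = DashCam(root_dir="/path/to/MASKER_MD",
                  video_list_file="/path/to/MASKER_MD/train_video_list.txt")
loader = DataLoader(dataset, batch_size=2, shuffle=False)

batch = next(iter(loader))
# Bounding boxes in 'location' are normalized by the 1280x720 frame size in __getitem__.
print(batch["location"].shape, batch["video_id"])
```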
/msp/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import *
2 |
3 |
4 | #__all__ = ['Compose', 'ToTensor', 'Normalize']
5 | __all__ = ['Compose']
--------------------------------------------------------------------------------
/msp/datasets/pipelines/pipelines.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Sequence, Dict, Union, Callable
2 | import numpy as np
3 | import torch
4 | from torch import Tensor
5 | from ..builder import PIPELINES, build_pipelines
6 |
7 |
8 | class Compose(object):
9 |
10 | def __init__(self, transforms: Sequence[Union[dict, Callable]]):
11 | self.transforms = []
12 | for transform in transforms:
13 | if isinstance(transform, dict):
14 | transform = build_pipelines(transform)
15 | self.transforms.append(transform)
16 | elif callable(transform):
17 | self.transforms.append(transform)
18 | else:
19 | raise TypeError('transform must be callable or a dict')
20 |
21 | def __call__(self, data: Dict[str, Dict[str, Union[Tensor, np.ndarray]]]) \
22 | -> Dict[str, Dict[str, Union[Tensor, np.ndarray]]]:
23 |
24 | for t in self.transforms:
25 | data = t(data)
26 | return data
27 |
28 | def __repr__(self):
29 | format_string = self.__class__.__name__ + '('
30 | for t in self.transforms:
31 | format_string += '\n'
32 | format_string += f' {t}'
33 | format_string += '\n)'
34 | return format_string
--------------------------------------------------------------------------------
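`Compose` accepts either registry config dicts or plain callables, so a quick sanity check needs no registered pipeline at all. A minimal sketch with a hypothetical callable transform:

```python
import torch
from msp.datasets.pipelines.pipelines import Compose


def double_feature(data):  # hypothetical transform, for demonstration only
    data["feature"] = data["feature"] * 2
    return data


pipeline = Compose([double_feature])
out = pipeline({"feature": torch.ones(3)})
print(out["feature"])  # tensor([2., 2., 2.])
```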
/msp/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import evaluate_metric
2 |
3 | __all__ = ['evaluate_metric']
--------------------------------------------------------------------------------
/msp/datasets/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def evaluate_metric(pred_np, label, fps=25):
5 | all_pred = pred_np[:,:,1] # shape [batch_size, frames]
6 | accident_status = label['accident'].cpu().numpy() # shape [batch_size, ]
7 | accident_time = 76
8 | pred_eval = []
9 | min_pred = np.inf
10 | n_frames = 0
11 |
12 |
13 | # access the frames before accident
14 | for idx, toa in enumerate(accident_status):
15 | if toa == True:
16 | pred = all_pred[idx, :int(accident_time)] # positive video
17 | else:
18 |             pred = all_pred[idx, :]  # negative video
19 | # find the minimum prediction
20 | min_pred = np.min(pred) if min_pred > np.min(pred) else min_pred
21 | pred_eval.append(pred)
22 | n_frames += len(pred)
23 | total_seconds = all_pred.shape[1] / fps
24 |     # sweep a set of thresholds upward from the minimum prediction
25 |     thresholds = np.arange(max(min_pred, 0), 1.0, 0.001)
26 |     thresholds_num = thresholds.shape[0]
27 |     Precision = np.zeros((thresholds_num))
28 |     Recall = np.zeros((thresholds_num))
29 |     Time = np.zeros((thresholds_num))
30 |     cnt = 0
31 |     for Th in thresholds:
32 | Tp = 0.0
33 | Tp_Fp = 0.0
34 | Tp_Tn = 0.0
35 | time = 0.0
36 | counter = 0.0 # number of TP videos
37 | # iterate each video sample
38 | for i in range(len(pred_eval)):
39 |             # true positive frames: (pred->1) & (gt->True)
40 | tp = np.where(pred_eval[i] * accident_status[i] >=Th)
41 | Tp += float(len(tp[0])>0)
42 | if float(len(tp[0])>0) > 0:
43 | time += tp[0][0] / float(accident_time)
44 | counter = counter + 1
45 | Tp_Fp += float(len(np.where(pred_eval[i]>=Th)[0])>0)
46 |
47 | if Tp_Fp == 0:
48 | continue
49 | else:
50 | Precision[cnt] = Tp/Tp_Fp
51 | if np.sum(accident_status)==0:
52 | continue
53 | else:
54 | Recall[cnt] = Tp/np.sum(accident_status)
55 | if counter == 0:
56 | continue
57 | else:
58 | Time[cnt] = (1-time/counter)
59 | cnt += 1
60 |
61 | new_index = np.argsort(Recall)
62 | Precision = Precision[new_index]
63 | Recall = Recall[new_index]
64 |
65 | p_r_plot = {'precision':Precision,
66 | 'recall':Recall}
67 |
68 | Time = Time[new_index]
69 |     # deduplicate the recall values
70 |     _, rep_index = np.unique(Recall, return_index=True)
71 | rep_index = rep_index[1:]
72 | new_Time = np.zeros(len(rep_index))
73 | new_Precision = np.zeros(len(rep_index))
74 | for i in range(len(rep_index)-1):
75 | new_Time[i] = np.max(Time[rep_index[i]:rep_index[i+1]])
76 | new_Precision[i] = np.max(Precision[rep_index[i]:rep_index[i+1]])
77 |     # fill in the last unique-recall segment
78 | if rep_index.size != 0:
79 | new_Time[-1] = Time[rep_index[-1]]
80 | new_Precision[-1] = Precision[rep_index[-1]]
81 | new_Recall = Recall[rep_index]
82 | else:
83 | new_Recall = Recall
84 | # compute AP (area under P-R curve)
85 | AP = 0.0
86 | if new_Recall[0] != 0:
87 | AP += new_Precision[0]*(new_Recall[0]-0)
88 | for i in range(1,len(new_Precision)):
89 | AP += (new_Precision[i-1]+new_Precision[i])*(new_Recall[i]-new_Recall[i-1])/2
90 |
91 | # transform the relative mTTA to seconds
92 | mTTA = np.mean(new_Time) * total_seconds
93 | #print("Average Precision= %.4f, mean Time to accident= %.4f"%(AP, mTTA))
94 | sort_time = new_Time[np.argsort(new_Recall)]
95 | sort_recall = np.sort(new_Recall)
96 | TTA_R80 = sort_time[np.argmin(np.abs(sort_recall-0.8))] * total_seconds
97 | #print("Recall@80%, Time to accident= " +"{:.4}".format(TTA_R80))
98 |
99 | return AP, mTTA, TTA_R80, p_r_plot
100 |
101 |
102 |
--------------------------------------------------------------------------------
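A smoke test of `evaluate_metric` on synthetic scores (the metric values are meaningless here; this only exercises the input contract: `pred_np` of shape `[batch, frames, 2]` and a label dict with a boolean `accident` tensor):

```python
import numpy as np
import torch
from msp.datasets.utils.metrics import evaluate_metric

# 4 synthetic videos x 100 frames x 2 classes; column 1 is the accident score.
rng = np.random.default_rng(0)
pred_np = rng.uniform(0.01, 0.99, size=(4, 100, 2))
label = {"accident": torch.tensor([True, False, True, False])}

AP, mTTA, TTA_R80, pr_curve = evaluate_metric(pred_np, label, fps=25)
print(f"AP={AP:.3f}  mTTA={mTTA:.2f}s  TTA@R80={TTA_R80:.2f}s")
```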
/msp/lightning_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import pytorch_lightning as pl
4 | import torch.optim as optim
5 | import torch
6 | from mmcv.utils import Config
7 | from torch.utils.data import DataLoader
8 | from .models import build_predictor
9 | from .datasets.utils.metrics import evaluate_metric
10 | from .datasets import build_dataset
11 | from matplotlib import pyplot as plt
12 |
13 | import os
14 | from .utils.train_utils import summarize_metric
15 |
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | class LightningModel(pl.LightningModule):
19 |
20 | def __init__(self, cfg: Config):
21 | super(LightningModel, self).__init__()
22 |
23 | self.cfg = cfg
24 | self.model_cfg = cfg.model
25 | self.data_cfg = cfg.data
26 | self.optim_cfg = cfg.optimizer_cfg
27 | self.lr_cfg = cfg.lr_cfg
28 |
29 | self.model = build_predictor(self.model_cfg)
30 | self.hparams = dict(lr=self.optim_cfg.lr,
31 | batch_size=self.data_cfg.batch_size * cfg.batch_accumulate_size)
32 |
33 | self.whether_save = True
34 | if self.whether_save:
35 | self.writer = SummaryWriter(comment="Training")
36 | self.writer.add_hparams({'lr':self.optim_cfg.lr, 'bsize':self.data_cfg.batch_size * cfg.batch_accumulate_size}, {})
37 | self.train_step = 1
38 | self.val_step = 1
39 | self.tests_step = 1
40 |
41 | def forward(self, data, **kwargs):
42 | pred, masker_pred = self.model(data, **kwargs)
43 | return pred, masker_pred
44 |
45 | def training_step(self, batch, batch_idx):
46 | data = batch
47 | pred, masker_pred = self(data, module="Train")
48 | acdt_loss, masker_loss = self.model.loss(pred, masker_pred, data)
49 | loss = acdt_loss + masker_loss
50 |
51 | with torch.no_grad():
52 | masker_pred_np = masker_pred.cpu().numpy()
53 |
54 | shelter_precision = (masker_pred_np[masker_pred_np[:,-1] == 1., 1] >= 0.5).sum() / (masker_pred_np[:,-1] == 1.).sum()
55 | go_away_precision = (masker_pred_np[masker_pred_np[:,-1] == 0., 0] >= 0.5).sum() / (masker_pred_np[:,-1] == 0.).sum()
56 | precision = (np.sum((np.where(masker_pred_np[:, :2] >= 0.50)[1] == masker_pred_np[:, -1])!=0)) / masker_pred_np.shape[0]
57 |
58 | if self.whether_save:
59 | self.writer.add_scalar('Training/acdt_loss', acdt_loss, self.train_step)
60 | self.writer.add_scalar('Training/masker_loss', masker_loss, self.train_step)
61 | self.writer.add_scalar('Training/shelter_precision', shelter_precision, self.train_step)
62 | self.writer.add_scalar('Training/go_away_precision', go_away_precision, self.train_step)
63 | self.writer.add_scalar('Training/precision', precision, self.train_step)
64 |
65 |
66 | logs = {"loss": loss}
67 | self.train_step += 1
68 |
69 | return {"loss": loss,
70 | "log": logs}
71 |
72 | def validation_step(self, batch, batch_idx):
73 | data = batch
74 | pred_np, pred_gate_np = self.model.predict(data, module="Test")
75 | ap, mtta, tta_r80, p_r_plot = evaluate_metric(pred_np, data)
76 |
77 | self.val_step += 1
78 |
79 | ap_list = []
80 | mtta_list = []
81 | tta_r80_list = []
82 |
83 | ap_list.append(ap)
84 | mtta_list.append(mtta)
85 | tta_r80_list.append(tta_r80)
86 |
87 | logs = {"ap":ap_list, "mtta":mtta_list, "tta_r80":tta_r80_list}
88 | return logs
89 |
90 | def validation_epoch_end(self, output):
91 | average_ap, average_mtta, average_tta_r80 = summarize_metric(output)
92 | if self.whether_save:
93 | self.writer.add_scalar('Metrics/AP', average_ap, self.tests_step)
94 | self.writer.add_scalar('Metrics/mtta', average_mtta, self.tests_step)
95 | self.writer.add_scalar('Metrics/TTA_R80', average_tta_r80, self.tests_step)
96 | self.tests_step += 1
97 | return {"output": output}
98 |
99 | def test_step(self, batch, batch_idx):
100 | data = batch
101 |         pred_np, _ = self.model.predict(data, module="Test")
102 | for i in range(data['location'].shape[0]):
103 | root_fill = './results/' + str(149) + "/"
104 | if not os.path.exists(root_fill):
105 | os.makedirs(root_fill)
106 | file_name = str(data["video_id"][i]) + ".csv"
107 | np.savetxt(root_fill + file_name, pred_np[i][:,1], delimiter=',')
108 |
109 | return None
110 |
111 | def prepare_data(self):
112 | self.train_dataset = build_dataset(self.data_cfg.train)
113 | self.val_dataset = build_dataset(self.data_cfg.val)
114 |
115 | def train_dataloader(self):
116 | return DataLoader(self.train_dataset,
117 | batch_size=self.data_cfg.batch_size,
118 | shuffle=False,
119 | num_workers=self.data_cfg.num_workers)
120 |
121 | def val_dataloader(self):
122 | return DataLoader(self.val_dataset,
123 | batch_size=self.data_cfg.batch_size,
124 | shuffle=False,
125 | num_workers=self.data_cfg.num_workers)
126 |
127 | def test_dataloader(self):
128 | return DataLoader(self.val_dataset,
129 | batch_size=1,
130 | shuffle=False,
131 | num_workers=self.data_cfg.num_workers)
132 |
133 | def configure_optimizers(self):
134 |
135 | optim_cfg = self.optim_cfg.copy()
136 | optim_class = getattr(optim, optim_cfg.pop("type"))
137 | optimizer = optim_class(self.parameters(), **optim_cfg)
138 |
139 | lr_cfg = self.lr_cfg.copy()
140 | lr_sheduler_class = getattr(optim.lr_scheduler, lr_cfg.pop("type"))
141 | scheduler = {
142 | "scheduler": lr_sheduler_class(optimizer, **lr_cfg),
143 | "monitor": 'avg_val_loss',
144 | "interval": "epoch",
145 | "frequency": 1
146 | }
147 |
148 | return [optimizer], [scheduler]
149 |
150 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure = None,
151 | on_tpu: bool = False, using_native_amp=False, using_lbfgs: bool = False):
152 |
153 | warm_up_type, warm_up_step = self.cfg.warm_up_cfg.type, self.cfg.warm_up_cfg.step_size
154 | if warm_up_type == 'Exponential':
155 | lr_scale = self.model_cfg.spatial_graph.hidden_feature ** -0.5
156 | lr_scale *= min((self.trainer.global_step + 1) ** (-0.5),
157 | (self.trainer.global_step + 1) * warm_up_step ** (-1.5))
158 | elif warm_up_type == "Linear":
159 | lr_scale = min(1., float(self.trainer.global_step + 1) / warm_up_step)
160 | else:
161 | raise NotImplementedError
162 |
163 | for pg in optimizer.param_groups:
164 | # import pdb; pdb.set_trace()
165 | if self.whether_save:
166 | self.writer.add_scalar('Training/lr', pg['lr'], self.train_step)
167 | pg['lr'] = lr_scale * self.hparams.lr
168 |
169 | optimizer.step()
170 | optimizer.zero_grad()
171 |
172 |
173 |
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
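`run.py` (below) wires this module into a trainer; a stripped-down equivalent under the pinned PyTorch-Lightning 0.9.0 API would look roughly like this sketch:

```python
import pytorch_lightning as pl
from mmcv import Config
from msp import LightningModel

cfg = Config.fromfile("configs/config.py")
model = LightningModel(cfg)

# PL 0.9-style Trainer: `gpus` takes the device list from the config.
trainer = pl.Trainer(gpus=cfg.num_gpus,
                     max_epochs=cfg.max_epochs,
                     accumulate_grad_batches=cfg.batch_accumulate_size)
trainer.fit(model)
```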
/msp/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import (SGMODEL, TGMODEL, ATTNMODEL, ACCIDENTMODEL, PREDICTORS, LOSSES, GATEMODEL, NEARFUTUREMODEL,
2 | build_spatial, build_temporal, build_attn, build_accident_block, build_predictor, build_loss, build_gate, build_nearfuture)
3 |
4 | from .spatialgraphs import *
5 | from .temporalgraphs import *
6 | from .temporalattns import *
7 | from .nearfuture import *
8 | from .accidentblocks import *
9 | from .gates import *
10 | from .predictors import *
11 | from .losses import *
12 |
13 |
14 | __all__ = ["SGMODEL", "TGMODEL", "PREDICTORS", "ATTNMODEL", "ACCIDENTMODEL", "LOSSES", "GATEMODEL", "NEARFUTUREMODEL",
15 | "build_spatial", "build_temporal", "build_predictor", "build_attn", "build_accident_block", "build_loss", "build_gate", "build_nearfuture"]
16 |
17 |
18 |
--------------------------------------------------------------------------------
/msp/models/accidentblocks/__init__.py:
--------------------------------------------------------------------------------
1 | from .accident_lstm import AccidentLSTM
2 |
3 | __all__ = ['AccidentLSTM']
--------------------------------------------------------------------------------
/msp/models/accidentblocks/accident_lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ..builder import ACCIDENTMODEL
5 |
6 |
7 | @ACCIDENTMODEL.register_module()
8 | class AccidentLSTM(nn.Module):
9 | def __init__(self,
10 | temporal_feature: int,
11 | hidden_feature: int,
12 | num_layers: int):
13 | super(AccidentLSTM, self).__init__()
14 |
15 | self.acc_lstm = torch.nn.LSTM(temporal_feature, hidden_feature, num_layers=num_layers)
16 | self.hidden_feature = hidden_feature
17 | self.pred = nn.Sequential(
18 | nn.Linear(hidden_feature, 64),
19 | nn.ReLU(),
20 | nn.Linear(64, 2),
21 | nn.Softmax(dim=-1))
22 |
23 | def forward(self,
24 | inputs,
25 | h0,
26 | c0):
27 |
28 | output, (hn, cn) = self.acc_lstm(inputs, (h0, c0))
29 |
30 | return self.pred(output)
31 |
32 | def _reset_parameters(self):
33 | for p in self.parameters():
34 | if p.dim() > 1:
35 | nn.init.xavier_uniform_(p)
36 |
37 |
--------------------------------------------------------------------------------
/msp/models/builder.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional, Any, Dict, List
2 | import torch.nn as nn
3 | from mmcv.utils import Registry, Config, build_from_cfg
4 |
5 |
6 | SGMODEL = Registry("spatialgraphs")
7 | TGMODEL = Registry("temporalgraphs")
8 | ATTNMODEL = Registry("temporalattns")
9 | NEARFUTUREMODEL = Registry("nearfuture")
10 | ACCIDENTMODEL = Registry("accidentblocks")
11 | GATEMODEL = Registry("gates")
12 | LOSSES = Registry("losses")
13 | PREDICTORS = Registry("predictors")
14 |
15 |
16 | def build(cfg: Union[Dict, List[Dict]],
17 | registry: Registry,
18 | default_args: Optional[Dict] = None) -> Any:
19 | """Build a module.
20 |
21 | Args:
22 |         cfg: The config of modules; it is either a dict or a list of config dicts.
23 | registry: A registry the module belongs to.
24 | default_args: Default arguments to build the module. Defaults to None.
25 |
26 | Returns:
27 | nn.Module: A built nn module.
28 | """
29 | if isinstance(cfg, list):
30 | modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
31 | return nn.Sequential(*modules)
32 | else:
33 | return build_from_cfg(cfg, registry, default_args)
34 |
35 |
36 | def build_spatial(cfg: Union[Dict, List[Dict]]) -> Any:
37 | """Build spatial graph"""
38 | return build(cfg, SGMODEL)
39 |
40 |
41 | def build_temporal(cfg: Union[Dict, List[Dict]]) -> Any:
42 | """Build temporal graph"""
43 | return build(cfg, TGMODEL)
44 |
45 |
46 | def build_attn(cfg: Union[Dict, List[Dict]]) -> Any:
47 | """Build temporal attn"""
48 | return build(cfg, ATTNMODEL)
49 |
50 | def build_nearfuture(cfg: Union[Dict, List[Dict]]) -> Any:
51 |     """Build near-future module"""
52 | return build(cfg, NEARFUTUREMODEL)
53 |
54 |
55 | def build_accident_block(cfg: Union[Dict, List[Dict]]) -> Any:
56 | """Build accident block"""
57 | return build(cfg, ACCIDENTMODEL)
58 |
59 |
60 | def build_gate(cfg: Union[Dict, List[Dict]]) -> Any:
61 | """Build gate"""
62 | return build(cfg, GATEMODEL)
63 |
64 |
65 | def build_loss(cfg: Union[Dict, List[Dict]]) -> Any:
66 | """Build loss"""
67 | return build(cfg, LOSSES)
68 |
69 |
70 | def build_predictor(cfg: Union[Dict, List[Dict]]) -> Any:
71 | """Build model"""
72 | return build(cfg, PREDICTORS)
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
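The builders are thin wrappers over `build_from_cfg`; passing a list of config dicts returns an `nn.Sequential` of the built modules. A sketch that mirrors the `spatial_graph` and `loss` blocks of `configs/config.py` (importing from `msp.models` so all modules are registered first):

```python
from msp.models import build_spatial, build_loss

sg = build_spatial(dict(type="SpatialGraphBatch",
                        node_feature=2048,
                        hidden_feature=1024,
                        out_feature=512))
loss_fn = build_loss(dict(type="LogLoss"))
print(type(sg).__name__, type(loss_fn).__name__)  # SpatialGraphBatch LogLoss
```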
/msp/models/gates/__init__.py:
--------------------------------------------------------------------------------
1 | from .s_t_masker import MaskGate
2 |
3 | __all__ = ['MaskGate']
--------------------------------------------------------------------------------
/msp/models/gates/s_t_masker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ..builder import GATEMODEL
4 |
5 |
6 | @GATEMODEL.register_module()
7 | class MaskGate(torch.nn.Module):
8 | def __init__(self, past_num):
9 |
10 | super(MaskGate, self).__init__()
11 | self.gate_cnn = nn.Sequential(
12 | nn.Conv2d(512, 256, (1,1), stride=1),
13 | nn.ReLU(),
14 | nn.Conv2d(256, 128, (1,1), stride=1),
15 | nn.ReLU()
16 | )
17 | self.gate_fc = nn.Sequential(
18 | nn.Linear(128, 64),
19 | nn.ReLU(),
20 | nn.Linear(64, 2),
21 | nn.Softmax(dim=-1)
22 | )
23 | self.past_num = past_num
24 |
25 | def forward(self,
26 | param_dict,
27 | module):
28 |
29 | if module == "Train":
30 | pred = self.masker_train(param_dict["select_feature"])
31 | return pred
32 |
33 | elif module == "Test":
34 | m_t, m_d, fill = self.masker_test(
35 | param_dict["all_mask"],
36 | param_dict["present_mask"],
37 | param_dict["feature_bank"],
38 | param_dict["time"],
39 | param_dict["location_all"],
40 | param_dict["masker_dict"])
41 | return m_t, m_d, fill
42 |
43 | def masker_train(self,
44 | select_feature):
45 |
46 | select_feature_un = select_feature.unsqueeze(-1).unsqueeze(-1)
47 | temp = self.gate_cnn(select_feature_un)
48 | temp_sq = temp.squeeze(-1).squeeze(-1)
49 | pred_status = self.gate_fc(temp_sq)
50 |
51 | return pred_status
52 |
53 |
54 | def masker_test(self,
55 | all_mask,
56 | present_mask,
57 | temporal_feature_bank,
58 | t,
59 | location_all,
60 | masker_dict):
61 |
62 | if len(all_mask) == 0:
63 |
64 | return present_mask, masker_dict, (torch.empty(([0,])).type_as(location_all),
65 | torch.empty(([0,])).type_as(location_all))
66 | else:
67 |             # mask from the previous timestep
68 |             before_mask = all_mask[-1] # shape: batch_size x objects
69 |             # objects that were visible before (True) but are missing now (False)
70 |             index = torch.where((before_mask==True) & (present_mask==False))
71 |
72 | if min(index[0].shape) != 0:
73 | un_index, indexed = self.get_index(t, index, masker_dict, present_mask)
74 | present_mask[indexed] = True
75 | temporal_feature_bank_tensor = torch.stack(temporal_feature_bank)[-1,:,:,:]
76 |
77 | tfbt_un = temporal_feature_bank_tensor[un_index].unsqueeze(-1).unsqueeze(-1)
78 | temp = self.gate_cnn(tfbt_un)
79 | temp_sq = temp.squeeze(-1).squeeze(-1)
80 | gate_status = self.gate_fc(temp_sq)
81 |
82 | near_feature_infor = location_all[:,t+1:t+self.past_num+1,:,:]
83 | n_index = near_feature_infor[un_index[0],:,un_index[1], :]
84 | test_label = torch.sum(n_index[:,:,1:5].flatten(start_dim=1), dim=1, keepdim=True).ge(0.01)
85 |
86 | shelter_objects = torch.where(gate_status[:, 1] >= 0.50)
87 |
88 | shelter_index = (un_index[0][shelter_objects], un_index[1][shelter_objects])
89 | present_mask[shelter_index] = True
90 | fill_position_index = (torch.cat((indexed[0], shelter_index[0])),
91 | torch.cat((indexed[1], shelter_index[1])))
92 |
93 | training_test_samples = torch.hstack((gate_status, test_label))
94 |
95 | masker_dict['index'][0] = torch.cat((masker_dict['index'][0], un_index[0]))
96 | masker_dict['index'][1] = torch.cat((masker_dict['index'][1], un_index[1]))
97 | masker_dict['pred_label'] = torch.cat((masker_dict['pred_label'], training_test_samples))
98 |
99 | for i in range(un_index[0].shape[0]):
100 | masker_dict['timestamp'].append(t)
101 |
102 | return present_mask, masker_dict, fill_position_index
103 |
104 | else:
105 |
106 | return present_mask, masker_dict, (torch.empty(([0,])).type_as(location_all),
107 | torch.empty(([0,])).type_as(location_all))
108 |
109 | def get_index(self,
110 | t,
111 | present_index,
112 | index_dict,
113 | present_mask):
114 |
115 | un_index = []
116 | indexed = []
117 | for i in range(present_index[0].shape[0]):
118 | batch_index = torch.where(index_dict['index'][0]==present_index[0][i])
119 | if min(batch_index[0].shape) != 0:
120 | object_index = torch.where(index_dict['index'][1][batch_index]==present_index[1][i])
121 | if min(object_index[0].shape) != 0 :
122 | # check time
123 | nearest_timestamps_index = batch_index[0][object_index[0]][-1]
124 | if (t - index_dict['timestamp'][nearest_timestamps_index] <= self.past_num) & \
125 | (present_mask[present_index[0][i].item(), present_index[1][i].item()].item() == False):
126 | indexed.append(i)
127 | else:
128 | un_index.append(i)
129 | else:
130 | un_index.append(i)
131 | else:
132 | un_index.append(i)
133 |
134 | return (present_index[0][un_index], present_index[1][un_index]), \
135 | (present_index[0][indexed], present_index[1][indexed])
136 |
137 | def _reset_parameters(self):
138 | for p in self.parameters():
139 | if p.dim() > 1:
140 | nn.init.xavier_uniform_(p)
--------------------------------------------------------------------------------
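The training path of `MaskGate` classifies each candidate object feature into (go-away, occluded) with per-object 1x1 convolutions. A shape sketch:

```python
import torch
from msp.models.gates.s_t_masker import MaskGate

gate = MaskGate(past_num=10)

# 6 candidate objects, each with a 512-d temporal feature.
select_feature = torch.randn(6, 512)
pred_status = gate.masker_train(select_feature)

# (6, 512) -> (6, 512, 1, 1) -> conv stack -> (6, 128) -> fc -> (6, 2) softmax rows.
print(pred_status.shape)        # torch.Size([6, 2])
print(pred_status.sum(dim=-1))  # each row sums to ~1.0
```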
/msp/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .mse_loss import LogLoss
2 |
3 |
4 | __all__ = ['LogLoss']
5 |
6 |
--------------------------------------------------------------------------------
/msp/models/losses/mse_loss.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import math
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from ..builder import LOSSES
9 |
10 |
11 | @LOSSES.register_module()
12 | class LogLoss(nn.Module):
13 |
14 | def __init__(self):
15 | super(LogLoss, self).__init__()
16 | self.accident = True
17 |
18 | def forward(self, frame, labels):
19 | if self.accident:
20 | loss = (- math.exp(-max(0, 76 - frame)) * torch.log(labels))
21 | else:
22 | loss = - torch.log(1 - labels)
23 |
24 | return loss
25 |
26 |
--------------------------------------------------------------------------------
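Despite the file name, `LogLoss` implements an exponentially weighted log loss: the factor `exp(-max(0, 76 - frame))` makes frames far before the (hard-coded) accident frame 76 contribute almost nothing. A quick illustration:

```python
import torch
from msp.models.losses.mse_loss import LogLoss

loss_fn = LogLoss()
p = torch.tensor(0.9)  # predicted accident probability

# The penalty grows as the frame index approaches the accident frame (76).
for frame in (10, 70, 76):
    print(frame, loss_fn(frame, p).item())
```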
/msp/models/nearfuture/__init__.py:
--------------------------------------------------------------------------------
1 | from .near_future import NearFuture
2 |
3 |
4 | __all__ = ['NearFuture']
5 |
6 |
--------------------------------------------------------------------------------
/msp/models/nearfuture/near_future.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ..builder import NEARFUTUREMODEL
5 |
6 |
7 | @NEARFUTUREMODEL.register_module()
8 | class NearFuture(torch.nn.Module):
9 | def __init__(self) -> None:
10 | super().__init__()
11 |
12 | def forward(self,
13 | location_all,
14 | feature_all,
15 | t,
16 | fill_position_index):
17 | if t==0:
18 | return location_all[:, t], feature_all[:, t]
19 | else:
20 | # process for location prediction
21 | if min(fill_position_index[0].shape) != 0:
22 | if t <= 5:
23 | past_movement_bank = location_all[fill_position_index[0], :t, fill_position_index[1], :].data
24 | else:
25 | past_movement_bank = location_all[fill_position_index[0], (t-5):t, fill_position_index[1], :].data
26 |
27 | real_timestamp = past_movement_bank.shape[1]
28 | delta_movement = (past_movement_bank[:,-1] - past_movement_bank[:, 0])/(real_timestamp -1)
29 | pred_location = past_movement_bank[:, -1, 1:5] - delta_movement[:, 1:5]
30 | location_all[fill_position_index[0], t, fill_position_index[1], 1:5].data = pred_location
31 |
32 | # fill the feature
33 | past_feature = feature_all[fill_position_index[0], t-1, fill_position_index[1], :].data
34 | feature_all[fill_position_index[0], t, fill_position_index[1], :].data = past_feature
35 | return location_all[:, t], feature_all[:, t]
36 |
37 | else:
38 | return location_all[:, t], feature_all[:, t]
39 |
40 | def _reset_parameters(self):
41 | for p in self.parameters():
42 | if p.dim() > 1:
43 | nn.init.xavier_uniform_(p)
--------------------------------------------------------------------------------
/msp/models/predictors/__init__.py:
--------------------------------------------------------------------------------
1 | from .all_batch_version import STAttnsGraphSimAllBatch
2 |
3 | __all__ = ['STAttnsGraphSimAllBatch']
--------------------------------------------------------------------------------
/msp/models/predictors/all_batch_version.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from mmcv.utils.config import ConfigDict
4 | from ..utils import *
5 | from ..builder import PREDICTORS, build_spatial, build_temporal, build_attn, \
6 | build_accident_block, build_gate, build_loss, build_nearfuture
7 |
8 |
9 | @PREDICTORS.register_module()
10 | class STAttnsGraphSimAllBatch(nn.Module):
11 |
12 | def __init__(self,
13 | spatial_graph: ConfigDict,
14 | temporal_graph:ConfigDict,
15 | gate: ConfigDict,
16 | temporal_attn: ConfigDict,
17 | near_future: ConfigDict,
18 | accident_block: ConfigDict,
19 | loss: ConfigDict):
20 |
21 | super(STAttnsGraphSimAllBatch, self).__init__()
22 |
23 | self.spatial_graph = build_spatial(spatial_graph)
24 | self.temporal_graph = build_temporal(temporal_graph)
25 | self.mask_gate = build_gate(gate)
26 | self.near_future = build_nearfuture(near_future)
27 | self.temporal_attn = build_attn(temporal_attn)
28 | self.accident_block = build_accident_block(accident_block)
29 | self.loss_func = build_loss(loss)
30 | # self.phi_x = nn.Sequential(
31 | # nn.Linear(2048, 1024),
32 | # nn.LeakyReLU(0.1)
33 | # )
34 |
35 | def forward(self, data, module="Train"):
36 | location = data['location'] # batch x frames x 19 x 6
37 | maskers = data["maskers"] # batch x frames x 19
38 | missing_status = data["missing_status"]
39 | feature = data['feature'][:,:,:,1:]
40 | batch_size = location.shape[0]
41 | h0 = torch.randn(2, batch_size, self.accident_block.hidden_feature).type_as(location)
42 | c0 = torch.randn(2, batch_size, self.accident_block.hidden_feature).type_as(location)
43 | location_past_index = torch.stack([location[:, 0] for x in range(5)], dim=1)
44 | location_past = torch.cat([location_past_index, location], dim=1) # batch_size x (frames+5) x 19 x 6
45 |
46 |
47 | if module == "Train":
48 | graph_index = data["graph_index"]
49 | graph_weight = data["graph_weight"]
50 | train_dict = {
51 | "feature_all":feature,
52 | "graph_index":graph_index,
53 | "graph_weight":graph_weight}
54 | spatial_feature_bank = self.spatial_graph(train_dict, module)
55 |
56 | train_dict_t = {
57 | "spatial_feature": spatial_feature_bank
58 | }
59 | temporal_feature_bank = self.temporal_graph(train_dict_t, module)
60 |
61 | select_feature, label = select_gate_feature(temporal_feature_bank, missing_status)
62 | train_gate_dict = {
63 | "select_feature":select_feature}
64 | pred_status = self.mask_gate(train_gate_dict, module)
65 | pred_label = torch.hstack((pred_status, label.unsqueeze(-1)))
66 | # weighted feature
67 | attn_feature = self.temporal_attn(temporal_feature_bank, module)
68 | # LSTM
69 | pred_accident = self.accident_block(attn_feature, h0, c0).transpose(1,0)
70 |
71 | return pred_accident, pred_label
72 |
73 | elif module == "Test":
74 | spatial_feature_bank = []
75 | temporal_feature_bank = []
76 | temporal_feature_single_bank = []
77 | masks = []
78 | masker_dict = {
79 | 'index':[torch.empty((0,)).type_as(location), torch.empty((0,)).type_as(location)],
80 | 'pred_label':torch.empty((0, 3)).type_as(location),
81 | 'timestamp': [] # add timestamps
82 | }
83 |
84 |             for t in range(location.size(1)):  # step through time
85 |
86 | location_t = location[:, t]
87 | past_location_t = location_past[:, t:(t+5)]
88 |                 # detect missing objects and fill in their positions
89 | mask_t = maskers[:, t]
90 | masker_param_dict={
91 | "all_mask": masks,
92 | "present_mask": mask_t,
93 | "feature_bank": temporal_feature_single_bank,
94 | "time":t,
95 | "location_all":location,
96 | "masker_dict":masker_dict}
97 |
98 | mask_t_filled, masker_dict, fill_position_index = self.mask_gate(masker_param_dict, module)
99 |
100 | masks.append(mask_t_filled)
101 |
102 | location_t_filled, feature_t_filled = self.near_future(location,
103 | feature,
104 | t,
105 | fill_position_index)
106 |
107 | location[:, t].data = location_t_filled
108 | feature[:, t].data = feature_t_filled
109 |
110 | # GCN encoder
111 | test_param_dict = {
112 | "location_t":location_t_filled,
113 | "feature_t":feature_t_filled,
114 | "mask_t":mask_t_filled,
115 | "past_location_t":past_location_t
116 | }
117 |
118 | enc = self.spatial_graph(module="Test", param_dict=test_param_dict)
119 |
120 | # save the spatial_bank
121 | spatial_feature_bank.append(enc)
122 | # temporal graph
123 | test_dict_t = {
124 | "spatial_feature":spatial_feature_bank,
125 | "mask":mask_t_filled}
126 | temporal_feature_t = self.temporal_graph(test_dict_t, module)
127 |
128 | temporal_feature_single_bank.append(temporal_feature_t)
129 | attn_feature_t = self.temporal_attn(temporal_feature_t, module)
130 |
131 | temporal_feature_bank.append(attn_feature_t)
132 |
133 | temporal_feature_bank = torch.stack(temporal_feature_bank, dim=0)
134 |
135 | # LSTM
136 |
137 | pred = self.accident_block(temporal_feature_bank, h0, c0).transpose(1,0)
138 |
139 | return pred, masker_dict['pred_label']
140 |
141 | def loss(self, pred, masker_pred, data):
142 |
143 | accident = data["accident"]
144 | pred_loss = caculate_softmax_e_loss(pred, accident)
145 | mask_gate_loss = caculate_masker_gate_loss(masker_pred)
146 |
147 | return pred_loss, mask_gate_loss
148 |
149 | def predict(self, data, module):
150 | with torch.no_grad():
151 | pred, pred_gate = self(data, module)
152 | pred_np = pred.cpu().numpy()
153 | pred_gate_np = pred_gate.cpu().numpy()
154 |
155 |         return pred_np, pred_gate_np
156 |
157 |
--------------------------------------------------------------------------------
/msp/models/spatialgraphs/__init__.py:
--------------------------------------------------------------------------------
1 | from .spatial_graph_batch import SpatialGraphBatch
2 |
3 |
4 | __all__ = ['SpatialGraphBatch']
--------------------------------------------------------------------------------
/msp/models/spatialgraphs/spatial_graph_batch.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch_geometric.data import Data
3 | from torch_geometric.loader import DataLoader
4 | from torch_geometric.nn import GCNConv
5 | from torch_geometric.data import Batch, batch
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | from collections import deque
10 | import itertools
11 | import math
12 | import networkx
13 |
14 |
15 | from ..builder import SGMODEL
16 |
17 |
18 | @SGMODEL.register_module()
19 | class SpatialGraphBatch(nn.Module):
20 |
21 | def __init__(self,
22 | node_feature: int,
23 | hidden_feature: int,
24 | out_feature: int,
25 | ):
26 |
27 | super(SpatialGraphBatch, self).__init__()
28 |
29 | self.g_conv1 = GCNConv(node_feature, hidden_feature)
30 | self.g_conv2 = GCNConv(hidden_feature, out_feature)
31 |
32 | def forward(self,
33 | param_dict,
34 | module):
35 |
36 | if module == "Train":
37 | enc = self.gate_training(param_dict["feature_all"],
38 | param_dict["graph_index"],
39 | param_dict["graph_weight"])
40 |
41 | elif module == "Test":
42 | enc = self.module_testing(param_dict["location_t"],
43 | param_dict["feature_t"],
44 | param_dict["mask_t"],
45 | param_dict["past_location_t"])
46 | return enc
47 |
48 | def gate_training(self,
49 | feature_all,
50 | graph_index,
51 | graph_weight):
52 |
53 |
54 | feature_flatten = torch.flatten(feature_all, start_dim=0, end_dim=1)
55 | graph_weight_flatten = torch.flatten(graph_weight, start_dim=0, end_dim=1)
56 | single_graph = graph_index[0]
57 | train_loader = self.generate_gcn_batch(feature_flatten, single_graph, graph_weight_flatten)
58 | for batch in train_loader:
59 |             all_spatial_feature = torch.sigmoid(self.g_conv1(batch.x, batch.edge_index, batch.edge_attr))
60 |             env_feature = torch.sigmoid(self.g_conv2(all_spatial_feature, batch.edge_index, batch.edge_attr))
61 | env_feature_re = env_feature.reshape(feature_all.shape[0], feature_all.shape[1], 19, 512)
62 |
63 | return env_feature_re
64 |
65 |
66 | def generate_gcn_batch(self,
67 | x,
68 | edge_index,
69 | edge_weight):
70 | data_list = []
71 |
72 | for i in range(x.shape[0]):
73 |
74 | temp_data = Data(x=x[i], edge_index=edge_index, edge_attr=edge_weight[i])
75 | data_list.append(temp_data)
76 |
77 | loader = DataLoader(data_list, batch_size=x.shape[0], shuffle=False)
78 |
79 | return loader
80 |
81 | def module_testing(self,
82 | location_t,
83 | feature_t,
84 | mask,
85 | past_location_t):
86 |
87 | num_boxes = location_t.shape[1]
88 | graph_edges = self.generate_graph_from_list(range(num_boxes))
89 | graph_weight = self.generate_weight(location_t, graph_edges, mask, past_location_t)
90 |
91 | graph_edges = torch.tensor(graph_edges).type_as(location_t).t().contiguous().type(torch.long)
92 | graph_edges_b = torch.stack([graph_edges for x in range(location_t.shape[0])], dim=0)
93 | graph_weight = graph_weight.type_as(location_t)
94 |
95 | out_batch = []
96 | for i in range(graph_edges_b.size(0)):
97 |             hidden_enc_feature = torch.sigmoid(self.g_conv1(feature_t[i], graph_edges_b[i], graph_weight[i]))
98 |             enc_feature = torch.sigmoid(self.g_conv2(hidden_enc_feature, graph_edges_b[i], graph_weight[i]))
99 | out_batch.append(enc_feature)
100 |
101 | out_batch = torch.stack(out_batch)
102 |
103 | return out_batch
104 |
105 |
106 | def generate_weight(self, location_t, graph_edges, mask, past_location_t):
107 |
108 | weights = torch.zeros((location_t.shape[0], len(graph_edges), ), dtype=torch.float32).type_as(location_t)
109 |
110 | for i, edge in enumerate(graph_edges):
111 |
112 | c1 = [0.5 * (location_t[:, edge[0], 1] + location_t[:, edge[0], 3]),
113 | 0.5 * (location_t[:, edge[0], 2] + location_t[:, edge[0], 4])]
114 | c2 = [0.5 * (location_t[:, edge[1], 1] + location_t[:, edge[1], 3]),
115 | 0.5 * (location_t[:, edge[1], 2] + location_t[:, edge[1], 4])]
116 | d = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2
117 | # md
118 | md = self.get_motion_distance(edge[0], edge[1], past_location_t)
119 | weights[:, i] = torch.exp(-(0.7*d + 0.3*md))
120 | # mask empty position
121 | batch_index = ((mask[:,edge[0]]==True) & (mask[:,edge[1]]==True))
122 | weights[batch_index==False, i] = 0
123 |
124 | weights_nor = (weights - torch.min(weights, dim=1, keepdim=True)[0]) / (torch.max(weights, dim=1, keepdim=True)[0] - torch.min(weights, dim=1, keepdim=True)[0])
125 |
126 | return weights_nor
127 |
128 | def get_motion_distance(self, id_0, id_1, past_location):
129 |
130 | center_point = past_location[:,:,:,3:-1] - past_location[:,:,:,1:3]
131 | center_point_n_x = (center_point[:,:,:,0] * 1280).unsqueeze(dim=-1)
132 | center_point_n_y = (center_point[:,:,:,1] * 720).unsqueeze(dim=-1)
133 | center_point_n = torch.cat((center_point_n_x, center_point_n_y), dim=-1)
134 |
135 | delta_0 = (center_point_n[:, -1, id_0, 0] - center_point_n[:, 0, id_0, 0],
136 | center_point_n[:, -1, id_0, 1] - center_point_n[:, 0, id_0, 1])
137 | delta_1 = (center_point_n[:, -1, id_1, 0] - center_point_n[:, 0, id_1, 0],
138 | center_point_n[:, -1, id_1, 1] - center_point_n[:, 0, id_1, 1])
139 |
140 | length_0 = (delta_0[0]**2 + delta_0[1]**2)**0.5
141 | length_1 = (delta_1[0]**2 + delta_1[1]**2)**0.5
142 |
143 | with torch.no_grad():
144 | length_0[length_0==0.] = 1.0
145 | length_1[length_1==0.] = 1.0
146 |
147 | nor_delta_x_0 = delta_0[0] / length_0
148 | nor_delta_y_0 = delta_0[1] / length_0
149 |
150 | nor_delta_x_1 = delta_1[0] / length_1
151 | nor_delta_y_1 = delta_1[1] / length_1
152 |
153 | d1 = ((center_point_n[:, 0, id_0, 0] + nor_delta_x_0 - center_point_n[:, 0, id_1, 0] - nor_delta_x_1)**2 + \
154 | (center_point_n[:, 0, id_0, 1] + nor_delta_y_0 - center_point_n[:, 0, id_1, 1] - nor_delta_y_1)**2)**0.5
155 |
156 | d2 = ((center_point_n[:, 0, id_0, 0] - center_point_n[:, 0, id_1, 0])**2 + \
157 | (center_point_n[:, 0, id_0, 1] - center_point_n[:, 0, id_1, 1])**2)**0.5
158 |
159 | # motion distance
160 | md = d1 - d2
161 |
162 | return md
163 |
164 |
165 | def generate_graph_from_list(self, L, create_using=None):
166 | G = networkx.empty_graph(len(L), create_using)
167 | if len(L) > 1:
168 | if G.is_directed():
169 | edges = itertools.permutations(L,2)
170 | else:
171 | edges = itertools.combinations(L,2)
172 |
173 | G.add_edges_from(edges)
174 |
175 | graph_edges = list(G.edges())
176 |
177 | return graph_edges
178 |
--------------------------------------------------------------------------------
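`generate_graph_from_list` simply builds a fully connected undirected graph over the object slots, whose edges are then weighted by centre distance and motion distance in `generate_weight`. A small sketch of the helper:

```python
from msp.models.spatialgraphs.spatial_graph_batch import SpatialGraphBatch

sg = SpatialGraphBatch(node_feature=2048, hidden_feature=1024, out_feature=512)

# Complete graph over 4 nodes: all unordered pairs.
print(sg.generate_graph_from_list(range(4)))
# [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```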
/msp/models/temporalattns/__init__.py:
--------------------------------------------------------------------------------
1 | from .s_t_attn import STAttn
2 |
3 |
4 | __all__ = ['STAttn']
--------------------------------------------------------------------------------
/msp/models/temporalattns/s_t_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | from ..builder import ATTNMODEL
7 |
8 |
9 | @ATTNMODEL.register_module()
10 | class STAttn(nn.Module):
11 |
12 | def __init__(self):
13 |
14 | super(STAttn, self).__init__()
15 |
16 | self.ue = nn.Linear(512, 64)
17 | self.be = nn.Parameter(torch.zeros(64))
18 | self.w = nn.Linear(64, 1)
19 |
20 | self.fc1 = nn.Linear(512, 256)
21 |
22 | self.softmax = nn.Softmax(dim=-2)
23 |
24 | def forward(self,
25 | inputs,
26 | module):
27 |
28 | if module=="Train":
29 | fc_attr = self.attn_train(inputs)
30 | elif module=="Test":
31 | fc_attr = self.attn_test(inputs)
32 |
33 | return fc_attr
34 |
35 | def attn_train(self, inputs):
36 |
37 | inputs_flatten = torch.flatten(inputs, start_dim=0, end_dim=1)
38 | # weight the feature
39 | e_j = self.w(F.leaky_relu((self.ue(inputs_flatten) + self.be), negative_slope=0.2))
40 | a_j = self.softmax(e_j)
41 | attr = torch.mul(a_j, inputs_flatten).sum(dim=1)
42 | fc_attr = self.fc1(attr)
43 | fc_attr_re = fc_attr.reshape(inputs.shape[0], inputs.shape[1], 256)
44 |
45 | fc_attr_re_lstm = fc_attr_re.permute(1,0,2)
46 |
47 | return fc_attr_re_lstm
48 |
49 | def attn_test(self, inputs):
50 |
51 |         # attention-weighted average over the object dimension
52 | e_j = self.w(F.leaky_relu((self.ue(inputs) + self.be), negative_slope=0.2))
53 | a_j = self.softmax(e_j)
54 | attr = torch.mul(a_j, inputs).sum(dim=1)
55 |         # reduce the feature dimension (512 -> 256)
56 | fc_attr = self.fc1(attr)
57 |
58 | return fc_attr
59 |
60 | def _reset_parameters(self):
61 | for p in self.parameters():
62 | if p.dim() > 1:
63 | nn.init.xavier_uniform_(p)
64 |
--------------------------------------------------------------------------------
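On the test path, `STAttn` collapses the object dimension with learned attention weights and projects 512-d features down to 256-d. A shape sketch:

```python
import torch
from msp.models.temporalattns.s_t_attn import STAttn

attn = STAttn()

# One timestep of per-object temporal features: batch=2, 19 objects, 512-d each.
x = torch.randn(2, 19, 512)
out = attn(x, module="Test")
print(out.shape)  # torch.Size([2, 256]); objects pooled by attention, 512 -> 256
```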
/msp/models/temporalgraphs/__init__.py:
--------------------------------------------------------------------------------
1 | from .temporal_graph_batch import TemporalGraphBatch
2 |
3 | __all__ = ['TemporalGraphBatch']
--------------------------------------------------------------------------------
/msp/models/temporalgraphs/temporal_graph_batch.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.nn import GCNConv
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | from ..builder import TGMODEL
8 |
9 |
10 | @TGMODEL.register_module()
11 | class TemporalGraphBatch(nn.Module):
12 | def __init__(self,
13 | past_num: int,
14 | input_feature: int,
15 | out_feature: int):
16 |
17 | super(TemporalGraphBatch, self).__init__()
18 |
19 | self.past_num = past_num
20 | self.out_feature = out_feature
21 | self.in_feature = input_feature
22 |
23 | edge_index, edge_weight = self.generate_graph(self.past_num)
24 |
25 | self.register_buffer("t_edge_index", edge_index)
26 | self.register_buffer("t_edge_weight", edge_weight)
27 |
28 | self.conv1 = GCNConv(input_feature, out_feature)
29 |
30 | def forward(self,
31 | param_dict,
32 | module):
33 | if module == "Train":
34 | enc = self.ts_train(param_dict["spatial_feature"])
35 | elif module == "Test":
36 | enc = self.ts_test(param_dict["spatial_feature"],
37 | param_dict["mask"])
38 |
39 | return enc
40 |
41 | def ts_train(self,
42 | spatial_feature_bank
43 | ):
44 | batch_size = spatial_feature_bank.shape[0]
45 | filled = spatial_feature_bank[:,0,:,:]
46 | pred_filled = torch.stack([filled for i in range(self.past_num-1)], dim=1)
47 | s_bank_fill = torch.cat((spatial_feature_bank, pred_filled), dim=1)
48 | s_bank_per = s_bank_fill.permute(0,2,1,3)
49 | s_bank_re = torch.flatten(s_bank_per, start_dim=0, end_dim=1)
50 |         through_time = torch.empty(size=[batch_size * 19, 0, self.past_num, self.in_feature]).type_as(spatial_feature_bank)
51 | 
52 |         for time_step in range(self.past_num-1, 100+self.past_num-1):
53 |             single_step = s_bank_re[:, time_step-self.past_num+1:time_step+1, :].unsqueeze(dim=1)
54 |             through_time = torch.cat((through_time, single_step), dim=1)
55 | 
56 |         t_t_re = torch.flatten(through_time, start_dim=0, end_dim=1)
57 | enc = self.conv1(t_t_re, self.t_edge_index, self.t_edge_weight)
58 | shape_1 = enc.reshape(batch_size*19, 100, self.past_num, self.out_feature)[:,:,-1,:]
59 | shape_2 = shape_1.reshape(batch_size, 19, 100, self.out_feature)
60 | final = shape_2.permute(0,2,1,3)
61 |
62 | return final
63 |
64 | def ts_test(self,
65 | spatial_feature_bank,
66 | mask):
67 |
68 | spatial_feature_bank_tensor = torch.stack(spatial_feature_bank, dim=0).permute(1,0,2,3)
69 | temporal_feature_t = torch.zeros((mask.shape[0], 19, 512,), dtype=torch.float32).type_as(spatial_feature_bank_tensor)
70 |
71 | if spatial_feature_bank_tensor.shape[1] < self.past_num:
72 | missing_num = self.past_num - spatial_feature_bank_tensor.shape[1]
73 | unpack = torch.stack([spatial_feature_bank_tensor[:,0,:,:] for x in range(missing_num)], dim=1)
74 | spatial_feature_bank_pack = torch.cat([unpack, spatial_feature_bank_tensor], dim=1)
75 | else:
76 | spatial_feature_bank_pack = spatial_feature_bank_tensor[:,-5:]
77 |
78 | sf_pack = spatial_feature_bank_pack.permute(0,2,1,3)
79 |
80 | sf_pack_f = torch.flatten(sf_pack, start_dim=0, end_dim=1)
81 | t_f_t_all = self.conv1(sf_pack_f, self.t_edge_index, self.t_edge_weight)
82 | t_f_t_all_re = t_f_t_all.reshape(mask.shape[0],19,5,512).permute(0,2,1,3)[:,-1,:,:]
83 | temporal_feature_t[torch.where(mask==True)] = t_f_t_all_re[torch.where(mask==True)]
84 |
85 | return temporal_feature_t
86 |
87 | def generate_graph(self,
88 | past_num: int):
89 |
90 | edge_index = []
91 | edge_weight = []
92 |
93 | for i in range(past_num - 1):
94 | edge_index.append([i, past_num - 1])
95 | edge_weight.append(math.exp(-(past_num - 1 -i)))
96 | return torch.tensor(edge_index, dtype=torch.long).t().contiguous(), torch.tensor(edge_weight, dtype=torch.float32)
--------------------------------------------------------------------------------
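`generate_graph` builds a fixed star graph in which the `past_num - 1` previous timesteps all feed the current (last) node, with exponentially decaying edge weights. For the default `past_num=5`:

```python
from msp.models.temporalgraphs.temporal_graph_batch import TemporalGraphBatch

tg = TemporalGraphBatch(past_num=5, input_feature=512, out_feature=512)

print(tg.t_edge_index)   # tensor([[0, 1, 2, 3], [4, 4, 4, 4]])
print(tg.t_edge_weight)  # tensor([0.0183, 0.0498, 0.1353, 0.3679]) = exp(-4)..exp(-1)
```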
/msp/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import get_non_zero, generate_label_matrix, caculate_loss, caculate_softmax_e_loss, caculate_masker_gate_loss, select_gate_feature
2 |
3 | __all__ = ['get_non_zero', 'generate_label_matrix', 'caculate_loss', 'caculate_softmax_e_loss', 'caculate_masker_gate_loss','select_gate_feature']
--------------------------------------------------------------------------------
/msp/models/utils/utils.py:
--------------------------------------------------------------------------------
1 | from networkx.algorithms.shortest_paths.weighted import negative_edge_cycle
2 | from networkx.utils import decorators
3 | from networkx.utils.misc import flatten
4 | import numpy as np
5 | import torch
6 | import torch.nn
7 | import math
8 | import torch.nn.functional as F
9 |
10 | def get_non_zero(id_with_zero):
11 | mask = (id_with_zero != 0)
12 | id_without_zero = id_with_zero[mask]
13 |
14 | return id_without_zero
15 |
16 | def generate_label_matrix(accident):
17 | label_matrix = torch.zeros((accident.shape[0], 40), device=torch.cuda.current_device())
18 | pos = torch.tensor([-math.exp(-max(0, (36-t) / 25)) for t in range(40)], device=torch.cuda.current_device())
19 | neg = torch.tensor([-1. for t in range(40)], device=torch.cuda.current_device())
20 | label_matrix[accident==True] = pos
21 | label_matrix[accident==False] = neg
22 |
23 | return label_matrix
24 |
25 | def caculate_loss(pred, labels):
26 |
27 | weight_loss = torch.mul(torch.log(pred), labels).sum(dim=1).mean()
28 |
29 | return weight_loss
30 |
31 | def caculate_softmax_e_loss(pred, accident):
32 |
33 | crterion = torch.nn.CrossEntropyLoss(reduction='none')
34 | #positive sample
35 | frames = pred.shape[1]
36 | pos = pred[accident==True]
37 | pos_target = torch.tensor([1. for x in range(frames)], dtype=torch.long, device=torch.cuda.current_device())
38 |     pos_penalty = torch.tensor([math.exp(-max(0, (76-t)/ 25)) for t in range(frames)], device=torch.cuda.current_device(), dtype=torch.float32)  # float weights: a long dtype would truncate exp(...) to 0
39 | all_positive_loss = []
40 | for batch in range(pos.shape[0]):
41 | loss = -crterion(pos[batch], pos_target)
42 | positive_loss = -torch.mul(loss, pos_penalty).sum()
43 | all_positive_loss.append(positive_loss)
44 | if len(all_positive_loss)!=0:
45 | all_positive_loss_t = torch.stack(all_positive_loss).mean()
46 | else:
47 | all_positive_loss_t = torch.tensor(0., dtype=torch.float32).type_as(pred)
48 |
49 | #neg
50 | neg = pred[accident==False]
51 | neg_target = torch.tensor([0. for x in range(frames)], dtype=torch.long, device=torch.cuda.current_device())
52 | all_neg_loss = []
53 |
54 | for batch in range(neg.shape[0]):
55 | loss_neg = crterion(neg[batch], neg_target).sum()
56 | all_neg_loss.append(loss_neg)
57 | if len(all_neg_loss) != 0:
58 | all_neg_loss_t = torch.stack(all_neg_loss).mean()
59 | else:
60 | all_neg_loss_t = torch.tensor(0., dtype=torch.float32).type_as(pred)
61 |
62 | pred_loss = all_positive_loss_t + 0.2 * all_neg_loss_t
63 | return pred_loss
64 |
65 | def caculate_masker_gate_loss(masker_pred):
66 |
67 | pred_results = masker_pred[:, :2]
68 | label = masker_pred[:, -1]
69 |
70 | crterion = torch.nn.CrossEntropyLoss(reduction='none')
71 | go_away_pred = pred_results
72 | go_away_label = label
73 | go_away_loss = crterion(go_away_pred, 1 - go_away_label.long()).sum()
74 |
75 | shelter_pred = pred_results
76 | shelter_label = label
77 | shelter_loss = crterion(shelter_pred, shelter_label.long()).sum()
78 |
79 | loss_masker = 0.02 * go_away_loss + 0.02 * shelter_loss
80 |
81 | return loss_masker
82 |
83 |
84 | def select_gate_feature(feature_bank, missing_status):
85 |
86 | go_away_index = torch.where(missing_status==1.0)
87 | go_away_feature = feature_bank[go_away_index]
88 | go_away_label = torch.zeros((go_away_index[0].shape[0], )).type_as(feature_bank)
89 |     # occluded (sheltered) objects are given label 1
90 | shelter_index = torch.where(missing_status==2.0)
91 | shelter_feature = feature_bank[shelter_index]
92 | shelter_label = torch.ones((shelter_index[0].shape[0], )).type_as(feature_bank)
93 |     # concatenate the features and labels
94 | missing_feature = torch.cat((go_away_feature, shelter_feature), dim=0)
95 | labels = torch.cat((go_away_label, shelter_label), dim=0)
96 |
97 | return missing_feature, labels
--------------------------------------------------------------------------------
/msp/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_utils import setup_seed, partial_state_dict, summarize_metric
2 |
3 |
4 | __all__ = ['setup_seed', 'partial_state_dict', 'summarize_metric']
5 |
--------------------------------------------------------------------------------
/msp/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def setup_seed(seed: int):
8 | torch.manual_seed(seed)
9 | torch.cuda.manual_seed(seed)
10 | torch.cuda.manual_seed_all(seed)
11 | np.random.seed(seed)
12 | random.seed(seed)
13 | torch.backends.cudnn.deterministic = True
14 |
15 |
16 | def partial_state_dict(model: torch.nn.Module, ckpt_path: str):
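17 | """Load a checkpoint on CPU and keep only the weights whose keys also exist in
18 | `model`, returning the merged state dict so partially matching checkpoints can be loaded."""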
17 | pre_ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['state_dict']
18 | model_ckpt = model.state_dict()
19 | pre_ckpt = {k: v for k, v in pre_ckpt.items() if k in model_ckpt}
20 | model_ckpt.update(pre_ckpt)
21 |
22 | return model_ckpt
23 |
24 | def summarize_metric(output):
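25 | """Average 'ap', 'mtta' and 'tta_r80' over a list of per-video metric dicts."""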
25 | average_ap = [item['ap'] for item in output]
26 | average_mtta = [item['mtta'] for item in output]
27 | average_tta_r80 = [item['tta_r80'] for item in output]
28 |
29 | return np.mean(average_ap), np.mean(average_mtta), np.mean(average_tta_r80)
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.14.1
2 | addict==2.4.0
3 | cachetools==4.2.4
4 | certifi==2020.12.5
5 | charset-normalizer==2.0.7
6 | cycler==0.10.0
7 | decorator==4.4.2
8 | future==0.18.2
9 | google-auth==1.35.0
10 | google-auth-oauthlib==0.4.6
11 | googledrivedownloader==0.4
12 | grpcio==1.41.0
13 | idna==3.3
14 | importlib-metadata==4.8.1
15 | isodate==0.6.0
16 | Jinja2==3.0.2
17 | joblib==1.1.0
18 | kiwisolver==1.3.1
19 | Markdown==3.3.4
20 | MarkupSafe==2.0.1
21 | matplotlib==3.3.4
22 | mmcv==1.3.0
23 | networkx==2.5.1
24 | numpy
25 | oauthlib==3.1.1
26 | olefile==0.46
27 | opencv-python==4.5.3.56
28 | packaging==21.0
29 | pandas==1.1.5
30 | Pillow
31 | protobuf==3.18.1
32 | pyasn1==0.4.8
33 | pyasn1-modules==0.2.8
34 | pyparsing==2.4.7
35 | python-dateutil==2.8.2
36 | pytorch-lightning==0.9.0
37 | pytz==2021.3
38 | PyYAML==6.0
39 | rdflib==5.0.0
40 | requests==2.26.0
41 | requests-oauthlib==1.3.0
42 | rsa==4.7.2
43 | scikit-learn==0.24.2
44 | scipy==1.5.4
45 | six==1.16.0
46 | tensorboard==2.2.0
47 | tensorboard-plugin-wit==1.8.0
48 | threadpoolctl==3.0.0
49 | torch==1.7.0
50 | torch-cluster==1.5.9
51 | torch-geometric==2.0.1
52 | torch-scatter==2.0.7
53 | torch-sparse==0.6.9
54 | torchaudio==0.7.0a0+ac17b64
55 | torchvision==0.8.0
56 | tqdm==4.62.3
57 | typing==3.7.4.3
58 | typing-extensions
59 | urllib3==1.26.7
60 | Werkzeug==2.0.2
61 | yacs==0.1.8
62 | yapf==0.31.0
63 | zipp==3.6.0
64 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | from mmcv import Config
6 | import pytorch_lightning as pl
7 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateLogger
8 | from pytorch_lightning.loggers import TensorBoardLogger
9 | from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
10 |
11 | from msp import LightningModel
12 | from msp.utils import setup_seed, partial_state_dict
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description="Train or test the accident anticipation model.")
17 | parser.add_argument("config", help="Train config file path.")
18 | parser.add_argument("--train", action="store_true")
19 | parser.add_argument("--test", action="store_true")
20 |
21 | args = parser.parse_args()
22 | return args
23 |
24 |
25 | def main():
26 | args = parse_args()
27 | cfg = Config.fromfile(args.config)
28 | setup_seed(cfg.random_seed)
29 |
30 | model = LightningModel(cfg)
31 |
32 | checkpoint_callback = ModelCheckpoint(
33 | filepath=f"{cfg.checkpoint_path}/{cfg.name}/{cfg.version}/"
34 | f"{cfg.name}_{cfg.version}_{{epoch}}_{{avg_val_loss:.3f}}_{{ade:.3f}}_{{fde:.3f}}_{{fiou:.3f}}",
35 | save_last=None,
36 | save_top_k=-1,
37 | verbose=True,
38 | monitor='fiou',
39 | mode='max',
40 | prefix=''
41 | )
42 |
43 | lr_logger_callback = LearningRateLogger(logging_interval='step')
44 | profiler = SimpleProfiler() if cfg.simple_profiler else AdvancedProfiler()
45 | logger = TensorBoardLogger(save_dir=cfg.log_path, name=cfg.name, version=cfg.version)
46 |
47 | trainer = pl.Trainer(
48 | gpus=cfg.num_gpus,
49 | #distributed_backend='dp',
50 | max_epochs=cfg.max_epochs,
51 | logger=logger,
52 | profiler=profiler,
53 | callbacks=[lr_logger_callback],
54 | # gradient_clip_val=cfg.gradient_clip_val,
55 | checkpoint_callback=checkpoint_callback,
56 | check_val_every_n_epoch=10,
57 | resume_from_checkpoint=cfg.resume_from_checkpoint,
58 | accumulate_grad_batches=cfg.batch_accumulate_size)  # each batch holds a single sample, so gradients are accumulated over several batches instead
59 |
60 | if (not (args.train or args.test)) or args.train:
61 |
62 | shutil.copy(args.config, os.path.join(cfg.log_path, cfg.name, cfg.version, args.config.split('/')[-1]))
63 |
64 | if cfg.load_from_checkpoint is not None:
65 | model_ckpt = partial_state_dict(model, cfg.load_from_checkpoint)
66 | model.load_state_dict(model_ckpt)
67 |
68 | trainer.fit(model)
69 |
70 | if args.test:
71 | if cfg.test_checkpoint is not None:
72 | model_ckpt = partial_state_dict(model, cfg.test_checkpoint)
73 | model.load_state_dict(model_ckpt)
74 |
75 | trainer.test(model)
76 |
77 |
78 | if __name__ == '__main__':
79 | main()
80 |
--------------------------------------------------------------------------------
/scripts/preprecess_data_MASKER_MD.py:
--------------------------------------------------------------------------------
3 | import numpy as np
4 | import os
5 | import random
6 | import pandas as pd
7 | import networkx
8 | import itertools
9 | random.seed(1234)
10 |
11 | data_root = '/media/group1/data/tianhang/sorted_samples_pn_2048/'
12 | train_root = '/media/group1/data/tianhang/MASKER_MD/'
13 | test_root = '/media/group1/data/tianhang/MASKER_MD/'
14 |
15 | # frames to wait before deciding a vanished object has left the scene
16 | miss_wait = 10
17 |
18 | def proprecess_data(data_root, train_root, test_root):
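19 | """Split the samples 70/30 into train/valid lists. Training clips get short
20 | occlusions filled by interpolation plus fully connected graph edges and
21 | distance/motion-based edge weights; validation clips only get their
22 | per-object missing status recorded."""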
19 |
20 | if not os.path.exists(train_root):
21 | os.makedirs(train_root)
22 | if not os.path.exists(test_root):
23 | os.makedirs(test_root)
24 |
25 | all_file_list = os.listdir(data_root)
26 | random.shuffle(all_file_list)
27 |
28 | train_number = int(len(all_file_list) * 0.7)
29 | train_file_list = all_file_list[:train_number]
30 | val_file_list = all_file_list[train_number:]
31 | generate_corresponding_txt(train_file_list, train_root, module="train")
32 | generate_corresponding_txt(val_file_list, test_root, module="valid")
33 |
34 | print("Start to process training set!")
35 | graph_edges = generate_graph_from_list(range(19))
36 |
37 | for number, i in enumerate(train_file_list):
38 |
39 | filled_before = np.load(os.path.join(data_root, i), allow_pickle=True).item()
40 | frames = filled_before["location"].shape[0]
41 |
42 | maskers = []
43 | mask_dict = np.zeros((frames, 19), dtype=np.float32)
44 |
45 | filled_before_location = filled_before['location'] # frames x 19 x 6
46 | filled_before_feature = filled_before['feature'] # frames x 19 x 2048
47 | for t in range(frames):
48 | #print("time", t)
49 | #print("location_present",filled_before_location[t,:,1:5])
50 | mask_t = filled_before_location[t,:,1:5].sum(-1) != 0.
51 | #print("mask_t", mask_t)
52 | maskers, mask_dict, filled_after_location, filled_after_feature = fill_during_training(mask_t,
53 | maskers,
54 | mask_dict,
55 | t,
56 | location_all=filled_before_location,
57 | feature_all=filled_before_feature)
58 |
59 | maskers_np = np.array(maskers)
60 | #print("maskers_np", maskers_np)
61 | graph_weight = generate_weight(filled_after_location, graph_edges, maskers_np)
62 | filled_before["graph_index"] = graph_edges
63 | filled_before["graph_weight"] = graph_weight
64 | filled_before["maskers"] = maskers_np
65 | filled_before["missing_status"] = mask_dict
66 | filled_before["location"] = filled_after_location
67 | filled_before["feature"] = filled_after_feature
68 |
69 | np.save(os.path.join(train_root, i), filled_before)
70 | print("The Training process has been {:.2f} % done!".format((number + 1) / len(train_file_list) * 100))
71 |
72 | print("#" * 30)
73 | print("Start to process Testing set!")
74 | for val_number, j in enumerate(val_file_list):
75 |
76 | val_infor = np.load(os.path.join(data_root, j), allow_pickle=True).item()
77 | frames = val_infor["location"].shape[0]
78 | val_maskers = []
79 | val_mask_dict = np.zeros((frames, 19), dtype=np.float32)
80 | val_location = val_infor["location"] # frames x 19 x 6
81 | for t in range(frames):
82 |
83 | val_mask_t = val_location[t,:,1:5].sum(-1) != 0.
84 | #val_maskers.append(val_mask_t)
85 | val_maskers, val_mask_dict = spect_during_testing(val_mask_t, val_maskers, val_mask_dict, t, val_location)
86 |
87 | val_maskers_np = np.array(val_maskers)
88 | val_infor["maskers"] = val_maskers_np
89 | val_infor["missing_status"] = val_mask_dict
90 | np.save(os.path.join(test_root, j), val_infor)
91 | print("The Testing process has been {:.2f} % done!".format((val_number + 1) / len(val_file_list) * 100))
92 |
93 | def generate_weight(location, graph_edges, mask):
94 | # location: frames x 19 x 6
95 | # graph_weight: frames x num_edge
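96 | # edge weight: exp(-(0.7 * d + 0.3 * md)), where d is the squared distance between
97 | # normalised box centres and md is the motion-distance term; edges whose endpoints
98 | # are not both visible get weight 0, then weights are min-max normalised per frame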
96 |
97 | weights = np.zeros((location.shape[0], len(graph_edges), ), dtype=np.float32)
98 | location_n = location.copy()
99 | location_n[:, :, [1,3]] /= 1280
100 | location_n[:, :, [2,4]] /= 720
101 | for i, edge in enumerate(graph_edges):
102 |
103 | c1 = [0.5 * (location_n[:, edge[0], 1] + location_n[:, edge[0], 3]),
104 | 0.5 * (location_n[:, edge[0], 2] + location_n[:, edge[0], 4])]
105 | c2 = [0.5 * (location_n[:, edge[1], 1] + location_n[:, edge[1], 3]),
106 | 0.5 * (location_n[:, edge[1], 2] + location_n[:, edge[1], 4])]
107 |
108 | d = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2
109 |
110 | md = get_motion_distance(edge[0], edge[1], location)
111 |
112 | weights[:, i] = np.exp(-(0.7 * d + 0.3 * md))
113 | frame_index = ((mask[:, edge[0]]==True) & (mask[:, edge[1]] == True))
114 | weights[frame_index==False, i] = 0
115 |
116 | # normalize
117 | weights_nor = (weights - np.min(weights, axis=1, keepdims=True)) / (np.max(weights, axis=1, keepdims=True) - np.min(weights, axis=1, keepdims=True))
118 | # account for NaN
119 | weights_nor = np.nan_to_num(weights_nor)
120 |
121 | return weights_nor
122 |
123 | def get_motion_distance(id_0, id_1, location_all):
124 | md = []
125 | center_point = location_all[:,:,3:-1] - location_all[:,:,1:3]
126 | for t in range(100):
127 | if t <= 4:
128 | d = get_normolize_distance(id_0, id_1, 0, t, center_point)
129 | else:
130 | d = get_normolize_distance(id_0, id_1, t-4, t, center_point)
131 | md.append(d)
132 |
133 | md = np.array(md)
134 | return md
135 |
136 | def get_normolize_distance(id_0, id_1, t1, t2, center_point):
137 |
138 | v_id_0 = normalize(((center_point[t1, id_0, 0], center_point[t1, id_0, 1]),(center_point[t2, id_0, 0], center_point[t2, id_0, 1])))
139 | v_id_1 = normalize(((center_point[t1, id_1, 0], center_point[t1, id_1, 1]),(center_point[t2, id_1, 0], center_point[t2, id_1, 1])))
140 |
141 | # Euclidean distances
142 | d_1 = ((v_id_0[1][0] - v_id_1[1][0])**2 + (v_id_0[1][1] - v_id_1[1][1])**2)**0.5 # distance between vector heads
143 | d_2 = ((v_id_0[0][0] - v_id_1[0][0])**2 + (v_id_0[0][1] - v_id_1[0][1])**2)**0.5 # distance between vector tails
144 |
145 | md = d_1 - d_2
146 |
147 |
148 | return md
149 |
150 | def normalize(v):
151 | delta_x = v[1][0] - v[0][0]
152 | delta_y = v[1][1] - v[0][1]
153 | length = (delta_x**2 + delta_y**2)**0.5
154 | if length != 0:
155 | nor_delta_x = delta_x/length
156 | nor_delta_y = delta_y/length
157 | return ((v[0][0], v[0][1]),(v[0][0]+nor_delta_x, v[0][1]+nor_delta_y))
158 | else:
159 | return v
160 |
161 | def generate_graph_from_list(L, create_using=None):
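162 | """Return the edge list of a complete graph over the node ids in L
163 | (ordered pairs when the graph is directed)."""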
162 | G = networkx.empty_graph(len(L), create_using)
163 | if len(L) > 1:
164 | if G.is_directed():
165 | edges = itertools.permutations(L,2)
166 | else:
167 | edges = itertools.combinations(L,2)
168 | G.add_edges_from(edges)
169 | graph_edges = list(G.edges())
170 |
171 | return graph_edges
172 |
173 |
174 | def spect_during_testing(mask_present,
175 | all_mask,
176 | mask_dict,
177 | t,
178 | location_all):
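179 | """For every object that flips from visible to missing at frame t, label frame
180 | t-1 in mask_dict: 2.0 if the object reappears within the next miss_wait frames
181 | (occluded), 1.0 otherwise (left the scene)."""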
179 |
180 | if len(all_mask) == 0:
181 | all_mask.append(mask_present)
182 |
183 | return all_mask, mask_dict
184 |
185 | else:
186 | before_mask = all_mask[-1]
187 | # compare from True to False
188 | index = np.where((before_mask == True) & (mask_present ==False))
189 | if min(index[0].shape) != 0:
190 | for missing_object in index[0]:
191 | #print("missing_object", missing_object)
192 | near_future_infor = location_all[(t+1):(t+miss_wait), missing_object, 1:5]
193 | #print("near_future_infor", near_future_infor)
194 | near_future_status = np.where((near_future_infor.sum(-1) != 0) == True)
195 | if min(near_future_status[0].shape) != 0:
196 | mask_dict[t-1, missing_object] = 2.
197 | else:
198 | mask_dict[t-1, missing_object] = 1.
199 |
200 | all_mask.append(mask_present)
201 | return all_mask, mask_dict
202 |
203 | else:
204 | all_mask.append(mask_present)
205 | return all_mask, mask_dict
206 |
207 | def fill_during_training(mask_present,
208 | all_mask,
209 | mask_dict,
210 | t,
211 | location_all,
212 | feature_all
213 | ):
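214 | """Same status bookkeeping as spect_during_testing, but occluded objects also
215 | get their boxes interpolated (fill_functions) and their appearance feature
216 | carried forward from the last visible frame."""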
214 | # miss_wait (module level) bounds how long to wait for a reappearance
215 |
216 | if len(all_mask) == 0:
217 | all_mask.append(mask_present)
218 |
219 | return all_mask, mask_dict, location_all, feature_all
220 | else:
221 | # mask_before
222 | before_mask= all_mask[-1]
223 | # compare from True to False
224 | index = np.where((before_mask==True) & (mask_present==False))
225 | if min(index[0].shape) != 0:
226 | #print("missing index", index)
227 | for missing_object in index[0]:
228 | #print("Missing objects", missing_object)
229 | near_future_infor = location_all[(t+1):(t+miss_wait), missing_object,1:5]
230 | near_future_status = np.where((near_future_infor.sum(-1) != 0) == True)
231 | if min(near_future_status[0].shape) != 0:
232 | mask_dict[t-1, missing_object] = 2.
233 | nearest_future = near_future_status[0][0]
234 | location_all[t:(t+nearest_future+1),missing_object, 1:5] = fill_functions(
235 | location_all[(t-5 if t>5 else 0):t, missing_object,1:5],
236 | location_all[t+nearest_future+1, missing_object, 1:5],
237 | nearest_future+1)
238 | feature_all[t:(t+nearest_future+1), missing_object, 1:] = feature_all[t-1,missing_object, 1:]
239 | else:
240 | mask_dict[t-1, missing_object] = 1.
241 |
242 | mask_present = location_all[t,:,1:5].sum(-1) != 0.
243 |
244 | all_mask.append(mask_present)
245 |
246 | return all_mask, mask_dict, location_all, feature_all
247 |
248 | else:
249 |
250 | all_mask.append(mask_present)
251 | return all_mask, mask_dict, location_all, feature_all
252 |
253 | def fill_functions(past_movement, near_future, fill_number):
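254 | """Linearly interpolate the missing boxes: stack the recent observed boxes,
255 | `fill_number` NaN rows and the first re-observed box, let pandas interpolate
256 | column-wise, and return the rounded filled rows."""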
254 | fill_matrix = np.full((fill_number, 4), np.nan)
255 | interplot = np.vstack((past_movement, fill_matrix, near_future))
256 | df_interplot = pd.DataFrame(interplot)
257 | s = df_interplot.interpolate()
258 | interplot = np.around(np.array(s))
259 | fill_matrix = interplot[(-(fill_number+1)):-1]
260 |
261 | return fill_matrix
262 |
263 | def generate_corresponding_txt(file_list, root, module="train"):
264 | """
265 | all valid module is train or valid
266 | """
267 | str = '\n'
268 | f = open(os.path.join(root, module + "_video_list.txt"), "w")
269 | f.write(str.join(file_list))
270 | f.close()
271 |
272 | proprecess_data(data_root, train_root, test_root)
--------------------------------------------------------------------------------
/scripts/preprecess_data_MD-SG.py:
--------------------------------------------------------------------------------
2 | import numpy as np
3 | import os
4 | import random
5 | import pandas as pd
6 | import networkx
7 | import itertools
8 | random.seed(1234)
9 |
10 |
11 | data_root = '/media/group1/data/tianhang/sorted_samples_pn_2048/'
12 | train_root = '/media/group1/data/tianhang/MD/'
13 | test_root = '/media/group1/data/tianhang/MD/'
14 |
15 |
16 | def proprecess_data(data_root, train_root, test_root):
17 | if not os.path.exists(train_root):
18 | os.makedirs(train_root)
19 | if not os.path.exists(test_root):
20 | os.makedirs(test_root)
21 |
22 | all_file_list = os.listdir(data_root)
23 | random.shuffle(all_file_list)
24 | train_number = int(len(all_file_list) * 0.7)
25 | train_file_list = all_file_list[:train_number]
26 | val_file_list = all_file_list[train_number:]
27 | generate_corresponding_txt(train_file_list, train_root, module="train")
28 | generate_corresponding_txt(val_file_list, test_root, module="valid")
29 | print("Start to process training set!")
30 | graph_edges = generate_graph_from_list(range(19))
31 |
32 | for number, i in enumerate(train_file_list):
33 | filled_before = np.load(os.path.join(data_root, i), allow_pickle=True).item()
34 | frames = filled_before["location"].shape[0]
35 | maskers = []
36 |
37 | filled_before_location = filled_before['location']
38 | for t in range(frames):
39 | mask_t = filled_before_location[t,:,1:5].sum(-1) != 0.
40 | maskers.append(mask_t)
41 | maskers_np = np.array(maskers)
42 | graph_weight = generate_weight(filled_before_location, graph_edges, maskers_np)
43 | filled_before["graph_index"] = graph_edges
44 | filled_before["graph_weight"] = graph_weight
45 | filled_before["maskers"] = maskers_np
46 | np.save(os.path.join(train_root, i), filled_before)
47 | print("The Training process has been {:.2f} % done!".format((number + 1) / len(train_file_list) * 100))
48 |
49 | print("#" * 30)
50 | print("Start to process Testing set!")
51 | for val_number, j in enumerate(val_file_list):
52 |
53 | val_infor = np.load(os.path.join(data_root, j), allow_pickle=True).item()
54 | frames = val_infor["location"].shape[0]
55 | val_maskers = []
56 | val_location = val_infor["location"]
57 | for t in range(frames):
58 | val_mask_t = val_location[t,:,1:5].sum(-1) != 0.
59 | val_maskers.append(val_mask_t)
60 |
61 | val_maskers_np = np.array(val_maskers)
62 | val_infor["maskers"] = val_maskers_np
63 | np.save(os.path.join(test_root, j), val_infor)
64 | print("The Testing process has been {:.2f} % done!".format((val_number + 1) / len(val_file_list) * 100))
65 |
66 | def generate_weight(location, graph_edges, mask):
67 | # location: frames x 19 x 6
68 | # graph_weight: frames x num_edge
69 |
70 | weights = np.zeros((location.shape[0], len(graph_edges), ), dtype=np.float32)
71 | # normalize
72 | location_n = location.copy()
73 | location_n[:, :, [1,3]] /= 1280
74 | location_n[:, :, [2,4]] /= 720
75 |
76 | for i, edge in enumerate(graph_edges):
77 |
78 | c1 = [0.5 * (location_n[:, edge[0], 1] + location_n[:, edge[0], 3]),
79 | 0.5 * (location_n[:, edge[0], 2] + location_n[:, edge[0], 4])]
80 | c2 = [0.5 * (location_n[:, edge[1], 1] + location_n[:, edge[1], 3]),
81 | 0.5 * (location_n[:, edge[1], 2] + location_n[:, edge[1], 4])]
82 |
83 | d = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2
84 |
85 | md = get_motion_distance(edge[0], edge[1], location)
86 |
87 | weights[:, i] = np.exp(-(0.7*d + 0.3*md))
88 | frame_index = ((mask[:, edge[0]]==True) & (mask[:, edge[1]] == True))
89 | weights[frame_index==False, i] = 0
90 |
91 | # normalize
92 | weights_nor = (weights - np.min(weights, axis=1, keepdims=True)) / (np.max(weights, axis=1, keepdims=True) - np.min(weights, axis=1, keepdims=True))
93 | # account for NaN
94 | weights_nor = np.nan_to_num(weights_nor)
95 |
96 | return weights_nor
97 |
98 | def get_motion_distance(id_0, id_1, location_all):
99 | md = []
100 | center_point = location_all[:,:,3:-1] - location_all[:,:,1:3]
101 | for t in range(100):
102 | if t <= 4:
103 | d = get_normolize_distance(id_0, id_1, 0, t, center_point)
104 | else:
105 | d = get_normolize_distance(id_0, id_1, t-4, t, center_point)
106 | md.append(d)
107 |
108 | md = np.array(md)
109 | return md
110 |
111 | def get_normolize_distance(id_0, id_1, t1, t2, center_point):
112 |
113 | v_id_0 = normalize(((center_point[t1, id_0, 0], center_point[t1, id_0, 1]),(center_point[t2, id_0, 0], center_point[t2, id_0, 1])))
114 | v_id_1 = normalize(((center_point[t1, id_1, 0], center_point[t1, id_1, 1]),(center_point[t2, id_1, 0], center_point[t2, id_1, 1])))
115 |
116 | # Euclidean distances
117 | d_1 = ((v_id_0[1][0] - v_id_1[1][0])**2 + (v_id_0[1][1] - v_id_1[1][1])**2)**0.5 # distance between vector heads
118 | d_2 = ((v_id_0[0][0] - v_id_1[0][0])**2 + (v_id_0[0][1] - v_id_1[0][1])**2)**0.5 # distance between vector tails
119 |
120 | md = d_1 - d_2
121 |
122 | return md
123 |
124 | def normalize(v):
125 | delta_x = v[1][0] - v[0][0]
126 | delta_y = v[1][1] - v[0][1]
127 | length = (delta_x**2 + delta_y**2)**0.5
128 | if length != 0:
129 | nor_delta_x = delta_x/length
130 | nor_delta_y = delta_y/length
131 | return ((v[0][0], v[0][1]),(v[0][0]+nor_delta_x, v[0][1]+nor_delta_y))
132 | else:
133 | return v
134 |
135 | def generate_graph_from_list(L, create_using=None):
136 | G = networkx.empty_graph(len(L), create_using)
137 | if len(L) > 1:
138 | if G.is_directed():
139 | edges = itertools.permutations(L,2)
140 | else:
141 | edges = itertools.combinations(L,2)
142 | G.add_edges_from(edges)
143 | graph_edges = list(G.edges())
144 |
145 | return graph_edges
146 |
147 | def generate_corresponding_txt(file_list, root, module="train"):
148 | """
149 | all valid module is train or valid
150 | """
151 | str = '\n'
152 | f = open(os.path.join(root, module + "_video_list.txt"), "w")
153 | f.write(str.join(file_list))
154 | f.close()
155 |
156 | proprecess_data(data_root, train_root, test_root)
--------------------------------------------------------------------------------
/scripts/preprocess_data_MASKER.py:
--------------------------------------------------------------------------------
3 | import numpy as np
4 | import os
5 | import random
6 | import pandas as pd
7 | import networkx
8 | import itertools
9 | random.seed(1234)
10 |
11 |
12 | data_root = '/media/group1/data/tianhang/sorted_samples_pn_2048/'
13 | train_root = '/media/group1/data/tianhang/MASKER/'
14 | test_root = '/media/group1/data/tianhang/MASKER/'
15 |
16 | # frames to wait before deciding a vanished object has left the scene
17 | miss_wait = 10
18 |
19 | def proprecess_data(data_root, train_root, test_root):
20 | if not os.path.exists(train_root):
21 | os.makedirs(train_root)
22 | if not os.path.exists(test_root):
23 | os.makedirs(test_root)
24 |
25 | all_file_list = os.listdir(data_root)
26 | random.shuffle(all_file_list)
27 | train_number = int(len(all_file_list) * 0.7)
28 | train_file_list = all_file_list[:train_number]
29 | val_file_list = all_file_list[train_number:]
30 | generate_corresponding_txt(train_file_list, train_root, module="train")
31 | generate_corresponding_txt(val_file_list, test_root, module="valid")
32 | print("Start to process training set!")
33 | graph_edges = generate_graph_from_list(range(19))
34 |
35 | for number, i in enumerate(train_file_list):
36 |
37 | filled_before = np.load(os.path.join(data_root, i), allow_pickle=True).item()
38 | frames = filled_before["location"].shape[0]
39 | maskers = []
40 | mask_dict = np.zeros((frames, 19), dtype=np.float32)
41 |
42 | filled_before_location = filled_before['location']
43 | filled_before_feature = filled_before['feature']
44 | for t in range(frames):
45 | mask_t = filled_before_location[t,:,1:5].sum(-1) != 0.
46 |
47 | maskers, mask_dict, filled_after_location, filled_after_feature = fill_during_training(mask_t,
48 | maskers,
49 | mask_dict,
50 | t,
51 | location_all=filled_before_location,
52 | feature_all=filled_before_feature)
53 | maskers_np = np.array(maskers)
54 |
55 | graph_weight = generate_weight(filled_after_location, graph_edges, maskers_np)
56 | filled_before["graph_index"] = graph_edges
57 | filled_before["graph_weight"] = graph_weight
58 | filled_before["maskers"] = maskers_np
59 | filled_before["missing_status"] = mask_dict
60 | filled_before["location"] = filled_after_location
61 | filled_before["feature"] = filled_after_feature
62 |
63 | np.save(os.path.join(train_root, i), filled_before)
64 | print("The Training process has been {:.2f} % done!".format((number + 1) / len(train_file_list) * 100))
65 |
66 | print("#" * 30)
67 | print("Start to process Testing set!")
68 | for val_number, j in enumerate(val_file_list):
69 |
70 | val_infor = np.load(os.path.join(data_root, j), allow_pickle=True).item()
71 | frames = val_infor["location"].shape[0]
72 | val_maskers = []
73 | val_mask_dict = np.zeros((frames, 19), dtype=np.float32)
74 | val_location = val_infor["location"] # frames x 19 x 6
75 | for t in range(frames):
76 |
77 | val_mask_t = val_location[t,:,1:5].sum(-1) != 0.
78 |
79 | val_maskers, val_mask_dict = spect_during_testing(val_mask_t, val_maskers, val_mask_dict, t, val_location)
80 |
81 | val_maskers_np = np.array(val_maskers)
82 | val_infor["maskers"] = val_maskers_np
83 | val_infor["missing_status"] = val_mask_dict
84 | np.save(os.path.join(test_root, j), val_infor)
85 | print("The Testing process has been {:.2f} % done!".format((val_number + 1) / len(val_file_list) * 100))
86 |
87 | def generate_weight(location, graph_edges, mask):
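88 | # location: frames x 19 x 6; graph_weight: frames x num_edge
89 | # MASKER variant: edge weights use only the spatial term, w = exp(-d);
90 | # the _MD scripts additionally mix in the motion distance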
88 |
89 |
90 | weights = np.zeros((location.shape[0], len(graph_edges), ), dtype=np.float32)
91 | location_n = location.copy()
92 | location_n[:, :, [1,3]] /= 1280
93 | location_n[:, :, [2,4]] /= 720
94 | for i, edge in enumerate(graph_edges):
95 |
96 | c1 = [0.5 * (location_n[:, edge[0], 1] + location_n[:, edge[0], 3]),
97 | 0.5 * (location_n[:, edge[0], 2] + location_n[:, edge[0], 4])]
98 | c2 = [0.5 * (location_n[:, edge[1], 1] + location_n[:, edge[1], 3]),
99 | 0.5 * (location_n[:, edge[1], 2] + location_n[:, edge[1], 4])]
100 |
101 | d = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2
102 |
103 | weights[:, i] = np.exp(-d)
104 | frame_index = ((mask[:, edge[0]]==True) & (mask[:, edge[1]] == True))
105 | weights[frame_index==False, i] = 0
106 |
107 | # normalize
108 | weights_nor = (weights - np.min(weights, axis=1, keepdims=True)) / (np.max(weights, axis=1, keepdims=True) - np.min(weights, axis=1, keepdims=True))
109 | # account for NaN
110 | weights_nor = np.nan_to_num(weights_nor)
111 |
112 | return weights_nor
113 |
114 | def get_motion_distance(id_0, id_1, location_all):
115 | md = []
116 | center_point = location_all[:,:,3:-1] - location_all[:,:,1:3]
117 | for t in range(100):
118 | if t <= 4:
119 | d = get_normolize_distance(id_0, id_1, 0, t, center_point)
120 | else:
121 | d = get_normolize_distance(id_0, id_1, t-4, t, center_point)
122 | md.append(d)
123 |
124 | md = np.array(md)
125 | return md
126 |
127 | def get_normolize_distance(id_0, id_1, t1, t2, center_point):
128 | v_id_0 = normalize(((center_point[t1, id_0, 0], center_point[t1, id_0, 1]),(center_point[t2, id_0, 0], center_point[t2, id_0, 1])))
129 | v_id_1 = normalize(((center_point[t1, id_1, 0], center_point[t1, id_1, 1]),(center_point[t2, id_1, 0], center_point[t2, id_1, 1])))
130 |
131 | # Euclidean distances
132 | d_1 = ((v_id_0[1][0] - v_id_1[1][0])**2 + (v_id_0[1][1] - v_id_1[1][1])**2)**0.5 # distance between vector heads
133 | d_2 = ((v_id_0[0][0] - v_id_1[0][0])**2 + (v_id_0[0][1] - v_id_1[0][1])**2)**0.5 # distance between vector tails
134 |
135 | md = d_1 - d_2
136 |
137 | return md
138 |
139 | def normalize(v):
140 | delta_x = v[1][0] - v[0][0]
141 | delta_y = v[1][1] - v[0][1]
142 | length = (delta_x**2 + delta_y**2)**0.5
143 | if length != 0:
144 | nor_delta_x = delta_x/length
145 | nor_delta_y = delta_y/length
146 | return ((v[0][0], v[0][1]),(v[0][0]+nor_delta_x, v[0][1]+nor_delta_y))
147 | else:
148 | return v
149 |
150 |
151 | def generate_graph_from_list(L, create_using=None):
152 | G = networkx.empty_graph(len(L), create_using)
153 | if len(L) > 1:
154 | if G.is_directed():
155 | edges = itertools.permutations(L,2)
156 | else:
157 | edges = itertools.combinations(L,2)
158 | G.add_edges_from(edges)
159 | graph_edges = list(G.edges())
160 |
161 | return graph_edges
162 |
163 |
164 | def spect_during_testing(mask_present,
165 | all_mask,
166 | mask_dict,
167 | t,
168 | location_all):
169 |
170 | if len(all_mask) == 0:
171 | all_mask.append(mask_present)
172 |
173 | return all_mask, mask_dict
174 |
175 | else:
176 | before_mask = all_mask[-1]
177 | # compare from True to False
178 | index = np.where((before_mask == True) & (mask_present ==False))
179 | if min(index[0].shape) != 0:
180 | for missing_object in index[0]:
181 | #print("missing_object", missing_object)
182 | near_future_infor = location_all[(t+1):(t+miss_wait), missing_object, 1:5]
183 | #print("near_future_infor", near_future_infor)
184 | near_future_status = np.where((near_future_infor.sum(-1) != 0) == True)
185 | if min(near_future_status[0].shape) != 0:
186 | mask_dict[t-1, missing_object] = 2.
187 | else:
188 | mask_dict[t-1, missing_object] = 1.
189 |
190 | all_mask.append(mask_present)
191 | return all_mask, mask_dict
192 |
193 | else:
194 | all_mask.append(mask_present)
195 | return all_mask, mask_dict
196 |
197 | def fill_during_training(mask_present,
198 | all_mask,
199 | mask_dict,
200 | t,
201 | location_all,
202 | feature_all
203 | ):
204 | # miss_wait (module level) bounds how long to wait for a reappearance
205 |
206 | if len(all_mask) == 0:
207 | all_mask.append(mask_present)
208 |
209 | return all_mask, mask_dict, location_all, feature_all
210 | else:
211 | before_mask= all_mask[-1]
212 | index = np.where((before_mask==True) & (mask_present==False))
213 | if min(index[0].shape) != 0:
214 | #print("missing index", index)
215 | for missing_object in index[0]:
216 | #print("Missing objects", missing_object)
217 | near_future_infor = location_all[(t+1):(t+miss_wait), missing_object,1:5]
218 | near_future_status = np.where((near_future_infor.sum(-1) != 0) == True)
219 | if min(near_future_status[0].shape) != 0:
220 |
221 | mask_dict[t-1, missing_object] = 2.
222 | nearest_future = near_future_status[0][0]
223 | location_all[t:(t+nearest_future+1),missing_object, 1:5] = fill_functions(
224 | location_all[(t-5 if t>5 else 0):t, missing_object,1:5],
225 | location_all[t+nearest_future+1, missing_object, 1:5],
226 | nearest_future+1)
227 | feature_all[t:(t+nearest_future+1), missing_object, 1:] = feature_all[t-1,missing_object, 1:]
228 |
229 | else:
230 | mask_dict[t-1, missing_object] = 1.
231 |
232 | mask_present = location_all[t,:,1:5].sum(-1) != 0.
233 |
234 | all_mask.append(mask_present)
235 |
236 | return all_mask, mask_dict, location_all, feature_all
237 |
238 | else:
239 |
240 | all_mask.append(mask_present)
241 | return all_mask, mask_dict, location_all, feature_all
242 |
243 | def fill_functions(past_movement, near_future, fill_number):
244 | fill_matrix = np.full((fill_number, 4), np.nan)
245 | interplot = np.vstack((past_movement, fill_matrix, near_future))
246 | df_interplot = pd.DataFrame(interplot)
247 | s = df_interplot.interpolate()
248 | interplot = np.around(np.array(s))
249 | fill_matrix = interplot[(-(fill_number+1)):-1]
250 |
251 | return fill_matrix
252 |
253 | def generate_corresponding_txt(file_list, root, module="train"):
254 | """
255 | all valid module is train or valid
256 | """
257 | str = '\n'
258 | f = open(os.path.join(root, module + "_video_list.txt"), "w")
259 | f.write(str.join(file_list))
260 | f.close()
261 |
262 | proprecess_data(data_root, train_root, test_root)
--------------------------------------------------------------------------------