├── main
│   ├── readme.txt
│   ├── 1-train_sce.sh
│   ├── 1-train_bid.sh
│   ├── 1-train_ae.sh
│   ├── 1-train_cvae.sh
│   ├── 2-test_ft.sh
│   ├── 1-finetune.sh
│   ├── clean_failed_exp.py
│   ├── schedule_task.py
│   ├── metrics.py
│   ├── module_cvae.py
│   ├── misc.py
│   ├── argmanager.py
│   ├── 1-train_sce.py
│   ├── 1-train_ae.py
│   ├── 1-finetune.py
│   ├── 1-train_cvae.py
│   ├── 1-train_bid.py
│   ├── 2-test_ft.py
│   ├── model_builder.py
│   ├── dsets.py
│   └── dsets_bid.py
├── .gitignore
├── readme.md
└── tools
    ├── make_new_gt.py
    ├── make_link.py
    └── create_scene_set.py

/main/readme.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__/
2 | *.pyc
3 | data.*
4 | save.*
5 | *.npy
--------------------------------------------------------------------------------
/main/1-train_sce.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES="3"
2 | 
3 | python 1-train_sce.py \
4 |     --data_name "ST|Ave|Cor|NWPU" --video_dir "**" --track_dir "**" --frame_dir "**" --scene_classes N \
5 |     --print_model 'yes' \
6 |     --snippet_inp 1 --snippet_tgt 1 --snippet_itv 2 \
7 |     --lr 0.01 --batch_size 32 --iterations 32 --fr 1 \
8 |     --epochs 10 --schedule 7 \
9 |     --workers 4 --save_freq 5 --print_freq 20 \
10 |     --note "Train BackgroundEncoder" $@
11 | 
--------------------------------------------------------------------------------
/main/1-train_bid.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES="0"
2 | 
3 | python 1-train_bid.py \
4 |     --data_name "ST" --video_dir "../data.ST_mem/videos/Train" --track_dir "../data.ST_mem/tracking/Train" --scene_classes 13 \
5 |     --print_model 'no' \
6 |     --pre_model "../save.ckpts/main/1-finetune_ST_0607-165004/checkpoint_1.pth.tar" \
7 |     --snippet_inp 8 --snippet_tgt 7 --snippet_itv 12.5 \
8 |     --lr 0.01 --batch_size 32 --iterations 32 --fr 1 \
9 |     --epochs 50 --schedule 40 \
10 |     --workers 16 --save_freq 2 --print_freq 50 \
11 |     --note "" $@
12 | 
--------------------------------------------------------------------------------
/main/1-train_ae.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES="3"
2 | 
3 | python 1-train_ae.py \
4 |     --data_name "ST" --video_dir "../data.ST_mem/videos/Train" --track_dir "../data.ST_mem/tracking/Train" --scene_classes 13 \
5 |     --print_model 'yes' \
6 |     --bgd_encoder "../save.ckpts/main/1-train_sce_ST_0607-152755/checkpoint_1.pth.tar" \
7 |     --snippet_inp 8 --snippet_tgt 1 --snippet_itv 2 \
8 |     --lr 0.01 --batch_size 32 --iterations 160 --lam_vae 0.1 --fr 0 \
9 |     --epochs 16 --schedule 12 \
10 |     --workers 16 --save_freq 4 --print_freq 50 \
11 |     --note "Train SceneFrameAE" $@
12 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # A New Comprehensive Benchmark for Semi-Supervised Video Anomaly Detection and Anticipation
2 | 
3 | # Note
4 | This repository contains the code for `A New Comprehensive Benchmark for Semi-Supervised Video Anomaly Detection and Anticipation` (CVPR 2023).
5 | 
6 | The full code and `README` will be released in a few days.
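The training and testing entry points live in `main/` (shell scripts with their matching Python scripts); `tools/` holds the dataset-preparation utilities (`create_scene_set.py`, `make_new_gt.py`) and `make_link.py` for linking the `save.*` output directories.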
7 | 8 | # Training 9 | 10 | `1-train_sce.sh` -> `1-train_ae.sh` -> `1-train_cvae.sh` -> `1-finetune.sh` 11 | 12 | The files of extracted trackings: (BaiduYun) https://pan.baidu.com/s/1aPVPLlVq3FZEYzhy20qNiQ code: uq3i; (Google Drive) https://drive.google.com/file/d/1d-UaJT4Vfr3ke4AFkDA_pi7gMXo05ymS 13 | 14 | # Test 15 | `2-test_ft.sh` 16 | -------------------------------------------------------------------------------- /main/1-train_cvae.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="3" 2 | 3 | python 1-train_cvae.py \ 4 | --data_name "ST" --video_dir "../data.ST_mem/videos/Train" --track_dir "../data.ST_mem/tracking/Train" --scene_classes 13 \ 5 | --print_model 'yes' \ 6 | --bgd_encoder "../save.ckpts/main/1-train_sce_ST_0607-152755/checkpoint_1.pth.tar" \ 7 | --frame_ae "../save.ckpts/main/1-train_ae_ST_0607-153012/checkpoint_1.pth.tar" \ 8 | --snippet_inp 8 --snippet_tgt 1 --snippet_itv 2 \ 9 | --lr 0.01 --batch_size 32 --iterations 160 --lam_vae 0.1 --fr 0 \ 10 | --epochs 16 --schedule 12 \ 11 | --workers 16 --save_freq 4 --print_freq 50 \ 12 | --note "Train the CVAE in SceneFrameAE" $@ 13 | -------------------------------------------------------------------------------- /main/2-test_ft.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="1,3" 2 | 3 | python 2-test_ft.py \ 4 | --data_name "ST" --video_dir "../data.ST_mem/videos/Test" --track_dir "../data.ST_mem/tracking/Test" --gtnpz_path "../data.ST_mem/gt.npz" --frame_dir "../data.ST_mem/frames/Test" --scene_classes 13 \ 5 | --print_model 'yes' \ 6 | --bgd_encoder "../save.ckpts/main/1-train_sce_ST_0607-152755/checkpoint_1.pth.tar" \ 7 | --resume "../save.ckpts/main/1-finetune_ST_0607-165004/checkpoint_1.pth.tar" \ 8 | --snippet_inp 8 --snippet_tgt 1 --snippet_itv 2 \ 9 | --error_type "patch" --patch_size 256 128 64 32 16 --patch_stride 8 --use_channel_l2 --lam_l1 1.0 --crop_fuse_type "max" \ 10 | --score_post_process "filt" \ 11 | --workers 2 --to_gpu --threads 48 --note "" $@ 12 | # for ST: --score_post_process "filt" "norm" -------------------------------------------------------------------------------- /tools/make_new_gt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, exists 3 | import numpy as np 4 | 5 | ''' 6 | convert anomaly to normality 7 | ''' 8 | 9 | correct_vid_name_list = ''' 10 | 01_0014 11 | 01_0029 12 | 01_0052 13 | 01_0138 14 | 01_0139 15 | 01_0163 16 | '''.strip().split('\n') 17 | 18 | 19 | src_gt_npz: dict = np.load("**/data.ST/gt.npz") # NOTE: path to the original GT file 20 | dst_gt_dict = {} 21 | 22 | all_test_vid_list = sorted(os.listdir("./frames/Test")) 23 | 24 | for vid_name in all_test_vid_list: 25 | src_gt_ary: np.ndarray = src_gt_npz[vid_name] 26 | if vid_name in correct_vid_name_list: 27 | dst_gt_ary = np.zeros_like(src_gt_ary) 28 | else: 29 | dst_gt_ary = src_gt_ary.copy() 30 | dst_gt_dict[vid_name] = dst_gt_ary 31 | 32 | np.savez("scene_gt.npz", **dst_gt_dict) 33 | -------------------------------------------------------------------------------- /main/1-finetune.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="3" 2 | 3 | python 1-finetune.py \ 4 | --data_name "ST" --video_dir "../data.ST_mem/videos/Train" --track_dir "../data.ST_mem/tracking/Train" --scene_classes 13 \ 5 | --print_model 'yes' \ 6 | --bgd_encoder 
"../save.ckpts/main/1-train_sce_ST_0607-152755/checkpoint_1.pth.tar" \ 7 | --frame_ae "../save.ckpts/main/1-train_ae_ST_0607-153012/checkpoint_1.pth.tar" \ 8 | --snippet_inp 8 --snippet_tgt 1 --snippet_itv 2 \ 9 | --lr 0.01 --batch_size 32 --iterations 160 --lam_vae 0.1 --fr 1 \ 10 | --epochs 16 --schedule 8 \ 11 | --workers 16 --save_freq 4 --print_freq 50 \ 12 | --note "Finetune the network" $@ 13 | # --data_name "ST|Ave|Cor|NWPU" --video_dir "**" --track_dir "**" --frame_dir "**" --scene_classes N \ 14 | # --bgd_encoder "path to the trained BackgroundEncoder checkpoint" \ 15 | # --frame_ae "path to the trained SceneFrameAE checkpoint" \ 16 | -------------------------------------------------------------------------------- /main/clean_failed_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import listdir 3 | from os.path import join, isdir, isfile 4 | import shutil 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser("") 8 | parser.add_argument("time_stamp_list", nargs='+') 9 | args = parser.parse_args() 10 | 11 | save_dir_root = "../" 12 | save_dir_name_list = ("save.ckpts", "save.logs", "save.tbxs") 13 | 14 | for time_stamp in args.time_stamp_list: 15 | for save_dir_name in save_dir_name_list: 16 | for proj_dir_name in listdir(p1 := join(save_dir_root, save_dir_name)): 17 | for exp_name in listdir(p2 := join(p1, proj_dir_name)): 18 | exp_path = join(p2, exp_name) 19 | if time_stamp in exp_name: 20 | if isdir(exp_path): 21 | shutil.rmtree(exp_path) 22 | print(f"remove the dir: \"{exp_path}\"") 23 | elif isfile(exp_path): 24 | os.remove(exp_path) 25 | print(f"remove the file: \"{exp_path}\"") 26 | else: 27 | raise NotImplementedError(f"{exp_path}") 28 | -------------------------------------------------------------------------------- /main/schedule_task.py: -------------------------------------------------------------------------------- 1 | #!/home/lvqiny/miniconda3/envs/tc18cu11/bin/python 2 | import os 3 | import time 4 | import argparse 5 | from typing import Tuple 6 | 7 | parser = argparse.ArgumentParser("Schedule task.") 8 | parser.add_argument('cmd', type=str) 9 | parser.add_argument('pid', type=int, nargs="+") 10 | parser.add_argument('--cond', '-c', type=str, default='all', choices=('all', 'any')) 11 | parser.add_argument('--wtime', '-t', type=int, default=60) 12 | 13 | args = parser.parse_args() 14 | 15 | 16 | def pid_exists(pid: int): 17 | flag_pid_exists = None 18 | try: 19 | os.kill(pid, 0) 20 | flag_pid_exists = True 21 | except ProcessLookupError: 22 | flag_pid_exists = False 23 | except: 24 | flag_pid_exists = True 25 | 26 | return flag_pid_exists 27 | 28 | 29 | cmd: str = args.cmd 30 | dst_pid_list: Tuple[int] = args.pid 31 | cond_str: str = args.cond 32 | wtime: int = args.wtime 33 | 34 | cond_type = None 35 | if cond_str == 'all': 36 | cond_type = all 37 | elif cond_str == 'any': 38 | cond_type = any 39 | else: 40 | raise NameError(cond_str) 41 | 42 | while True: 43 | print(time.ctime()) 44 | if cond_type([not pid_exists(dst_pid) for dst_pid in dst_pid_list]): 45 | print(cmd) 46 | os.system(cmd) 47 | break 48 | 49 | time.sleep(wtime) 50 | -------------------------------------------------------------------------------- /tools/make_link.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import system, mkdir, link 3 | from os.path import join, exists, realpath, abspath, dirname, basename, islink 4 | 5 | name_list = ('save.ckpts', 'save.logs', 
'save.results', 'save.tbxs', 'save.visual')
6 | hdd_root = "/home/**/data"  # NOTE: save root
7 | proj_root = realpath('./')
8 | proj_name = basename(proj_root)
9 | 
10 | 
11 | if input(f"'{hdd_root}'\nIs this the hdd root to save the save.* dirs? (y/n)\n") != 'y':
12 |     print("Exit!")
13 |     exit(1)
14 | if input(f"'{proj_root}'\nIs this the project root? (y/n)\n") != 'y':
15 |     print("Exit!")
16 |     exit(1)
17 | 
18 | for dir_name in name_list:
19 |     if exists(p1 := join(hdd_root, dir_name)):
20 |         if not exists(p2 := join(p1, proj_name)):
21 |             mkdir(p2)
22 |             print(f"Dir '{p2}' is created.")
23 |         else:
24 |             print(f"'{p2}' already exists!")
25 |         dp = join(proj_root, dir_name)
26 |         if not exists(dp):
27 |             system(f"ln -s {p2} {dp}")
28 |             print(f"Link '{p2}' to '{dp}'.")
29 |         else:
30 |             print(f"The link '{dp}' already exists!")
31 |     else:
32 |         print(f"Please manually make dir: '{p1}'!")
33 | 
--------------------------------------------------------------------------------
/tools/create_scene_set.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import listdir
3 | import os.path as osp
4 | from os.path import join, exists
5 | 
6 | '''
7 | Make the ShanghaiTech-sd dataset.
8 | Then use `make_new_gt.py` to create the correct scene-dependent labels.
9 | '''
10 | 
11 | test_str_vid_list = '''
12 | 01_0014
13 | 01_0029
14 | 01_0052
15 | 01_0138
16 | 01_0139
17 | 01_0163
18 | 06_0147
19 | 06_0150
20 | 06_0155
21 | 10_0037
22 | 10_0074
23 | 12_0142
24 | 12_0148
25 | 12_0151
26 | 12_0154
27 | 12_0173
28 | 12_0174
29 | 12_0175
30 | '''
31 | train_str_vid_list = '''
32 | 01_0016
33 | 01_0051
34 | 01_0063
35 | 01_0073
36 | 01_0076
37 | 01_0129
38 | 01_0131
39 | 01_0134
40 | 01_0177
41 | 06_001
42 | 06_002
43 | 06_003
44 | 06_004
45 | 06_005
46 | 06_007
47 | 06_008
48 | 06_009
49 | 06_014
50 | 10_001
51 | 10_002
52 | 10_006
53 | 10_007
54 | 10_008
55 | 10_009
56 | 10_010
57 | 10_011
58 | 12_002
59 | 12_003
60 | 12_004
61 | 12_005
62 | 12_006
63 | 12_007
64 | 12_008
65 | 12_009
66 | 12_015
67 | '''
68 | 
69 | src_root = "**"  # NOTE: dataset root path
70 | split = 'Test'  # NOTE: set to 'Train' or 'Test'
71 | 
72 | scene_vid_dict = {'Train': train_str_vid_list, 'Test': test_str_vid_list}
73 | vid_name_list = scene_vid_dict[split].strip().split('\n')
74 | data_dir_list = {'frames': '', 'tracking': '.pkl', 'videos': '.avi'}
75 | dst_root = os.getcwd()
76 | 
77 | for data_dir in data_dir_list:
78 |     src_dir = join(src_root, data_dir)
79 |     dst_split_dir = join(dst_root, data_dir, split)
80 |     if not exists(dst_split_dir):
81 |         os.mkdir(dst_split_dir)
82 |     assert exists(src_dir) and exists(dst_split_dir)
83 | 
84 |     for vid_name in vid_name_list:
85 |         dst_file_name = f"{vid_name}{data_dir_list[data_dir]}"
86 |         dst_file_list = [join(src_dir, src_split, dst_file_name) for src_split in ('Train', 'Test')]
87 | 
88 |         src_file_or_dir = ""
89 |         for cur_dst_file in dst_file_list:
90 |             if exists(cur_dst_file):
91 |                 src_file_or_dir = cur_dst_file
92 |         assert exists(src_file_or_dir)
93 | 
94 |         dst_file_or_dir = join(dst_split_dir, dst_file_name)
95 | 
96 |         if exists(dst_file_or_dir):
97 |             continue
98 | 
99 |         print(src_file_or_dir, dst_file_or_dir)
100 | 
101 |         if osp.isfile(src_file_or_dir):
102 |             os.system(f"ln -s {src_file_or_dir} {dst_file_or_dir}")
103 |             # os.system(f"cp {src_file_or_dir} {dst_file_or_dir}")
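            # NOTE: symlinks are created by default so the scene-dependent split
            # shares storage with the source dataset; switch to the commented-out
            # `cp` variants in both branches to make physical copies instead.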
104 |         elif osp.isdir(src_file_or_dir):
105 |             os.system(f"ln -s {src_file_or_dir} {dst_file_or_dir}")
106 |             # os.system(f"cp -r {src_file_or_dir} {dst_file_or_dir}")
107 |         else:
108 |             raise NotImplementedError(src_file_or_dir)
109 | 
--------------------------------------------------------------------------------
/main/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.lib.npyio
3 | from scipy.ndimage import gaussian_filter1d
4 | from sklearn.metrics import roc_auc_score
5 | import functools
6 | from typing import Union, Tuple, Dict
7 | 
8 | gaussian_filter1d = functools.partial(gaussian_filter1d, axis=0, mode='constant')
9 | 
10 | 
11 | def cal_micro_auc(score_dict: Dict[str, np.ndarray], gt_npz: numpy.lib.npyio.NpzFile, slen: int, sitv: float, return_score_gt: bool = False, **kwargs) -> Union[float, Tuple[float, np.ndarray, np.ndarray]]:
12 |     '''
13 |     Concatenate the post-processed scores and ground truth of all test videos, then compute the frame-level micro-averaged AUC.
14 |     '''
15 |     vid_name_list = sorted(list(score_dict.keys()))
16 | 
17 |     def _concat_score(_score_dict: dict):
18 |         _cat_gts = []
19 |         _cat_scores = []
20 |         for _vid_name in vid_name_list:
21 |             _vid_gt = gt_npz[_vid_name]
22 |             _vid_score: np.ndarray = _score_dict[_vid_name]
23 |             _vid_score = _vid_score.squeeze()
24 |             assert _vid_gt.ndim == _vid_score.ndim == 1, f"{_vid_gt.shape}, {_vid_score.shape}"
25 | 
26 |             for _p in kwargs['score_post_process']:
27 |                 if _p == "filt":
28 |                     _vid_score = gaussian_filter1d(_vid_score, sigma=sitv * slen / 2)
29 |                 elif _p == "norm":
30 |                     _vid_score = normalize_score(_vid_score, 'minmax')
31 |                 else:
32 |                     raise NotImplementedError(f"{_p}")
33 | 
34 |             assert len(_vid_gt) == len(_vid_score), f"{_vid_gt.shape}, {_vid_score.shape}, {slen}, {sitv}"
35 |             _cat_gts.append(_vid_gt)
36 |             _cat_scores.append(_vid_score)
37 |         _cat_gts = np.concatenate(_cat_gts)
38 |         _cat_scores = np.concatenate(_cat_scores)
39 |         return _cat_gts, _cat_scores
40 | 
41 |     cat_gts, cat_scores = _concat_score(score_dict)
42 | 
43 |     micro_auc = roc_auc_score(cat_gts, cat_scores)
44 | 
45 |     if return_score_gt:
46 |         return micro_auc, cat_scores, cat_gts
47 |     return micro_auc
48 | 
49 | 
50 | def normalize_score(input_score: np.ndarray, ntype: Union[str, None]) -> np.ndarray:
51 |     if ntype is None:
52 |         return input_score
53 | 
54 |     assert input_score.ndim in (1, 2), f"{input_score.shape}"
55 |     ntype = ntype.lower()
56 | 
57 |     score: np.ndarray = input_score.copy()
58 |     if score.ndim == 1:
59 |         score = np.expand_dims(score, 1)
60 | 
61 |     if ntype == 'minmax':
62 |         # MinMax
63 |         denominator = score.max(0, keepdims=True) - score.min(0, keepdims=True)
64 |         if np.all(denominator == 0):
65 |             # A constant score vector cannot be min-max scaled; leave it unchanged.
66 |             print("WARNING: np.all(denominator == 0) in `normalize_score`")
67 |         else:
68 |             score = (score - score.min(0, keepdims=True)) / denominator
69 |     elif ntype == 'meanstd':
70 |         # MeanStd
71 |         denominator = score.std(0, keepdims=True)
72 |         assert np.all(denominator != 0)
73 |         score = (score - score.mean(0, keepdims=True)) / denominator
74 |     elif ntype == 'l2norm':
75 |         # L2Norm
76 |         denominator = np.linalg.norm(score, ord=2, axis=0, keepdims=True)
77 |         assert np.all(denominator != 0)
78 |         score = score / denominator
79 |     else:
80 |         raise NotImplementedError(ntype)
81 | 
82 |     if input_score.ndim == 1:
83 |         score = score.squeeze()
84 |     assert score.shape == input_score.shape
85 |     return score
86 | 
--------------------------------------------------------------------------------
/main/module_cvae.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | 6 | def idx2onehot(idx: torch.Tensor, n: int): 7 | assert torch.max(idx).item() < n, f"{idx}, {torch.max(idx).item()}, {n}" 8 | 9 | if idx.dim() == 1: 10 | idx = idx.unsqueeze(1) 11 | onehot = torch.zeros(idx.size(0), n).to(idx.device) 12 | onehot.scatter_(1, idx, 1) 13 | 14 | return onehot 15 | 16 | 17 | class CVAE(nn.Module): 18 | 19 | def __init__(self, inp_size, inter_size, latent_size, 20 | conditional=False, num_labels=0): 21 | super().__init__() 22 | 23 | if conditional: 24 | assert num_labels > 0 25 | 26 | self.latent_size = latent_size 27 | 28 | self.encoder = Encoder(inp_size, inter_size, latent_size, conditional, num_labels) 29 | self.decoder = Decoder(inp_size, inter_size, latent_size, conditional, num_labels) 30 | 31 | def forward(self, x, c=None): 32 | means, log_var = self.encoder(x, c) 33 | means: Tensor 34 | log_var: Tensor 35 | z = self.reparameterize(means, log_var) 36 | recon_x = self.decoder(z, c) 37 | 38 | return recon_x, means, log_var, z 39 | 40 | def reparameterize(self, mu, log_var): 41 | std = torch.exp(0.5 * log_var) 42 | eps = torch.randn_like(std) 43 | 44 | return mu + eps * std 45 | 46 | def inference(self, z, c=None): 47 | recon_x = self.decoder(z, c) 48 | 49 | return recon_x 50 | 51 | 52 | class Encoder(nn.Module): 53 | def __init__(self, in_size, inter_size, latent_size, conditional, num_labels): 54 | super().__init__() 55 | 56 | self.conditional = conditional 57 | if self.conditional: 58 | in_size += num_labels 59 | self.num_labels = num_labels 60 | 61 | self.MLP = nn.Sequential( 62 | nn.LayerNorm(in_size), 63 | nn.Linear(in_size, inter_size), 64 | nn.GELU(), 65 | nn.LayerNorm(inter_size), 66 | nn.Linear(inter_size, 32), 67 | nn.GELU(), 68 | nn.LayerNorm(32), 69 | ) 70 | 71 | self.linear_means = nn.Linear(32, latent_size) 72 | self.linear_log_var = nn.Linear(32, latent_size) 73 | 74 | def forward(self, x, c=None): 75 | 76 | if self.conditional: 77 | c = idx2onehot(c, self.num_labels) 78 | x = torch.cat((x, c), dim=-1) 79 | 80 | x = self.MLP(x) 81 | 82 | means = self.linear_means(x) 83 | log_vars = self.linear_log_var(x) 84 | 85 | return means, log_vars 86 | 87 | 88 | class Decoder(nn.Module): 89 | def __init__(self, out_size, inter_size, latent_size, conditional, num_labels): 90 | super().__init__() 91 | 92 | self.num_labels = num_labels 93 | 94 | self.conditional = conditional 95 | if self.conditional: 96 | input_size = latent_size + num_labels 97 | else: 98 | input_size = latent_size 99 | 100 | self.MLP = nn.Sequential( 101 | nn.Linear(input_size, inter_size), 102 | nn.GELU(), 103 | nn.LayerNorm(inter_size), 104 | nn.Linear(inter_size, out_size), 105 | ) 106 | 107 | def forward(self, z, c): 108 | 109 | if self.conditional: 110 | c = idx2onehot(c, n=self.num_labels) 111 | z = torch.cat((z, c), dim=-1) 112 | 113 | x = self.MLP(z) 114 | 115 | return x 116 | -------------------------------------------------------------------------------- /main/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, exists, dirname, basename, abspath, realpath 3 | import argparse 4 | import time 5 | import logging 6 | import shutil 7 | import torch 8 | import builtins 9 | from tensorboardX import SummaryWriter 10 | import numpy as np 11 | from collections import OrderedDict 12 | from typing import Dict, List, Tuple, OrderedDict as ODt 13 | 
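# Shared utilities for the training/testing scripts: project-rooted save
# paths, timestamped loggers, checkpoint saving, argument pretty-printing,
# and the ScalarWriter used for step/epoch scalar logging.
# get_time_stamp() returns e.g. "0607-152755" (MMDD-HHMMSS), the suffix seen
# in the checkpoint directory names referenced by the shell scripts.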
14 | 15 | def get_proj_root() -> str: 16 | proj_root = dirname(dirname(realpath(__file__))) 17 | return proj_root 18 | 19 | 20 | def get_time_stamp() -> str: 21 | _t = time.localtime() 22 | time_stamp = f"{str(_t.tm_mon).zfill(2)}{str(_t.tm_mday).zfill(2)}" + \ 23 | f"-{str(_t.tm_hour).zfill(2)}{str(_t.tm_min).zfill(2)}{str(_t.tm_sec).zfill(2)}" 24 | return time_stamp 25 | 26 | 27 | def get_dir_name(file=__file__) -> str: 28 | return basename(dirname(realpath(file))) 29 | 30 | 31 | def get_logger(time_stamp, file_name: str = '', log_root='save.logs', data_name='') -> logging.Logger: 32 | logger = logging.getLogger() 33 | logger.setLevel(logging.INFO) 34 | 35 | formatter = logging.Formatter('[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 36 | 37 | if file_name: 38 | _save_subdir = basename(dirname(abspath(__file__))) 39 | if data_name: 40 | data_name = '_' + data_name 41 | log_file = join(get_proj_root(), log_root, _save_subdir, f"{basename(file_name).split('.')[0]}{data_name}_{time_stamp}.log") 42 | 43 | if not exists(dirname(log_file)): 44 | os.makedirs(dirname(log_file)) 45 | 46 | file_handler = logging.FileHandler(log_file, mode='a') 47 | file_handler.setLevel(logging.INFO) 48 | file_handler.setFormatter(formatter) 49 | logger.addHandler(file_handler) 50 | 51 | console_handler = logging.StreamHandler() 52 | console_handler.setLevel(logging.INFO) 53 | console_handler.setFormatter(formatter) 54 | logger.addHandler(console_handler) 55 | 56 | return logger 57 | 58 | 59 | def get_ckpt_dir(time_stamp, file_name, ckpt_root='save.ckpts', data_name='') -> str: 60 | _save_subdir = basename(dirname(realpath(__file__))) 61 | if data_name: 62 | data_name = '_' + data_name 63 | ckpt_dir = join(get_proj_root(), ckpt_root, _save_subdir, f"{basename(file_name).split('.')[0]}{data_name}_{time_stamp}") 64 | if not exists(ckpt_dir): 65 | os.makedirs(ckpt_dir) 66 | return ckpt_dir 67 | 68 | 69 | def get_result_dir(result_root='save.results') -> str: 70 | result_dir = join(get_proj_root(), result_root, basename(dirname(abspath(__file__)))) 71 | if not exists(result_dir): 72 | os.makedirs(result_dir) 73 | return result_dir 74 | 75 | 76 | def format_args(args: argparse.Namespace, sorted_key: bool = True) -> str: 77 | _cont = '\n' + '-' * 30 + "args" + '-' * 30 + '\n' 78 | args: dict = args.__dict__ 79 | 80 | m_l = max([len(k) for k in args.keys()]) 81 | 82 | key_list = list(args.keys()) 83 | if sorted_key: 84 | key_list.sort() 85 | 86 | for _k in key_list: 87 | _v = args[_k] 88 | _cont += f"{_k:>{m_l}s} = {_v}\n" 89 | _cont += '-' * 60 + '\n' 90 | return _cont 91 | 92 | 93 | def save_checkpoint(state, is_best, filedir, epoch, writer=builtins.print) -> None: 94 | if not exists(filedir): 95 | os.makedirs(filedir) 96 | 97 | filename = join(filedir, f'checkpoint_{epoch}.pth.tar') 98 | torch.save(state, filename) 99 | writer(f"Saved checkpoint to: {filename}") 100 | if is_best: 101 | shutil.copyfile(filename, join(filedir, 'checkpoint_best.pth.tar')) 102 | 103 | 104 | class ScalarWriter(): 105 | ''' 106 | ''' 107 | def __init__(self, name_list: Tuple[str], tensorboard_writer: SummaryWriter = None, init_global_step: int = 0, init_epoch: int = 0): 108 | self.tbxs_writer = tensorboard_writer 109 | self.global_step = init_global_step 110 | self.epoch = init_epoch 111 | self.epoch_value_list_dict: ODt[str, List[float]] = OrderedDict([(n, []) for n in name_list]) 112 | self.epoch_sample_counter_list = [] 113 | 114 | def add_step_value(self, step_value_dict: ODt[str, float], num_samples=1): 115 | ''' 116 | ''' 
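        # Record one step's scalars: append each value for the epoch-level
        # weighted average (weight = `num_samples`, typically the batch size)
        # and, if a SummaryWriter is attached, also log it as "step/<name>"
        # at the current global step.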
117 | for k, v in step_value_dict.items(): 118 | assert k in self.epoch_value_list_dict, f"{k}" 119 | self.epoch_value_list_dict[k].append(v) 120 | if isinstance(self.tbxs_writer, SummaryWriter): 121 | self.tbxs_writer.add_scalar(f"step/{k}", v, self.global_step) 122 | self.epoch_sample_counter_list.append(num_samples) 123 | self.global_step += 1 124 | 125 | def update_epoch_average_value(self) -> ODt[str, float]: 126 | epoch_average_value_dict = OrderedDict() 127 | sc_ary: np.ndarray = np.array(self.epoch_sample_counter_list) 128 | for k, v in self.epoch_value_list_dict.items(): 129 | v_ary: np.ndarray = np.array(v) 130 | assert v_ary.shape == sc_ary.shape 131 | v_mean = np.dot(v_ary, sc_ary) / np.sum(sc_ary) 132 | if isinstance(self.tbxs_writer, SummaryWriter): 133 | self.tbxs_writer.add_scalar(f"epoch/{k}", v_mean, self.epoch) 134 | epoch_average_value_dict[k] = v_mean 135 | 136 | self.epoch += 1 137 | 138 | for k in self.epoch_value_list_dict.keys(): 139 | self.epoch_value_list_dict[k].clear() 140 | self.epoch_sample_counter_list.clear() 141 | 142 | return epoch_average_value_dict 143 | -------------------------------------------------------------------------------- /main/argmanager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from misc import get_time_stamp 3 | 4 | 5 | def str2bool(v: str): 6 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 7 | return True 8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 9 | return False 10 | else: 11 | raise argparse.ArgumentTypeError('Expect boolean value in string.') 12 | 13 | 14 | def _base_parser(training: bool): 15 | parser = argparse.ArgumentParser(description=f'') 16 | 17 | # Path settings 18 | parser.add_argument('--data_name', type=str, default='', 19 | help='Name of the dataset.') 20 | parser.add_argument('--video_dir', type=str, required=True, 21 | help='Path to the videos.') 22 | parser.add_argument('--track_dir', type=str, required=True, 23 | help='Path to the tracks.') 24 | parser.add_argument('--frame_dir', type=str, default='', 25 | help='Path to the frames.') 26 | 27 | parser.add_argument('--time_stamp', type=str, default=get_time_stamp(), 28 | help='') 29 | 30 | parser.add_argument('--resume', default="", type=str, 31 | help='') 32 | parser.add_argument('--bgd_encoder', default="", type=str, 33 | help='') 34 | parser.add_argument('--frame_ae', default="", type=str, 35 | help='') 36 | 37 | parser.add_argument('--scene_classes', type=int, 38 | help='') 39 | 40 | # Resource usage settings 41 | parser.add_argument('--workers', default=6 if training else 1, type=int, 42 | help='Number of data loading workers') 43 | 44 | # Model settings 45 | parser.add_argument('--lam_cvae', type=float, default=0., 46 | help="") 47 | 48 | # Dataset settings 49 | parser.add_argument('--snippet_inp', type=int, default=8, 50 | help="") 51 | parser.add_argument('--snippet_itv', type=float, default=2, 52 | help="") 53 | parser.add_argument('--snippet_tgt', type=int, default=1, 54 | help="") 55 | 56 | # Loss weight settings 57 | parser.add_argument('--lam_l1', default=1.0, type=float, 58 | help="") 59 | parser.add_argument('--lam_vae', default=1.0, type=float, 60 | help="") 61 | 62 | # Other settings 63 | parser.add_argument('--log_root', default="save.logs", type=str, 64 | help='') 65 | parser.add_argument('--note', default="", type=str, 66 | help='A note for this experiment') 67 | parser.add_argument('--print_model', type=str2bool, default='yes', 68 | help='') 69 | 
parser.add_argument('--debug_mode', action="store_true", 70 | help='') 71 | 72 | return parser 73 | 74 | 75 | def train_parser(): 76 | parser = _base_parser(training=True) 77 | 78 | # Optimizer settings 79 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, 80 | help='Learning rate.') 81 | parser.add_argument('--fr', '--funetune_rate', default=1, type=float, 82 | help='') 83 | parser.add_argument('--batch_size', default=32, type=int, 84 | help='Batch size.') 85 | parser.add_argument('--epochs', default=80, type=int, 86 | help='Number of total epochs to run.') 87 | parser.add_argument('--schedule', default=[60], nargs='*', type=int, 88 | help='Learning rate schedule.') 89 | 90 | # Dataset settings 91 | parser.add_argument('--iterations', default=32, type=int, 92 | help='A way to simulate more epochs.') 93 | 94 | # Saving and logging settings 95 | parser.add_argument('--ckpt_root', default="save.ckpts", type=str, 96 | help='') 97 | parser.add_argument('--tbxs_root', default="save.tbxs", type=str, 98 | help='') 99 | parser.add_argument('--save_freq', default=1, type=int, 100 | help='Save frequency.') 101 | parser.add_argument('--print_freq', default=10, type=int, 102 | help='Print frequency.') 103 | 104 | parser.add_argument('--pre_model', default="", type=str, 105 | help='the pre-trained forward model') 106 | 107 | return parser 108 | 109 | 110 | def test_parser(): 111 | parser = _base_parser(training=False) 112 | 113 | # Path settings 114 | parser.add_argument('--gtnpz_path', type=str, required=True, 115 | help='Path to groundtruth npz file.') 116 | parser.add_argument('--score_dict_path', type=str, default="", 117 | help='Only calculate AUCs for this score_dict. --video_dir and --resume will be ignored.') 118 | 119 | # Dataset settings 120 | parser.add_argument('--to_gpu', action='store_true', 121 | help="put data to gpu") 122 | parser.add_argument("--ignore_first_frame_score", action="store_true") 123 | 124 | # Resource usage settings 125 | parser.add_argument('--threads', default=24, type=int, 126 | help='Number of threads used by pytorch') 127 | 128 | # Error settings 129 | parser.add_argument('--error_type', type=str, default='frame', choices=('frame', 'patch'), 130 | help='') 131 | parser.add_argument('--patch_size', type=int, nargs='+', 132 | help='') 133 | parser.add_argument('--patch_stride', type=int, default=8, 134 | help='') 135 | parser.add_argument('--use_channel_l2', action="store_true", 136 | help='') 137 | parser.add_argument('--crop_fuse_type', type=str, default='mean', choices=('mean', 'max'), 138 | help='Use mean or max to obtaion snippet_score') 139 | 140 | parser.add_argument('--score_post_process', type=str, nargs='*', default=['filt'], choices=('filt', 'norm'), 141 | help='') 142 | 143 | # Saving settings 144 | parser.add_argument('--tmp_score_dir', default="", type=str, 145 | help='') 146 | parser.add_argument('--result_root', default="save.results", type=str, 147 | help='') 148 | parser.add_argument('--visual_root', default="save.visual", type=str, 149 | help='') 150 | return parser 151 | 152 | 153 | if __name__ == '__main__': 154 | pass 155 | -------------------------------------------------------------------------------- /main/1-train_sce.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from model_builder import BackgroundEncoder 3 | from argmanager import train_parser 4 | from misc import get_logger, format_args, get_ckpt_dir, save_checkpoint, ScalarWriter 5 
| from dsets import TrainSetTrackingObject 6 | from torch.backends import cudnn 7 | import torch.utils.data 8 | import torch.nn as nn 9 | import os 10 | from functools import partial 11 | from collections import OrderedDict 12 | 13 | import torch 14 | import random 15 | import numpy as np 16 | 17 | rand_seed = 2022 18 | random.seed(rand_seed) 19 | torch.manual_seed(rand_seed) 20 | torch.cuda.manual_seed(rand_seed) 21 | torch.cuda.manual_seed_all(rand_seed) 22 | np.random.seed(rand_seed) 23 | 24 | 25 | if __name__ == '__main__': 26 | args = train_parser().parse_args() 27 | logger = get_logger(args.time_stamp, '' if args.debug_mode else __file__, args.log_root, args.data_name) 28 | logger.info(format_args(args)) 29 | if args.debug_mode: 30 | logger.info(f"ATTENTION: You are in DEBUG mode. Nothing will be saved!") 31 | 32 | cudnn.benchmark = not args.debug_mode 33 | 34 | n_gpus = torch.cuda.device_count() 35 | 36 | train_dataset = TrainSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv, args.iterations, frame_dir=args.frame_dir) 37 | 38 | smpl_weight = train_dataset.vid_samp_weight * train_dataset.iterations 39 | assert len(smpl_weight) == len(train_dataset) 40 | weighted_sampler = torch.utils.data.WeightedRandomSampler(smpl_weight, len(train_dataset), replacement=True) 41 | 42 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=weighted_sampler, 43 | num_workers=args.workers, pin_memory=True, drop_last=False, prefetch_factor=1) 44 | 45 | model = BackgroundEncoder(args.scene_classes) 46 | 47 | if args.print_model: 48 | logger.info(f"{model}") 49 | 50 | model.train() 51 | model = model.cuda() 52 | 53 | special_layer = 'NONE' 54 | if args.fr == 0: 55 | for n, p in model.named_parameters(): 56 | if special_layer in n: 57 | p.requires_grad_(False) 58 | 59 | logger.info("[Learning Rate] Set requires_grad ...") 60 | param_list = [] 61 | for n, p in model.named_parameters(): 62 | if p.requires_grad: 63 | param_list.append(p) 64 | else: 65 | logger.info(f"{n}: requires_grad={p.requires_grad}") 66 | 67 | elif args.fr == 1: 68 | logger.info("[Learning Rate] All layers have the same `lr` ") 69 | param_list = [] 70 | for n, p in model.named_parameters(): 71 | param_list.append(p) 72 | 73 | elif args.fr > 0: 74 | param_list = [{'params': [], 'lr':args.lr}, 75 | {'params': [], 'lr':args.lr * args.fr}] 76 | pname_list = {_p['lr']: [] for _p in param_list} 77 | 78 | logger.info("[Learning Rate] Set finetuning_rate ...") 79 | for n, p in model.named_parameters(): 80 | if p.requires_grad: 81 | _group_idx = 1 if special_layer in n else 0 82 | param_list[_group_idx]['params'].append(p) 83 | pname_list[param_list[_group_idx]['lr']].append(n) 84 | else: 85 | logger.info(f"{n}: requires_grad={p.requires_grad}") 86 | 87 | for _lr, _pn in pname_list.items(): 88 | logger.info(f'[Optimizer] lr={_lr}:') 89 | logger.info(' | '.join(_pn)) 90 | 91 | else: 92 | raise ValueError(f"`args.fr({args.fr})` shoule be >=0") 93 | 94 | criterion_ce = nn.CrossEntropyLoss().cuda() 95 | 96 | optimizer = torch.optim.Adam(param_list, args.lr) 97 | lr_sch = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=0.1) 98 | 99 | _tbxs_witer = None 100 | if not args.debug_mode: 101 | ckpt_save_func = partial(save_checkpoint, is_best=False, filedir=get_ckpt_dir(args.time_stamp, __file__, args.ckpt_root, args.data_name), writer=logger.info) 102 | _tbxs_witer = SummaryWriter(get_ckpt_dir(args.time_stamp, __file__, args.tbxs_root, 
args.data_name)) 103 | 104 | start_epoch = 1 105 | if args.resume: 106 | if os.path.exists(args.resume) and os.path.isfile(args.resume): 107 | ckpt = torch.load(args.resume) 108 | model.load_state_dict(ckpt['model']) 109 | optimizer.load_state_dict(ckpt['optimizer']) 110 | lr_sch.load_state_dict(ckpt['lr_sch']) 111 | start_epoch = ckpt['epoch'] + 1 112 | else: 113 | raise FileNotFoundError(f"{args.resume}") 114 | else: 115 | if not args.debug_mode: 116 | ckpt_save_func({'epoch': 0, 117 | 'model': model.state_dict(), 118 | 'optimizer': optimizer.state_dict(), 119 | 'lr_sch': lr_sch.state_dict()}, 120 | epoch=0) 121 | 122 | scalar_writer = ScalarWriter(('lr', 'loss_ce'), 123 | _tbxs_witer, (start_epoch - 1) * len(train_loader), start_epoch) 124 | 125 | for epoch in range(start_epoch, args.epochs + 1): 126 | for step, (batch_snp, batch_bgd, batch_lbl) in enumerate(train_loader): 127 | batch_bgd: torch.Tensor 128 | batch_lbl: torch.Tensor 129 | batch_bgd = batch_bgd.cuda(non_blocking=True) # [b, c, t, h, w] 130 | batch_lbl = batch_lbl.cuda(non_blocking=True) # [b, c, t, h, w] 131 | 132 | batch_out = model(batch_bgd) 133 | loss_ce: torch.Tensor = criterion_ce(batch_out, batch_lbl) 134 | 135 | optimizer.zero_grad() 136 | loss_ce.backward() 137 | optimizer.step() 138 | 139 | _scalar_dict = OrderedDict(lr=lr_sch.get_last_lr()[0], loss_ce=loss_ce.item()) 140 | scalar_writer.add_step_value(_scalar_dict, len(batch_bgd)) 141 | 142 | if step % args.print_freq == 0: 143 | logger.info(f"{'(DEBUG) ' if args.debug_mode else ''}Epoch[{epoch}/{args.epochs}] step {step:>4d}/{len(train_loader)}: " + " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _scalar_dict.items()])) 144 | 145 | _epoch_average_dict = scalar_writer.update_epoch_average_value() 146 | logger.info(f"Epoch[{epoch}/{args.epochs}] [ Average ]: " + " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _epoch_average_dict.items()])) 147 | 148 | _last_lr = lr_sch.get_last_lr()[0] 149 | lr_sch.step() 150 | if lr_sch.get_last_lr()[0] < _last_lr: 151 | logger.info(f"[Learning Rate] Decay `lr` from {_last_lr} to {lr_sch.get_last_lr()[0]}") 152 | 153 | if not args.debug_mode: 154 | if epoch % args.save_freq == 0 or epoch == args.epochs: 155 | ckpt_save_func({'epoch': epoch, 156 | 'model': model.state_dict(), 157 | 'optimizer': optimizer.state_dict(), 158 | 'lr_sch': lr_sch.state_dict()}, 159 | epoch=epoch) 160 | if not args.debug_mode: 161 | if isinstance(_tbxs_witer, SummaryWriter): 162 | _tbxs_witer.close() 163 | -------------------------------------------------------------------------------- /main/1-train_ae.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from model_builder import SceneFrameAE, BackgroundEncoder 3 | from argmanager import train_parser 4 | from misc import get_logger, format_args, get_ckpt_dir, save_checkpoint, ScalarWriter 5 | from dsets import TrainSetTrackingObject 6 | from torch.backends import cudnn 7 | import torch.utils.data 8 | import torch.nn as nn 9 | import os 10 | from functools import partial 11 | from collections import OrderedDict 12 | 13 | import torch 14 | import random 15 | import numpy as np 16 | 17 | rand_seed = 2022 18 | random.seed(rand_seed) 19 | torch.manual_seed(rand_seed) 20 | torch.cuda.manual_seed(rand_seed) 21 | torch.cuda.manual_seed_all(rand_seed) 22 | np.random.seed(rand_seed) 23 | 24 | 25 | if __name__ == '__main__': 26 | args = train_parser().parse_args() 27 | logger = 
get_logger(args.time_stamp, '' if args.debug_mode else __file__, args.log_root, args.data_name) 28 | logger.info(format_args(args)) 29 | if args.debug_mode: 30 | logger.info(f"ATTENTION: You are in DEBUG mode. Nothing will be saved!") 31 | 32 | cudnn.benchmark = not args.debug_mode 33 | 34 | n_gpus = torch.cuda.device_count() 35 | 36 | train_dataset = TrainSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv, args.iterations, frame_dir=args.frame_dir) 37 | 38 | smpl_weight = train_dataset.vid_samp_weight * train_dataset.iterations 39 | assert len(smpl_weight) == len(train_dataset) 40 | weighted_sampler = torch.utils.data.WeightedRandomSampler(smpl_weight, len(train_dataset), replacement=True) 41 | 42 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=weighted_sampler, 43 | num_workers=args.workers, pin_memory=True, drop_last=False, prefetch_factor=1) 44 | 45 | bgd_encoder = BackgroundEncoder(args.scene_classes) 46 | bgd_encoder.load_state_dict(torch.load(args.bgd_encoder, map_location='cpu')['model']) 47 | 48 | model = SceneFrameAE(inp_frm=args.snippet_inp, tgt_frm=args.snippet_tgt, bgd_encoder=bgd_encoder, lam_cvae=0.) 49 | 50 | model.train() 51 | model = model.cuda() 52 | 53 | if args.print_model: 54 | logger.info(f"{model}") 55 | 56 | special_layer = ('bgd_encoder',) 57 | if args.fr == 0: 58 | for n, p in model.named_parameters(): 59 | for _s in special_layer: 60 | if _s in n: 61 | p.requires_grad_(False) 62 | 63 | logger.info("[Learning Rate] Set requires_grad ...") 64 | param_list = [] 65 | for n, p in model.named_parameters(): 66 | if p.requires_grad: 67 | param_list.append(p) 68 | logger.info(f"{n}: requires_grad={p.requires_grad}") 69 | else: 70 | logger.info(f"{n}: requires_grad={p.requires_grad}") 71 | 72 | for n, m in model.named_modules(): 73 | for _s in special_layer: 74 | if _s in n: 75 | m.eval() 76 | logger.info(f"{n} is set to {'training()' if m.training else 'evel()'} mode.") 77 | else: 78 | logger.info(f"{n} is set to {'training()' if m.training else 'evel()'} mode.") 79 | 80 | elif args.fr == 1: 81 | logger.info("[Learning Rate] All layers have the same `lr` ") 82 | param_list = [] 83 | for n, p in model.named_parameters(): 84 | param_list.append(p) 85 | 86 | elif args.fr > 0: 87 | raise NotImplementedError("special_layer is not processed") 88 | param_list = [{'params': [], 'lr':args.lr}, 89 | {'params': [], 'lr':args.lr * args.fr}] 90 | pname_list = {_p['lr']: [] for _p in param_list} 91 | 92 | logger.info("[Learning Rate] Set finetuning_rate ...") 93 | for n, p in model.named_parameters(): 94 | if p.requires_grad: 95 | _group_idx = 1 if special_layer in n else 0 96 | param_list[_group_idx]['params'].append(p) 97 | pname_list[param_list[_group_idx]['lr']].append(n) 98 | else: 99 | logger.info(f"{n}: requires_grad={p.requires_grad}") 100 | 101 | for _lr, _pn in pname_list.items(): 102 | logger.info(f'[Optimizer] lr={_lr}:') 103 | logger.info(' | '.join(_pn)) 104 | 105 | else: 106 | raise ValueError(f"`args.fr({args.fr})` shoule be >=0") 107 | 108 | criterion_mse = nn.MSELoss().cuda() 109 | criterion_l1 = nn.L1Loss().cuda() 110 | 111 | optimizer = torch.optim.Adam(param_list, args.lr) 112 | lr_sch = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=0.1) 113 | 114 | _tbxs_witer = None 115 | if not args.debug_mode: 116 | ckpt_save_func = partial(save_checkpoint, is_best=False, filedir=get_ckpt_dir(args.time_stamp, __file__, args.ckpt_root, args.data_name), 
writer=logger.info) 117 | _tbxs_witer = SummaryWriter(get_ckpt_dir(args.time_stamp, __file__, args.tbxs_root, args.data_name)) 118 | 119 | start_epoch = 1 120 | if args.resume: 121 | if os.path.exists(args.resume) and os.path.isfile(args.resume): 122 | ckpt = torch.load(args.resume) 123 | model.load_state_dict(ckpt['model']) 124 | optimizer.load_state_dict(ckpt['optimizer']) 125 | lr_sch.load_state_dict(ckpt['lr_sch']) 126 | start_epoch = ckpt['epoch'] + 1 127 | else: 128 | raise FileNotFoundError(f"{args.resume}") 129 | else: 130 | if not args.debug_mode: 131 | ckpt_save_func({'epoch': 0, 132 | 'model': model.state_dict(), 133 | 'optimizer': optimizer.state_dict(), 134 | 'lr_sch': lr_sch.state_dict()}, 135 | epoch=0) 136 | 137 | scalar_writer = ScalarWriter(('lr', 'img_mse', 'img_l1', 'loss_tot'), _tbxs_witer, (start_epoch - 1) * len(train_loader), start_epoch) 138 | 139 | for epoch in range(start_epoch, args.epochs + 1): 140 | for step, (batch_snp, batch_bgd, batch_lbl) in enumerate(train_loader): 141 | batch_snp: torch.Tensor 142 | batch_bgd: torch.Tensor 143 | batch_lbl: torch.Tensor 144 | batch_snp = batch_snp.cuda(non_blocking=True) # [b, c, t, h, w] 145 | batch_bgd = batch_bgd.cuda(non_blocking=True) # [b, c, t, h, w] 146 | batch_lbl = batch_lbl.cuda(non_blocking=True) # [b, c, t, h, w] 147 | 148 | inp_snp = batch_snp[:, :, :args.snippet_inp] 149 | tgt_snp = batch_snp[:, :, -args.snippet_tgt:] 150 | inp_bgd = batch_bgd 151 | 152 | out_snp = model(inp_snp, inp_bgd) 153 | 154 | img_mse: torch.Tensor = criterion_mse(out_snp, tgt_snp) 155 | img_l1: torch.Tensor = criterion_l1(out_snp, tgt_snp) 156 | loss_tot: torch.Tensor = img_mse + args.lam_l1 * img_l1 157 | 158 | optimizer.zero_grad() 159 | loss_tot.backward() 160 | optimizer.step() 161 | 162 | _scalar_dict = OrderedDict(lr=lr_sch.get_last_lr()[0], img_mse=img_mse.item(), img_l1=img_l1.item(), loss_tot=loss_tot.item()) 163 | scalar_writer.add_step_value(_scalar_dict, len(tgt_snp)) 164 | 165 | if step % args.print_freq == 0: 166 | logger.info(f"{'(DEBUG) ' if args.debug_mode else ''}Epoch[{epoch}/{args.epochs}] step {step:>4d}/{len(train_loader)}: " + 167 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _scalar_dict.items()])) 168 | 169 | _epoch_average_dict = scalar_writer.update_epoch_average_value() 170 | logger.info(f"Epoch[{epoch}/{args.epochs}] [ Average ]: " + 171 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _epoch_average_dict.items()])) 172 | 173 | _last_lr = lr_sch.get_last_lr()[0] 174 | lr_sch.step() 175 | if lr_sch.get_last_lr()[0] < _last_lr: 176 | logger.info(f"[Learning Rate] Decay `lr` from {_last_lr} to {lr_sch.get_last_lr()[0]}") 177 | 178 | if not args.debug_mode: 179 | if epoch % args.save_freq == 0 or epoch == args.epochs: 180 | ckpt_save_func({'epoch': epoch, 181 | 'model': model.state_dict(), 182 | 'optimizer': optimizer.state_dict(), 183 | 'lr_sch': lr_sch.state_dict()}, 184 | epoch=epoch) 185 | if not args.debug_mode: 186 | if isinstance(_tbxs_witer, SummaryWriter): 187 | _tbxs_witer.close() 188 | -------------------------------------------------------------------------------- /main/1-finetune.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from model_builder import SceneFrameAE, BackgroundEncoder 3 | from argmanager import train_parser 4 | from misc import get_logger, format_args, get_ckpt_dir, save_checkpoint, ScalarWriter 5 | from dsets import TrainSetTrackingObject 6 | 
from torch.backends import cudnn 7 | import torch.utils.data 8 | from torch.nn import functional as tf 9 | import torch.nn as nn 10 | import os 11 | from functools import partial 12 | from collections import OrderedDict 13 | 14 | import torch 15 | import random 16 | import numpy as np 17 | 18 | rand_seed = 2022 19 | random.seed(rand_seed) 20 | torch.manual_seed(rand_seed) 21 | torch.cuda.manual_seed(rand_seed) 22 | torch.cuda.manual_seed_all(rand_seed) 23 | np.random.seed(rand_seed) 24 | 25 | 26 | if __name__ == '__main__': 27 | args = train_parser().parse_args() 28 | logger = get_logger(args.time_stamp, '' if args.debug_mode else __file__, args.log_root, args.data_name) 29 | logger.info(format_args(args)) 30 | if args.debug_mode: 31 | logger.info(f"ATTENTION: You are in DEBUG mode. Nothing will be saved!") 32 | 33 | cudnn.benchmark = not args.debug_mode 34 | 35 | n_gpus = torch.cuda.device_count() 36 | 37 | train_dataset = TrainSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv, args.iterations, frame_dir=args.frame_dir) 38 | 39 | smpl_weight = train_dataset.vid_samp_weight * train_dataset.iterations 40 | assert len(smpl_weight) == len(train_dataset) 41 | weighted_sampler = torch.utils.data.WeightedRandomSampler(smpl_weight, len(train_dataset), replacement=True) 42 | 43 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=weighted_sampler, 44 | num_workers=args.workers, pin_memory=True, drop_last=False, prefetch_factor=1) 45 | 46 | bgd_encoder = BackgroundEncoder(args.scene_classes) 47 | bgd_encoder.load_state_dict(torch.load(args.bgd_encoder, map_location='cpu')['model']) 48 | frame_ae: OrderedDict = torch.load(args.frame_ae, map_location='cpu')['model'] 49 | 50 | model = SceneFrameAE(inp_frm=args.snippet_inp, tgt_frm=args.snippet_tgt, bgd_encoder=bgd_encoder, lam_cvae=1.) 
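    # Load the pretrained SceneFrameAE weights with strict=False: the CVAE
    # branches enabled here via lam_cvae=1. have no counterpart in the
    # `--frame_ae` checkpoint (trained with lam_cvae=0.), so they keep their
    # fresh initialization while the shared encoder/decoder weights are
    # restored; `_load_info` below reports the missing/unexpected keys.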
51 | 52 | _load_info = model.load_state_dict(frame_ae, strict=False) 53 | print(_load_info) 54 | 55 | model.train() 56 | model = model.cuda() 57 | 58 | if args.print_model: 59 | logger.info(f"{model}") 60 | 61 | special_layer = ('NONE',) 62 | if args.fr == 0: 63 | for n, p in model.named_parameters(): 64 | for _s in special_layer: 65 | if _s in n: 66 | p.requires_grad_(False) 67 | 68 | logger.info("[Learning Rate] Set requires_grad ...") 69 | param_list = [] 70 | for n, p in model.named_parameters(): 71 | if p.requires_grad: 72 | param_list.append(p) 73 | logger.info(f"{n}: requires_grad={p.requires_grad}") 74 | else: 75 | logger.info(f"{n}: requires_grad={p.requires_grad}") 76 | 77 | for n, m in model.named_modules(): 78 | for _s in special_layer: 79 | if _s in n: 80 | m.eval() 81 | logger.info(f"{n} is set to {'training()' if m.training else 'evel()'} mode.") 82 | else: 83 | logger.info(f"{n} is set to {'training()' if m.training else 'evel()'} mode.") 84 | 85 | elif args.fr == 1: 86 | logger.info("[Learning Rate] All layers have the same `lr` ") 87 | param_list = [] 88 | for n, p in model.named_parameters(): 89 | param_list.append(p) 90 | 91 | elif args.fr > 0: 92 | raise NotImplementedError("special_layer is not processed") 93 | param_list = [{'params': [], 'lr':args.lr}, 94 | {'params': [], 'lr':args.lr * args.fr}] 95 | pname_list = {_p['lr']: [] for _p in param_list} 96 | 97 | logger.info("[Learning Rate] Set finetuning_rate ...") 98 | for n, p in model.named_parameters(): 99 | if p.requires_grad: 100 | _group_idx = 1 if special_layer in n else 0 101 | param_list[_group_idx]['params'].append(p) 102 | pname_list[param_list[_group_idx]['lr']].append(n) 103 | else: 104 | logger.info(f"{n}: requires_grad={p.requires_grad}") 105 | 106 | for _lr, _pn in pname_list.items(): 107 | logger.info(f'[Optimizer] lr={_lr}:') 108 | logger.info(' | '.join(_pn)) 109 | 110 | else: 111 | raise ValueError(f"`args.fr({args.fr})` shoule be >=0") 112 | 113 | criterion_mse = nn.MSELoss().cuda() 114 | criterion_l1 = nn.L1Loss().cuda() 115 | 116 | def loss_fn_cvae(recon_x: torch.Tensor, x: torch.Tensor, mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: 117 | MSE_L1 = criterion_mse(recon_x, x.detach()) + criterion_l1(recon_x, x.detach()) 118 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) 119 | return (MSE_L1 + KLD) / x.size(0) 120 | 121 | optimizer = torch.optim.Adam(param_list, args.lr) 122 | lr_sch = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=0.1) 123 | 124 | _tbxs_witer = None 125 | if not args.debug_mode: 126 | ckpt_save_func = partial(save_checkpoint, is_best=False, filedir=get_ckpt_dir(args.time_stamp, __file__, args.ckpt_root, args.data_name), writer=logger.info) 127 | _tbxs_witer = SummaryWriter(get_ckpt_dir(args.time_stamp, __file__, args.tbxs_root, args.data_name)) 128 | 129 | start_epoch = 1 130 | if args.resume: 131 | if os.path.exists(args.resume) and os.path.isfile(args.resume): 132 | ckpt = torch.load(args.resume) 133 | model.load_state_dict(ckpt['model']) 134 | optimizer.load_state_dict(ckpt['optimizer']) 135 | lr_sch.load_state_dict(ckpt['lr_sch']) 136 | start_epoch = ckpt['epoch'] + 1 137 | else: 138 | raise FileNotFoundError(f"{args.resume}") 139 | else: 140 | if not args.debug_mode: 141 | ckpt_save_func({'epoch': 0, 142 | 'model': model.state_dict(), 143 | 'optimizer': optimizer.state_dict(), 144 | 'lr_sch': lr_sch.state_dict()}, 145 | epoch=0) 146 | 147 | scalar_writer = ScalarWriter(('lr', 'img_mse', 'img_l1', 'loss_v32', 'loss_v64', 
'loss_tot'), 148 | _tbxs_witer, (start_epoch - 1) * len(train_loader), start_epoch) 149 | 150 | for epoch in range(start_epoch, args.epochs + 1): 151 | for step, (batch_snp, batch_bgd, batch_lbl) in enumerate(train_loader): 152 | batch_snp: torch.Tensor 153 | batch_bgd: torch.Tensor 154 | batch_lbl: torch.Tensor 155 | batch_snp = batch_snp.cuda(non_blocking=True) # [b, c, t, h, w] 156 | batch_bgd = batch_bgd.cuda(non_blocking=True) 157 | batch_lbl = batch_lbl.cuda(non_blocking=True) 158 | 159 | inp_snp = batch_snp[:, :, :args.snippet_inp] 160 | tgt_snp = batch_snp[:, :, -args.snippet_tgt:] 161 | inp_bgd = batch_bgd 162 | 163 | out_snp, vae32, vae64 = model(inp_snp, inp_bgd) 164 | 165 | loss_v32 = loss_fn_cvae(vae32['rec_x32'], vae32['feat_32'], vae32['mean_32'], vae32['log_var_32']) 166 | loss_v64 = loss_fn_cvae(vae64['rec_x64'], vae64['feat_64'], vae64['mean_64'], vae64['log_var_64']) 167 | 168 | img_mse: torch.Tensor = criterion_mse(out_snp, tgt_snp) 169 | img_l1: torch.Tensor = criterion_l1(out_snp, tgt_snp) 170 | loss_tot: torch.Tensor = img_mse + args.lam_l1 * img_l1 + args.lam_vae * loss_v32 + args.lam_vae * loss_v64 171 | 172 | optimizer.zero_grad() 173 | loss_tot.backward() 174 | optimizer.step() 175 | 176 | _scalar_dict = OrderedDict(lr=lr_sch.get_last_lr()[0], img_mse=img_mse.item(), img_l1=img_l1.item(), loss_v32=loss_v32.item(), loss_v64=loss_v64.item(), loss_tot=loss_tot.item()) 177 | scalar_writer.add_step_value(_scalar_dict, len(tgt_snp)) 178 | 179 | if step % args.print_freq == 0: 180 | logger.info(f"{'(DEBUG) ' if args.debug_mode else ''}Epoch[{epoch}/{args.epochs}] step {step:>4d}/{len(train_loader)}: " + 181 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _scalar_dict.items()])) 182 | 183 | _epoch_average_dict = scalar_writer.update_epoch_average_value() 184 | logger.info(f"Epoch[{epoch}/{args.epochs}] [ Average ]: " + 185 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _epoch_average_dict.items()])) 186 | 187 | _last_lr = lr_sch.get_last_lr()[0] 188 | lr_sch.step() 189 | if lr_sch.get_last_lr()[0] < _last_lr: 190 | logger.info(f"[Learning Rate] Decay `lr` from {_last_lr} to {lr_sch.get_last_lr()[0]}") 191 | 192 | if not args.debug_mode: 193 | if epoch % args.save_freq == 0 or epoch == args.epochs: 194 | ckpt_save_func({'epoch': epoch, 195 | 'model': model.state_dict(), 196 | 'optimizer': optimizer.state_dict(), 197 | 'lr_sch': lr_sch.state_dict()}, 198 | epoch=epoch) 199 | if not args.debug_mode: 200 | if isinstance(_tbxs_witer, SummaryWriter): 201 | _tbxs_witer.close() 202 | -------------------------------------------------------------------------------- /main/1-train_cvae.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from model_builder import SceneFrameAE, BackgroundEncoder 3 | from argmanager import train_parser 4 | from misc import get_logger, format_args, get_ckpt_dir, save_checkpoint, ScalarWriter 5 | from dsets import TrainSetTrackingObject 6 | from einops import rearrange 7 | from torch.backends import cudnn 8 | import torch.utils.data 9 | from torch.nn import functional as tf 10 | import torch.nn as nn 11 | import os 12 | from functools import partial 13 | from collections import OrderedDict 14 | 15 | import torch 16 | import random 17 | import numpy as np 18 | 19 | rand_seed = 2022 20 | random.seed(rand_seed) 21 | torch.manual_seed(rand_seed) 22 | torch.cuda.manual_seed(rand_seed) 23 | torch.cuda.manual_seed_all(rand_seed) 
24 | np.random.seed(rand_seed) 25 | 26 | 27 | if __name__ == '__main__': 28 | args = train_parser().parse_args() 29 | logger = get_logger(args.time_stamp, '' if args.debug_mode else __file__, args.log_root, args.data_name) 30 | logger.info(format_args(args)) 31 | if args.debug_mode: 32 | logger.info(f"ATTENTION: You are in DEBUG mode. Nothing will be saved!") 33 | 34 | cudnn.benchmark = not args.debug_mode 35 | 36 | n_gpus = torch.cuda.device_count() 37 | 38 | train_dataset = TrainSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv, args.iterations, frame_dir=args.frame_dir) 39 | 40 | smpl_weight = train_dataset.vid_samp_weight * train_dataset.iterations 41 | assert len(smpl_weight) == len(train_dataset) 42 | weighted_sampler = torch.utils.data.WeightedRandomSampler(smpl_weight, len(train_dataset), replacement=True) 43 | 44 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=weighted_sampler, 45 | num_workers=args.workers, pin_memory=True, drop_last=False, prefetch_factor=1) 46 | 47 | bgd_encoder = BackgroundEncoder(args.scene_classes) 48 | bgd_encoder.load_state_dict(torch.load(args.bgd_encoder, map_location='cpu')['model']) 49 | frame_ae: OrderedDict = torch.load(args.frame_ae, map_location='cpu')['model'] 50 | 51 | model = SceneFrameAE(inp_frm=args.snippet_inp, tgt_frm=args.snippet_tgt, bgd_encoder=bgd_encoder, lam_cvae=1) 52 | 53 | _load_info = model.load_state_dict(frame_ae, strict=False) 54 | print(_load_info) 55 | 56 | model.train() 57 | model = model.cuda() 58 | 59 | if args.print_model: 60 | logger.info(f"{model}") 61 | 62 | special_layer = ('bgd_encoder', 'frame_encoder', 'frame_decoder') 63 | if args.fr == 0: 64 | for n, p in model.named_parameters(): 65 | for _s in special_layer: 66 | if _s in n: 67 | p.requires_grad_(False) 68 | 69 | logger.info("[Learning Rate] Set requires_grad ...") 70 | param_list = [] 71 | for n, p in model.named_parameters(): 72 | if p.requires_grad: 73 | param_list.append(p) 74 | logger.info(f"{n}: requires_grad={p.requires_grad}") 75 | else: 76 | logger.info(f"{n}: requires_grad={p.requires_grad}") 77 | 78 | for n, m in model.named_modules(): 79 | for _s in special_layer: 80 | if _s in n: 81 | m.eval() 82 | logger.info(f"{n} is set to {'training()' if m.training else 'evel()'} mode.") 83 | else: 84 | logger.info(f"{n} is set to {'training()' if m.training else 'evel()'} mode.") 85 | 86 | elif args.fr == 1: 87 | logger.info("[Learning Rate] All layers have the same `lr` ") 88 | param_list = [] 89 | for n, p in model.named_parameters(): 90 | param_list.append(p) 91 | 92 | elif args.fr > 0: 93 | raise NotImplementedError("special_layer is not processed") 94 | param_list = [{'params': [], 'lr':args.lr}, 95 | {'params': [], 'lr':args.lr * args.fr}] 96 | pname_list = {_p['lr']: [] for _p in param_list} 97 | 98 | logger.info("[Learning Rate] Set finetuning_rate ...") 99 | for n, p in model.named_parameters(): 100 | if p.requires_grad: 101 | _group_idx = 1 if special_layer in n else 0 102 | param_list[_group_idx]['params'].append(p) 103 | pname_list[param_list[_group_idx]['lr']].append(n) 104 | else: 105 | logger.info(f"{n}: requires_grad={p.requires_grad}") 106 | 107 | for _lr, _pn in pname_list.items(): 108 | logger.info(f'[Optimizer] lr={_lr}:') 109 | logger.info(' | '.join(_pn)) 110 | 111 | else: 112 | raise ValueError(f"`args.fr({args.fr})` shoule be >=0") 113 | 114 | criterion_mse = nn.MSELoss().cuda() 115 | criterion_l1 = nn.L1Loss().cuda() 116 | 117 | def 
loss_fn_cvae(recon_x: torch.Tensor, x: torch.Tensor, mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: 118 | MSE_L1 = criterion_mse(recon_x, x.detach()) + criterion_l1(recon_x, x.detach())  # reconstruction term against the detached encoder feature 119 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())  # KL(q(z|x, c) || N(0, I)) 120 | return (MSE_L1 + KLD) / x.size(0) 121 | 122 | optimizer = torch.optim.Adam(param_list, args.lr) 123 | lr_sch = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=0.1) 124 | 125 | _tbxs_witer = None 126 | if not args.debug_mode: 127 | ckpt_save_func = partial(save_checkpoint, is_best=False, filedir=get_ckpt_dir(args.time_stamp, __file__, args.ckpt_root, args.data_name), writer=logger.info) 128 | _tbxs_witer = SummaryWriter(get_ckpt_dir(args.time_stamp, __file__, args.tbxs_root, args.data_name)) 129 | 130 | start_epoch = 1 131 | if args.resume: 132 | if os.path.exists(args.resume) and os.path.isfile(args.resume): 133 | ckpt = torch.load(args.resume) 134 | model.load_state_dict(ckpt['model']) 135 | optimizer.load_state_dict(ckpt['optimizer']) 136 | lr_sch.load_state_dict(ckpt['lr_sch']) 137 | start_epoch = ckpt['epoch'] + 1 138 | else: 139 | raise FileNotFoundError(f"{args.resume}") 140 | else: 141 | if not args.debug_mode: 142 | ckpt_save_func({'epoch': 0, 143 | 'model': model.state_dict(), 144 | 'optimizer': optimizer.state_dict(), 145 | 'lr_sch': lr_sch.state_dict()}, 146 | epoch=0) 147 | 148 | scalar_writer = ScalarWriter(('lr', 'img_mse', 'img_l1', 'loss_v32', 'loss_v64', 'loss_tot'), 149 | _tbxs_witer, (start_epoch - 1) * len(train_loader), start_epoch) 150 | 151 | for epoch in range(start_epoch, args.epochs + 1): 152 | for step, (batch_snp, batch_bgd, batch_lbl) in enumerate(train_loader): 153 | batch_snp: torch.Tensor 154 | batch_bgd: torch.Tensor 155 | batch_lbl: torch.Tensor 156 | batch_snp = batch_snp.cuda(non_blocking=True) # [b, c, t, h, w] 157 | batch_bgd = batch_bgd.cuda(non_blocking=True) # [b, c, h, w] 158 | batch_lbl = batch_lbl.cuda(non_blocking=True) # [b] 159 | 160 | inp_snp = batch_snp[:, :, :args.snippet_inp] 161 | tgt_snp = batch_snp[:, :, -args.snippet_tgt:] 162 | inp_bgd = batch_bgd 163 | 164 | out_snp, vae32, vae64 = model(inp_snp, inp_bgd) 165 | 166 | loss_v32 = loss_fn_cvae(vae32['rec_x32'], vae32['feat_32'], vae32['mean_32'], vae32['log_var_32']) 167 | loss_v64 = loss_fn_cvae(vae64['rec_x64'], vae64['feat_64'], vae64['mean_64'], vae64['log_var_64']) 168 | 169 | img_mse: torch.Tensor = criterion_mse(out_snp, tgt_snp) 170 | img_l1: torch.Tensor = criterion_l1(out_snp, tgt_snp) 171 | loss_tot: torch.Tensor = img_mse + args.lam_l1 * img_l1 + args.lam_vae * loss_v32 + args.lam_vae * loss_v64 172 | 173 | optimizer.zero_grad() 174 | loss_tot.backward() 175 | optimizer.step() 176 | 177 | _scalar_dict = OrderedDict(lr=lr_sch.get_last_lr()[0], img_mse=img_mse.item(), img_l1=img_l1.item(), loss_v32=loss_v32.item(), loss_v64=loss_v64.item(), loss_tot=loss_tot.item()) 178 | scalar_writer.add_step_value(_scalar_dict, len(tgt_snp)) 179 | 180 | if step % args.print_freq == 0: 181 | logger.info(f"{'(DEBUG) ' if args.debug_mode else ''}Epoch[{epoch}/{args.epochs}] step {step:>4d}/{len(train_loader)}: " + 182 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _scalar_dict.items()])) 183 | 184 | _epoch_average_dict = scalar_writer.update_epoch_average_value() 185 | logger.info(f"Epoch[{epoch}/{args.epochs}] [ Average ]: " + 186 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _epoch_average_dict.items()])) 187 | 188 | _last_lr =
lr_sch.get_last_lr()[0] 189 | lr_sch.step() 190 | if lr_sch.get_last_lr()[0] < _last_lr: 191 | logger.info(f"[Learning Rate] Decay `lr` from {_last_lr} to {lr_sch.get_last_lr()[0]}") 192 | 193 | if not args.debug_mode: 194 | if epoch % args.save_freq == 0 or epoch == args.epochs: 195 | ckpt_save_func({'epoch': epoch, 196 | 'model': model.state_dict(), 197 | 'optimizer': optimizer.state_dict(), 198 | 'lr_sch': lr_sch.state_dict()}, 199 | epoch=epoch) 200 | if not args.debug_mode: 201 | if isinstance(_tbxs_witer, SummaryWriter): 202 | _tbxs_witer.close() 203 | -------------------------------------------------------------------------------- /main/1-train_bid.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from model_builder import BidirectionalFrameAE 3 | from argmanager import train_parser 4 | from misc import get_logger, format_args, get_ckpt_dir, save_checkpoint, ScalarWriter 5 | from dsets import TrainSetTrackingObject 6 | from torch.backends import cudnn 7 | import torch.utils.data 8 | import torch.nn as nn 9 | import os 10 | from functools import partial 11 | from collections import OrderedDict 12 | 13 | import torch 14 | import random 15 | import numpy as np 16 | 17 | rand_seed = 2022 18 | random.seed(rand_seed) 19 | torch.manual_seed(rand_seed) 20 | torch.cuda.manual_seed(rand_seed) 21 | torch.cuda.manual_seed_all(rand_seed) 22 | np.random.seed(rand_seed) 23 | 24 | 25 | if __name__ == '__main__': 26 | args = train_parser().parse_args() 27 | logger = get_logger(args.time_stamp, '' if args.debug_mode else __file__, args.log_root, args.data_name) 28 | logger.info(format_args(args)) 29 | if args.debug_mode: 30 | logger.info(f"ATTENTION: You are in DEBUG mode. Nothing will be saved!") 31 | 32 | cudnn.benchmark = not args.debug_mode 33 | 34 | n_gpus = torch.cuda.device_count() 35 | 36 | snippet_cur = 1 37 | train_dataset = TrainSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv, args.iterations, frame_dir=args.frame_dir) 38 | 39 | smpl_weight = train_dataset.vid_samp_weight * train_dataset.iterations 40 | assert len(smpl_weight) == len(train_dataset) 41 | weighted_sampler = torch.utils.data.WeightedRandomSampler(smpl_weight, len(train_dataset), replacement=True) 42 | 43 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=weighted_sampler, 44 | num_workers=args.workers, pin_memory=True, drop_last=False, prefetch_factor=1) 45 | 46 | model = BidirectionalFrameAE(inp_frm=args.snippet_inp, tgt_frm=args.snippet_tgt, scene_classes=args.scene_classes) 47 | pre_state = torch.load(args.pre_model, map_location='cpu') 48 | 49 | for n, p in model.named_parameters(): 50 | if n.startswith('f_frameAE'): 51 | pre_w: torch.Tensor = pre_state['model'][n.replace('f_frameAE.', '')].data 52 | if p.data.shape != pre_w.shape: 53 | assert 'rec_conv.3' in n 54 | p.data[:3, ...] = pre_w[:, ...] 
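# The `rec_conv.3` branch above covers the only expected shape mismatch when warming up
# BidirectionalFrameAE from the single-target checkpoint: `rec_conv.3` is the final 3x3 conv,
# whose output channels grow from 3 (one RGB target frame in the pretrained SceneFrameAE) to
# 3 * tgt_frm here, so only its first 3 output channels can be copied from the pretrained
# weight; the remaining channels keep their fresh initialization.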
55 | else: 56 | p.data = pre_w 57 | else: 58 | pre_w: torch.Tensor = pre_state['model'][n.replace('b_frameAE.', '')].data 59 | assert p.data.shape == pre_w.shape, f"{n}, {p.data.shape}, {pre_w.shape}" 60 | p.data = pre_w 61 | 62 | # for n, p in model.named_parameters(): 63 | # if 'bn' in n: 64 | # print(n, p) 65 | 66 | model.train() 67 | model = model.cuda() 68 | 69 | if args.print_model: 70 | logger.info(f"{model}") 71 | 72 | # special_layer = ('bgd_encoder', 'f_frameAE', 'b_frameAE') 73 | # special_layer = ('bgd_encoder', 'f_frameAE') 74 | special_layer = ('bgd_encoder',) 75 | # exclude_layer = 'rec_conv.3' 76 | exclude_layer = 'rec_conv' 77 | if args.fr == 0: 78 | for n, p in model.named_parameters(): 79 | for _s in special_layer: 80 | if _s in n and exclude_layer not in n: 81 | p.requires_grad_(False) 82 | 83 | logger.info("[Learning Rate] Set requires_grad ...") 84 | param_list = [] 85 | for n, p in model.named_parameters(): 86 | if p.requires_grad: 87 | param_list.append(p) 88 | logger.info(f"{n}: requires_grad={p.requires_grad}") 89 | # else: 90 | # logger.info(f"{n}: requires_grad={p.requires_grad}") 91 | 92 | model.train() 93 | for n, m in model.named_modules(): 94 | for _s in special_layer: 95 | if _s in n and exclude_layer not in n: 96 | m.eval() 97 | for n, m in model.named_modules(): 98 | if m.training: 99 | logger.info(f"{n} is set to {'training()' if m.training else 'eval()'} mode.") 100 | 101 | elif args.fr == 1: 102 | logger.info("[Learning Rate] All layers have the same `lr` ") 103 | param_list = [] 104 | for n, p in model.named_parameters(): 105 | param_list.append(p) 106 | 107 | elif args.fr > 0: 108 | raise NotImplementedError("special_layer is not processed") 109 | param_list = [{'params': [], 'lr':args.lr}, 110 | {'params': [], 'lr':args.lr * args.fr}] 111 | pname_list = {_p['lr']: [] for _p in param_list} 112 | 113 | logger.info("[Learning Rate] Set finetuning_rate ...") 114 | for n, p in model.named_parameters(): 115 | if p.requires_grad: 116 | _group_idx = 1 if any(_s in n for _s in special_layer) else 0 117 | param_list[_group_idx]['params'].append(p) 118 | pname_list[param_list[_group_idx]['lr']].append(n) 119 | else: 120 | logger.info(f"{n}: requires_grad={p.requires_grad}") 121 | 122 | for _lr, _pn in pname_list.items(): 123 | logger.info(f'[Optimizer] lr={_lr}:') 124 | logger.info(' | '.join(_pn)) 125 | 126 | else: 127 | raise ValueError(f"`args.fr({args.fr})` should be >= 0") 128 | 129 | criterion_mse = nn.MSELoss().cuda() 130 | criterion_l1 = nn.L1Loss().cuda() 131 | 132 | optimizer_f = torch.optim.Adam(filter(lambda p: p.requires_grad, model.f_frameAE.parameters()), args.lr) 133 | optimizer_b = torch.optim.Adam(filter(lambda p: p.requires_grad, model.b_frameAE.parameters()), args.lr) 134 | lr_sch_f = torch.optim.lr_scheduler.MultiStepLR(optimizer_f, args.schedule, gamma=0.1) 135 | lr_sch_b = torch.optim.lr_scheduler.MultiStepLR(optimizer_b, args.schedule, gamma=0.1) 136 | 137 | _tbxs_witer = None 138 | if not args.debug_mode: 139 | ckpt_save_func = partial(save_checkpoint, is_best=False, filedir=get_ckpt_dir(args.time_stamp, __file__, args.ckpt_root, args.data_name), writer=logger.info) 140 | _tbxs_witer = SummaryWriter(get_ckpt_dir(args.time_stamp, __file__, args.tbxs_root, args.data_name)) 141 | 142 | start_epoch = 1 143 | if args.resume: 144 | if os.path.exists(args.resume) and os.path.isfile(args.resume): 145 | ckpt = torch.load(args.resume) 146 | model.load_state_dict(ckpt['model']) 147 | optimizer_f.load_state_dict(ckpt['optimizer_f']) 148 |
optimizer_b.load_state_dict(ckpt['optimizer_b']) 149 | lr_sch_f.load_state_dict(ckpt['lr_sch_f']) 150 | lr_sch_b.load_state_dict(ckpt['lr_sch_b']) 151 | start_epoch = ckpt['epoch'] + 1 152 | else: 153 | raise FileNotFoundError(f"{args.resume}") 154 | else: 155 | if not args.debug_mode: 156 | ckpt_save_func({'epoch': 0, 157 | 'model': model.state_dict(), 158 | 'optimizer_f': optimizer_f.state_dict(), 159 | 'optimizer_b': optimizer_b.state_dict(), 160 | 'lr_sch_f': lr_sch_f.state_dict(), 161 | 'lr_sch_b': lr_sch_b.state_dict()}, 162 | epoch=0) 163 | 164 | def loss_fn_cvae(recon_x: torch.Tensor, x: torch.Tensor, mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: 165 | MSE_L1 = criterion_mse(recon_x, x.detach()) + criterion_l1(recon_x, x.detach()) 166 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) 167 | return (MSE_L1 + KLD) / x.size(0) 168 | 169 | scalar_writer = ScalarWriter(('lr', 'f_loss', 'b_loss'), 170 | _tbxs_witer, (start_epoch - 1) * len(train_loader), start_epoch) 171 | 172 | for epoch in range(start_epoch, args.epochs + 1): 173 | for step, (batch_snp, batch_bgd, batch_lbl) in enumerate(train_loader): 174 | batch_snp: torch.Tensor = batch_snp.cuda(non_blocking=True) # [b, c, t, h, w] 175 | batch_bgd: torch.Tensor = batch_bgd.cuda(non_blocking=True) 176 | batch_lbl: torch.Tensor = batch_lbl.cuda(non_blocking=True) 177 | 178 | f_inp_snp = batch_snp[:, :, :args.snippet_inp] 179 | f_tgt_snp = batch_snp[:, :, - args.snippet_tgt:] 180 | f_out_snp, f_vae32, f_vae64 = model.f_frameAE(f_inp_snp, batch_bgd) # [b, c, tgt_frm, h, w] 181 | # f_out_snp = model.f_frameAE(f_inp_snp, batch_bgd) # [b, c, tgt_frm, h, w] 182 | 183 | f_loss: torch.Tensor = criterion_mse(f_out_snp, f_tgt_snp) + args.lam_l1 * criterion_l1(f_out_snp, f_tgt_snp) 184 | 185 | optimizer_f.zero_grad() 186 | f_loss.backward() 187 | optimizer_f.step() 188 | 189 | # Only randomly select one backward snippet for training 190 | _i_snp = random.randrange(0, args.snippet_tgt) 191 | 192 | b_all_snp_true = torch.flip(batch_snp[:, :, _i_snp: _i_snp + args.snippet_inp + snippet_cur], [2]) 193 | b_inp_snp_true = b_all_snp_true[:, :, :args.snippet_inp] 194 | b_tgt_snp_true = b_all_snp_true[:, :, args.snippet_inp: args.snippet_inp + snippet_cur] 195 | 196 | b_all_snp_fcst = torch.flip(torch.cat([f_inp_snp, f_out_snp.detach()], 2)[:, :, _i_snp: _i_snp + args.snippet_inp + snippet_cur], [2]) 197 | b_inp_snp_fcst = b_all_snp_fcst[:, :, :args.snippet_inp] 198 | b_tgt_snp_fcst = b_all_snp_fcst[:, :, args.snippet_inp: args.snippet_inp + snippet_cur] 199 | 200 | b_inp_snp = torch.cat([b_inp_snp_true, b_inp_snp_fcst], 0) 201 | b_tgt_snp = torch.cat([b_tgt_snp_true, b_tgt_snp_fcst], 0) 202 | b_batch_bgd = torch.cat([batch_bgd, batch_bgd], 0) 203 | b_out_snp, b_vae32, b_vae64 = model.b_frameAE(b_inp_snp, b_batch_bgd) 204 | # b_out_snp = model.b_frameAE(b_inp_snp, batch_bgd) 205 | 206 | b_loss: torch.Tensor = criterion_mse(b_out_snp, b_tgt_snp) + args.lam_l1 * criterion_l1(b_out_snp, b_tgt_snp) 207 | 208 | optimizer_b.zero_grad() 209 | b_loss.backward() 210 | optimizer_b.step() 211 | 212 | _scalar_dict = OrderedDict(lr=lr_sch_f.get_last_lr()[0], f_loss=f_loss.item(), b_loss=b_loss.item()) 213 | scalar_writer.add_step_value(_scalar_dict, len(batch_snp)) 214 | 215 | if step % args.print_freq == 0: 216 | logger.info(f"{'(DEBUG) ' if args.debug_mode else ''}Epoch[{epoch}/{args.epochs}] step {step:>4d}/{len(train_loader)}: " + 217 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _scalar_dict.items()])) 
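# Worked example of the backward-snippet construction above, assuming the defaults from
# 1-train_bid.sh (snippet_inp=8, snippet_tgt=7, snippet_cur=1) and a drawn _i_snp of 2:
#     batch_snp stacks 15 frames [f0 .. f14] along dim 2 (8 inputs followed by 7 targets);
#     b_all_snp_true = flip(batch_snp[:, :, 2:11]) -> [f10, f9, ..., f2],
#     so the backward AE reads [f10 .. f3] as input and must reconstruct f2;
#     the "fcst" variant builds the same reversed window from the 8 real inputs plus the 7
#     forecast frames, detached so the backward loss never backpropagates into the forward AE.
# Both windows are concatenated into one batch, with the background tensor doubled to match.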
218 | 219 | _epoch_average_dict = scalar_writer.update_epoch_average_value() 220 | logger.info(f"Epoch[{epoch}/{args.epochs}] [ Average ]: " + 221 | " ".join([f"{k}={v:.4f}" if not k in ['lr'] else f"{k}={v}" for k, v in _epoch_average_dict.items()])) 222 | 223 | _last_lr = lr_sch_f.get_last_lr()[0] 224 | lr_sch_f.step() 225 | lr_sch_b.step() 226 | if lr_sch_f.get_last_lr()[0] < _last_lr: 227 | logger.info(f"[Learning Rate] Decay `lr` from {_last_lr} to {lr_sch_f.get_last_lr()[0]}") 228 | 229 | if not args.debug_mode: 230 | if epoch % args.save_freq == 0 or epoch == args.epochs: 231 | ckpt_save_func({'epoch': epoch, 232 | 'model': model.state_dict(), 233 | 'optimizer_f': optimizer_f.state_dict(), 234 | 'optimizer_b': optimizer_b.state_dict(), 235 | 'lr_sch_f': lr_sch_f.state_dict(), 236 | 'lr_sch_b': lr_sch_b.state_dict()}, 237 | epoch=epoch) 238 | if not args.debug_mode: 239 | if isinstance(_tbxs_witer, SummaryWriter): 240 | _tbxs_witer.close() 241 | -------------------------------------------------------------------------------- /main/2-test_ft.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, exists, isfile 3 | import time 4 | from time import time as ttime 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import functools 8 | import tqdm 9 | import random 10 | from collections import OrderedDict 11 | from typing import Tuple 12 | 13 | import torch 14 | import torch.cuda 15 | import torch.multiprocessing as mp 16 | from torch import nn 17 | from torch.backends import cudnn 18 | 19 | from model_builder import SceneFrameAE, BackgroundEncoder 20 | from dsets import TestSetTrackingObject 21 | from misc import get_logger, get_result_dir, format_args 22 | from metrics import cal_micro_auc 23 | from argmanager import test_parser 24 | 25 | 26 | def load_model(args): 27 | bgd_encoder = BackgroundEncoder(args.scene_classes) 28 | bgd_encoder.load_state_dict(torch.load(args.bgd_encoder, map_location='cpu')['model']) 29 | 30 | model = SceneFrameAE(inp_frm=args.snippet_inp, tgt_frm=args.snippet_tgt, bgd_encoder=bgd_encoder, lam_cvae=1.) 31 | if args.resume: 32 | if isfile(args.resume): 33 | print("Loading checkpoint '{}'".format(args.resume)) 34 | 35 | checkpoint = torch.load(args.resume, map_location='cpu') 36 | 37 | new_state_dict = OrderedDict() 38 | _prefix = 'module.' 
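# The loop below strips the 'module.' prefix that torch.nn.DataParallel prepends to every
# parameter name, so checkpoints trained with multiple GPUs load into this single-model
# instance. A one-line sketch of the same idea (str.removeprefix needs Python >= 3.9; the
# explicit startswith/slice below is the version-agnostic form):
#     new_state_dict = OrderedDict((k.removeprefix(_prefix), v) for k, v in checkpoint['model'].items())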
39 | for k, v in checkpoint['model'].items(): 40 | if k.startswith(_prefix): 41 | new_state_dict[k[len(_prefix):]] = v 42 | else: 43 | new_state_dict[k] = v 44 | 45 | model.load_state_dict(new_state_dict) 46 | 47 | else: 48 | raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume)) 49 | else: 50 | raise NotImplementedError("A checkpoint should be loaded.") 51 | 52 | model.eval() 53 | return model 54 | 55 | 56 | def patch_error(img_true: torch.Tensor, img_test: torch.Tensor, patch_func_list: nn.ModuleList, lam_l1: float, use_channel_l2: bool): 57 | # [C,H,W] 58 | assert img_true.ndim == img_test.ndim == 3, f"{img_true.shape}, {img_test.shape}" 59 | assert img_true.shape == img_test.shape 60 | assert img_true.shape[0] == img_test.shape[0] == 3 61 | 62 | if use_channel_l2: 63 | diff_mse = torch.square(img_true - img_test).sum(0, True).sqrt().div(img_true.shape[0]) 64 | else: 65 | diff_mse = torch.square(img_true - img_test).mean(0, True) 66 | diff_l1 = torch.abs(img_true - img_test).mean(0, True) 67 | 68 | patch_score_list = [] 69 | for _patch_func in patch_func_list: 70 | _patch_err_mse: torch.Tensor = _patch_func(diff_mse) 71 | _patch_err_l1: torch.Tensor = _patch_func(diff_l1) 72 | _patch_score = _patch_err_mse.amax() + lam_l1 * _patch_err_l1.amax() 73 | patch_score_list.append(_patch_score) 74 | return patch_score_list 75 | 76 | 77 | def frame_error(img_true: torch.Tensor, img_test: torch.Tensor, lam_l1: float, use_channel_l2: bool): 78 | assert img_true.ndim == img_test.ndim == 3, f"{img_true.shape}, {img_test.shape}" 79 | assert img_true.shape == img_test.shape 80 | assert img_true.shape[0] == img_test.shape[0] == 3 81 | if not (img_true.shape[1] > 0 and img_true.shape[2] > 0): 82 | print(f"\nError in `frame_error()` function: {img_true.shape}. 
It will be given a high value (2.0)") 83 | return torch.as_tensor(2.0).cpu() 84 | 85 | if use_channel_l2: 86 | diff_mse = torch.square(img_true - img_test).sum(0, True).sqrt().div(img_true.shape[0]) 87 | else: 88 | diff_mse = torch.square(img_true - img_test).mean(0, True) 89 | diff_l1 = torch.abs(img_true - img_test).mean(0, True) 90 | 91 | frame_score = diff_mse.mean() + lam_l1 * diff_l1.mean() 92 | return frame_score.cpu() 93 | 94 | 95 | def cal_anomaly_score(i_proc: int, proc_cnt: int, score_queue: mp.Queue, args): 96 | ''' 97 | Calculate anomaly scores 98 | ''' 99 | gpu_id = i_proc % torch.cuda.device_count() 100 | 101 | test_dataset = TestSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv, 102 | device=f'cuda:{gpu_id}' if args.to_gpu else 'cpu', frame_dir=args.frame_dir) 103 | num_video = len(test_dataset) 104 | 105 | model = load_model(args) 106 | model.cuda(gpu_id) 107 | 108 | if args.print_model and i_proc == 0: 109 | print(model) 110 | 111 | fuse_func = torch.mean if args.crop_fuse_type == 'mean' else torch.amax 112 | 113 | if args.error_type == 'patch': 114 | _avg_pool_list = nn.ModuleList() 115 | for _patch_size in args.patch_size: 116 | assert 0 < _patch_size <= 256, f"{_patch_size}" 117 | _avg_pool_list.append(nn.AvgPool2d(_patch_size, args.patch_stride)) 118 | snp_error_func = functools.partial(patch_error, patch_func_list=_avg_pool_list, lam_l1=args.lam_l1, use_channel_l2=args.use_channel_l2) 119 | elif args.error_type == 'frame': 120 | snp_error_func = functools.partial(frame_error, lam_l1=args.lam_l1, use_channel_l2=args.use_channel_l2) 121 | else: 122 | raise NameError(f"ERROR args.error_type: {args.error_type}") 123 | 124 | if not args.debug_mode: 125 | if not exists(args.tmp_score_dir): 126 | time.sleep(i_proc) 127 | if not exists(args.tmp_score_dir): 128 | os.makedirs(args.tmp_score_dir) 129 | 130 | if exists(args.tmp_score_dir): 131 | if len(os.listdir(args.tmp_score_dir)) == num_video: 132 | print("ATTENTION: The temp_dir is full. 
Check it and ensure the old dir has been emptied.") 133 | 134 | for vid_idx in range(i_proc, num_video, proc_cnt): 135 | vid_name = list(test_dataset.all_trk_dict.keys())[vid_idx] 136 | tmp_score_path = join(args.tmp_score_dir, f"{vid_name}.npy") 137 | score_dict = {} 138 | 139 | if exists(tmp_score_path): 140 | vid_scores = np.load(tmp_score_path) 141 | else: 142 | vid_stream = test_dataset[vid_idx] 143 | assert vid_stream.vid_name == vid_name 144 | 145 | if args.error_type == 'patch': 146 | vid_scores: np.ndarray = np.zeros([len(vid_stream), len(args.patch_size)]) 147 | elif args.error_type == 'frame': 148 | vid_scores: np.ndarray = np.zeros([len(vid_stream), 1]) 149 | else: 150 | raise NameError(f"ERROR args.error_type: {args.error_type}") 151 | 152 | tbars = functools.partial(tqdm.tqdm, desc=f"{vid_stream.vid_name}({vid_idx+1:>{len(str(num_video))}}/{num_video})", total=len(vid_stream), 153 | ncols=120, disable=False, unit='frame', position=i_proc, colour=random.choice(['green', 'blue', 'red', 'yellow', 'magenta', 'cyan', 'white'])) 154 | 155 | for _snippet_idx in tbars(range(len(vid_stream))): 156 | batch_snp, background = vid_stream[_snippet_idx] 157 | batch_snp: torch.Tensor 158 | background: torch.Tensor 159 | if batch_snp is not None: 160 | if batch_snp.device.type == 'cpu': 161 | batch_snp = batch_snp.cuda(gpu_id) # [b, c, t, h, w] 162 | if background.device.type == 'cpu': 163 | background = background.cuda(gpu_id) # [1, c, h, w] 164 | 165 | inp_snp = batch_snp[:, :, :args.snippet_inp] 166 | tgt_snp = batch_snp[:, :, -args.snippet_tgt:] 167 | 168 | with torch.no_grad(): 169 | inp_bgd = background.repeat(batch_snp.shape[0], 1, 1, 1) 170 | out_snp, vae32, vae64 = model(inp_snp, inp_bgd) 171 | out_snp: torch.Tensor 172 | 173 | out_snp.squeeze_(2) 174 | tgt_snp.squeeze_(2) 175 | 176 | _obj_score = torch.as_tensor([snp_error_func(out_snp[i_obj], tgt_snp[i_obj]) for i_obj in range(len(batch_snp))]) 177 | vid_scores[_snippet_idx] = fuse_func(_obj_score, dim=0).cpu().numpy() 178 | else: 179 | pass 180 | 181 | if not args.debug_mode: 182 | np.save(tmp_score_path, vid_scores) 183 | 184 | score_dict[vid_name] = vid_scores 185 | 186 | assert not score_queue.full() 187 | score_queue.put(score_dict) 188 | 189 | 190 | if __name__ == '__main__': 191 | args = test_parser().parse_args() 192 | 193 | if not args.tmp_score_dir: 194 | args.tmp_score_dir = f"./{args.data_name}" 195 | res_stamp = f"{args.data_name}_{args.time_stamp}" 196 | 197 | logger = get_logger(args.time_stamp, '' if args.debug_mode else __file__, args.log_root, args.data_name) 198 | logger.info(format_args(args)) 199 | if args.debug_mode: 200 | logger.info(f"ATTENTION: You are in DEBUG mode.
Nothing will be saved!") 201 | 202 | cudnn.benchmark = not args.debug_mode 203 | torch.set_num_threads(args.threads) 204 | 205 | t0 = ttime() 206 | 207 | gt_npz: dict = np.load(args.gtnpz_path) 208 | 209 | if args.score_dict_path: 210 | logger.info(f"Using the score_dict from '{args.score_dict_path}'") 211 | assert exists(args.score_dict_path), f"{args.score_dict_path}" 212 | score_dict = np.load(args.score_dict_path) 213 | else: 214 | epoch = torch.load(args.resume, map_location='cpu')['epoch'] 215 | args.tmp_score_dir += f"_{epoch}" 216 | 217 | logger.info(f"Testing epoch [{epoch}] ...") 218 | len_dataset = len(TestSetTrackingObject(args.video_dir, args.track_dir, args.snippet_inp + args.snippet_tgt, args.snippet_itv)) 219 | score_queue = mp.Manager().Queue(maxsize=len_dataset) 220 | 221 | mp.spawn(cal_anomaly_score, args=(args.workers, score_queue, args), nprocs=args.workers) 222 | 223 | assert score_queue.full() 224 | score_dict = {} 225 | while not score_queue.empty(): 226 | score_dict.update(score_queue.get()) 227 | assert len(score_dict) == len_dataset 228 | 229 | # Save scores 230 | if not args.debug_mode: 231 | score_dict_path = join(get_result_dir(args.result_root), f"{res_stamp}_score_dict_{epoch}.npz") 232 | np.savez(score_dict_path, **score_dict) 233 | logger.info(f"Saved score_dict to: {score_dict_path}") 234 | [os.remove(join(args.tmp_score_dir, f)) for f in os.listdir(args.tmp_score_dir)] 235 | os.removedirs(args.tmp_score_dir) 236 | logger.info(f"The tmp_score_dir {args.tmp_score_dir} is removed.") 237 | 238 | # Calculate AUC 239 | origin_score_dict = OrderedDict() 240 | smooth_score_dict = OrderedDict() 241 | vid_macro_auc_dict = OrderedDict() 242 | default_ps = 256 243 | error_level_list = args.patch_size if args.error_type == 'patch' else (default_ps,) 244 | for _i_patch, _patch_size in enumerate(error_level_list): 245 | _p_score_dict = {} 246 | for _vid_name, _vid_score in score_dict.items(): 247 | _p_score = _vid_score[:, _i_patch] 248 | 249 | if np.any(np.isnan(_p_score)): 250 | _p_score[np.isnan(_p_score)] = 2. 
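# NaN scores (e.g. from degenerate crops) are clamped to 2.0, the same "certainly anomalous"
# fallback value that frame_error() returns, so they raise the anomaly score instead of
# breaking the ROC computation. For reference, micro-AUC as used here concatenates all test
# videos into one score/label vector before computing a single ROC-AUC; a minimal sketch,
# assuming scikit-learn and already frame-aligned per-video arrays:
#     from sklearn.metrics import roc_auc_score
#     y_true = np.concatenate([gt_npz[v] for v in per_video_scores])
#     y_score = np.concatenate([per_video_scores[v] for v in per_video_scores])
#     micro_auc = roc_auc_score(y_true, y_score)
# (cal_micro_auc in metrics.py additionally receives the snippet length/interval and the
# requested score post-processing mode to align object scores with frame indices.)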
251 | assert not np.any(np.isnan(_p_score)) 252 | _p_score_dict[_vid_name] = _p_score[1:] if args.ignore_first_frame_score else _p_score[:] 253 | 254 | _snippet_len = args.snippet_inp + args.snippet_tgt 255 | micro_auc = cal_micro_auc(_p_score_dict, gt_npz, _snippet_len, args.snippet_itv, score_post_process=args.score_post_process) 256 | origin_score_dict[_patch_size] = _p_score_dict 257 | logger.info(f"Patch_size {_patch_size:>3}: Micro-AUC = {micro_auc:.2%}") 258 | 259 | if not args.debug_mode: 260 | origin_score_path = join(get_result_dir(args.visual_root), f"{res_stamp}_origin_score.pkl") 261 | with open(origin_score_path, 'wb') as f: 262 | np.savez(f, origin_score_dict) 263 | logger.info(f"Saved origin_score to: {origin_score_path}") 264 | smooth_score_path = join(get_result_dir(args.visual_root), f"{res_stamp}_smooth_score.pkl") 265 | with open(smooth_score_path, 'wb') as f: 266 | np.savez(f, smooth_score_dict)  # NOTE: smooth_score_dict is not populated in this script, so this file may be empty 267 | logger.info(f"Saved smooth_score to: {smooth_score_path}") 268 | 269 | if not args.debug_mode: 270 | score_curve_dir = join(get_result_dir(args.visual_root), f"{res_stamp}_score_curve") 271 | if not exists(score_curve_dir): 272 | os.mkdir(score_curve_dir) 273 | 274 | for vid_name in tqdm.tqdm(sorted(score_dict.keys()), total=len(score_dict), desc="Generating score curves"): 275 | dst_curve_path = join(score_curve_dir, f"{vid_name}.png") 276 | gt_ary = gt_npz[vid_name] 277 | frm_idx = np.arange(len(gt_ary)) 278 | ori_score = origin_score_dict[default_ps][vid_name] 279 | smo_score = smooth_score_dict.get(default_ps, {}).get(vid_name)  # may be None when no smoothing was computed 280 | plt.figure(figsize=(12, 7), dpi=300) 281 | plt.plot(frm_idx, gt_ary, 'r') 282 | plt.plot(frm_idx, ori_score, 'b') 283 | if smo_score is not None: plt.plot(frm_idx, smo_score, 'g') 284 | plt.xticks(frm_idx[:: max(int(round(len(frm_idx) // 25, -1)), 10)]) 285 | plt.title(f"{vid_name} AUC={vid_macro_auc_dict.get(default_ps, {}).get(vid_name, float('nan')):.2%}") 286 | plt.legend(['GT', 'Ori', 'Smo'] if smo_score is not None else ['GT', 'Ori']) 287 | plt.savefig(dst_curve_path) 288 | plt.close() 289 | 290 | t1 = ttime() 291 | logger.info(f"Time={(t1-t0)/60:.1f} min") 292 | -------------------------------------------------------------------------------- /main/model_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn import functional as tf 6 | from torchvision.transforms import functional as tvf 7 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, load_state_dict_from_url, model_urls 8 | from typing import Type, Any, Callable, Union, List, Optional, Tuple 9 | from functools import partial 10 | from einops import rearrange 11 | 12 | from torchvision.models.resnet import BasicBlock, Bottleneck 13 | from module_cvae import CVAE 14 | 15 | feat_dim_dict = {'basic': (64, 64, 128, 256, 512), 16 | 'bottleneck': (64, 256, 512, 1024, 2048)} 17 | feat_size_list = (128, 64, 32, 16, 8) 18 | 19 | 20 | class BasicBlockUpsample(BasicBlock): 21 | ''' 22 | Here `downsample` is actually an upsample 23 | ''' 24 | 25 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, relu_type=nn.LeakyReLU, conv1: nn.Module = None) -> None: 26 | super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) 27 | if conv1: 28 | if issubclass(conv1, nn.ConvTranspose2d): 29 | self.conv1 = conv1(inplanes, planes,
kernel_size=stride, stride=stride, bias=False) 30 | else: 31 | raise TypeError(type(conv1)) 32 | if downsample: 33 | if issubclass(conv1, nn.ConvTranspose2d): 34 | self.downsample = downsample(inplanes, planes, kernel_size=stride, stride=stride, bias=False) 35 | else: 36 | raise TypeError(type(conv1)) 37 | 38 | self.relu = relu_type() 39 | 40 | 41 | class BottleneckUpsample(Bottleneck): 42 | ''' 43 | Here `downsample` is actually an upsample 44 | ''' 45 | expansion: int = 1 46 | 47 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, relu_type=nn.LeakyReLU, conv1: nn.Module = None) -> None: 48 | super().__init__(inplanes, planes, 1, downsample, groups, base_width, dilation, norm_layer) 49 | if conv1: 50 | if issubclass(conv1, nn.ConvTranspose2d): 51 | self.conv1 = conv1(inplanes, planes, kernel_size=stride, stride=stride, bias=False) 52 | else: 53 | raise TypeError(type(conv1)) 54 | if downsample: 55 | if issubclass(conv1, nn.ConvTranspose2d): 56 | self.downsample = downsample(inplanes, planes, kernel_size=stride, stride=stride, bias=False) 57 | else: 58 | raise TypeError(type(conv1)) 59 | 60 | self.relu = relu_type() 61 | 62 | 63 | class FrameDecoder(nn.Module): 64 | def __init__(self, encoder_block_type: str, decoder_block_type: str, num_frm: int = 1): 65 | super().__init__() 66 | 67 | encoder_block_type = encoder_block_type.lower() 68 | if encoder_block_type == 'basic': 69 | _ch_div = 1 70 | elif encoder_block_type == 'bottleneck': 71 | _ch_div = 4 72 | else: 73 | raise NotImplementedError(f"{encoder_block_type}") 74 | feat_dim = list(feat_dim_dict[encoder_block_type]) 75 | feat_dim.reverse() 76 | 77 | decoder_block_type = decoder_block_type.lower() 78 | if decoder_block_type == 'basic': 79 | block_type = BasicBlockUpsample 80 | elif decoder_block_type == 'bottleneck': 81 | block_type = BottleneckUpsample 82 | else: 83 | raise NotImplementedError(f"{decoder_block_type}") 84 | 85 | upconv = nn.ConvTranspose2d 86 | relu_type = nn.LeakyReLU 87 | norm_layer = nn.BatchNorm2d 88 | 89 | up_block = partial(block_type, stride=2, conv1=upconv, downsample=upconv, relu_type=relu_type, norm_layer=norm_layer) 90 | 91 | # [1, 512+512, 32, 32] 92 | self.uconv3 = nn.Sequential(up_block(feat_dim[2], c := feat_dim[2] // 2), BasicBlock(c, c)) # [1, 256, 64, 64] 93 | 94 | # [1, 256+256, 64, 64] 95 | self.uconv4 = nn.Sequential(up_block(feat_dim[3] * 2, c := feat_dim[3] // _ch_div), BasicBlock(c, c)) # [1, 64, 128, 128] 96 | 97 | # [1, 64+64, 128, 128] 98 | self.uconv5 = nn.Sequential(up_block(feat_dim[4], c := feat_dim[4] // 2), BasicBlock(c, c)) # [1, 32, 256, 256] 99 | 100 | # [1, 32, 256, 256] 101 | _rec_ch = feat_dim[4] // 2 102 | 103 | self.num_frm = num_frm 104 | self.rec_conv = nn.Sequential( 105 | nn.Conv2d(_rec_ch, _rec_ch, 3, 1, 1, bias=False), 106 | norm_layer(_rec_ch), 107 | relu_type(), 108 | nn.Conv2d(_rec_ch, 3 * self.num_frm, 3, 1, 1) 109 | ) # [1, 3, 256, 256] 110 | 111 | self.init_params() 112 | 113 | def init_params(self, init_type='kaiming'): 114 | init_gain = 1.0 115 | 116 | def init_func(m): # define the initialization function 117 | classname = m.__class__.__name__ 118 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 119 | if init_type == 'normal': 120 | init.normal_(m.weight.data, 0.0, init_gain) 121 | elif init_type == 'xavier': 122 | init.xavier_normal_(m.weight.data, 
gain=init_gain) 123 | elif init_type == 'kaiming': 124 | init.kaiming_normal_(m.weight.data) 125 | elif init_type == 'orthogonal': 126 | init.orthogonal_(m.weight.data, gain=init_gain) 127 | else: 128 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 129 | if hasattr(m, 'bias') and m.bias is not None: 130 | init.constant_(m.bias.data, 0.0) 131 | 132 | self.apply(init_func) 133 | 134 | def forward(self, feat_list: List[Tensor]): 135 | # feat_list[i]: [b, c, h, w] 136 | y2, y3 = feat_list 137 | z2 = self.uconv3(y3) # (512+512,32)->(256,64) 138 | z1 = self.uconv4(torch.cat([z2, y2], 1)) # (256+256,64)->(64,128) 139 | z0 = self.uconv5(z1) # (64+64,128)->(32,256) 140 | out = self.rec_conv(z0) # (32,256)->(3,256) 141 | out = rearrange(out, 'b (c t) h w -> b c t h w', c=3, t=self.num_frm) 142 | return out 143 | 144 | 145 | class FrameEncoder(ResNet): 146 | def __init__(self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], img_dim: int = 3, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None): 147 | super().__init__(block, layers, 0, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer) 148 | 149 | self.conv1 = nn.Conv2d(img_dim, self.conv1.out_channels, kernel_size=self.conv1.kernel_size, stride=self.conv1.stride, padding=self.conv1.padding, bias=self.conv1.bias) 150 | 151 | del self.layer3 152 | del self.layer4 153 | del self.avgpool 154 | del self.fc 155 | 156 | self.init_params() 157 | 158 | def init_params(self, init_type='kaiming'): 159 | init_gain = 1.0 160 | 161 | def init_func(m): # define the initialization function 162 | classname = m.__class__.__name__ 163 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 164 | if init_type == 'normal': 165 | init.normal_(m.weight.data, 0.0, init_gain) 166 | elif init_type == 'xavier': 167 | init.xavier_normal_(m.weight.data, gain=init_gain) 168 | elif init_type == 'kaiming': 169 | init.kaiming_normal_(m.weight.data) 170 | elif init_type == 'orthogonal': 171 | init.orthogonal_(m.weight.data, gain=init_gain) 172 | else: 173 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 174 | if hasattr(m, 'bias') and m.bias is not None: 175 | init.constant_(m.bias.data, 0.0) 176 | 177 | self.apply(init_func) 178 | 179 | def _forward_impl(self, x: Tensor) -> List[Tensor]: 180 | feat_list = [] 181 | x = self.conv1(x) # (64,128) 182 | x = self.bn1(x) 183 | x = self.relu(x) 184 | x = self.maxpool(x) 185 | 186 | x = self.layer1(x) # (256,64) 187 | feat_list.append(x) 188 | x = self.layer2(x) # (512,32) 189 | feat_list.append(x) 190 | 191 | return feat_list 192 | 193 | 194 | class BackgroundEncoder(nn.Module): 195 | def __init__(self, num_classes: int, img_dim=3): 196 | super().__init__() 197 | 198 | backbone = nn.Sequential( 199 | # 480 200 | nn.Conv2d(3, 32, 7, 4), 201 | nn.BatchNorm2d(32), 202 | nn.LeakyReLU(), 203 | nn.Conv2d(32, 64, 7, 4), 204 | nn.BatchNorm2d(64), 205 | nn.LeakyReLU(), 206 | nn.Conv2d(64, 128, 5, 2), # 13 207 | nn.BatchNorm2d(128), 208 | nn.LeakyReLU(), 209 | ) 210 | 211 | self.num_classes = num_classes 212 | self.backbone = backbone 213 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 214 | self.classifier = nn.Linear(128, num_classes) 215 | 216 | def forward(self, inp_bgd: Tensor, only_feat: bool = False): 217 | # inp_bgd: [b, c, h, w] 218 | 
inp_bgd = tvf.resize(inp_bgd, (480, 480)) 219 | feat: Tensor = self.backbone(inp_bgd) # (b, 128, 13, 13) 220 | if only_feat: 221 | return feat 222 | feat: Tensor = self.avg_pool(feat) 223 | feat.squeeze_(2).squeeze_(2) # (b, 128) 224 | out = self.classifier(feat) 225 | return out 226 | 227 | 228 | class SceneFrameAE(nn.Module): 229 | def __init__(self, img_dim: int = 3, inp_frm: int = 8, tgt_frm: int = 1, bgd_encoder: BackgroundEncoder = None, encoder_block_type: str = 'basic', decoder_block_type='basic', pretrained_encoder: bool = False, writer=print, lam_cvae: float = 0.): 230 | super().__init__() 231 | 232 | encoder_block_type = encoder_block_type.lower() 233 | if encoder_block_type == 'basic': 234 | encoder_block_class = BasicBlock 235 | model_arch = 'resnet34' 236 | elif encoder_block_type == 'bottleneck': 237 | encoder_block_class = Bottleneck 238 | model_arch = 'resnet50' 239 | else: 240 | raise NameError(encoder_block_type) 241 | 242 | self.frame_encoder = FrameEncoder(encoder_block_class, [3, 4, 6, 3], img_dim=img_dim * inp_frm) 243 | 244 | assert isinstance(bgd_encoder, BackgroundEncoder), f"{type(bgd_encoder)}" 245 | self.bgd_encoder = bgd_encoder 246 | 247 | assert lam_cvae in (0., 1.), "Only support `lam_cvae = 0 or 1` now." 248 | self.use_cvae = bool(int(lam_cvae)) 249 | if self.use_cvae: 250 | self.cvae_32 = CVAE(32**2, 256, 2, True, self.bgd_encoder.num_classes) 251 | self.cvae_64 = CVAE(64**2, 256, 2, True, self.bgd_encoder.num_classes) 252 | 253 | self.frame_decoder = FrameDecoder(encoder_block_type, decoder_block_type, num_frm=tgt_frm) 254 | 255 | if pretrained_encoder: 256 | writer(f"Loading pre-trained weights on ImageNet for encoder {model_arch} ...") 257 | state_dict = load_state_dict_from_url(model_urls[model_arch]) 258 | _ret_load_info = self.frame_encoder.load_state_dict(state_dict, strict=False) 259 | writer(f"missing_keys: {_ret_load_info[0]}") 260 | writer(f"unexpected_keys: {_ret_load_info[1]}") 261 | else: 262 | writer(f"Do NOT load the pre-trained weights on ImageNet for encoder {model_arch}") 263 | 264 | def forward(self, inp_snp: Tensor, inp_bgd, sce_cls=None): 265 | # inp_snp: [b, c, t, h, w], inp_bgd: [b, c, h, w] 266 | assert inp_snp.ndim == 5, f"{inp_snp.shape}" 267 | b, c, t, h, w = inp_snp.shape 268 | inp_snp = rearrange(inp_snp, 'b c t h w -> b (c t) h w') 269 | 270 | feat_list: List[Tensor] = self.frame_encoder(inp_snp) # per-stage (c' bottleneck|basic, h'): (64, 128) from the stem (not returned), (256|64, 64**2), (512|128, 32**2); the deeper stages (1024, 16), (2048, 8) were removed from FrameEncoder 271 | # feat_list[i]: [b, c', h', w'] 272 | 273 | if self.use_cvae: 274 | if sce_cls is not None: 275 | sce_cls = torch.as_tensor([sce_cls] * b, dtype=torch.int64, device=inp_snp.device) 276 | else: 277 | sce_cls: Tensor = self.bgd_encoder(inp_bgd, False) # [b, d] 278 | sce_cls = torch.argmax(sce_cls, 1, True) 279 | 280 | c = feat_list[1].shape[1] 281 | feat_32: Tensor = rearrange(feat_list[1], 'b c h w -> (b c) (h w)') # [(b c), h*w] = [b*128, 32*32] 282 | rec_x32, mean_32, log_var_32, z_32 = self.cvae_32(feat_32.detach(), torch.repeat_interleave(sce_cls, c, 0)) 283 | feat_list[1] = rearrange(rec_x32, '(b c) (h w) -> b c h w', b=b, c=128, h=32, w=32) 284 | 285 | c = feat_list[0].shape[1] 286 | feat_64: Tensor = rearrange(feat_list[0], 'b c h w -> (b c) (h w)') # [(b c), h*w] = [b*64, 64*64] 287 | rec_x64, mean_64, log_var_64, z_64 = self.cvae_64(feat_64.detach(), torch.repeat_interleave(sce_cls, c, 0)) 288 | feat_list[0] = rearrange(rec_x64, '(b c) (h w) -> b c h w', b=b, c=64, h=64, w=64) 289 | 290 | rec_img = self.frame_decoder(feat_list) # [b, c, t, h, w]
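# Shape walk-through of the CVAE branch above, assuming 'basic' blocks and 256x256 crops:
# feat_list[1] is [b, 128, 32, 32] and feat_list[0] is [b, 64, 64, 64]. Each is flattened
# channel-wise ('b c h w -> (b c) (h w)'), so every channel map becomes one CVAE sample of
# dimension 32*32 or 64*64, conditioned on the scene class repeated once per channel via
# repeat_interleave. The reconstructions then replace the skip features, so the decoder can
# only use what the scene-conditioned CVAE was able to rebuild; the .detach() on the CVAE
# inputs keeps the CVAE losses from shaping the frame-encoder features themselves.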
291 | 292 | if self.use_cvae: 293 | return rec_img, {'rec_x32': rec_x32, 'feat_32': feat_32.detach(), 'mean_32': mean_32, 'log_var_32': log_var_32}, {'rec_x64': rec_x64, 'feat_64': feat_64.detach(), 'mean_64': mean_64, 'log_var_64': log_var_64} 294 | else: 295 | return rec_img 296 | 297 | 298 | class BidirectionalFrameAE(nn.Module): 299 | def __init__(self, img_dim: int = 3, inp_frm: int = 8, tgt_frm: int = 1, scene_classes=2, lam_cvae=1., encoder_block_type: str = 'basic', decoder_block_type='basic'): 300 | super().__init__() 301 | f_bgd_encoder = BackgroundEncoder(scene_classes) 302 | b_bgd_encoder = BackgroundEncoder(scene_classes) 303 | self.f_frameAE = SceneFrameAE(img_dim, inp_frm, tgt_frm, f_bgd_encoder, encoder_block_type, decoder_block_type, lam_cvae=lam_cvae) 304 | self.b_frameAE = SceneFrameAE(img_dim, inp_frm, 1, b_bgd_encoder, encoder_block_type, decoder_block_type, lam_cvae=lam_cvae) 305 | 306 | def forward(self, inp_frm: Tensor): 307 | # inp_frm: [b, c, t, h, w] 308 | # pred_forward_frm = self.f_frameAE(inp_frm) # [b, c, tgt_frm+pre_frm, h, w] 309 | # pre_backward_frm = self.b_frameAE(torch.cat([pred_forward_frm, torch.flip(inp_frm[:, :, 1:], [-1])], -1)) 310 | # return pred_forward_frm, pre_backward_frm 311 | raise NotImplementedError() 312 | -------------------------------------------------------------------------------- /main/dsets.py: -------------------------------------------------------------------------------- 1 | from os import listdir, getpid 2 | from os.path import join, exists, basename 3 | 4 | import torch 5 | from torch.utils.data.dataset import Dataset as tc_Dataset 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | from typing import Tuple, Dict, List, Union 10 | from collections import OrderedDict 11 | import pickle 12 | import random 13 | import numpy as np 14 | import mmcv 15 | import warnings 16 | 17 | mmcv.use_backend('turbojpeg') 18 | 19 | 20 | class DatasetInfoDict(): 21 | def __init__(self): 22 | # (h, w); None represents more than one resolutions 23 | self.frame_size = {'ST': (480, 856), 'Ave': (360, 640), 'Cor': (1080, 1920), 'NWPU': (1080, 1920)} 24 | 25 | @property 26 | def data_names(self): 27 | return ('ST', 'Ave', 'Cor', 'NWPU') 28 | 29 | def __getitem__(self, data_name: str): 30 | data_attr = {} 31 | for attr_name, attr_dict in self.__dict__.items(): 32 | if not attr_name.startswith('__'): 33 | if data_name in attr_dict: 34 | data_attr[attr_name] = attr_dict[data_name] 35 | else: 36 | raise KeyError(f"Cannot find the attribute '{attr_name}' for '{data_name}'") 37 | return data_attr 38 | 39 | 40 | class BoxTransform(): 41 | def __init__(self, frame_height, frame_width): 42 | self.frame_height = frame_height 43 | self.frame_width = frame_width 44 | 45 | def __call__(self, batch_bbox: torch.Tensor): 46 | # batch_bbox: [C=4,T] 47 | assert batch_bbox.ndim == 2 and batch_bbox.shape[0] == 4, f"{batch_bbox.shape}" 48 | dst_batch_box = torch.zeros_like(batch_bbox) 49 | dst_batch_box[0] = (batch_bbox[0] + batch_bbox[2]) / 2 / self.frame_width 50 | dst_batch_box[1] = (batch_bbox[1] + batch_bbox[3]) / 2 / self.frame_height 51 | dst_batch_box[2] = (batch_bbox[2] - batch_bbox[0]) / self.frame_width 52 | dst_batch_box[3] = (batch_bbox[3] - batch_bbox[1]) / self.frame_height 53 | return dst_batch_box 54 | 55 | 56 | class TrainSetTrackingObject(tc_Dataset): 57 | def __init__(self, video_dir: str, track_dir: str, snippet_len: int, snippet_itv: float, iterations: int = 1, vid_suffix='avi', cache_video: bool 
= False, device='cpu', **kwargs): 58 | super().__init__() 59 | self._video_dir = video_dir 60 | self._snippet_len = snippet_len 61 | self._snippet_itv = snippet_itv 62 | self._iterations = iterations 63 | self._vid_suffix = vid_suffix 64 | self.device = device 65 | self._kwargs = kwargs 66 | 67 | data_info_dict = DatasetInfoDict() 68 | self.data_info = data_info_dict[kwargs.get('data_name', list(filter(lambda dname: dname in video_dir, data_info_dict.data_names))[0])] 69 | 70 | self._all_trk_dict = self._get_track_dict(track_dir) 71 | self._vid_samp_weight = self._get_video_sampling_weight(self._all_trk_dict) 72 | self._frm_samp_weight = self._get_frame_sampling_weight(self._all_trk_dict) 73 | self._scene_label_dict = self._get_scene_label_dict() 74 | 75 | self._tsfm_img = self._get_tsfm_img() 76 | self._tsfm_box = self._get_tsfm_box() 77 | self._tsfm_bgd = self._get_tsfm_bgd() 78 | 79 | self._rng = random.Random(kwargs.get('seed', 2022 + getpid())) 80 | 81 | self._frame_dir = kwargs.get('frame_dir', "") 82 | 83 | self.__cache_max_vid = kwargs.get('cache_max_vid', len(self._all_trk_dict)) 84 | self.__cache_max_frm = kwargs.get('cache_max_frm', 4096) 85 | 86 | if cache_video: 87 | warnings.warn("Using vid_cache with num_workers > 1 may double the memory footprint and trigger OpenCV errors!") 88 | self._vid_cache = mmcv.Cache(self.__cache_max_vid) 89 | else: 90 | self._vid_cache = None 91 | 92 | self.check_init() 93 | 94 | self._frname_tmpl = self._get_frm_file_tmpl() if self._frame_dir else None 95 | 96 | @staticmethod 97 | def _get_track_dict(track_dir: str) -> OrderedDict: 98 | track_dict = OrderedDict() 99 | for pkl_name in sorted(listdir(track_dir)): 100 | with open(join(track_dir, pkl_name), 'rb') as f: 101 | _trk_dict: OrderedDict = pickle.load(f) 102 | track_dict[pkl_name.split('.')[0]] = _trk_dict 103 | return track_dict 104 | 105 | @staticmethod 106 | def _get_video_sampling_weight(all_track_dict: OrderedDict) -> List[int]: 107 | vid_samp_weight = [len(vid_trk_dict) for vid_trk_dict in all_track_dict.values()] 108 | return vid_samp_weight 109 | 110 | @staticmethod 111 | def _get_frame_sampling_weight(all_track_dict: OrderedDict) -> Dict[str, Dict[int, int]]: 112 | frm_samp_weight = OrderedDict() 113 | for vid_name, vid_trk_dict in all_track_dict.items(): 114 | vid_trk_dict: OrderedDict 115 | assert vid_name not in frm_samp_weight 116 | frm_samp_weight[vid_name] = OrderedDict() 117 | for frm_idx, trk_ary in vid_trk_dict.items(): 118 | # Later, a list of (key, value) pairs could be used to set a custom weight for each frm_idx 119 | assert frm_idx not in frm_samp_weight[vid_name] 120 | frm_samp_weight[vid_name][frm_idx] = len(trk_ary) 121 | return frm_samp_weight 122 | 123 | @property 124 | def snippet_len(self): 125 | return self._snippet_len 126 | 127 | @property 128 | def snippet_itv(self): 129 | return self._snippet_itv 130 | 131 | @property 132 | def vid_suffix(self): 133 | return self._vid_suffix 134 | 135 | @property 136 | def iterations(self): 137 | return self._iterations 138 | 139 | @property 140 | def vid_samp_weight(self): 141 | return self._vid_samp_weight 142 | 143 | @property 144 | def frm_samp_weight(self): 145 | return self._frm_samp_weight 146 | 147 | @property 148 | def all_trk_dict(self): 149 | return self._all_trk_dict 150 | 151 | def check_init(self): 152 | vid_list = sorted(listdir(self._video_dir)) 153 | assert len(vid_list) > 0 154 | assert len(vid_list) == len(self.all_trk_dict) == len(self.vid_samp_weight) == len(self.frm_samp_weight) 155 | 156 | if self._frame_dir: 157 | assert exists(self._frame_dir) 158 | 159 | for file_name in vid_list: 160 | vid_name =
file_name.split('.')[0] 161 | assert vid_name in self.all_trk_dict, f"{vid_name}" 162 | 163 | if self._frame_dir: 164 | assert exists(p := join(self._frame_dir, vid_name)), f"{p}" 165 | 166 | with mmcv.VideoReader(join(self._video_dir, file_name)) as v: 167 | n = v.frame_cnt 168 | if not n == len(lp := listdir(p)): 169 | print(f"WARNING: video \"{p}\": vid_len = {n}, num_frm = {len(lp)}. The difference of frame numbers is ignored. Note it might cause bugs.") 170 | 171 | assert np.all(np.array(alp := list(map(len, lp))) == alp[0]), f"{p}" 172 | 173 | @staticmethod 174 | def _get_tsfm_img(): 175 | tsfm = T.Compose([ 176 | T.Resize((256, 256)), 177 | T.Normalize([0.45, 0.45, 0.45], [0.225, 0.225, 0.225], inplace=True), 178 | ]) 179 | return tsfm 180 | 181 | def _get_tsfm_bgd(self): 182 | tsfm = T.Compose([ 183 | T.Normalize([0.45, 0.45, 0.45], [0.225, 0.225, 0.225], inplace=True), 184 | ]) 185 | return tsfm 186 | 187 | def _get_tsfm_box(self): 188 | tsfm = BoxTransform(self.data_info['frame_size'][0], self.data_info['frame_size'][1]) 189 | return tsfm 190 | 191 | def _get_frm_file_tmpl(self) -> str: 192 | one_frame_name = listdir(join(self._frame_dir, listdir(self._frame_dir)[0]))[0] 193 | frm_name, suffix = one_frame_name.split('.') 194 | fname_tmpl = "{:0" + str(len(frm_name)) + "d}" + ".{}".format(suffix) 195 | # fname_tmpl = "{:{}d}.{}".format(len(frm_name), suffix) 196 | return fname_tmpl 197 | 198 | def _get_scene_label_dict(self): 199 | scene_label_dict = OrderedDict() 200 | for vid_name in sorted(self.all_trk_dict.keys()): 201 | vid_name: str 202 | if not '_' in vid_name: 203 | scene_label_dict['single'] = len(scene_label_dict) 204 | break 205 | scene_name = vid_name.split('_')[0] 206 | if not scene_name in scene_label_dict: 207 | scene_label_dict[scene_name] = len(scene_label_dict) 208 | return scene_label_dict 209 | 210 | def choose_one_video(self) -> str: 211 | raise NotImplementedError() 212 | cho_vid_name = self._rng.choices(list(self.all_trk_dict.keys()), weights=self.vid_samp_weight, k=1)[0] 213 | return cho_vid_name 214 | 215 | def choose_one_frame(self, vid_name: str, uniformly: bool = False) -> int: 216 | if uniformly: 217 | cho_frm_idx = self._rng.choice(list(self.frm_samp_weight[vid_name].keys())) 218 | else: 219 | cho_frm_idx = self._rng.choices(list(self.frm_samp_weight[vid_name].keys()), weights=self.frm_samp_weight[vid_name].values(), k=1)[0] 220 | return cho_frm_idx 221 | 222 | def choose_one_track(self, vid_name: str, frm_idx: int) -> List[Tuple[int, np.ndarray]]: 223 | trk_data: np.ndarray = self.all_trk_dict[vid_name][frm_idx] 224 | trk_ary = trk_data[self._rng.randrange(0, len(trk_data))] # trk_id(0), bbox(1~4), prob(5), cls(6) 225 | trk_id = trk_ary[0] 226 | 227 | snippet_trk_ary_list = [(frm_idx, trk_ary)] 228 | 229 | for _fi in [frm_idx - i * self.snippet_itv for i in range(1, self.snippet_len)]: 230 | _fi = round(_fi) 231 | if _fi in self.all_trk_dict[vid_name]: 232 | _other_trk_data = self.all_trk_dict[vid_name][_fi] 233 | _ret_idx = np.where(trk_id == _other_trk_data[:, 0]) 234 | assert len(_ret_idx) == 1, f"{len(_ret_idx)}" 235 | if len(_ret_idx[0]) == 0: 236 | return [] 237 | else: 238 | _ary_idx = _ret_idx[0][0] 239 | _other_trk_ary = self.all_trk_dict[vid_name][_fi][_ary_idx] 240 | snippet_trk_ary_list.append((_fi, _other_trk_ary)) 241 | else: 242 | return [] 243 | 244 | snippet_trk_ary_list.reverse() 245 | 246 | return snippet_trk_ary_list 247 | 248 | def sample_one_snippet_track(self, vid_name: str) -> List[Tuple[int, np.ndarray]]: 249 | # vid_name = 
self.choose_one_video() 250 | end_frm_idx = self.choose_one_frame(vid_name) 251 | snippet_trk_ary_list = self.choose_one_track(vid_name, end_frm_idx) 252 | n_attempt = 30 253 | while snippet_trk_ary_list == []: 254 | n_attempt -= 1 255 | if n_attempt == 0: 256 | raise TimeoutError(f"{vid_name}") 257 | # raise TimeoutError(f"{vid_name}") 258 | end_frm_idx = self.choose_one_frame(vid_name, n_attempt <= 5) 259 | snippet_trk_ary_list = self.choose_one_track(vid_name, end_frm_idx) 260 | return snippet_trk_ary_list 261 | 262 | def _read_from_image(self, video_path: str, frame_idx: int) -> np.ndarray: 263 | frm_path = join(self._frame_dir, basename(video_path).split('.')[0], self._frname_tmpl.format(frame_idx)) 264 | assert exists(frm_path), f"{frm_path}" 265 | # print(abspath(frm_path), realpath(frm_path)) 266 | return mmcv.imread(frm_path, backend='turbojpeg') 267 | 268 | def load_snippet(self, vid_name: str, snippet_trk_ary_list: List[Tuple[int, np.ndarray]]): 269 | vid_path = join(self._video_dir, f"{vid_name}.{self.vid_suffix}") 270 | if not exists(vid_path): 271 | raise FileNotFoundError(vid_path) 272 | 273 | vid_cap = None 274 | if self._frame_dir: 275 | pass 276 | else: 277 | if self._vid_cache: 278 | if _v := self._vid_cache.get(vid_name): 279 | vid_cap = _v 280 | else: 281 | vid_cap = mmcv.VideoReader(vid_path, cache_capacity=self.__cache_max_frm) 282 | self._vid_cache.put(vid_name, vid_cap) 283 | else: 284 | vid_cap = mmcv.VideoReader(vid_path, cache_capacity=self.__cache_max_frm) 285 | 286 | # Fix one square viewing window for the whole snippet, based on a single frame 287 | obj_bbox_list = np.asarray([[max(int(p), 0) for p in _trk_ary[1:5]] for _fi, _trk_ary in snippet_trk_ary_list], np.int64) 288 | # For now, anchor the window on the middle frame of the snippet 289 | anchor_box = obj_bbox_list[int(len(obj_bbox_list) / 2), :] 290 | 291 | view_center_xy = [(anchor_box[0] + anchor_box[2]) // 2, (anchor_box[1] + anchor_box[3]) // 2] 292 | 293 | full_view_range = 256 294 | # full_view_range = max(anchor_box[2] - anchor_box[0], anchor_box[3] - anchor_box[1]) 295 | half_view_range = full_view_range // 2 296 | 297 | _fidx = snippet_trk_ary_list[0][0] 298 | fh, fw, _ = self._read_from_image(vid_path, _fidx).shape if self._frame_dir else vid_cap.get_frame(_fidx).shape 299 | 300 | view_center_xy[0] = min(max(half_view_range, view_center_xy[0]), fw - half_view_range) 301 | view_center_xy[1] = min(max(half_view_range, view_center_xy[1]), fh - half_view_range) 302 | 303 | full_view_box = [view_center_xy[0] - half_view_range, view_center_xy[1] - half_view_range, 304 | view_center_xy[0] + half_view_range, view_center_xy[1] + half_view_range] 305 | 306 | # Then crop this same region from every frame of the snippet 307 | snippet = [] 308 | for _fi, _trk_ary in snippet_trk_ary_list: 309 | frm: np.ndarray = self._read_from_image(vid_path, _fi) if self._frame_dir else vid_cap.get_frame(_fi) 310 | obj: np.ndarray = frm[full_view_box[1]:full_view_box[3], full_view_box[0]:full_view_box[2], :] 311 | 312 | obj: torch.Tensor = torch.as_tensor(obj, dtype=torch.float32, device=self.device) 313 | obj = self._tsfm_img(obj.permute(2, 0, 1).div(255.)) 314 | 315 | snippet.append(obj) 316 | 317 | last_frm = frm.copy() 318 | last_fidx: int = _fi 319 | return torch.stack(snippet, 1), last_frm, last_fidx 320 | 321 | def load_background(self, frame: np.ndarray, frm_trk_list: np.ndarray): 322 | fh, fw, _ = frame.shape 323 | for _trk_ary in frm_trk_list: 324 | obj_bbox = np.asarray(_trk_ary[1:5], np.int64) 325 | obj_bbox[0] = max(obj_bbox[0], 0) 326 | obj_bbox[1] = max(obj_bbox[1], 0) 327 | obj_bbox[2] = min(obj_bbox[2], fw) 328 | obj_bbox[3] = min(obj_bbox[3], fh) 329 | 330 |
frame[obj_bbox[1]:obj_bbox[3], obj_bbox[0]:obj_bbox[2], :] = 0 331 | 332 | frame: torch.Tensor = torch.as_tensor(frame, dtype=torch.float32, device=self.device) 333 | frame = self._tsfm_bgd(frame.permute(2, 0, 1).div(255.)) 334 | return frame 335 | 336 | def load_bgd_label(self, vid_name: str): 337 | vid_name: str 338 | if not '_' in vid_name: 339 | scene_name = 'single' 340 | else: 341 | scene_name = vid_name.split('_')[0] 342 | scene_label = self._scene_label_dict[scene_name] 343 | return scene_label 344 | 345 | def __len__(self) -> int: 346 | return len(self.all_trk_dict) * self.iterations 347 | 348 | def __getitem__(self, vid_idx: int): 349 | vid_name: str = list(self.all_trk_dict.keys())[vid_idx % len(self.all_trk_dict)] 350 | while True: 351 | try: 352 | snippet_trk_ary_list = self.sample_one_snippet_track(vid_name) 353 | except TimeoutError as e: 354 | vid_name: str = list(self.all_trk_dict.keys())[random.randrange(0, len(self.all_trk_dict))] 355 | else: 356 | break 357 | 358 | snippet, last_frm, last_fidx = self.load_snippet(vid_name, snippet_trk_ary_list) # [C, T, H, W] 359 | background = self.load_background(last_frm, self.all_trk_dict[vid_name][last_fidx]) 360 | scene_label = self.load_bgd_label(vid_name) 361 | 362 | return snippet, background, scene_label 363 | 364 | 365 | class SnippetVideoReader(mmcv.VideoReader): 366 | def __init__(self, video_path: str, video_track: OrderedDict, snippet_len: int, snippet_itv: int, device='cpu', tsfm_img=None, tsfm_box=None, tsfm_bgd=None, frame_dir: str = '', frname_tmpl: str = ''): 367 | super().__init__(video_path, (snippet_len + 1) * snippet_itv) 368 | self._vid_name = basename(video_path).split('.')[0] 369 | self._vid_trk = video_track 370 | self._slen = snippet_len 371 | self._sitv = snippet_itv 372 | self.device = device 373 | 374 | self._tsfm_img = tsfm_img 375 | self._tsfm_box = tsfm_box 376 | self._tsfm_bgd = tsfm_bgd 377 | 378 | self._frame_dir = frame_dir 379 | self._frname_tmpl = frname_tmpl 380 | 381 | if self.frame_cnt < 3050: 382 | self.all_frames = self.read_all_frames() 383 | if len(self.all_frames) < 2500: 384 | try: 385 | self.all_frames = self.all_frames.to(device) 386 | except RuntimeError as e: 387 | print(f"{self.vid_name}: {len(self.all_frames)} frames to 'cpu'") 388 | torch.cuda.empty_cache() 389 | else: 390 | self.all_frames = None 391 | 392 | @property 393 | def vid_name(self): 394 | return self._vid_name 395 | 396 | @property 397 | def vid_trk(self): 398 | return self._vid_trk 399 | 400 | @property 401 | def snippet_len(self): 402 | return self._slen 403 | 404 | @property 405 | def snippet_itv(self): 406 | return self._sitv 407 | 408 | def get_all_tracks(self, frm_idx: int) -> List[Tuple[int, np.ndarray]]: 409 | if not frm_idx in self.vid_trk: 410 | return [] 411 | 412 | all_snippet_trks = [] 413 | 414 | trk_data: np.ndarray = self.vid_trk[frm_idx] 415 | for ary_idx in range(0, len(trk_data)): 416 | trk_ary = trk_data[ary_idx] # trk_id(0), bbox(1~4), prob(5), cls(6) 417 | trk_id = trk_ary[0] 418 | 419 | snippet_trk_ary_list = [(frm_idx, trk_ary)] 420 | for _fi in [frm_idx - i * self.snippet_itv for i in range(1, self.snippet_len)]: 421 | _fi = round(_fi) 422 | if _fi in self.vid_trk: 423 | _other_trk_data = self.vid_trk[_fi] 424 | _ret_idx = np.where(trk_id == _other_trk_data[:, 0]) 425 | assert len(_ret_idx) == 1, f"{len(_ret_idx)}" 426 | if len(_ret_idx[0]) == 0: 427 | break 428 | else: 429 | _ary_idx = _ret_idx[0][0] 430 | _other_trk_ary = self.vid_trk[_fi][_ary_idx] 431 | snippet_trk_ary_list.append((_fi, 
_other_trk_ary)) 432 | else: 433 | break 434 | 435 | if len(snippet_trk_ary_list) < self.snippet_len: 436 | continue 437 | 438 | snippet_trk_ary_list.reverse() 439 | all_snippet_trks.append(snippet_trk_ary_list) 440 | 441 | return all_snippet_trks 442 | 443 | def _read_from_image(self, frame_idx: int) -> np.ndarray: 444 | if (frm := self._cache.get(frame_idx)) is not None: 445 | return frm 446 | else: 447 | frm_path = join(self._frame_dir, self.vid_name, self._frname_tmpl.format(frame_idx)) 448 | assert exists(frm_path), f"{frm_path}" 449 | frm = mmcv.imread(frm_path, backend='turbojpeg') 450 | self._cache.put(frame_idx, frm) 451 | return frm 452 | 453 | def read_all_frames(self): 454 | all_frames = [] 455 | for frm_idx in range(self.frame_cnt): 456 | all_frames.append(torch.from_numpy(self.get_frame(frm_idx))) 457 | all_frames = torch.stack(all_frames) 458 | return all_frames # [T, H, W, C] 459 | 460 | def load_snippet(self, snippet_trk_ary_list: List[Tuple[int, np.ndarray]]): 461 | obj_bbox_list = np.asarray([[max(int(p), 0) for p in _trk_ary[1:5]] for _fi, _trk_ary in snippet_trk_ary_list], np.int64) 462 | anchor_box = obj_bbox_list[int(len(obj_bbox_list) / 2), :] 463 | view_center_xy = [(anchor_box[0] + anchor_box[2]) // 2, (anchor_box[1] + anchor_box[3]) // 2] 464 | 465 | full_view_range = 256 466 | half_view_range = full_view_range // 2 467 | 468 | _fidx = snippet_trk_ary_list[0][0] 469 | fh, fw, _ = self._read_from_image(_fidx).shape if self._frame_dir else self.get_frame(_fidx).shape 470 | view_center_xy[0] = min(max(half_view_range, view_center_xy[0]), fw - half_view_range) 471 | view_center_xy[1] = min(max(half_view_range, view_center_xy[1]), fh - half_view_range) 472 | 473 | full_view_box = [view_center_xy[0] - half_view_range, view_center_xy[1] - half_view_range, 474 | view_center_xy[0] + half_view_range, view_center_xy[1] + half_view_range] 475 | 476 | snippet = [] 477 | for _fi, _trk_ary in snippet_trk_ary_list: 478 | if self.all_frames is None: 479 | frm: np.ndarray = self._read_from_image(_fi) if self._frame_dir else self.get_frame(_fi) 480 | else: 481 | frm: torch.Tensor = self.all_frames[_fi] 482 | 483 | obj = frm[full_view_box[1]:full_view_box[3], full_view_box[0]:full_view_box[2], :] 484 | 485 | obj: torch.Tensor = torch.as_tensor(obj, dtype=torch.float32, device=self.device) 486 | obj = self._tsfm_img(obj.permute(2, 0, 1).div(255.)) 487 | 488 | snippet.append(obj) 489 | 490 | last_frm = frm.copy() if isinstance(frm, np.ndarray) else frm.clone() 491 | last_fidx: int = _fi 492 | return torch.stack(snippet, 1), last_frm, last_fidx # [C,T,H,W] 493 | 494 | def load_background(self, frame: np.ndarray, frm_trk_list: np.ndarray): 495 | fh, fw, _ = frame.shape 496 | for _trk_ary in frm_trk_list: 497 | obj_bbox = np.asarray(_trk_ary[1:5], np.int64) 498 | obj_bbox[0] = max(obj_bbox[0], 0) 499 | obj_bbox[1] = max(obj_bbox[1], 0) 500 | obj_bbox[2] = min(obj_bbox[2], fw) 501 | obj_bbox[3] = min(obj_bbox[3], fh) 502 | 503 | frame[obj_bbox[1]:obj_bbox[3], obj_bbox[0]:obj_bbox[2], :] = 0 504 | 505 | frame: torch.Tensor = torch.as_tensor(frame, dtype=torch.float32, device=self.device) 506 | frame = self._tsfm_bgd(frame.permute(2, 0, 1).div(255.)) 507 | return frame 508 | 509 | def __getitem__(self, end_frm_idx: int) -> torch.Tensor: 510 | ''' 511 | Return the snippets of all objects tracked at this frame index as one batch 512 | ''' 513 | assert 0 <= end_frm_idx < self.frame_cnt 514 | if end_frm_idx < (self._slen - 1) * self._sitv: 515 | return None, None 516 | else: 517 | all_trk_list = self.get_all_tracks(end_frm_idx) 518 | if
    def load_background(self, frame: np.ndarray, frm_trk_list: np.ndarray):
        fh, fw, _ = frame.shape
        for _trk_ary in frm_trk_list:
            obj_bbox = np.asarray(_trk_ary[1:5], np.int64)
            obj_bbox[0] = max(obj_bbox[0], 0)
            obj_bbox[1] = max(obj_bbox[1], 0)
            obj_bbox[2] = min(obj_bbox[2], fw)
            obj_bbox[3] = min(obj_bbox[3], fh)

            frame[obj_bbox[1]:obj_bbox[3], obj_bbox[0]:obj_bbox[2], :] = 0

        frame: torch.Tensor = torch.as_tensor(frame, dtype=torch.float32, device=self.device)
        frame = self._tsfm_bgd(frame.permute(2, 0, 1).div(255.))
        return frame

    def __getitem__(self, end_frm_idx: int) -> torch.Tensor:
        '''
        Return the snippets of all tracked objects ending at this frame as one batch.
        '''
        assert 0 <= end_frm_idx < self.frame_cnt
        if end_frm_idx < (self._slen - 1) * self._sitv:
            return None, None
        else:
            all_trk_list = self.get_all_tracks(end_frm_idx)
            if all_trk_list:
                batch_snippet = []
                background = None
                for _trk_list in all_trk_list:
                    obj_snippet, last_frm, last_fidx = self.load_snippet(_trk_list)
                    batch_snippet.append(obj_snippet)
                    if background is None:
                        background = self.load_background(last_frm, self.vid_trk[last_fidx])
                batch_snippet = torch.stack(batch_snippet, 0)  # [B, C, T, H, W]
                return batch_snippet, background.unsqueeze(0)
            else:
                return None, None

    def __iter__(self):
        raise NotImplementedError()

    def __next__(self):
        raise NotImplementedError()


class TestSetTrackingObject(TrainSetTrackingObject):
    def __init__(self, video_dir: str, track_dir: str, snippet_len: int, snippet_itv: int, vid_suffix='avi', device='cpu', **kwargs):
        super().__init__(video_dir, track_dir, snippet_len, snippet_itv, 1, vid_suffix, False, device, **kwargs)

    def __getitem__(self, vid_idx: int):
        vid_name = list(self.all_trk_dict.keys())[vid_idx]
        vid_stream = SnippetVideoReader(join(self._video_dir, f"{vid_name}.{self.vid_suffix}"), self.all_trk_dict[vid_name],
                                        self.snippet_len, self.snippet_itv, self.device, self._tsfm_img, self._tsfm_box, self._tsfm_bgd,
                                        self._frame_dir, self._frname_tmpl)
        return vid_stream
--------------------------------------------------------------------------------
/main/dsets_bid.py:
--------------------------------------------------------------------------------
from os import listdir, getpid
from os.path import join, exists, basename, dirname

import torch
from torch.utils.data.dataset import Dataset as tc_Dataset
from torchvision import transforms as T
from torchvision.transforms import functional as F

from typing import Tuple, Dict, List, Union
from collections import OrderedDict
import pickle
import random
import numpy as np
import mmcv
import warnings

mmcv.use_backend('turbojpeg')


class DatasetInfoDict():
    def __init__(self):
        # (h, w); None represents more than one resolution
        self.frame_size = {'ST': (480, 856), 'Ave': (360, 640), 'Cor': (1080, 1920), 'NWPU': (1080, 1920)}

    @property
    def data_names(self):
        return ('ST', 'Ave', 'Cor', 'NWPU')

    def __getitem__(self, data_name: str):
        data_attr = {}
        for attr_name, attr_dict in self.__dict__.items():
            if not attr_name.startswith('__'):
                if data_name in attr_dict:
                    data_attr[attr_name] = attr_dict[data_name]
                else:
                    raise KeyError(f"Cannot find the attribute '{attr_name}' for '{data_name}'")
        return data_attr


class BoxTransform():
    def __init__(self, frame_height, frame_width):
        self.frame_height = frame_height
        self.frame_width = frame_width

    def __call__(self, batch_bbox: torch.Tensor):
        # batch_bbox: [C=4, T], rows are (x1, y1, x2, y2) in pixels; the output
        # rows are (cx, cy, w, h) normalized to [0, 1] by the frame size
        assert batch_bbox.ndim == 2 and batch_bbox.shape[0] == 4, f"{batch_bbox.shape}"
        dst_batch_box = torch.zeros_like(batch_bbox)
        dst_batch_box[0] = (batch_bbox[0] + batch_bbox[2]) / 2 / self.frame_width
        dst_batch_box[1] = (batch_bbox[1] + batch_bbox[3]) / 2 / self.frame_height
        dst_batch_box[2] = (batch_bbox[2] - batch_bbox[0]) / self.frame_width
        dst_batch_box[3] = (batch_bbox[3] - batch_bbox[1]) / self.frame_height
        return dst_batch_box
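
# A minimal BoxTransform sketch (values hypothetical): for a 360x640 frame,
#   tsfm = BoxTransform(360, 640)
#   box = torch.tensor([[64.], [90.], [192.], [270.]])  # (x1, y1, x2, y2), [C=4, T=1]
#   tsfm(box)  # -> [[0.2], [0.5], [0.2], [0.5]], i.e. normalized (cx, cy, w, h)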
class TrainSetTrackingObject(tc_Dataset):
    def __init__(self, video_dir: str, track_dir: str, snippet_len: int, snippet_itv: int,
                 iterations: int = 1, vid_suffix='avi', cache_video: bool = False, device='cpu', **kwargs):
        super().__init__()
        self._video_dir = video_dir
        self._snippet_len = snippet_len
        self._snippet_itv = snippet_itv
        self._iterations = iterations
        self._vid_suffix = vid_suffix
        self.device = device
        self._kwargs = kwargs

        data_info_dict = DatasetInfoDict()
        self.data_info = data_info_dict[kwargs.get('data_name', list(filter(lambda dname: dname in video_dir, data_info_dict.data_names))[0])]

        self._all_trk_dict = self._get_track_dict(track_dir)
        self._vid_samp_weight = self._get_video_sampling_weight(self._all_trk_dict)
        self._frm_samp_weight = self._get_frame_sampling_weight(self._all_trk_dict)

        self._tsfm_img = self._get_tsfm_img()
        self._tsfm_box = self._get_tsfm_box()

        self._rng = random.Random(kwargs.get('seed', 2022 + getpid()))

        self._frame_dir = kwargs.get('frame_dir', "")

        self.__cache_max_vid = kwargs.get('cache_max_vid', len(self._all_trk_dict))
        self.__cache_max_frm = kwargs.get('cache_max_frm', 4096)

        if cache_video:
            warnings.warn("Using the video cache with num_workers > 1 may double the memory usage and trigger OpenCV errors!")
            self._vid_cache = mmcv.Cache(self.__cache_max_vid)
        else:
            self._vid_cache = None

        self.check_init()

        self._frname_tmpl = self._get_frm_file_tmpl() if self._frame_dir else None

    @staticmethod
    def _get_track_dict(track_dir: str) -> OrderedDict:
        track_dict = OrderedDict()
        for pkl_name in sorted(listdir(track_dir)):
            with open(join(track_dir, pkl_name), 'rb') as f:
                _trk_dict: OrderedDict = pickle.load(f)
            track_dict[pkl_name.split('.')[0]] = _trk_dict
        return track_dict

    @staticmethod
    def _get_video_sampling_weight(all_track_dict: OrderedDict) -> List[int]:
        vid_samp_weight = [len(vid_trk_dict) for vid_trk_dict in all_track_dict.values()]
        return vid_samp_weight

    @staticmethod
    def _get_frame_sampling_weight(all_track_dict: OrderedDict) -> Dict[str, Dict[int, int]]:
        frm_samp_weight = OrderedDict()
        for vid_name, vid_trk_dict in all_track_dict.items():
            vid_trk_dict: OrderedDict
            assert vid_name not in frm_samp_weight
            frm_samp_weight[vid_name] = OrderedDict()
            for frm_idx, trk_ary in vid_trk_dict.items():
                # later, (key, value) pairs could be used to assign a custom weight to each frm_idx
                assert frm_idx not in frm_samp_weight[vid_name]
                frm_samp_weight[vid_name][frm_idx] = len(trk_ary)
        return frm_samp_weight
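
    # Assumed on-disk layout (inferred from the two weight builders above): each
    # <vid_name>.pkl is an OrderedDict mapping frame index -> ndarray of rows
    #   {0: np.array([[trk_id, x1, y1, x2, y2, prob, cls], ...]), 2: ...},
    # so vid_samp_weight counts tracked frames per video and frm_samp_weight
    # counts tracked objects per frame.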
    @property
    def snippet_len(self):
        return self._snippet_len

    @property
    def snippet_itv(self):
        return self._snippet_itv

    @property
    def vid_suffix(self):
        return self._vid_suffix

    @property
    def iterations(self):
        return self._iterations

    @property
    def vid_samp_weight(self):
        return self._vid_samp_weight

    @property
    def frm_samp_weight(self):
        return self._frm_samp_weight

    @property
    def all_trk_dict(self):
        return self._all_trk_dict

    def check_init(self):
        vid_list = sorted(listdir(self._video_dir))
        assert len(vid_list) > 0
        assert len(vid_list) == len(self.all_trk_dict) == len(self.vid_samp_weight) == len(self.frm_samp_weight)

        if self._frame_dir:
            assert exists(self._frame_dir)

        for file_name in vid_list:
            vid_name = file_name.split('.')[0]
            assert vid_name in self.all_trk_dict, f"{vid_name}"

            if self._frame_dir:
                assert exists(p := join(self._frame_dir, vid_name)), f"{p}"

                with mmcv.VideoReader(join(self._video_dir, file_name)) as v:
                    n = v.frame_cnt
                if n != len(lp := listdir(p)):
                    print(f"WARNING: video \"{p}\": vid_len = {n}, num_frm = {len(lp)}. The mismatch in frame counts is ignored; note that it might cause bugs.")

                # all frame file names must have the same length, otherwise the zero-padded template breaks
                assert np.all(np.array(alp := list(map(len, lp))) == alp[0]), f"{p}"

    @staticmethod
    def _get_tsfm_img():
        tsfm = T.Compose([
            T.Resize((256, 256)),
            T.Normalize([0.45, 0.45, 0.45], [0.225, 0.225, 0.225], inplace=True),
        ])
        return tsfm

    def _get_tsfm_box(self):
        tsfm = BoxTransform(self.data_info['frame_size'][0], self.data_info['frame_size'][1])
        return tsfm

    def _get_frm_file_tmpl(self) -> str:
        # infer the zero-padded frame-name template from one existing frame file
        one_frame_name = listdir(join(self._frame_dir, listdir(self._frame_dir)[0]))[0]
        frm_name, suffix = one_frame_name.split('.')
        fname_tmpl = "{:0" + str(len(frm_name)) + "d}" + ".{}".format(suffix)
        return fname_tmpl
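
    # For example, frames named "0000.jpg" yield the template "{:04d}.jpg", so
    # self._frname_tmpl.format(12) -> "0012.jpg" (the padding width is inferred
    # from whichever frame files actually exist).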
    def choose_one_video(self) -> str:
        # disabled: video choice is driven by the DataLoader sampler instead
        raise NotImplementedError()
        # cho_vid_name = self._rng.choices(list(self.all_trk_dict.keys()), weights=self.vid_samp_weight, k=1)[0]
        # return cho_vid_name

    def choose_one_frame(self, vid_name: str, uniformly: bool = False) -> int:
        frm_idx_list = list(self.frm_samp_weight[vid_name].keys())
        if frm_idx_list == []:
            return None
        if uniformly:
            cho_frm_idx = self._rng.choice(frm_idx_list)
        else:
            cho_frm_idx = self._rng.choices(frm_idx_list, weights=self.frm_samp_weight[vid_name].values(), k=1)[0]
        return cho_frm_idx

    def choose_one_track(self, vid_name: str, frm_idx: int) -> List[Tuple[int, np.ndarray]]:
        trk_data: np.ndarray = self.all_trk_dict[vid_name][frm_idx]
        trk_ary = trk_data[self._rng.randrange(0, len(trk_data))]  # trk_id(0), bbox(1~4), prob(5), cls(6)
        trk_id = trk_ary[0]

        snippet_trk_ary_list = [(frm_idx, trk_ary)]

        for _fi in [frm_idx - i * self.snippet_itv for i in range(1, self.snippet_len)]:
            _fi = round(_fi)
            if _fi in self.all_trk_dict[vid_name]:
                _other_trk_data = self.all_trk_dict[vid_name][_fi]
                _ret_idx = np.where(trk_id == _other_trk_data[:, 0])
                assert len(_ret_idx) == 1, f"{len(_ret_idx)}"
                if len(_ret_idx[0]) == 0:
                    return []
                else:
                    _ary_idx = _ret_idx[0][0]
                    _other_trk_ary = self.all_trk_dict[vid_name][_fi][_ary_idx]
                    snippet_trk_ary_list.append((_fi, _other_trk_ary))
            else:
                return []

        snippet_trk_ary_list.reverse()

        return snippet_trk_ary_list

    def sample_one_snippet_track(self, vid_name: str) -> List[Tuple[int, np.ndarray]]:
        if (end_frm_idx := self.choose_one_frame(vid_name)) is None:
            raise TimeoutError(f"{vid_name}")
        snippet_trk_ary_list = self.choose_one_track(vid_name, end_frm_idx)
        n_attempt = 30
        while snippet_trk_ary_list == []:
            n_attempt -= 1
            if n_attempt == 0:
                raise TimeoutError(f"{vid_name}")
            # for the last few attempts, fall back to uniform frame sampling
            end_frm_idx = self.choose_one_frame(vid_name, n_attempt <= 5)
            snippet_trk_ary_list = self.choose_one_track(vid_name, end_frm_idx)
        return snippet_trk_ary_list

    def _read_from_image(self, video_path: str, frame_idx: int) -> np.ndarray:
        frm_path = join(self._frame_dir, basename(video_path).split('.')[0], self._frname_tmpl.format(frame_idx))
        assert exists(frm_path), f"{frm_path}"
        return mmcv.imread(frm_path, backend='turbojpeg')

    def load_snippet(self, vid_name: str, snippet_trk_ary_list: List[Tuple[int, np.ndarray]]) -> torch.Tensor:
        vid_path = join(self._video_dir, f"{vid_name}.{self.vid_suffix}")
        if not exists(vid_path):
            raise FileNotFoundError(vid_path)

        vid_cap = None
        if not self._frame_dir:
            if self._vid_cache is not None:
                if (_v := self._vid_cache.get(vid_name)) is not None:
                    vid_cap = _v
                else:
                    vid_cap = mmcv.VideoReader(vid_path, cache_capacity=self.__cache_max_frm)
                    self._vid_cache.put(vid_name, vid_cap)
            else:
                vid_cap = mmcv.VideoReader(vid_path, cache_capacity=self.__cache_max_frm)

        # fix one square view window shared by all frames of the snippet,
        # anchored on the middle frame of the track
        obj_bbox_list = np.asarray([[max(int(p), 0) for p in _trk_ary[1:5]] for _fi, _trk_ary in snippet_trk_ary_list], np.int64)
        anchor_box = obj_bbox_list[int(len(obj_bbox_list) / 2), :]

        view_center_xy = [(anchor_box[0] + anchor_box[2]) // 2, (anchor_box[1] + anchor_box[3]) // 2]

        full_view_range = 512
        half_view_range = full_view_range // 2

        _fidx = snippet_trk_ary_list[0][0]
        fh, fw, _ = self._read_from_image(vid_path, _fidx).shape if self._frame_dir else vid_cap.get_frame(_fidx).shape

        view_center_xy[0] = min(max(half_view_range, view_center_xy[0]), fw - half_view_range)
        view_center_xy[1] = min(max(half_view_range, view_center_xy[1]), fh - half_view_range)

        full_view_box = [view_center_xy[0] - half_view_range, view_center_xy[1] - half_view_range,
                         view_center_xy[0] + half_view_range, view_center_xy[1] + half_view_range]

        # crop this same region from every frame
        snippet = []
        for _fi, _trk_ary in snippet_trk_ary_list:
            frm: np.ndarray = self._read_from_image(vid_path, _fi) if self._frame_dir else vid_cap.get_frame(_fi)
            obj: np.ndarray = frm[full_view_box[1]:full_view_box[3], full_view_box[0]:full_view_box[2], :]

            obj: torch.Tensor = torch.as_tensor(obj, dtype=torch.float32, device=self.device)
            obj = self._tsfm_img(obj.permute(2, 0, 1).div(255.))

            snippet.append(obj)

        return torch.stack(snippet, 1)  # [C, T, H, W]
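
    # To eyeball a crop during debugging, the Normalize above can be inverted
    # (sketch; `obj` is one [C, H, W] tensor from the loop above):
    #   img = np.transpose(((obj.cpu().numpy() * 0.225 + 0.45) * 255).astype(np.uint8), (1, 2, 0))
    #   mmcv.imwrite(img, f"../save.visual/{vid_name}/restore_{_fi}.jpg", auto_mkdir=True)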
    def __len__(self) -> int:
        return len(self.all_trk_dict) * self.iterations

    def __getitem__(self, vid_idx: int):
        vid_name = list(self.all_trk_dict.keys())[vid_idx % len(self.all_trk_dict)]
        while True:
            try:
                snippet_trk_ary_list = self.sample_one_snippet_track(vid_name)
            except TimeoutError:
                # no usable track in this video: try another one
                vid_name = list(self.all_trk_dict.keys())[random.randrange(0, len(self.all_trk_dict))]
            else:
                break

        # NOTE: the bounding boxes (via self._tsfm_box) are currently not returned
        snippet = self.load_snippet(vid_name, snippet_trk_ary_list)  # [C, T, H, W]
        return snippet


class SnippetVideoReader(mmcv.VideoReader):
    def __init__(self, video_path: str, video_track: OrderedDict, snippet_len: int, snippet_itv: int, device='cpu',
                 tsfm_img=None, tsfm_box=None, frame_dir: str = '', frname_tmpl: str = ''):
        super().__init__(video_path, (snippet_len + 1) * snippet_itv)
        self._vid_name = basename(video_path).split('.')[0]
        self._vid_trk = video_track
        self._slen = snippet_len
        self._sitv = snippet_itv
        self.device = device

        self._tsfm_img = tsfm_img
        self._tsfm_box = tsfm_box

        self._frame_dir = frame_dir
        self._frname_tmpl = frname_tmpl

        # short videos are decoded once and kept in memory; very short ones are
        # additionally moved to `device` to speed up per-frame indexing
        if self.frame_cnt < 3050:
            self.all_frames = self.read_all_frames()
            if len(self.all_frames) < 2500:
                try:
                    self.all_frames = self.all_frames.to(device)
                except RuntimeError:
                    print(f"{self.vid_name}: failed to move {len(self.all_frames)} frames to '{device}', keeping them on 'cpu'")
                    torch.cuda.empty_cache()
        else:
            self.all_frames = None

    @property
    def vid_name(self):
        return self._vid_name

    @property
    def vid_trk(self):
        return self._vid_trk

    @property
    def snippet_len(self):
        return self._slen

    @property
    def snippet_itv(self):
        return self._sitv

    def get_all_tracks(self, frm_idx: int) -> List[Tuple[int, np.ndarray]]:
        if frm_idx not in self.vid_trk:
            return []

        all_snippet_trks = []

        trk_data: np.ndarray = self.vid_trk[frm_idx]
        for ary_idx in range(0, len(trk_data)):
            trk_ary = trk_data[ary_idx]  # trk_id(0), bbox(1~4), prob(5), cls(6)
            trk_id = trk_ary[0]

            snippet_trk_ary_list = [(frm_idx, trk_ary)]
            for _fi in [frm_idx - i * self.snippet_itv for i in range(1, self.snippet_len)]:
                _fi = round(_fi)
                if _fi in self.vid_trk:
                    _other_trk_data = self.vid_trk[_fi]
                    _ret_idx = np.where(trk_id == _other_trk_data[:, 0])
                    assert len(_ret_idx) == 1, f"{len(_ret_idx)}"
                    if len(_ret_idx[0]) == 0:
                        break
                    else:
                        _ary_idx = _ret_idx[0][0]
                        _other_trk_ary = self.vid_trk[_fi][_ary_idx]
                        snippet_trk_ary_list.append((_fi, _other_trk_ary))
                else:
                    break

            if len(snippet_trk_ary_list) < self.snippet_len:
                continue

            snippet_trk_ary_list.reverse()
            all_snippet_trks.append(snippet_trk_ary_list)

        return all_snippet_trks
    def _read_from_image(self, frame_idx: int) -> np.ndarray:
        if (frm := self._cache.get(frame_idx)) is not None:
            return frm
        else:
            frm_path = join(self._frame_dir, self.vid_name, self._frname_tmpl.format(frame_idx))
            assert exists(frm_path), f"{frm_path}"
            frm = mmcv.imread(frm_path, backend='turbojpeg')
            self._cache.put(frame_idx, frm)
            return frm

    def read_all_frames(self):
        all_frames = []
        for frm_idx in range(self.frame_cnt):
            all_frames.append(torch.from_numpy(self.get_frame(frm_idx)))
        all_frames = torch.stack(all_frames)
        return all_frames  # [T, H, W, C]

    def load_snippet(self, snippet_trk_ary_list: List[Tuple[int, np.ndarray]]) -> torch.Tensor:
        # fix one square view window shared by all frames of the snippet,
        # anchored on the middle frame of the track
        obj_bbox_list = np.asarray([[max(int(p), 0) for p in _trk_ary[1:5]] for _fi, _trk_ary in snippet_trk_ary_list], np.int64)
        anchor_box = obj_bbox_list[int(len(obj_bbox_list) / 2), :]
        view_center_xy = [(anchor_box[0] + anchor_box[2]) // 2, (anchor_box[1] + anchor_box[3]) // 2]

        full_view_range = 512
        half_view_range = full_view_range // 2

        _fidx = snippet_trk_ary_list[0][0]
        fh, fw, _ = self._read_from_image(_fidx).shape if self._frame_dir else self.get_frame(_fidx).shape
        view_center_xy[0] = min(max(half_view_range, view_center_xy[0]), fw - half_view_range)
        view_center_xy[1] = min(max(half_view_range, view_center_xy[1]), fh - half_view_range)

        full_view_box = [view_center_xy[0] - half_view_range, view_center_xy[1] - half_view_range,
                         view_center_xy[0] + half_view_range, view_center_xy[1] + half_view_range]

        # crop this same region from every frame
        snippet = []
        for _fi, _trk_ary in snippet_trk_ary_list:
            if self.all_frames is None:
                frm: np.ndarray = self._read_from_image(_fi) if self._frame_dir else self.get_frame(_fi)
            else:
                frm: torch.Tensor = self.all_frames[_fi]

            obj = frm[full_view_box[1]:full_view_box[3], full_view_box[0]:full_view_box[2], :]

            obj: torch.Tensor = torch.as_tensor(obj, dtype=torch.float32, device=self.device)
            obj = self._tsfm_img(obj.permute(2, 0, 1).div(255.))

            snippet.append(obj)

        return torch.stack(snippet, 1)  # [C, T, H, W]

    def __getitem__(self, end_frm_idx: int) -> torch.Tensor:
        '''
        Return the snippets of all tracked objects ending at this frame as one batch.
        '''
        assert 0 <= end_frm_idx < self.frame_cnt
        if end_frm_idx < (self._slen - 1) * self._sitv:
            return None
        else:
            all_trk_list = self.get_all_tracks(end_frm_idx)
            if all_trk_list:
                batch_snippet = []
                for _trk_list in all_trk_list:
                    batch_snippet.append(self.load_snippet(_trk_list))
                batch_snippet = torch.stack(batch_snippet, 0)  # [B, C, T, H, W]
                return batch_snippet
            else:
                return None

    def __iter__(self):
        raise NotImplementedError()

    def __next__(self):
        raise NotImplementedError()
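
# A hedged usage sketch (`iterate_video_snippets` is a hypothetical helper that
# mirrors the __main__ demo below and is not used by the training scripts):
# step a reader frame by frame; each index yields one batch of per-object
# snippets, or None when no complete track ends at that frame.
def iterate_video_snippets(vid_stream: 'SnippetVideoReader'):
    for end_frm_idx in range(len(vid_stream)):
        batch_snippet = vid_stream[end_frm_idx]
        if batch_snippet is not None:
            yield end_frm_idx, batch_snippet  # [B, C, T, H, W]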

class TestSetTrackingObject(TrainSetTrackingObject):
    def __init__(self, video_dir: str, track_dir: str, snippet_len: int, snippet_itv: int, vid_suffix='avi', device='cpu', **kwargs):
        super().__init__(video_dir, track_dir, snippet_len, snippet_itv, 1, vid_suffix, False, device, **kwargs)

    def __getitem__(self, vid_idx: int):
        vid_name = list(self.all_trk_dict.keys())[vid_idx]
        vid_stream = SnippetVideoReader(join(self._video_dir, f"{vid_name}.{self.vid_suffix}"), self.all_trk_dict[vid_name],
                                        self.snippet_len, self.snippet_itv, self.device, self._tsfm_img, self._tsfm_box, self._frame_dir, self._frname_tmpl)
        return vid_stream


if __name__ == "__main__":
    seed = 2
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    from torch.utils.data import DataLoader, WeightedRandomSampler

    trainset = TrainSetTrackingObject("../data.Ave_hd/videos/Train", "../data.Ave_hd/tracking/Train", 9, 2, iterations=60, cache_video=False, device='cpu', frame_dir="../data.Ave_hd/frames/Train")

    smpl_weight = trainset.vid_samp_weight * trainset.iterations
    assert len(smpl_weight) == len(trainset)
    weighted_sampler = WeightedRandomSampler(smpl_weight, len(trainset), replacement=True)
    dloader = DataLoader(trainset, sampler=weighted_sampler, batch_size=1, pin_memory=True, num_workers=16)

    # __getitem__ returns only the snippet tensor, so iterate accordingly
    for i, snippet in enumerate(dloader):
        snippet: torch.Tensor = snippet.cuda(non_blocking=True)
        print(snippet.shape, snippet.min(), snippet.max(), snippet.mean(), snippet.std())
        break

    testset = TestSetTrackingObject("../data.Ave_mem/videos/Test", "../data.Ave_mem/tracking/Test", 9, 2, device='cpu', frame_dir="../data.Ave_mem/frames/Test")
    from time import time as ttime
    for vid_stream in testset:
        print(vid_stream.vid_name)
        vid_stream.device = 'cuda:0'
        t0 = ttime()
        for frm_idx in range(len(vid_stream)):
            batch_snippet = vid_stream[frm_idx]
            if batch_snippet is not None:
                print(f"{frm_idx}, {batch_snippet.shape}, {batch_snippet.dtype}, {batch_snippet.device}, {batch_snippet.min()}, {batch_snippet.max()}, {batch_snippet.mean()}")
            else:
                print(f"{frm_idx}, None")
        print(ttime() - t0)
        break
--------------------------------------------------------------------------------