├── utils ├── __init__.py ├── logger.py ├── optimizer.py ├── config.py ├── cluster.py └── tools.py ├── datasets ├── __init__.py ├── blending.py ├── rand_augment.py └── build.py ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── model_zoo.py ├── simple_tokenizer.py ├── clip.py └── model.py ├── labels ├── ucf_2_labels.csv ├── traffic_2_labels.csv ├── TAD_test.txt ├── TAD_train.txt └── UCF_test.txt ├── requirements.txt ├── tools ├── dist_test_recognizer.sh ├── dist_train_recognizer.sh └── dist_umil_recognizer.sh ├── configs ├── ucf │ └── 32_5.yaml └── traffic │ └── 32_5.yaml ├── models ├── mit.py ├── prompt.py ├── cct.py └── xclip.py ├── README.md ├── main.py └── main_umil.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /labels/ucf_2_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,normal 3 | 1,abnormal 4 | -------------------------------------------------------------------------------- /labels/traffic_2_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,normal 3 | 1,abnormal 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv-full 2 | decord 3 | ftfy 4 | einops 5 | termcolor 6 | timm 7 | regex 8 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ktr-hubrt/UMIL/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/model_zoo.py: -------------------------------------------------------------------------------- 1 | import os 2 | def get_model_path(ckpt): 3 | if os.path.isfile(ckpt): 4 | return ckpt 5 | else: 6 | print('not found pretrained model in {}'.format(ckpt)) 7 | raise FileNotFoundError 8 | -------------------------------------------------------------------------------- /tools/dist_test_recognizer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | $PYTHON -m torch.distributed.launch --nproc_per_node=$1 main_umil.py -cfg configs/ucf/32_5.yaml --output output/test --only_test --pretrained output/test/best.pth 5 | -------------------------------------------------------------------------------- /tools/dist_train_recognizer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | 5 | $PYTHON -m torch.distributed.launch --nproc_per_node=$1 main.py -cfg configs/ucf/32_5.yaml --batch-size 2 --accumulation-steps 8 --output output/mil --pretrained k400_32_8.pth -------------------------------------------------------------------------------- 
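Both recognizer scripts above pass a local checkpoint through `--pretrained`; `clip/model_zoo.py` exposes `get_model_path` to validate such a path. A minimal sketch of using it (editor's addition — it assumes the example file `k400_32_8.pth` has been placed in the repository root and that the packages from `requirements.txt` plus PyTorch are installed, since importing the `clip` package pulls in the model code):

```
from clip.model_zoo import get_model_path

# Returns the path unchanged when the file exists; otherwise prints a message
# and raises FileNotFoundError.
ckpt = get_model_path('k400_32_8.pth')
print('Using pretrained weights at', ckpt)
```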
/tools/dist_umil_recognizer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | $PYTHON -m torch.distributed.launch --nproc_per_node=$1 main_umil.py -cfg configs/ucf/32_5.yaml --batch-size 1 --batch-size-umil 16 --accumulation-steps 8 --output output/umil --pretrained k400_32_8.pth -------------------------------------------------------------------------------- /configs/ucf/32_5.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | ROOT: 'root of UCF dataset' 3 | TRAIN_FILE: 'labels/UCF_train.txt' 4 | VAL_FILE: 'labels/UCF_test.txt' 5 | DATASET: ucf 6 | NUM_CLIPS: 16 7 | NUM_FRAMES: 5 8 | FRAME_INTERVAL: 6 9 | NUM_CLASSES: 2 10 | LABEL_LIST: 'labels/ucf_2_labels.csv' 11 | FILENAME_TMPL: 'img_{:08}.jpg' 12 | MODEL: 13 | ARCH: ViT-B/32 14 | TRAIN: 15 | BATCH_SIZE: 2 16 | ACCUMULATION_STEPS: 8 17 | -------------------------------------------------------------------------------- /configs/traffic/32_5.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | ROOT: 'root of traffic dataset' 3 | TRAIN_FILE: 'labels/TAD_train.txt' 4 | VAL_FILE: 'labels/TAD_test.txt' 5 | DATASET: traffic 6 | NUM_CLIPS: 16 7 | NUM_FRAMES: 5 8 | FRAME_INTERVAL: 6 9 | NUM_CLASSES: 2 10 | LABEL_LIST: 'labels/traffic_2_labels.csv' 11 | FILENAME_TMPL: 'img_{:08}.jpg' 12 | MODEL: 13 | ARCH: ViT-B/32 14 | TRAIN: 15 | BATCH_SIZE: 2 16 | ACCUMULATION_STEPS: 8 -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import functools 5 | from termcolor import colored 6 | 7 | 8 | @functools.lru_cache() 9 | def create_logger(output_dir, dist_rank=0, name=''): 10 | # create logger 11 | logger = logging.getLogger(name) 12 | logger.setLevel(logging.DEBUG) 13 | logger.propagate = False 14 | 15 | # create formatter 16 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 17 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 18 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 19 | 20 | # create console handlers for master process 21 | if dist_rank == 0: 22 | console_handler = logging.StreamHandler(sys.stdout) 23 | console_handler.setLevel(logging.DEBUG) 24 | console_handler.setFormatter( 25 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 26 | logger.addHandler(console_handler) 27 | 28 | # create file handlers 29 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 30 | file_handler.setLevel(logging.DEBUG) 31 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 32 | logger.addHandler(file_handler) 33 | 34 | return logger 35 | -------------------------------------------------------------------------------- /models/mit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import OrderedDict 4 | from timm.models.layers import trunc_normal_ 5 | import sys 6 | sys.path.append("../") 7 | from clip.model import QuickGELU 8 | 9 | 10 | class ResidualAttentionBlock(nn.Module): 11 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 12 | super().__init__() 13 | 14 | self.attn = 
nn.MultiheadAttention(d_model, n_head) 15 | self.ln_1 = nn.LayerNorm(d_model) 16 | self.mlp = nn.Sequential(OrderedDict([ 17 | ("c_fc", nn.Linear(d_model, d_model * 4)), 18 | ("gelu", QuickGELU()), 19 | ("c_proj", nn.Linear(d_model * 4, d_model)) 20 | ])) 21 | self.ln_2 = nn.LayerNorm(d_model) 22 | self.attn_mask = attn_mask 23 | 24 | def attention(self, x: torch.Tensor): 25 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 26 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 27 | 28 | def forward(self, x: torch.Tensor): 29 | x = x + self.attention(self.ln_1(x)) 30 | x = x + self.mlp(self.ln_2(x)) 31 | return x 32 | 33 | 34 | class MultiframeIntegrationTransformer(nn.Module): 35 | def __init__(self, T, embed_dim=512, layers=1,): 36 | super().__init__() 37 | self.T = T 38 | transformer_heads = embed_dim // 64 39 | self.positional_embedding = nn.Parameter(torch.empty(1, T, embed_dim)) 40 | trunc_normal_(self.positional_embedding, std=0.02) 41 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(d_model=embed_dim, n_head=transformer_heads) for _ in range(layers)]) 42 | 43 | self.apply(self._init_weights) 44 | 45 | def _init_weights(self, m): 46 | if isinstance(m, (nn.Linear,)): 47 | trunc_normal_(m.weight, std=0.02) 48 | if m.bias is not None: 49 | nn.init.zeros_(m.bias) 50 | elif isinstance(m, nn.LayerNorm): 51 | nn.init.zeros_(m.bias) 52 | nn.init.ones_(m.weight) 53 | 54 | def forward(self, x): 55 | ori_x = x 56 | x = x + self.positional_embedding 57 | x = x.permute(1, 0, 2) 58 | x = self.resblocks(x) 59 | x = x.permute(1, 0, 2) 60 | x = x.type(ori_x.dtype) + ori_x 61 | 62 | return x.mean(dim=1, keepdim=False) 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is an official implementation of the CVPR 2023 paper 'Unbiased Multiple Instance Learning for Weakly Supervised Video Anomaly Detection' (https://arxiv.org/abs/2303.12369v1). 2 | 3 | 4 | # Environment Setup 5 | To set up the environment, run the following commands: 6 | ``` 7 | conda create -n UMIL python=3.7 8 | conda activate UMIL 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | Install Apex as follows: 13 | ``` 14 | git clone https://github.com/NVIDIA/apex 15 | cd apex 16 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 17 | ``` 18 | 19 | # Data Preparation 20 | 21 | Download the videos and labels for the UCF-Crime or TAD dataset and extract frames from the videos. 22 | The dataset directory should be organized as follows: 23 | ``` 24 | UCF/ 25 | ├─ frames/ 26 | ├─ Abuse/ 27 | ├─ Abuse001_x264.mp4/ 28 | ├─ img_00000000.jpg 29 | ├─ Arrest/ 30 | ... 31 | 32 | TAD/ 33 | ├─ frames/ 34 | ├─ abnormal/ 35 | ├─ 01_Accident_001.mp4/ 36 | ... 37 | ├─ normal/ 38 | ... 39 | ``` 40 | > [**TAD extracted frames**](https://smu-my.sharepoint.com/:f:/r/personal/huilyu_smu_edu_sg/Documents/UMIL/TAD?csf=1&web=1&e=HxzRqC) 41 | 42 | # Pre-trained model weights 43 | Please find the model weights below: 44 | > [**k400 pre-trained weights**](https://smu-my.sharepoint.com/:u:/g/personal/huilyu_smu_edu_sg/ESDZwxBmIAdLqJBuDwhU-YIB1kn7MNEQ0CEGAkkUSwfPkA?e=7dpMd5) 45 | 46 | # Train 47 | The config files lie in `configs`. 
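Each YAML under `configs` is merged into the yacs defaults defined in `utils/config.py`. As a rough sketch (editor's addition; the actual command-line interface lives in `main.py`, which builds an equivalent `args` object), a config can be loaded and inspected like this:

```
from argparse import Namespace
from utils.config import get_config

# Only the attributes read by update_config() are needed; the values below are placeholders.
args = Namespace(config='configs/ucf/32_5.yaml', opts=None, batch_size=2,
                 pretrained='k400_32_8.pth', resume=None, accumulation_steps=8,
                 output='output/mil', only_test=False, local_rank=0)
config = get_config(args)
print(config.MODEL.ARCH, config.DATA.NUM_FRAMES, config.DATA.NUM_CLIPS)  # ViT-B/32 5 16
```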
For example, to train X-CLIP-B/32 with 5 frames on UCF on 2 GPUs, you can run 48 | ``` 49 | CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_train_recognizer.sh 2 50 | ``` 51 | 52 | **Note:** 53 | - The test run during training uses a fast test strategy; it does not represent the real AUC. 54 | - Please specify the data paths in the config file (`configs/*.yaml`). Alternatively, you can set them by appending the argument `--opts DATA.ROOT /PATH/TO/videos DATA.TRAIN_FILE /PATH/TO/train.txt DATA.VAL_FILE /PATH/TO/val.txt`. Note that if you use the tar file (`videos.tar`), set `DATA.ROOT` to `/PATH/TO/videos.tar`; for a standard folder, set it to `/PATH/TO/videos`. 55 | - The pretrained model will be downloaded automatically. You can also specify it explicitly with `--pretrained /PATH/TO/PRETRAINED`. 56 | 57 | # Test 58 | For example, to test X-CLIP-B/32 with 5 frames on UCF, you can run 59 | ``` 60 | CUDA_VISIBLE_DEVICES=1 bash tools/dist_test_recognizer.sh 1 61 | ``` 62 | 63 | If you find this work helpful, please cite: 64 | ``` 65 | @inproceedings{Lv2023unbiased, 66 | title={Unbiased Multiple Instance Learning for Weakly Supervised Video Anomaly Detection}, 67 | author={Hui Lv and Zhongqi Yue and Qianru Sun and Bin Luo and Zhen Cui and Hanwang Zhang}, 68 | booktitle={CVPR}, 69 | year={2023} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /models/prompt.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers import trunc_normal_ 2 | import torch 3 | from torch import nn 4 | import sys 5 | sys.path.append("../") 6 | from clip.model import QuickGELU 7 | 8 | 9 | class MulitHeadAttention(nn.Module): 10 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 11 | super().__init__() 12 | self.num_heads = num_heads 13 | head_dim = dim // num_heads 14 | 15 | self.scale = qk_scale or head_dim ** -0.5 16 | 17 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 18 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) 19 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) 20 | 21 | 22 | self.attn_drop = nn.Dropout(attn_drop) 23 | self.proj = nn.Linear(dim, dim) 24 | self.proj_drop = nn.Dropout(proj_drop) 25 | 26 | def forward(self, q, k, v): 27 | B, N, C = q.shape 28 | B, M, C = k.shape 29 | q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0,2,1,3) 30 | k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads).permute(0,2,1,3) 31 | v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads).permute(0,2,1,3) 32 | 33 | attn = (q @ k.transpose(-2, -1)) * self.scale 34 | attn = attn.softmax(dim=-1) 35 | attn = self.attn_drop(attn) 36 | 37 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 38 | x = self.proj(x) 39 | x = self.proj_drop(x) 40 | return x 41 | 42 | 43 | class PromptGeneratorLayer(nn.Module): 44 | def __init__( 45 | self, 46 | d_model, 47 | nhead, 48 | dropout=0., 49 | ): 50 | super().__init__() 51 | self.cross_attn = MulitHeadAttention(d_model, nhead, proj_drop=dropout) 52 | 53 | self.norm1 = nn.LayerNorm(d_model) 54 | self.norm3 = nn.LayerNorm(d_model) 55 | 56 | self.dropout = nn.Dropout(dropout) 57 | 58 | self.mlp = nn.Sequential( 59 | nn.Linear(d_model, d_model * 4), 60 | QuickGELU(), 61 | nn.Dropout(dropout), 62 | nn.Linear(d_model * 4, d_model) 63 | ) 64 | 65 | def forward(self, x, visual): 66 | q = k = v = self.norm1(x) 67 | x = x + self.cross_attn(q, visual, visual) 68 | x = x + 
self.dropout(self.mlp(self.norm3(x))) 69 | return x 70 | 71 | 72 | class VideoSpecificPrompt(nn.Module): 73 | def __init__(self, layers=2, embed_dim=512, alpha=0.1,): 74 | super().__init__() 75 | self.norm = nn.LayerNorm(embed_dim) 76 | self.decoder = nn.ModuleList([PromptGeneratorLayer(embed_dim, embed_dim//64) for _ in range(layers)]) 77 | self.alpha = nn.Parameter(torch.ones(embed_dim) * alpha) 78 | self.apply(self._init_weights) 79 | 80 | 81 | def _init_weights(self, m): 82 | if isinstance(m, nn.Linear): 83 | trunc_normal_(m.weight, std=.02) 84 | if isinstance(m, nn.Linear) and m.bias is not None: 85 | nn.init.constant_(m.bias, 0) 86 | elif isinstance(m, nn.LayerNorm): 87 | nn.init.constant_(m.bias, 0) 88 | nn.init.constant_(m.weight, 1.0) 89 | 90 | 91 | def forward(self, text, visual): 92 | B, N, C = visual.shape 93 | visual = self.norm(visual) 94 | for layer in self.decoder: 95 | text = layer(text, visual) 96 | 97 | return self.alpha * text -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.optim as optim 3 | from timm.scheduler.cosine_lr import CosineLRScheduler 4 | import torch.distributed as dist 5 | 6 | def is_main_process(): 7 | return dist.get_rank() == 0 8 | 9 | def check_keywords_in_name(name, keywords=()): 10 | isin = False 11 | for keyword in keywords: 12 | if keyword in name: 13 | isin = True 14 | return isin 15 | 16 | def set_weight_decay(model, skip_list=(), skip_keywords=(), weight_decay=0.001, lr=2e-6, have=(), not_have=()): 17 | has_decay = [] 18 | no_decay = [] 19 | for name, param in model.named_parameters(): 20 | if not param.requires_grad: 21 | continue # frozen weights 22 | if len(have) > 0 and not check_keywords_in_name(name, have): 23 | continue 24 | if len(not_have) > 0 and check_keywords_in_name(name, not_have): 25 | continue 26 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 27 | check_keywords_in_name(name, skip_keywords): 28 | no_decay.append(param) 29 | else: 30 | has_decay.append(param) 31 | 32 | return [{'params': has_decay, 'weight_decay': weight_decay, 'lr': lr}, 33 | {'params': no_decay, 'weight_decay': 0., 'lr': lr}] 34 | 35 | 36 | def fix_text(model): 37 | for name, param in model.named_parameters(): 38 | if "visual." 
in name or "mit" in name or "prompts" in name or "_head" in name: 39 | continue 40 | else: 41 | param.requires_grad=False 42 | 43 | def build_optimizer(config, model): 44 | model = model.module if hasattr(model, 'module') else model 45 | 46 | # fix text 47 | if config.MODEL.FIX_TEXT: 48 | fix_text(model) 49 | 50 | # set decay and lr 51 | skip = {} 52 | skip_keywords = {} 53 | if hasattr(model, 'no_weight_decay'): 54 | skip = model.no_weight_decay() 55 | if hasattr(model, 'no_weight_decay_keywords'): 56 | skip_keywords = model.no_weight_decay_keywords() 57 | clip_parameters = set_weight_decay(model, skip, skip_keywords, 58 | weight_decay=config.TRAIN.WEIGHT_DECAY, lr=config.TRAIN.LR, 59 | have=(), not_have=("prompts", "mit", "message_", "anomaly_head", "cluster_head",) 60 | ) 61 | msg_parameters = set_weight_decay(model, skip, skip_keywords, 62 | weight_decay=config.TRAIN.WEIGHT_DECAY, lr=config.TRAIN.LR, 63 | have=("message_",), not_have=() 64 | ) 65 | mit_parameters = set_weight_decay(model, skip, skip_keywords, 66 | weight_decay=config.TRAIN.WEIGHT_DECAY, lr=config.TRAIN.LR, 67 | have=("mit",), not_have=() 68 | ) 69 | prompts_parameters = set_weight_decay(model, skip, skip_keywords, 70 | weight_decay=config.TRAIN.WEIGHT_DECAY, lr=config.TRAIN.LR, 71 | have=("prompts",), not_have=() 72 | ) 73 | anomaly_parameters = set_weight_decay(model, skip, skip_keywords, 74 | weight_decay=config.TRAIN.WEIGHT_DECAY, lr=config.TRAIN.LR*0.1, 75 | have=("anomaly_head",), not_have=() 76 | ) 77 | cluster_parameters = set_weight_decay(model, skip, skip_keywords, 78 | weight_decay=config.TRAIN.WEIGHT_DECAY, lr=config.TRAIN.LR, 79 | have=("cluster_head",), not_have=() 80 | ) 81 | 82 | optimizer = optim.AdamW(clip_parameters + mit_parameters + prompts_parameters + msg_parameters + anomaly_parameters, 83 | betas=(0.9, 0.98), eps=1e-8,) 84 | optimizer_umil = optim.AdamW(cluster_parameters + anomaly_parameters, 85 | betas=(0.9, 0.98), eps=1e-8,) 86 | return optimizer, optimizer_umil 87 | 88 | 89 | def build_scheduler(config, optimizer, n_iter_per_epoch): 90 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 91 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 92 | 93 | lr_scheduler = CosineLRScheduler( 94 | optimizer, 95 | t_initial=num_steps, 96 | lr_min=config.TRAIN.LR / 100, 97 | warmup_lr_init=0, 98 | warmup_t=warmup_steps, 99 | cycle_limit=1, 100 | t_in_epochs=False, 101 | ) 102 | 103 | return lr_scheduler -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | _C = CN() 6 | 7 | # Base config files 8 | _C.BASE = [''] 9 | 10 | # ----------------------------------------------------------------------------- 11 | # Data settings 12 | # ----------------------------------------------------------------------------- 13 | _C.DATA = CN() 14 | _C.DATA.ROOT = '' 15 | _C.DATA.TRAIN_FILE = '' 16 | _C.DATA.VAL_FILE = '' 17 | _C.DATA.DATASET = 'kinetics400' 18 | _C.DATA.INPUT_SIZE = 224 19 | _C.DATA.NUM_CLIPS = 16 20 | _C.DATA.NUM_FRAMES = 5 21 | _C.DATA.FRAME_INTERVAL = 6 22 | _C.DATA.NUM_CLASSES = 400 23 | _C.DATA.LABEL_LIST = 'labels/kinetics_400_labels.csv' 24 | _C.DATA.FILENAME_TMPL = 'img_{:08}.jpg' 25 | # ----------------------------------------------------------------------------- 26 | # Model settings 27 | # ----------------------------------------------------------------------------- 28 | 
_C.MODEL = CN() 29 | _C.MODEL.ARCH = 'ViT-B/32' 30 | _C.MODEL.DROP_PATH_RATE = 0. 31 | _C.MODEL.PRETRAINED = None 32 | _C.MODEL.RESUME = None 33 | _C.MODEL.FIX_TEXT = True 34 | 35 | # ----------------------------------------------------------------------------- 36 | # Training settings 37 | # ----------------------------------------------------------------------------- 38 | _C.TRAIN = CN() 39 | _C.TRAIN.EPOCHS = 40 40 | _C.TRAIN.WARMUP_EPOCHS = 5 41 | _C.TRAIN.WEIGHT_DECAY = 0.001 42 | _C.TRAIN.LR = 8.e-6 43 | _C.TRAIN.BATCH_SIZE = 8 44 | _C.TRAIN.BATCH_SIZE_UMIL = 4 45 | _C.TRAIN.ACCUMULATION_STEPS = 1 46 | _C.TRAIN.LR_SCHEDULER = 'cosine' 47 | _C.TRAIN.OPTIMIZER = 'adamw' 48 | _C.TRAIN.OPT_LEVEL = 'O1' 49 | _C.TRAIN.AUTO_RESUME = False 50 | _C.TRAIN.USE_CHECKPOINT = False 51 | 52 | # ----------------------------------------------------------------------------- 53 | # Augmentation settings 54 | # ----------------------------------------------------------------------------- 55 | _C.AUG = CN() 56 | _C.AUG.LABEL_SMOOTH = 0.0 57 | _C.AUG.COLOR_JITTER = 0.8 58 | _C.AUG.GRAY_SCALE = 0.2 59 | _C.AUG.MIXUP = 0.0 60 | _C.AUG.CUTMIX = 1.0 61 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 62 | 63 | # ----------------------------------------------------------------------------- 64 | # Testing settings 65 | # ----------------------------------------------------------------------------- 66 | _C.TEST = CN() 67 | _C.TEST.NUM_CLIP = 1 68 | _C.TEST.NUM_CROP = 1 69 | _C.TEST.ONLY_TEST = False 70 | 71 | # ----------------------------------------------------------------------------- 72 | # Misc 73 | # ----------------------------------------------------------------------------- 74 | _C.OUTPUT = '' 75 | _C.SAVE_FREQ = 1 76 | _C.PRINT_FREQ = 20 77 | _C.SEED = 1024 78 | 79 | 80 | 81 | def _update_config_from_file(config, cfg_file): 82 | config.defrost() 83 | with open(cfg_file, 'r') as f: 84 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 85 | 86 | for cfg in yaml_cfg.setdefault('BASE', ['']): 87 | if cfg: 88 | _update_config_from_file( 89 | config, os.path.join(os.path.dirname(cfg_file), cfg) 90 | ) 91 | print('=> merge config from {}'.format(cfg_file)) 92 | config.merge_from_file(cfg_file) 93 | config.freeze() 94 | 95 | 96 | def update_config(config, args): 97 | _update_config_from_file(config, args.config) 98 | 99 | config.defrost() 100 | if args.opts: 101 | config.merge_from_list(args.opts) 102 | # merge from specific arguments 103 | if args.batch_size: 104 | config.TRAIN.BATCH_SIZE = args.batch_size 105 | if args.pretrained: 106 | config.MODEL.PRETRAINED = args.pretrained 107 | if args.resume: 108 | config.MODEL.RESUME = args.resume 109 | if args.accumulation_steps: 110 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 111 | if args.output: 112 | config.OUTPUT = args.output 113 | if args.only_test: 114 | config.TEST.ONLY_TEST = True 115 | # set local rank for distributed training 116 | config.LOCAL_RANK = args.local_rank 117 | config.freeze() 118 | 119 | 120 | def get_config(args): 121 | """Get a yacs CfgNode object with default values.""" 122 | # Return a clone so that the defaults will not be altered 123 | # This is for the "local variable" use pattern 124 | config = _C.clone() 125 | update_config(config, args) 126 | 127 | return config 128 | -------------------------------------------------------------------------------- /utils/cluster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import torch.nn.functional 
as F 5 | import numpy as np 6 | 7 | class Normalize(nn.Module): 8 | def __init__(self, power=2): 9 | super(Normalize, self).__init__() 10 | self.power = power 11 | 12 | def forward(self, x): 13 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 14 | out = x.div(norm) 15 | return out 16 | 17 | def reduce_dimension(features, mode='umap', dim=50): 18 | if mode == 'pca': 19 | from sklearn.decomposition import PCA 20 | pca = PCA(n_components=dim) 21 | transformed_features = pca.fit_transform(features) 22 | fit_score = pca.explained_variance_ratio_.sum() 23 | elif mode == 'umap': 24 | import umap 25 | fit = umap.UMAP(n_components=dim) 26 | transformed_features = fit.fit_transform(features) 27 | fit_score = 0.0 28 | return transformed_features, fit_score 29 | 30 | 31 | def PairEnum(x,mask=None): 32 | # Enumerate all pairs of feature in x 33 | assert x.ndimension() == 2, 'Input dimension must be 2' 34 | x1 = x.repeat(x.size(0), 1) 35 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1)) 36 | if mask is not None: 37 | xmask = mask.view(-1,1).repeat(1,x.size(1)) 38 | #dim 0: #sample, dim 1:#feature 39 | x1 = x1[xmask].view(-1,x.size(1)) 40 | x2 = x2[xmask].view(-1,x.size(1)) 41 | return x1,x2 42 | 43 | class BCE(nn.Module): 44 | eps = 1e-7 # Avoid calculating log(0). Use the small value of float16. 45 | def forward(self, prob1, prob2, simi): 46 | # simi: 1->similar; -1->dissimilar; 0->unknown(ignore) 47 | assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi))) 48 | P = prob1.mul_(prob2) 49 | P = P.sum(1) 50 | P.mul_(simi).add_(simi.eq(-1).type_as(P)) 51 | neglogP = -P.add_(BCE.eps).log_() 52 | return neglogP.mean() 53 | 54 | class ClusterLoss(): 55 | def __init__(self, num_classes, bce_type, cosine_threshold, topk): 56 | self.num_classes = num_classes 57 | self.bce_type = bce_type 58 | self.costhre = cosine_threshold 59 | self.topk = topk 60 | self.bce = BCE() 61 | 62 | def compute_losses(self, inputs, include_label=False, unlabel_only=True): 63 | assert (include_label == False) or (unlabel_only == False) 64 | bce_loss = 0.0 65 | feat, feat_q, output2 = \ 66 | inputs["x1"], inputs["x1_norm"], inputs["preds1_u"] 67 | feat_bar, feat_k, output2_bar = \ 68 | inputs["x2"], inputs["x2_norm"], inputs["preds2_u"] 69 | label = inputs["labels"] 70 | 71 | if unlabel_only: 72 | mask_lb = inputs["mask"] 73 | else: 74 | mask_lb = torch.zeros_like(inputs["mask"]).bool() 75 | 76 | prob2, prob2_bar = F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1) 77 | 78 | rank_feat = (feat[~mask_lb]).detach() 79 | if self.bce_type == 'cos': 80 | # default: cosine similarity with threshold 81 | feat_row, feat_col = PairEnum(F.normalize(rank_feat, dim=1)) 82 | tmp_distance_ori = torch.bmm( 83 | feat_row.view(feat_row.size(0), 1, -1), 84 | feat_col.view(feat_row.size(0), -1, 1) 85 | ) 86 | tmp_distance_ori = tmp_distance_ori.squeeze() 87 | target_ulb = torch.zeros_like(tmp_distance_ori).float() - 1 88 | target_ulb[tmp_distance_ori > self.costhre] = 1 89 | elif self.bce_type == 'RK': 90 | # top-k rank statics 91 | rank_idx = torch.argsort(rank_feat, dim=1, descending=True) 92 | rank_idx1, rank_idx2 = PairEnum(rank_idx) 93 | rank_idx1, rank_idx2 = rank_idx1[:, :self.topk], rank_idx2[:, :self.topk] 94 | rank_idx1, _ = torch.sort(rank_idx1, dim=1) 95 | rank_idx2, _ = torch.sort(rank_idx2, dim=1) 96 | rank_diff = rank_idx1 - rank_idx2 97 | rank_diff = torch.sum(torch.abs(rank_diff), dim=1) 98 | target_ulb = 
torch.ones_like(rank_diff).float().cuda() 99 | target_ulb[rank_diff > 0] = -1 100 | 101 | if include_label: 102 | # use source domain label for similar/dissimilar 103 | labels = labels_s.contiguous().view(-1, 1) 104 | mask_l = torch.eq(labels, labels.T).float().to(device) 105 | mask_l = (mask_l - 0.5) * 2.0 106 | target_ulb_t = target_ulb.view(feat.size(0), -1) 107 | target_ulb_t[:num_s, :num_s] = mask_l 108 | target_ulb = target_ulb_t.flatten() 109 | 110 | prob1_ulb, _ = PairEnum(prob2[~mask_lb]) 111 | _, prob2_ulb = PairEnum(prob2_bar[~mask_lb]) 112 | 113 | bce_loss = self.bce(prob1_ulb, prob2_ulb, target_ulb) 114 | return bce_loss, target_ulb 115 | -------------------------------------------------------------------------------- /labels/TAD_test.txt: -------------------------------------------------------------------------------- 1 | frames/abnormal/01_Accident_006.mp4 266 1 75 110 2 | frames/abnormal/01_Accident_008.mp4 270 1 32 254 3 | frames/abnormal/01_Accident_011.mp4 235 1 16 66 4 | frames/abnormal/01_Accident_014.mp4 243 1 86 244 5 | frames/abnormal/01_Accident_020.mp4 342 1 47 343 6 | frames/abnormal/01_Accident_023.mp4 331 1 77 332 7 | frames/abnormal/01_Accident_025.mp4 479 1 47 158 8 | frames/abnormal/01_Accident_033.mp4 214 1 56 215 9 | frames/abnormal/01_Accident_036.mp4 396 1 130 322 10 | frames/abnormal/01_Accident_038.mp4 210 1 24 128 11 | frames/abnormal/01_Accident_039.mp4 232 1 22 233 12 | frames/abnormal/01_Accident_041.mp4 265 1 22 96 13 | frames/abnormal/01_Accident_044.mp4 394 1 6 57 14 | frames/abnormal/01_Accident_047.mp4 339 1 92 124 15 | frames/abnormal/01_Accident_050.mp4 242 1 36 243 16 | frames/abnormal/01_Accident_054.mp4 242 1 115 167 17 | frames/abnormal/01_Accident_062.mp4 386 1 48 387 18 | frames/abnormal/01_Accident_063.mp4 313 1 100 156 19 | frames/abnormal/01_Accident_064.mp4 364 1 245 272 20 | frames/abnormal/01_Accident_083.mp4 357 1 20 322 21 | frames/abnormal/01_Accident_086.mp4 187 1 113 188 22 | frames/abnormal/01_Accident_088.mp4 441 1 55 164 23 | frames/abnormal/01_Accident_091.mp4 100 1 35 101 24 | frames/abnormal/01_Accident_095.mp4 240 1 94 241 25 | frames/abnormal/02_IllegalTurn_004.mp4 278 2 140 202 26 | frames/abnormal/02_IllegalTurn_007.mp4 375 2 145 169 27 | frames/abnormal/02_IllegalTurn_011.mp4 805 2 73 264 28 | frames/abnormal/03_IllegalOccupation_003.mp4 246 3 3 247 29 | frames/abnormal/03_IllegalOccupation_004.mp4 255 3 0 256 30 | frames/abnormal/03_IllegalOccupation_005.mp4 337 3 0 338 31 | frames/abnormal/03_IllegalOccupation_019.mp4 233 3 0 234 32 | frames/abnormal/03_IllegalOccupation_020.mp4 546 3 0 513 33 | frames/abnormal/04_Retrograde_002.mp4 331 4 0 120 34 | frames/abnormal/04_Retrograde_004.mp4 254 4 86 117 35 | frames/abnormal/05_else_010.mp4 361 5 66 322 36 | frames/abnormal/05_else_020.mp4 296 5 95 285 37 | frames/abnormal/05_else_023.mp4 173 5 131 174 38 | frames/abnormal/05_else_026.mp4 147 5 0 140 39 | frames/abnormal/05_else_031.mp4 129 5 22 130 40 | frames/abnormal/05_else_037.mp4 175 5 14 176 41 | frames/abnormal/05_else_038.mp4 262 5 0 88 42 | frames/abnormal/05_else_045.mp4 301 5 45 160 43 | frames/abnormal/05_else_047.mp4 537 5 196 322 44 | frames/abnormal/05_else_049.mp4 951 5 0 952 45 | frames/abnormal/06_PedestrianOnRoad_002.mp4 462 6 278 421 46 | frames/abnormal/06_PedestrianOnRoad_004.mp4 213 6 150 171 47 | frames/abnormal/06_PedestrianOnRoad_006.mp4 753 6 0 754 48 | frames/abnormal/06_PedestrianOnRoad_009.mp4 666 6 0 534 49 | frames/abnormal/06_PedestrianOnRoad_014.mp4 606 6 117 293 50 | 
frames/abnormal/06_PedestrianOnRoad_016.mp4 631 6 382 430 51 | frames/abnormal/06_PedestrianOnRoad_017.mp4 275 6 113 252 52 | frames/abnormal/06_PedestrianOnRoad_019.mp4 337 6 203 264 53 | frames/abnormal/06_PedestrianOnRoad_021.mp4 375 6 0 376 54 | frames/abnormal/06_PedestrianOnRoad_022.mp4 409 6 248 410 55 | frames/abnormal/06_PedestrianOnRoad_025.mp4 918 6 661 729 56 | frames/abnormal/06_PedestrianOnRoad_027.mp4 376 6 0 81 57 | frames/abnormal/07_RoadSpills_003.mp4 397 7 0 398 58 | frames/abnormal/07_RoadSpills_008.mp4 392 7 175 276 59 | frames/abnormal/07_RoadSpills_013.mp4 155 7 51 70 60 | frames/abnormal/07_RoadSpills_014.mp4 292 7 45 72 61 | frames/normal/Normal_008.mp4 77 0 0 77 62 | frames/normal/Normal_011.mp4 70 0 0 70 63 | frames/normal/Normal_014.mp4 101 0 0 101 64 | frames/normal/Normal_017.mp4 104 0 0 104 65 | frames/normal/Normal_018.mp4 330 0 0 330 66 | frames/normal/Normal_019.mp4 257 0 0 257 67 | frames/normal/Normal_026.mp4 239 0 0 239 68 | frames/normal/Normal_030.mp4 477 0 0 477 69 | frames/normal/Normal_041.mp4 200 0 0 200 70 | frames/normal/Normal_048.mp4 149 0 0 149 71 | frames/normal/Normal_050.mp4 127 0 0 127 72 | frames/normal/Normal_066.mp4 118 0 0 118 73 | frames/normal/Normal_069.mp4 116 0 0 116 74 | frames/normal/Normal_073.mp4 119 0 0 119 75 | frames/normal/Normal_075.mp4 59 0 0 59 76 | frames/normal/Normal_092.mp4 163 0 0 163 77 | frames/normal/Normal_095.mp4 186 0 0 186 78 | frames/normal/Normal_120.mp4 1528 0 0 1528 79 | frames/normal/Normal_122.mp4 2967 0 0 2967 80 | frames/normal/Normal_126.mp4 595 0 0 595 81 | frames/normal/Normal_136.mp4 5394 0 0 5394 82 | frames/normal/Normal_139.mp4 7792 0 0 7792 83 | frames/normal/Normal_143.mp4 8991 0 0 8991 84 | frames/normal/Normal_155.mp4 2036 0 0 2036 85 | frames/normal/Normal_160.mp4 1168 0 0 1168 86 | frames/normal/Normal_164.mp4 2631 0 0 2631 87 | frames/normal/Normal_170.mp4 3596 0 0 3596 88 | frames/normal/Normal_178.mp4 239 0 0 239 89 | frames/normal/Normal_183.mp4 764 0 0 764 90 | frames/normal/Normal_190.mp4 2697 0 0 2697 91 | frames/normal/Normal_192.mp4 4391 0 0 4391 92 | frames/normal/Normal_194.mp4 665 0 0 665 93 | frames/normal/Normal_199.mp4 1949 0 0 1949 94 | frames/normal/Normal_203.mp4 400 0 0 400 95 | frames/normal/Normal_215.mp4 4039 0 0 4039 96 | frames/normal/Normal_222.mp4 2038 0 0 2038 97 | frames/normal/Normal_225.mp4 1981 0 0 1981 98 | frames/normal/Normal_230.mp4 3116 0 0 3116 99 | frames/normal/Normal_238.mp4 580 0 0 580 100 | frames/normal/Normal_241.mp4 300 0 0 300 101 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | 
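# --- Editor's usage sketch (not part of the original file) ----------------------
# encode() cleans, lower-cases and BPE-encodes raw text into token ids; decode()
# inverts the mapping, turning the '</w>' word-end markers back into spaces.
# The <|startoftext|> / <|endoftext|> tokens are added later by clip.tokenize(),
# not by encode() itself.
if __name__ == "__main__":
    _tokenizer = SimpleTokenizer()
    _ids = _tokenizer.encode("A car accident on the highway")
    print(_ids)                      # BPE token ids
    print(_tokenizer.decode(_ids))   # -> "a car accident on the highway "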
-------------------------------------------------------------------------------- /models/cct.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from timm.models.layers import trunc_normal_ 3 | import torch 4 | from torch import nn 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | import sys 7 | sys.path.append("../") 8 | from clip.model import LayerNorm, QuickGELU, DropPath 9 | 10 | 11 | class CrossFramelAttentionBlock(nn.Module): 12 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, droppath = 0., T=0, ): 13 | super().__init__() 14 | self.T = T 15 | 16 | self.message_fc = nn.Linear(d_model, d_model) 17 | self.message_ln = LayerNorm(d_model) 18 | self.message_attn = nn.MultiheadAttention(d_model, n_head,) 19 | 20 | self.attn = nn.MultiheadAttention(d_model, n_head,) 21 | self.ln_1 = LayerNorm(d_model) 22 | 23 | self.drop_path = DropPath(droppath) if droppath > 0. else nn.Identity() 24 | self.mlp = nn.Sequential(OrderedDict([ 25 | ("c_fc", nn.Linear(d_model, d_model * 4)), 26 | ("gelu", QuickGELU()), 27 | ("c_proj", nn.Linear(d_model * 4, d_model)) 28 | ])) 29 | self.ln_2 = LayerNorm(d_model) 30 | self.attn_mask = attn_mask 31 | 32 | def attention(self, x: torch.Tensor): 33 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 34 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 35 | 36 | 37 | def forward(self, x): 38 | l, bt, d = x.size() 39 | b = bt // self.T 40 | x = x.view(l, b, self.T, d) 41 | 42 | msg_token = self.message_fc(x[0,:,:,:]) 43 | msg_token = msg_token.view(b, self.T, 1, d) 44 | 45 | msg_token = msg_token.permute(1,2,0,3).view(self.T, b, d) 46 | msg_token = msg_token + self.drop_path(self.message_attn(self.message_ln(msg_token),self.message_ln(msg_token),self.message_ln(msg_token),need_weights=False)[0]) 47 | msg_token = msg_token.view(self.T, 1, b, d).permute(1,2,0,3) 48 | 49 | x = torch.cat([x, msg_token], dim=0) 50 | 51 | x = x.view(l+1, -1, d) 52 | x = x + self.drop_path(self.attention(self.ln_1(x))) 53 | x = x[:l,:,:] 54 | x = x + self.drop_path(self.mlp(self.ln_2(x))) 55 | return x 56 | 57 | 58 | class Transformer(nn.Module): 59 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, droppath=None, use_checkpoint=False, T=8): 60 | super().__init__() 61 | self.use_checkpoint = use_checkpoint 62 | if droppath is None: 63 | droppath = [0.0 for i in range(layers)] 64 | self.width = width 65 | self.layers = layers 66 | 67 | self.resblocks = nn.Sequential(*[CrossFramelAttentionBlock(width, heads, attn_mask, droppath[i], T) for i in range(layers)]) 68 | 69 | def forward(self, x: torch.Tensor): 70 | if not self.use_checkpoint: 71 | return self.resblocks(x) 72 | else: 73 | return checkpoint_sequential(self.resblocks, 3, x) 74 | 75 | 76 | class CrossFrameCommunicationTransformer(nn.Module): 77 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, 78 | droppath = None, T = 8, use_checkpoint = False,): 79 | super().__init__() 80 | self.input_resolution = input_resolution 81 | self.output_dim = output_dim 82 | 83 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 84 | 85 | scale = width ** -0.5 86 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 87 | self.positional_embedding = nn.Parameter(scale * 
torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 88 | self.ln_pre = LayerNorm(width) 89 | 90 | ## Attention Blocks 91 | self.transformer = Transformer(width, layers, heads, droppath=droppath, use_checkpoint=use_checkpoint, T=T,) 92 | self.ln_post = LayerNorm(width) 93 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 94 | 95 | 96 | def init_weights(self): 97 | self.apply(self._init_weights) 98 | 99 | def _init_weights(self, m): 100 | if isinstance(m, nn.Linear): 101 | trunc_normal_(m.weight, std=.02) 102 | if isinstance(m, nn.Linear) and m.bias is not None: 103 | nn.init.constant_(m.bias, 0) 104 | elif isinstance(m, nn.LayerNorm): 105 | nn.init.constant_(m.bias, 0) 106 | nn.init.constant_(m.weight, 1.0) 107 | 108 | def forward(self, x: torch.Tensor): 109 | # import pdb; 110 | # pdb.set_trace() 111 | x = self.conv1(x) # shape = [*, width, grid, grid] 112 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 113 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 114 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 115 | x = x + self.positional_embedding.to(x.dtype) 116 | 117 | x = self.ln_pre(x) 118 | 119 | x = x.permute(1, 0, 2) 120 | x = self.transformer(x) 121 | x = x.permute(1, 0, 2) 122 | 123 | cls_x = self.ln_post(x[:, 0, :]) 124 | 125 | if self.proj is not None: 126 | cls_x = cls_x @ self.proj 127 | 128 | return cls_x, x[:,1:,:] 129 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch.distributed as dist 3 | import torch 4 | import clip 5 | import os 6 | import numpy as np 7 | from sklearn.metrics import roc_auc_score, roc_curve 8 | import scipy.signal as signal 9 | from matplotlib import pyplot as plt 10 | 11 | def match(scores1, scores2): 12 | #score shape: T,2 13 | score1 = scores1[:, 1] 14 | score2 = scores2[:, 1] 15 | iou = np.stack((score1,score2),axis=1).min(1).sum() 16 | iou_ = np.stack((score1,1-score2),axis=1).min(1).sum() 17 | 18 | if iou > iou_: 19 | return score2 20 | else: 21 | return 1-score2 22 | 23 | def evaluate_result(vid2abnormality, anno_file, root=''): 24 | LABEL_PATH = anno_file 25 | gt = [] 26 | ans = [] 27 | GT = [] 28 | ANS = [] 29 | video_path_list = [] 30 | videos = {} 31 | for video in open(LABEL_PATH): 32 | vid = video.strip().split(' ')[0].split('/')[-1] 33 | video_len = int(video.strip().split(' ')[1]) 34 | sub_video_gt = np.zeros((video_len,), dtype=np.int8) 35 | anomaly_tuple = video.split(' ')[3:] 36 | for ind in range(len(anomaly_tuple) // 2): 37 | start = int(anomaly_tuple[2 * ind]) 38 | end = int(anomaly_tuple[2 * ind + 1]) 39 | if start > 0: 40 | sub_video_gt[start:end] = 1 41 | videos[vid] = sub_video_gt 42 | 43 | for vid in videos: 44 | if vid not in vid2abnormality.keys(): 45 | print("The video %s is excluded on the result!" 
% vid) 46 | continue 47 | 48 | cur_ab = np.array(vid2abnormality[vid]) 49 | if cur_ab.shape[0]==1: 50 | cur_ab = cur_ab[0, :,] 51 | else: 52 | cur_ab = cur_ab[:, 0,] 53 | cur_gt = np.array(videos[vid]) 54 | ratio = float(len(cur_gt)) / float(len(cur_ab)) 55 | cur_ans = np.zeros_like(cur_gt, dtype='float32') 56 | for i in range(len(cur_ab)): 57 | b = int(i * ratio + 0.5) 58 | e = int((i + 1) * ratio + 0.5) 59 | cur_ans[b: e] = cur_ab[i] 60 | 61 | cur_ans = postpress(cur_ans, seg_size=32) 62 | 63 | if cur_gt.max() >=1: 64 | gt.extend(cur_gt.tolist()) 65 | ans.extend(cur_ans.tolist()) 66 | 67 | GT.extend(cur_gt.tolist()) 68 | ANS.extend(cur_ans.tolist()) 69 | 70 | ret = roc_auc_score(gt, ans) 71 | Ret = roc_auc_score(GT, ANS) 72 | fpr, tpr, threshold = roc_curve(GT, ANS) 73 | 74 | if root != '': 75 | output_file = path + "AUC.npz" 76 | np.savez(output_file, fpr=fpr, tpr=tpr, thre=threshold) 77 | 78 | return Ret, ret 79 | 80 | def postpress(curve, seg_size=32): 81 | leng = curve.shape[0] 82 | window_size = leng//seg_size 83 | new_curve = np.zeros_like(curve) 84 | for i in range(seg_size): 85 | new_curve[window_size*i:window_size*(i+1)] = np.mean(curve[window_size*i:window_size*(i+1)]) 86 | if leng>window_size*seg_size: 87 | new_curve[seg_size*window_size:] = np.mean(curve[seg_size*window_size:]) 88 | return new_curve 89 | 90 | def reduce_tensor(tensor, n=None): 91 | if n is None: 92 | n = dist.get_world_size() 93 | rt = tensor.clone() 94 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 95 | rt = rt / n 96 | return rt 97 | 98 | 99 | class AverageMeter: 100 | """Computes and stores the average and current value""" 101 | def __init__(self): 102 | self.reset() 103 | 104 | def reset(self): 105 | self.val = 0 106 | self.avg = 0 107 | self.sum = 0 108 | self.count = 0 109 | 110 | def update(self, val, n=1): 111 | self.val = val 112 | self.sum += val * n 113 | self.count += n 114 | self.avg = self.sum / self.count 115 | 116 | def sync(self): 117 | rank = dist.get_rank() 118 | world_size = dist.get_world_size() 119 | val = torch.tensor(self.val).cuda() 120 | sum_v = torch.tensor(self.sum).cuda() 121 | count = torch.tensor(self.count).cuda() 122 | self.val = reduce_tensor(val, world_size).item() 123 | self.sum = reduce_tensor(sum_v, 1).item() 124 | self.count = reduce_tensor(count, 1).item() 125 | self.avg = self.sum / self.count 126 | 127 | 128 | def epoch_saving(config, epoch, model, max_accuracy, optimizer, lr_scheduler, optimizer_u, lr_scheduler_u, logger, working_dir, is_best): 129 | save_state = {'model': model.state_dict(), 130 | 'optimizer': optimizer.state_dict(), 131 | 'lr_scheduler': lr_scheduler.state_dict(), 132 | 'max_accuracy': max_accuracy, 133 | 'epoch': epoch, 134 | 'config': config} 135 | if (epoch + 1) % 10 == 0: 136 | save_path = os.path.join(working_dir, f'ckpt_epoch_{epoch}.pth') 137 | logger.info(f"{save_path} saving......") 138 | torch.save(save_state, save_path) 139 | logger.info(f"{save_path} saved !!!") 140 | if is_best: 141 | best_path = os.path.join(working_dir, f'best.pth') 142 | torch.save(save_state, best_path) 143 | logger.info(f"{best_path} saved !!!") 144 | 145 | 146 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 147 | if os.path.isfile(config.MODEL.RESUME): 148 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 149 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 150 | load_state_dict = checkpoint['model'] 151 | 152 | msg = model.load_state_dict(load_state_dict, strict=False) 153 | 
logger.info(f"resume model: {msg}") 154 | 155 | try: 156 | optimizer.load_state_dict(checkpoint['optimizer']) 157 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 158 | 159 | start_epoch = checkpoint['epoch'] + 1 160 | max_accuracy = checkpoint['max_accuracy'] 161 | 162 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 163 | 164 | del checkpoint 165 | torch.cuda.empty_cache() 166 | 167 | return start_epoch, max_accuracy 168 | except: 169 | del checkpoint 170 | torch.cuda.empty_cache() 171 | return 0, 0. 172 | 173 | else: 174 | logger.info(("=> no checkpoint found at '{}'".format(config.MODEL.RESUME))) 175 | return 0, 0 176 | 177 | 178 | def auto_resume_helper(output_dir): 179 | checkpoints = os.listdir(output_dir) 180 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 181 | print(f"All checkpoints founded in {output_dir}: {checkpoints}") 182 | if len(checkpoints) > 0: 183 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 184 | print(f"The latest checkpoint founded: {latest_checkpoint}") 185 | resume_file = latest_checkpoint 186 | else: 187 | resume_file = None 188 | return resume_file 189 | 190 | 191 | def generate_text(data): 192 | text_aug = f"{{}}" 193 | classes = torch.cat([clip.tokenize(text_aug.format(c), context_length=77) for i, c in data.classes]) 194 | 195 | return classes 196 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | # from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | import sys 15 | sys.path.append("../") 16 | from models.xclip import build_model 17 | 18 | 19 | __all__ = ["available_models", "load", "tokenize", "_download", "_MODELS"] 20 | _tokenizer = _Tokenizer() 21 | 22 | _MODELS = { 23 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 24 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 25 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 26 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 27 | } 28 | 29 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 30 | os.makedirs(root, exist_ok=True) 31 | filename = os.path.basename(url) 32 | 33 | expected_sha256 = url.split("/")[-2] 34 | download_target = os.path.join(root, filename) 35 | 36 | if os.path.exists(download_target) and not os.path.isfile(download_target): 37 | raise RuntimeError(f"{download_target} exists and is not a regular file") 38 | 39 | if os.path.isfile(download_target): 40 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 41 | return download_target 42 | else: 43 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 44 | 45 | with 
urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 46 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 47 | while True: 48 | buffer = source.read(8192) 49 | if not buffer: 50 | break 51 | 52 | output.write(buffer) 53 | loop.update(len(buffer)) 54 | 55 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 56 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 57 | 58 | return download_target 59 | 60 | 61 | def _transform(n_px): 62 | return Compose([ 63 | Resize(n_px, interpolation=Image.BICUBIC), 64 | CenterCrop(n_px), 65 | lambda image: image.convert("RGB"), 66 | ToTensor(), 67 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 68 | ]) 69 | 70 | 71 | def available_models() -> List[str]: 72 | """Returns the names of available CLIP models""" 73 | return list(_MODELS.keys()) 74 | 75 | 76 | def load(model_path, name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 77 | jit=True, T=8, droppath=0., use_checkpoint=False, logger=None, use_cache=True, prompts_alpha=1e-1, prompts_layers=2, mit_layers=1, 78 | ): 79 | """Load a CLIP model 80 | 81 | Parameters 82 | ---------- 83 | name : str 84 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 85 | 86 | device : Union[str, torch.device] 87 | The device to put the loaded model 88 | 89 | jit : bool 90 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 91 | 92 | Returns 93 | ------- 94 | model : torch.nn.Module 95 | The CLIP model 96 | 97 | preprocess : Callable[[PIL.Image], torch.Tensor] 98 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 99 | """ 100 | 101 | if model_path is None: 102 | model_path = _download(_MODELS[name]) 103 | try: 104 | # loading JIT archive 105 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 106 | state_dict = None 107 | except RuntimeError: 108 | # loading saved state dict 109 | if jit: 110 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 111 | jit = False 112 | state_dict = torch.load(model_path, map_location="cpu") 113 | 114 | if not jit: 115 | model = build_model(state_dict or model.state_dict(), T=T, droppath=droppath, 116 | use_checkpoint=use_checkpoint, logger=logger, 117 | prompts_alpha=prompts_alpha, 118 | prompts_layers=prompts_layers, 119 | use_cache=use_cache, 120 | mit_layers=mit_layers, 121 | ) 122 | if str(device) == "cpu": 123 | model.float() 124 | return model, model.state_dict() 125 | 126 | # patch the device names 127 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 128 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 129 | 130 | def patch_device(module): 131 | graphs = [module.graph] if hasattr(module, "graph") else [] 132 | if hasattr(module, "forward1"): 133 | graphs.append(module.forward1.graph) 134 | 135 | for graph in graphs: 136 | for node in graph.findAllNodes("prim::Constant"): 137 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 138 | node.copyAttributes(device_node) 139 | 140 | model.apply(patch_device) 141 | 142 | if str(device) == "cpu": 143 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 144 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 145 | float_node = float_input.node() 146 | 147 | def patch_float(module): 148 | graphs = [module.graph] if hasattr(module, "graph") else [] 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("aten::to"): 154 | inputs = list(node.inputs()) 155 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 156 | if inputs[i].node()["value"] == 5: 157 | inputs[i].node().copyAttributes(float_node) 158 | 159 | model.apply(patch_float) 160 | patch_float(model.encode_image) 161 | patch_float(model.encode_text) 162 | 163 | model.float() 164 | 165 | return model, _transform(model.input_resolution.item()) 166 | 167 | 168 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 169 | """ 170 | Returns the tokenized representation of given input string(s) 171 | 172 | Parameters 173 | ---------- 174 | texts : Union[str, List[str]] 175 | An input string or a list of input strings to tokenize 176 | 177 | context_length : int 178 | The context length to use; all CLIP models use 77 as the context length 179 | 180 | Returns 181 | ------- 182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 183 | """ 184 | if isinstance(texts, str): 185 | texts = [texts] 186 | 187 | sot_token = _tokenizer.encoder["<|startoftext|>"] 188 | eot_token = _tokenizer.encoder["<|endoftext|>"] 189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 191 | 192 | for i, tokens in enumerate(all_tokens): 193 | if len(tokens) > context_length: 194 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 195 | result[i, :len(tokens)] = torch.tensor(tokens) 196 | 197 | return result 198 | -------------------------------------------------------------------------------- /datasets/blending.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 
import torch 4 | import torch.nn.functional as F 5 | from torch.distributions.beta import Beta 6 | import numpy as np 7 | 8 | 9 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 10 | x = x.long().view(-1, 1) 11 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 12 | 13 | 14 | class BaseMiniBatchBlending(metaclass=ABCMeta): 15 | """Base class for image blending.""" 16 | 17 | def __init__(self, num_classes, smoothing=0.): 18 | self.num_classes = num_classes 19 | self.off_value = smoothing / self.num_classes 20 | self.on_value = 1. - smoothing + self.off_value 21 | 22 | @abstractmethod 23 | def do_blending(self, imgs, label, **kwargs): 24 | pass 25 | 26 | def __call__(self, imgs, label, **kwargs): 27 | """Blending data in a mini-batch. 28 | 29 | Images are float tensors with the shape of (B, N, C, H, W) for 2D 30 | recognizers or (B, N, C, T, H, W) for 3D recognizers. 31 | 32 | Besides, labels are converted from hard labels to soft labels. 33 | Hard labels are integer tensors with the shape of (B, 1) and all of the 34 | elements are in the range [0, num_classes - 1]. 35 | Soft labels (probability distribution over classes) are float tensors 36 | with the shape of (B, 1, num_classes) and all of the elements are in 37 | the range [0, 1]. 38 | 39 | Args: 40 | imgs (torch.Tensor): Model input images, float tensor with the 41 | shape of (B, N, C, H, W) or (B, N, C, T, H, W). 42 | label (torch.Tensor): Hard labels, integer tensor with the shape 43 | of (B, 1) and all elements are in range [0, num_classes). 44 | kwargs (dict, optional): Other keyword arguments to be used to 45 | blend imgs and labels in a mini-batch. 46 | 47 | Returns: 48 | mixed_imgs (torch.Tensor): Blended images, float tensor with the 49 | same shape as the input imgs. 50 | mixed_label (torch.Tensor): Blended soft labels, float tensor with 51 | the shape of (B, 1, num_classes) and all elements are in range 52 | [0, 1]. 53 | """ 54 | one_hot_label = one_hot(label, num_classes=self.num_classes, on_value=self.on_value, off_value=self.off_value, device=label.device) 55 | 56 | mixed_imgs, mixed_label = self.do_blending(imgs, one_hot_label, 57 | **kwargs) 58 | 59 | return mixed_imgs, mixed_label 60 | 61 | 62 | class MixupBlending(BaseMiniBatchBlending): 63 | """Implementing Mixup in a mini-batch. 64 | 65 | This module is proposed in `mixup: Beyond Empirical Risk Minimization 66 | <https://arxiv.org/abs/1710.09412>`_. 67 | Code Reference https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/utils/mixup.py # noqa 68 | 69 | Args: 70 | num_classes (int): The number of classes. 71 | alpha (float): Parameters for Beta distribution. 72 | """ 73 | 74 | def __init__(self, num_classes, alpha=.2, smoothing=0.): 75 | super().__init__(num_classes=num_classes, smoothing=smoothing) 76 | self.beta = Beta(alpha, alpha) 77 | 78 | def do_blending(self, imgs, label, **kwargs): 79 | """Blending images with mixup.""" 80 | assert len(kwargs) == 0, f'unexpected kwargs for mixup {kwargs}' 81 | 82 | lam = self.beta.sample() 83 | batch_size = imgs.size(0) 84 | rand_index = torch.randperm(batch_size) 85 | 86 | mixed_imgs = lam * imgs + (1 - lam) * imgs[rand_index, :] 87 | mixed_label = lam * label + (1 - lam) * label[rand_index, :] 88 | 89 | return mixed_imgs, mixed_label 90 | 91 | 92 | class CutmixBlending(BaseMiniBatchBlending): 93 | """Implementing Cutmix in a mini-batch. 94 | This module is proposed in `CutMix: Regularization Strategy to Train Strong 95 | Classifiers with Localizable Features <https://arxiv.org/abs/1905.04899>`_.
96 | Code Reference https://github.com/clovaai/CutMix-PyTorch 97 | Args: 98 | num_classes (int): The number of classes. 99 | alpha (float): Parameters for Beta distribution. 100 | """ 101 | 102 | def __init__(self, num_classes, alpha=.2, smoothing=0.): 103 | super().__init__(num_classes=num_classes, smoothing=smoothing) 104 | self.beta = Beta(alpha, alpha) 105 | 106 | @staticmethod 107 | def rand_bbox(img_size, lam): 108 | """Generate a random bounding box.""" 109 | w = img_size[-1] 110 | h = img_size[-2] 111 | cut_rat = torch.sqrt(1. - lam) 112 | cut_w = torch.tensor(int(w * cut_rat)) 113 | cut_h = torch.tensor(int(h * cut_rat)) 114 | 115 | # uniform 116 | cx = torch.randint(w, (1, ))[0] 117 | cy = torch.randint(h, (1, ))[0] 118 | 119 | bbx1 = torch.clamp(cx - cut_w // 2, 0, w) 120 | bby1 = torch.clamp(cy - cut_h // 2, 0, h) 121 | bbx2 = torch.clamp(cx + cut_w // 2, 0, w) 122 | bby2 = torch.clamp(cy + cut_h // 2, 0, h) 123 | 124 | return bbx1, bby1, bbx2, bby2 125 | 126 | def do_blending(self, imgs, label, **kwargs): 127 | """Blending images with cutmix.""" 128 | assert len(kwargs) == 0, f'unexpected kwargs for cutmix {kwargs}' 129 | 130 | batch_size = imgs.size(0) 131 | rand_index = torch.randperm(batch_size) 132 | lam = self.beta.sample() 133 | 134 | bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam) 135 | imgs[:, ..., bby1:bby2, bbx1:bbx2] = imgs[rand_index, ..., bby1:bby2, 136 | bbx1:bbx2] 137 | lam = 1 - (1.0 * (bbx2 - bbx1) * (bby2 - bby1) / 138 | (imgs.size()[-1] * imgs.size()[-2])) 139 | 140 | label = lam * label + (1 - lam) * label[rand_index, :] 141 | 142 | return imgs, label 143 | 144 | 145 | class LabelSmoothing(BaseMiniBatchBlending): 146 | def do_blending(self, imgs, label, **kwargs): 147 | return imgs, label 148 | 149 | 150 | class CutmixMixupBlending(BaseMiniBatchBlending): 151 | def __init__(self, num_classes=400, smoothing=0.1, mixup_alpha=.8, cutmix_alpha=1, switch_prob=0.5): 152 | super().__init__(num_classes=num_classes, smoothing=smoothing) 153 | self.mixup_beta = Beta(mixup_alpha, mixup_alpha) 154 | self.cutmix_beta = Beta(cutmix_alpha, cutmix_alpha) 155 | self.switch_prob = switch_prob 156 | 157 | @staticmethod 158 | def rand_bbox(img_size, lam): 159 | """Generate a random bounding box.""" 160 | w = img_size[-1] 161 | h = img_size[-2] 162 | cut_rat = torch.sqrt(1.
- lam) 163 | cut_w = torch.tensor(int(w * cut_rat)) 164 | cut_h = torch.tensor(int(h * cut_rat)) 165 | 166 | # uniform 167 | cx = torch.randint(w, (1, ))[0] 168 | cy = torch.randint(h, (1, ))[0] 169 | 170 | bbx1 = torch.clamp(cx - cut_w // 2, 0, w) 171 | bby1 = torch.clamp(cy - cut_h // 2, 0, h) 172 | bbx2 = torch.clamp(cx + cut_w // 2, 0, w) 173 | bby2 = torch.clamp(cy + cut_h // 2, 0, h) 174 | 175 | return bbx1, bby1, bbx2, bby2 176 | 177 | def do_cutmix(self, imgs, label, **kwargs): 178 | """Blending images with cutmix.""" 179 | assert len(kwargs) == 0, f'unexpected kwargs for cutmix {kwargs}' 180 | 181 | batch_size = imgs.size(0) 182 | rand_index = torch.randperm(batch_size) 183 | lam = self.cutmix_beta.sample() 184 | 185 | bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam) 186 | imgs[:, ..., bby1:bby2, bbx1:bbx2] = imgs[rand_index, ..., bby1:bby2, 187 | bbx1:bbx2] 188 | lam = 1 - (1.0 * (bbx2 - bbx1) * (bby2 - bby1) / 189 | (imgs.size()[-1] * imgs.size()[-2])) 190 | 191 | label = lam * label + (1 - lam) * label[rand_index, :] 192 | return imgs, label 193 | 194 | def do_mixup(self, imgs, label, **kwargs): 195 | """Blending images with mixup.""" 196 | assert len(kwargs) == 0, f'unexpected kwargs for mixup {kwargs}' 197 | 198 | lam = self.mixup_beta.sample() 199 | batch_size = imgs.size(0) 200 | rand_index = torch.randperm(batch_size) 201 | 202 | mixed_imgs = lam * imgs + (1 - lam) * imgs[rand_index, :] 203 | mixed_label = lam * label + (1 - lam) * label[rand_index, :] 204 | 205 | return mixed_imgs, mixed_label 206 | 207 | def do_blending(self, imgs, label, **kwargs): 208 | """Blending images with MViT style: Cutmix for half of the batches and Mixup for the other half.""" 209 | assert len(kwargs) == 0, f'unexpected kwargs for cutmix_half_mixup {kwargs}' 210 | 211 | if np.random.rand() < self.switch_prob : 212 | return self.do_cutmix(imgs, label) 213 | else: 214 | return self.do_mixup(imgs, label) 215 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import OrderedDict 3 | from typing import Tuple, Union 4 | from timm.models.layers import trunc_normal_ 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from einops import rearrange 10 | from torch.utils.checkpoint import checkpoint_sequential 11 | import math 12 | import clip 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0., training: bool = False): 16 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 17 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 18 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 19 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 20 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 21 | 'survival rate' as the argument. 22 | """ 23 | if drop_prob == 0.
or not training: 24 | return x 25 | keep_prob = 1 - drop_prob 26 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 27 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 28 | random_tensor.floor_() # binarize 29 | output = x.div(keep_prob) * random_tensor 30 | return output 31 | 32 | class DropPath(nn.Module): 33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 34 | """ 35 | def __init__(self, drop_prob=None): 36 | super(DropPath, self).__init__() 37 | self.drop_prob = drop_prob 38 | 39 | def forward(self, x): 40 | return drop_path(x, self.drop_prob, self.training) 41 | 42 | class LayerNorm(nn.LayerNorm): 43 | """Subclass torch's LayerNorm to handle fp16.""" 44 | 45 | def forward(self, x: torch.Tensor): 46 | # orig_type = x.dtype 47 | # ret = super().forward(x.type(torch.float32)) 48 | # return ret.type(orig_type) 49 | return super().forward(x) 50 | 51 | class QuickGELU(nn.Module): 52 | def forward(self, x: torch.Tensor): 53 | return x * torch.sigmoid(1.702 * x) 54 | 55 | class ResidualAttentionBlock(nn.Module): 56 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, ): 57 | super().__init__() 58 | 59 | self.attn = nn.MultiheadAttention(d_model, n_head,) 60 | self.ln_1 = LayerNorm(d_model) 61 | 62 | self.mlp = nn.Sequential(OrderedDict([ 63 | ("c_fc", nn.Linear(d_model, d_model * 4)), 64 | ("gelu", QuickGELU()), 65 | ("c_proj", nn.Linear(d_model * 4, d_model)) 66 | ])) 67 | self.ln_2 = LayerNorm(d_model) 68 | self.attn_mask = attn_mask 69 | 70 | def attention(self, x: torch.Tensor): 71 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 72 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 73 | 74 | def forward(self, x: torch.Tensor): 75 | x = x + self.attention(self.ln_1(x)) 76 | x = x + self.mlp(self.ln_2(x)) 77 | return x 78 | 79 | class Transformer(nn.Module): 80 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 81 | super().__init__() 82 | self.width = width 83 | self.layers = layers 84 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 85 | 86 | def forward(self, x: torch.Tensor): 87 | return self.resblocks(x) 88 | 89 | class VisionTransformer(nn.Module): 90 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 91 | super().__init__() 92 | self.input_resolution = input_resolution 93 | self.output_dim = output_dim 94 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 95 | 96 | scale = width ** -0.5 97 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 98 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 99 | self.ln_pre = LayerNorm(width) 100 | 101 | self.transformer = Transformer(width, layers, heads) 102 | 103 | self.ln_post = LayerNorm(width) 104 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 105 | 106 | def forward(self, x: torch.Tensor): 107 | x = self.conv1(x) # shape = [*, width, grid, grid] 108 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 109 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 110 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], 
dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 111 | x = x + self.positional_embedding.to(x.dtype) 112 | x = self.ln_pre(x) 113 | 114 | x = x.permute(1, 0, 2) # NLD -> LND 115 | x = self.transformer(x) 116 | x = x.permute(1, 0, 2) # LND -> NLD 117 | 118 | x = self.ln_post(x[:, 0, :]) 119 | 120 | if self.proj is not None: 121 | x = x @ self.proj 122 | return x 123 | 124 | class CLIP(nn.Module): 125 | def __init__(self, 126 | embed_dim: int, 127 | # vision 128 | image_resolution: int, 129 | vision_layers: Union[Tuple[int, int, int, int], int], 130 | vision_width: int, 131 | vision_patch_size: int, 132 | # text 133 | context_length: int, 134 | vocab_size: int, 135 | transformer_width: int, 136 | transformer_heads: int, 137 | transformer_layers: int 138 | ): 139 | super().__init__() 140 | 141 | self.context_length = context_length 142 | 143 | # vision_heads = vision_width // 64 144 | # self.visual = VisionTransformer( 145 | # input_resolution=image_resolution, 146 | # patch_size=vision_patch_size, 147 | # width=vision_width, 148 | # layers=vision_layers, 149 | # heads=vision_heads, 150 | # output_dim=embed_dim 151 | # ) 152 | 153 | # self.transformer = Transformer( 154 | # width=transformer_width, 155 | # layers=transformer_layers, 156 | # heads=transformer_heads, 157 | # attn_mask=self.build_attention_mask() 158 | # ) 159 | 160 | # self.vocab_size = vocab_size 161 | # self.token_embedding = nn.Embedding(vocab_size, transformer_width) 162 | # self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 163 | # self.ln_final = LayerNorm(transformer_width) 164 | 165 | # self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 166 | # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 167 | 168 | # self.initialize_parameters() 169 | 170 | def initialize_parameters(self): 171 | # import pdb;pdb.set_trace() 172 | # nn.init.normal_(self.anomaly_head.fc_1.weight, std=0.0255) 173 | # nn.init.normal_(self.cluster_head.fc_1.weight, std=0.0255) 174 | 175 | nn.init.normal_(self.token_embedding.weight, std=0.02) 176 | nn.init.normal_(self.positional_embedding, std=0.01) 177 | 178 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 179 | attn_std = self.transformer.width ** -0.5 180 | fc_std = (2 * self.transformer.width) ** -0.5 181 | for block in self.transformer.resblocks: 182 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 183 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 184 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 185 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 186 | 187 | if self.text_projection is not None: 188 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 189 | 190 | def build_attention_mask(self): 191 | # lazily create causal attention mask, with full attention between the vision tokens 192 | # pytorch uses additive attention mask; fill with -inf 193 | mask = torch.empty(self.context_length, self.context_length) 194 | mask.fill_(float("-inf")) 195 | mask.triu_(1) # zero out the lower diagonal 196 | return mask 197 | 198 | @property 199 | def dtype(self): 200 | return self.visual.conv1.weight.dtype 201 | 202 | def encode_image(self, image): 203 | return self.visual(image.type(self.dtype)) 204 | 205 | def encode_text(self, text): 206 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 207 | 208 | x = x + self.positional_embedding.type(self.dtype) 209 | x 
= x.permute(1, 0, 2) # NLD -> LND 210 | x = self.transformer(x) 211 | x = x.permute(1, 0, 2) # LND -> NLD 212 | x = self.ln_final(x).type(self.dtype) 213 | 214 | # x.shape = [batch_size, n_ctx, transformer.width] 215 | # take features from the eot embedding (eot_token is the highest number in each sequence) 216 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 217 | 218 | return x 219 | 220 | def forward(self, image, text): 221 | image_features = self.encode_image(image) 222 | text_features = self.encode_text(text) 223 | 224 | # normalized features 225 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 226 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 227 | 228 | # cosine similarity as logits 229 | logit_scale = self.logit_scale.exp() 230 | logits_per_image = logit_scale * image_features @ text_features.t() 231 | logits_per_text = logits_per_image.t() 232 | 233 | # shape = [global_batch_size, global_batch_size] 234 | return logits_per_image, logits_per_text 235 | 236 | -------------------------------------------------------------------------------- /models/xclip.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from .mit import MultiframeIntegrationTransformer 6 | from .prompt import VideoSpecificPrompt 7 | from .cct import CrossFrameCommunicationTransformer 8 | import sys 9 | import warnings 10 | sys.path.append("../") 11 | from clip.model import CLIP,LayerNorm,Transformer 12 | import clip 13 | from clip.model import QuickGELU 14 | 15 | class XCLIP(CLIP): 16 | def __init__(self, 17 | embed_dim: int, 18 | # vision 19 | image_resolution: int, 20 | vision_layers: Union[Tuple[int, int, int, int], int], 21 | vision_width: int, 22 | vision_patch_size: int, 23 | # text 24 | context_length: int, 25 | vocab_size: int, 26 | transformer_width: int, 27 | transformer_heads: int, 28 | transformer_layers: int, 29 | # video 30 | T=8, 31 | droppath=0., 32 | mit_layers=1, 33 | # prompt 34 | prompts_alpha=1e-4, 35 | prompts_layers=1, 36 | # other 37 | use_cache=True, 38 | use_checkpoint=False, 39 | frozen_backbone=True, 40 | ): 41 | super().__init__( 42 | embed_dim, 43 | image_resolution, vision_layers, vision_width, vision_patch_size, 44 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 45 | ) 46 | 47 | self.prompts_generator = VideoSpecificPrompt(layers=prompts_layers, embed_dim=embed_dim, alpha=prompts_alpha,) 48 | self.use_cache=use_cache 49 | self.mit = MultiframeIntegrationTransformer(T=T, embed_dim=embed_dim, layers=mit_layers,) 50 | 51 | dpr = [x.item() for x in torch.linspace(0, droppath, vision_layers)] if droppath > 0. 
else None 52 | 53 | vision_heads = vision_width // 64 54 | self.visual = CrossFrameCommunicationTransformer( 55 | input_resolution=image_resolution, 56 | patch_size=vision_patch_size, 57 | width=vision_width, 58 | layers=vision_layers, 59 | heads=vision_heads, 60 | output_dim=embed_dim, 61 | droppath=dpr, 62 | T=T, 63 | use_checkpoint=use_checkpoint, 64 | ) 65 | 66 | self.transformer = Transformer( 67 | width=transformer_width, 68 | layers=transformer_layers, 69 | heads=transformer_heads, 70 | attn_mask=self.build_attention_mask() 71 | ) 72 | self.vocab_size = vocab_size 73 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 74 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 75 | self.ln_final = LayerNorm(transformer_width) 76 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 77 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 78 | 79 | self.cache_text_features = None 80 | self.prompts_visual_ln = LayerNorm(vision_width) 81 | self.prompts_visual_proj = nn.Parameter(torch.randn(vision_width, embed_dim)) 82 | # for UDA 83 | self.head_video = nn.Linear(embed_dim, embed_dim) 84 | self.u_head_video = nn.Linear(embed_dim, embed_dim) 85 | 86 | self.initialize_parameters() 87 | 88 | @torch.jit.ignore 89 | def no_weight_decay_keywords(self): 90 | return {'positional_embedding'} 91 | 92 | def encode_image(self, image): 93 | return self.visual(image) 94 | 95 | def encode_text(self, text): 96 | x = self.token_embedding(text) 97 | eos_indx = text.argmax(dim=-1) 98 | K, N1, C = x.shape 99 | 100 | x = x + self.positional_embedding 101 | x = x.permute(1, 0, 2) # NLD -> LND 102 | x = self.transformer(x) 103 | x = x.permute(1, 0, 2) # LND -> NLD 104 | x = self.ln_final(x) 105 | # x.shape = [batch_size, n_ctx, transformer.width] 106 | # take features from the eot embedding (eot_token is the highest number in each sequence) 107 | x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection 108 | x = x.reshape(K, -1) 109 | return x 110 | 111 | def encode_video(self, image): 112 | b,t,c,h,w = image.size() 113 | image = image.reshape(-1,c,h,w) 114 | 115 | cls_features, img_features = self.encode_image(image) 116 | img_features = self.prompts_visual_ln(img_features) 117 | img_features = img_features @ self.prompts_visual_proj 118 | 119 | cls_features = cls_features.view(b, t, -1) 120 | img_features = img_features.view(b,t,-1,cls_features.shape[-1]) 121 | 122 | video_features = self.mit(cls_features) 123 | 124 | return video_features, img_features 125 | 126 | def cache_text(self, text, train_flag): 127 | self.eval() 128 | with torch.no_grad(): 129 | if self.cache_text_features is None: 130 | self.cache_text_features = self.encode_text(text) 131 | if train_flag: 132 | self.train() 133 | return self.cache_text_features 134 | 135 | def uda(self, video_feature, text_feature, train_flag): 136 | v_fea = self.head_video(video_feature) 137 | v_fea_u = self.u_head_video(video_feature) 138 | 139 | v_fea = v_fea / v_fea.norm(dim=-1, keepdim=True) 140 | v_fea_u = v_fea_u / v_fea_u.norm(dim=-1, keepdim=True) 141 | t_fea = text_feature / text_feature.norm(dim=-1, keepdim=True) 142 | 143 | if train_flag: 144 | v_fea_u_nograd = self.u_head_video(video_feature.detach()) 145 | t_fea_nograd = t_fea.detach() 146 | return video_feature, v_fea, v_fea_u, t_fea, v_fea_u_nograd, v_fea_u_nograd, t_fea_nograd 147 | else: 148 | return video_feature, v_fea, v_fea_u, t_fea 149 | 150 | def forward(self, image, text): 151 | b = 
image.shape[0] 152 | 153 | video_features, img_features = self.encode_video(image) 154 | 155 | img_features = img_features.mean(dim=1, keepdim=False) 156 | 157 | if self.use_cache: 158 | text_features = self.cache_text(text, self.training) 159 | else: 160 | text_features = self.encode_text(text) 161 | 162 | text_features = text_features.unsqueeze(0).expand(b, -1, -1) 163 | text_features = text_features + self.prompts_generator(text_features, img_features) 164 | 165 | logit_scale = self.logit_scale.exp() 166 | 167 | if self.training: 168 | _, v_features, v_features_u, t_features, \ 169 | _, v_features_u_n, t_features_n = self.uda(video_features, text_features, self.training) 170 | 171 | logits = torch.einsum("bd,bkd->bk", v_features, logit_scale * t_features) 172 | logits_u = torch.einsum("bd,bkd->bk", v_features_u, logit_scale * t_features) 173 | logits_u_n = torch.einsum("bd,bkd->bk", v_features_u_n, logit_scale * t_features) 174 | 175 | outputs= { 176 | "y": logits, 177 | "y_cluster_all": logits_u, 178 | "feature_v": video_features, 179 | "y_cluster_all_nograd": logits_u_n 180 | } 181 | return outputs 182 | else: 183 | video_features, v_features, v_features_u, t_features= self.uda(video_features, text_features, self.training) 184 | logits = torch.einsum("bd,bkd->bk", v_features, logit_scale * t_features) 185 | logits_u = torch.einsum("bd,bkd->bk", v_features_u, logit_scale * t_features) 186 | 187 | outputs = { 188 | "y": logits, 189 | "y_cluster_all": logits_u, 190 | "feature_v": video_features, 191 | "feature_t": text_features, 192 | } 193 | return outputs 194 | 195 | 196 | 197 | def build_model(state_dict: dict, T=8, droppath=0., use_checkpoint=False, logger=None, prompts_alpha=1e-1, prompts_layers=2, use_cache=True, mit_layers=4,): 198 | vit = "visual.proj" in state_dict 199 | 200 | if vit: 201 | vision_width = state_dict["visual.conv1.weight"].shape[0] 202 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 203 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 204 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 205 | image_resolution = vision_patch_size * grid_size 206 | else: 207 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 208 | vision_layers = tuple(counts) 209 | 210 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 211 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 212 | vision_patch_size = None 213 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 214 | image_resolution = output_width * 32 215 | 216 | embed_dim = state_dict["text_projection"].shape[1] 217 | context_length = state_dict["positional_embedding"].shape[0] 218 | vocab_size = state_dict["token_embedding.weight"].shape[0] 219 | transformer_width = state_dict["ln_final.weight"].shape[0] 220 | transformer_heads = transformer_width // 64 221 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 222 | 223 | model = XCLIP( 224 | embed_dim, 225 | image_resolution, vision_layers, vision_width, vision_patch_size, 226 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, 227 | T=T, droppath=droppath, mit_layers=mit_layers, 228 | prompts_alpha=prompts_alpha, prompts_layers=prompts_layers, 229 | use_checkpoint=use_checkpoint, 
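# Note: all of the architecture hyperparameters passed above (embed_dim, image_resolution,
# vision_layers, vision_width, transformer_width, transformer_heads, transformer_layers, ...)
# were inferred from tensor shapes in the pretrained CLIP state_dict a few lines earlier,
# so no separate architecture config is needed when loading a checkpoint.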
use_cache=use_cache, 230 | ) 231 | 232 | for key in ["input_resolution", "context_length", "vocab_size", "mit.positional_embedding"]: 233 | if key in state_dict: 234 | del state_dict[key] 235 | 236 | msg = model.load_state_dict(state_dict,strict=False) 237 | logger.info(f"load pretrained CLIP: {msg}") 238 | 239 | return model.eval() 240 | 241 | 242 | def load(model_path, name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 243 | jit=True, T=8, droppath=0., use_checkpoint=False, logger=None, use_cache=True, prompts_alpha=1e-1, prompts_layers=2, mit_layers=1, 244 | ): 245 | if model_path is None: 246 | model_path = clip._download(clip._MODELS[name]) 247 | 248 | try: 249 | # loading JIT archive 250 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 251 | state_dict = None 252 | except RuntimeError: 253 | # loading saved state dict 254 | if jit: 255 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 256 | jit = False 257 | state_dict = torch.load(model_path, map_location="cpu") 258 | 259 | model = build_model(state_dict['model'] or model.state_dict(), T=T, droppath=droppath, 260 | use_checkpoint=use_checkpoint, logger=logger, 261 | prompts_alpha=prompts_alpha, 262 | prompts_layers=prompts_layers, 263 | use_cache=use_cache, 264 | mit_layers=mit_layers, 265 | ) 266 | if str(device) == "cpu": 267 | model.float() 268 | return model, model.state_dict() 269 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.backends.cudnn as cudnn 5 | import torch.distributed as dist 6 | import argparse 7 | import datetime 8 | import shutil 9 | from pathlib import Path 10 | from utils.config import get_config 11 | from utils.optimizer import build_optimizer, build_scheduler 12 | from utils.tools import AverageMeter, reduce_tensor, epoch_saving, load_checkpoint, generate_text, auto_resume_helper, evaluate_result 13 | from utils.cluster import ClusterLoss, Normalize, BCE, NCLMemory, PairEnum 14 | from datasets.build import build_dataloader 15 | from utils.logger import create_logger 16 | import time 17 | import numpy as np 18 | import random 19 | import mmcv 20 | from apex import amp 21 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 22 | from datasets.blending import CutmixMixupBlending 23 | from utils.config import get_config 24 | from models import xclip 25 | from einops import rearrange 26 | import torch.nn.functional as F 27 | 28 | def parse_option(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--config', '-cfg', required=True, type=str, default='configs/k400/32_8.yaml') 31 | parser.add_argument( 32 | "--opts", 33 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 34 | default=None, 35 | nargs='+', 36 | ) 37 | parser.add_argument('--output', type=str, default="exp") 38 | parser.add_argument('--resume', type=str) 39 | parser.add_argument('--pretrained', type=str) 40 | parser.add_argument('--only_test', action='store_true') 41 | parser.add_argument('--batch-size', type=int) 42 | parser.add_argument('--accumulation-steps', type=int) 43 | # model parameters 44 | parser.add_argument("--local_rank", type=int, default=-1, help='local rank for DistributedDataParallel') 45 | parser.add_argument('--w-smooth', default=0.01, type=float, help='weight of smooth loss') 46 | parser.add_argument('--w-sparse', default=0.001, type=float, help='weight of sparse loss') 47 | 48 | args = parser.parse_args() 49 | 50 | config = get_config(args) 51 | 52 | return args, config 53 | 54 | 55 | def main(config): 56 | train_data, val_data, test_data, train_loader, val_loader, test_loader, val_loader_train,_ = build_dataloader(logger, config) 57 | model, _ = xclip.load(config.MODEL.PRETRAINED, config.MODEL.ARCH, 58 | device="cpu", jit=False, 59 | T=config.DATA.NUM_FRAMES, 60 | droppath=config.MODEL.DROP_PATH_RATE, 61 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 62 | use_cache=config.MODEL.FIX_TEXT, 63 | logger=logger, 64 | ) 65 | model = model.cuda() 66 | 67 | 68 | optimizer, _ = build_optimizer(config, model) 69 | lr_scheduler = build_scheduler(config, optimizer, len(train_loader)) 70 | if config.TRAIN.OPT_LEVEL != 'O0': 71 | model, optimizer = amp.initialize(models=model, optimizers=optimizer, opt_level=config.TRAIN.OPT_LEVEL) 72 | 73 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False, find_unused_parameters=False) 74 | 75 | start_epoch, best_epoch, max_auc = 0, 0, 0.0 76 | 77 | if config.TRAIN.AUTO_RESUME: 78 | resume_file = auto_resume_helper(config.OUTPUT) 79 | if resume_file: 80 | config.defrost() 81 | config.MODEL.RESUME = resume_file 82 | config.freeze() 83 | logger.info(f'auto resuming from {resume_file}') 84 | else: 85 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 86 | 87 | if config.MODEL.RESUME: 88 | start_epoch, max_accuracy = load_checkpoint(config, model.module, optimizer, lr_scheduler, logger) 89 | 90 | text_labels = generate_text(train_data) 91 | 92 | if config.TEST.ONLY_TEST: 93 | if not os.path.isdir(config.MODEL.PRETRAINED): 94 | #evaluate on val set 95 | out_path = config.MODEL.PRETRAINED.replace('pth','pkl') 96 | if os.path.exists(out_path): 97 | scores_dict = mmcv.load(out_path) 98 | else: 99 | scores_dict = validate(test_loader, text_labels, model, config, out_path) 100 | 101 | tmp_dict = {} 102 | for v_name in scores_dict["prd"].keys(): 103 | p_scores = np.array(scores_dict["prd"][v_name]).copy() 104 | if p_scores.shape[0] == 1: 105 | # 1,32,2 106 | tmp_dict[v_name] = [p_scores[0, :, 1]] 107 | else: 108 | # T,1,2 109 | tmp_dict[v_name] = [p_scores[:, 0, 1]] 110 | 111 | auc_all, auc_ano = evaluate_result(tmp_dict, config.DATA.VAL_FILE) 112 | 113 | logger.info(f"AUC@all/ano of version {out_path.split('/')[-2]} on epoch {out_path.split('/')[-1].split('_')[-1][:-4]} : {auc_all:.4f}({auc_ano:.4f})") 114 | return 115 | else: 116 | for epoch in range(config.TRAIN.EPOCHS): 117 | out_path = os.path.join(config.MODEL.PRETRAINED, 'ckpt_epoch_' + str(epoch) + '.pkl') 118 | scores_dict = validate(test_loader, text_labels, model, config, out_path) 119 | tmp_dict = {} 120 | for v_name in scores_dict["cls"].keys(): 121 | tmp_dict[v_name] = 
[np.array(scores_dict["prd"][v_name])[:,0]] # 1,32,2 + 122 | auc_all, auc_ano = evaluate_result(tmp_dict, config.DATA.VAL_FILE) 123 | is_best = auc_all > max_auc 124 | if is_best: 125 | best_epoch = epoch 126 | max_auc = max(max_auc, auc_all) 127 | logger.info(f"Auc on epoch {epoch}: {auc_all:.4f}({auc_ano:.4f})") 128 | logger.info(f'Max AUC@all {best_epoch}/{epoch} : {max_auc:.4f}') 129 | 130 | for epoch in range(start_epoch, config.TRAIN.EPOCHS): 131 | train_loader.sampler.set_epoch(epoch) 132 | train_one_epoch(epoch, model, optimizer, lr_scheduler, train_loader, text_labels, config) 133 | 134 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 135 | epoch_saving(config, epoch, model.module, max_auc, optimizer, lr_scheduler,_,_, logger, config.OUTPUT, is_best) 136 | 137 | def train_one_epoch(epoch, model, optimizer, lr_scheduler, train_loader, text_labels, config, data_dict=None): 138 | model.train() 139 | 140 | optimizer.zero_grad() 141 | 142 | num_steps = len(train_loader) 143 | batch_time = AverageMeter() 144 | tot_loss_meter = AverageMeter() 145 | mil_loss_meter = AverageMeter() 146 | sm_loss_meter = AverageMeter() 147 | sp_loss_meter = AverageMeter() 148 | 149 | start = time.time() 150 | end = time.time() 151 | 152 | texts = text_labels.cuda(non_blocking=True) 153 | 154 | for idx, batch_data in enumerate(train_loader): 155 | images = batch_data["imgs"].cuda(non_blocking=True)[:,:1] 156 | label_id = batch_data["label"].cuda(non_blocking=True)[:,:1] 157 | label_id = label_id.reshape(-1) 158 | bz = images.shape[0] 159 | a_aug = images.shape[1] 160 | n_clips = images.shape[2] 161 | 162 | images = rearrange(images, 'b a k c t h w -> (b a k) t c h w')# bz*num_aug*num_clips,num_frames,h,w 163 | 164 | if texts.shape[0] == 1: 165 | texts = texts.view(1, -1) 166 | 167 | output = model(images, texts) 168 | # mil loss on max scores among bags, view instance of max scores as labeled data 169 | logits = rearrange(output['y'], '(b a k) c -> (b a) k c', b=bz, a=a_aug,) 170 | 171 | scores = F.softmax(logits, dim=-1) 172 | scores_ano = scores[:,:,1] 173 | scores_nor = scores[:,:,0] 174 | max_prob_ano, max_ind = torch.max(scores_ano, dim=-1) 175 | max_prob_nor, _ = torch.max(scores_nor, dim=-1) 176 | 177 | logits_video = torch.gather(logits, 1, max_ind[:, None, None].repeat((1, 1, 2))).squeeze(1) 178 | max_prob_video, _ = torch.max(torch.gather(scores, 1, max_ind[:, None, None].repeat((1, 1, 2))).squeeze(1), 179 | dim=-1) 180 | labels_binary = label_id > 0 181 | 182 | # MIL loss 183 | loss_mil = F.cross_entropy(logits_video, labels_binary.long(), reduction='none') 184 | loss_mil = loss_mil * max_prob_video 185 | loss_mil = loss_mil.mean() 186 | 187 | scores_all = scores 188 | smoothed_scores = (scores_all[:,1:,1] - scores_all[:,:-1,1]) 189 | smoothed_loss = smoothed_scores.pow(2).sum(dim=-1).mean() 190 | 191 | sparsity_loss = scores_all[:,:,1].sum(dim=-1).mean() 192 | 193 | w_smooth = args.w_smooth 194 | w_sparse = args.w_sparse 195 | 196 | total_loss = loss_mil + smoothed_loss * w_smooth + sparsity_loss * w_sparse 197 | 198 | total_loss = total_loss / config.TRAIN.ACCUMULATION_STEPS 199 | 200 | if config.TRAIN.ACCUMULATION_STEPS == 1: 201 | optimizer.zero_grad() 202 | if config.TRAIN.OPT_LEVEL != 'O0': 203 | with amp.scale_loss(total_loss, optimizer) as scaled_loss: 204 | scaled_loss.backward() 205 | else: 206 | total_loss.backward() 207 | if config.TRAIN.ACCUMULATION_STEPS > 1: 208 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 209 | 
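# Gradient accumulation: total_loss was already divided by ACCUMULATION_STEPS above, so the
# gradients summed over ACCUMULATION_STEPS consecutive micro-batches match those of one large
# batch; the optimizer and LR scheduler therefore step only at this boundary, giving an
# effective batch size of BATCH_SIZE * ACCUMULATION_STEPS.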
optimizer.step() 210 | optimizer.zero_grad() 211 | lr_scheduler.step_update(epoch * num_steps + idx) 212 | else: 213 | optimizer.step() 214 | lr_scheduler.step_update(epoch * num_steps + idx) 215 | 216 | torch.cuda.synchronize() 217 | 218 | tot_loss_meter.update(total_loss.item(), len(label_id)) 219 | mil_loss_meter.update(loss_mil.item(), len(label_id)) 220 | sm_loss_meter.update((smoothed_loss * w_smooth).item(), len(label_id)) 221 | sp_loss_meter.update((sparsity_loss * w_sparse).item(), len(label_id)) 222 | batch_time.update(time.time() - end) 223 | end = time.time() 224 | 225 | if idx % config.PRINT_FREQ == 0: 226 | lr = optimizer.param_groups[0]['lr'] 227 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 228 | etas = batch_time.avg * (num_steps - idx) 229 | logger.info( 230 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 231 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.9f}\t' 232 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 233 | f'tot {tot_loss_meter.val:.4f} ({tot_loss_meter.avg:.4f})\t' 234 | f'mil {mil_loss_meter.val:.4f} ({mil_loss_meter.avg:.4f})\t' 235 | f'sm {sm_loss_meter.val:.4f} ({sm_loss_meter.avg:.4f})\t' 236 | f'sp {sp_loss_meter.val:.4f} ({sp_loss_meter.avg:.4f})\t' 237 | f'mem {memory_used:.0f}MB') 238 | 239 | epoch_time = time.time() - start 240 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 241 | 242 | 243 | @torch.no_grad() 244 | def validate(data_loader, text_labels, model, config, out_path): 245 | model.eval() 246 | vid_list = [] 247 | 248 | anno_file = config.DATA.VAL_FILE 249 | 250 | with open(anno_file, 'r') as fin: 251 | for line in fin: 252 | line_split = line.strip().split() 253 | filename = line_split[0].split('/')[-1] 254 | vid_list.append(filename) 255 | 256 | with torch.no_grad(): 257 | text_inputs = text_labels.cuda() 258 | logger.info(f"{config.TEST.NUM_CLIP * config.TEST.NUM_CROP} views inference") 259 | scores_dict = dict() 260 | scores_dict['prd'] = dict() 261 | for idx, batch_data in enumerate(data_loader): 262 | _image = batch_data["imgs"] 263 | label_id = batch_data["label"] 264 | label_id = label_id.reshape(-1) 265 | b, n, c, t, h, w = _image.size() 266 | _image = rearrange(_image, 'b n c t h w -> (b n) t c h w') 267 | output = model(_image, text_inputs) 268 | 269 | scores_prd = F.softmax(output['y'], dim=-1) 270 | scores_prd = rearrange(scores_prd, '(b n) c -> b n c', b=b) 271 | scores_np_prd = scores_prd.cpu().data.numpy() 272 | 273 | for ind in range(scores_np_prd.shape[0]): 274 | v_name = vid_list[batch_data["vid"][ind]] 275 | if v_name not in scores_dict['prd']: 276 | scores_dict['prd'][v_name] = [] 277 | scores_dict['prd'][v_name].append(scores_np_prd[ind]) 278 | if idx % 100 == 0 and len(data_loader) >= 100: 279 | logger.info( 280 | f'Test: [{idx}/{len(data_loader)}]\t' 281 | ) 282 | tmp_dict = {} 283 | for v_name in scores_dict["prd"].keys(): 284 | p_scores = np.array(scores_dict["prd"][v_name]).copy() 285 | if p_scores.shape[0] == 1: 286 | # 1,T,2 287 | tmp_dict[v_name] = [p_scores[0, :, 1]] 288 | else: 289 | # T,1,2 290 | tmp_dict[v_name] = [p_scores[:, 0, 1]] 291 | 292 | auc_all_p, auc_ano_p = evaluate_result(tmp_dict, config.DATA.VAL_FILE) 293 | 294 | logger.info( 295 | f'AUC: [{auc_all_p:.3f}/{auc_ano_p:.3f}]\t' 296 | ) 297 | logger.info(f'writing results to {out_path}') 298 | mmcv.dump(scores_dict, out_path) 299 | return scores_dict 300 | 301 | 302 | if __name__ == '__main__': 303 | # prepare config 304 | args, config = 
parse_option() 305 | 306 | # init_distributed 307 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 308 | rank = int(os.environ["RANK"]) 309 | world_size = int(os.environ['WORLD_SIZE']) 310 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 311 | else: 312 | rank = -1 313 | world_size = -1 314 | torch.cuda.set_device(args.local_rank) 315 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 316 | torch.distributed.barrier(device_ids=[args.local_rank]) 317 | 318 | seed = config.SEED + dist.get_rank() 319 | torch.manual_seed(seed) 320 | np.random.seed(seed) 321 | random.seed(seed) 322 | cudnn.benchmark = True 323 | 324 | # create working_dir 325 | Path(config.OUTPUT).mkdir(parents=True, exist_ok=True) 326 | 327 | # logger 328 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.ARCH}") 329 | logger.info(f"working dir: {config.OUTPUT}") 330 | 331 | # save config 332 | if dist.get_rank() == 0: 333 | logger.info(config) 334 | shutil.copy(args.config, config.OUTPUT) 335 | 336 | main(config) 337 | -------------------------------------------------------------------------------- /labels/TAD_train.txt: -------------------------------------------------------------------------------- 1 | frames/abnormal/01_Accident_001.mp4 0 224 1 2 | frames/abnormal/01_Accident_002.mp4 0 220 1 3 | frames/abnormal/01_Accident_003.mp4 0 199 1 4 | frames/abnormal/01_Accident_004.mp4 0 190 1 5 | frames/abnormal/01_Accident_005.mp4 0 318 1 6 | frames/abnormal/01_Accident_007.mp4 0 327 1 7 | frames/abnormal/01_Accident_009.mp4 0 183 1 8 | frames/abnormal/01_Accident_010.mp4 0 210 1 9 | frames/abnormal/01_Accident_013.mp4 0 299 1 10 | frames/abnormal/01_Accident_015.mp4 0 200 1 11 | frames/abnormal/01_Accident_016.mp4 0 298 1 12 | frames/abnormal/01_Accident_017.mp4 0 274 1 13 | frames/abnormal/01_Accident_018.mp4 0 296 1 14 | frames/abnormal/01_Accident_019.mp4 0 231 1 15 | frames/abnormal/01_Accident_021.mp4 0 296 1 16 | frames/abnormal/01_Accident_022.mp4 0 911 1 17 | frames/abnormal/01_Accident_024.mp4 0 475 1 18 | frames/abnormal/01_Accident_026.mp4 0 430 1 19 | frames/abnormal/01_Accident_027.mp4 0 221 1 20 | frames/abnormal/01_Accident_028.mp4 0 290 1 21 | frames/abnormal/01_Accident_029.mp4 0 315 1 22 | frames/abnormal/01_Accident_030.mp4 0 190 1 23 | frames/abnormal/01_Accident_031.mp4 0 186 1 24 | frames/abnormal/01_Accident_032.mp4 0 158 1 25 | frames/abnormal/01_Accident_034.mp4 0 167 1 26 | frames/abnormal/01_Accident_035.mp4 0 116 1 27 | frames/abnormal/01_Accident_037.mp4 0 146 1 28 | frames/abnormal/01_Accident_040.mp4 0 278 1 29 | frames/abnormal/01_Accident_042.mp4 0 586 1 30 | frames/abnormal/01_Accident_043.mp4 0 290 1 31 | frames/abnormal/01_Accident_045.mp4 0 407 1 32 | frames/abnormal/01_Accident_046.mp4 0 510 1 33 | frames/abnormal/01_Accident_048.mp4 0 365 1 34 | frames/abnormal/01_Accident_049.mp4 0 389 1 35 | frames/abnormal/01_Accident_051.mp4 0 535 1 36 | frames/abnormal/01_Accident_052.mp4 0 316 1 37 | frames/abnormal/01_Accident_053.mp4 0 365 1 38 | frames/abnormal/01_Accident_055.mp4 0 680 1 39 | frames/abnormal/01_Accident_056.mp4 0 483 1 40 | frames/abnormal/01_Accident_057.mp4 0 266 1 41 | frames/abnormal/01_Accident_058.mp4 0 290 1 42 | frames/abnormal/01_Accident_059.mp4 0 411 1 43 | frames/abnormal/01_Accident_060.mp4 0 218 1 44 | frames/abnormal/01_Accident_061.mp4 0 580 1 45 | frames/abnormal/01_Accident_065.mp4 0 360 1 46 | frames/abnormal/01_Accident_066.mp4 0 
260 1 47 | frames/abnormal/01_Accident_067.mp4 0 601 1 48 | frames/abnormal/01_Accident_068.mp4 0 327 1 49 | frames/abnormal/01_Accident_069.mp4 0 60 1 50 | frames/abnormal/01_Accident_070.mp4 0 144 1 51 | frames/abnormal/01_Accident_071.mp4 0 348 1 52 | frames/abnormal/01_Accident_072.mp4 0 39 1 53 | frames/abnormal/01_Accident_077.mp4 0 216 1 54 | frames/abnormal/01_Accident_078.mp4 0 240 1 55 | frames/abnormal/01_Accident_079.mp4 0 240 1 56 | frames/abnormal/01_Accident_080.mp4 0 156 1 57 | frames/abnormal/01_Accident_081.mp4 0 200 1 58 | frames/abnormal/01_Accident_082.mp4 0 113 1 59 | frames/abnormal/01_Accident_084.mp4 0 206 1 60 | frames/abnormal/01_Accident_085.mp4 0 240 1 61 | frames/abnormal/01_Accident_087.mp4 0 210 1 62 | frames/abnormal/01_Accident_090.mp4 0 294 1 63 | frames/abnormal/01_Accident_092.mp4 0 113 1 64 | frames/abnormal/01_Accident_093.mp4 0 330 1 65 | frames/abnormal/01_Accident_094.mp4 0 300 1 66 | frames/abnormal/01_Accident_097.mp4 0 162 1 67 | frames/abnormal/01_Accident_098.mp4 0 275 1 68 | frames/abnormal/01_Accident_099.mp4 0 176 1 69 | frames/abnormal/01_Accident_100.mp4 0 241 1 70 | frames/abnormal/01_Accident_101.mp4 0 252 1 71 | frames/abnormal/01_Accident_102.mp4 0 185 1 72 | frames/abnormal/01_Accident_104.mp4 0 127 1 73 | frames/abnormal/01_Accident_105.mp4 0 247 1 74 | frames/abnormal/01_Accident_106.mp4 0 209 1 75 | frames/abnormal/01_Accident_109.mp4 0 185 1 76 | frames/abnormal/01_Accident_110.mp4 0 149 1 77 | frames/abnormal/01_Accident_096.mp4 0 177 1 78 | frames/abnormal/01_Accident_103.mp4 0 147 1 79 | frames/abnormal/01_Accident_107.mp4 0 210 1 80 | frames/abnormal/01_Accident_108.mp4 0 172 1 81 | frames/abnormal/02_IllegalTurn_001.mp4 0 187 2 82 | frames/abnormal/02_IllegalTurn_002.mp4 0 275 2 83 | frames/abnormal/02_IllegalTurn_003.mp4 0 350 2 84 | frames/abnormal/02_IllegalTurn_005.mp4 0 507 2 85 | frames/abnormal/02_IllegalTurn_006.mp4 0 384 2 86 | frames/abnormal/02_IllegalTurn_008.mp4 0 349 2 87 | frames/abnormal/02_IllegalTurn_009.mp4 0 394 2 88 | frames/abnormal/02_IllegalTurn_010.mp4 0 849 2 89 | frames/abnormal/02_IllegalTurn_012.mp4 0 195 2 90 | frames/abnormal/02_IllegalTurn_013.mp4 0 407 2 91 | frames/abnormal/02_IllegalTurn_015.mp4 0 225 2 92 | frames/abnormal/03_IllegalOccupation_002.mp4 0 732 3 93 | frames/abnormal/03_IllegalOccupation_006.mp4 0 256 3 94 | frames/abnormal/03_IllegalOccupation_007.mp4 0 206 3 95 | frames/abnormal/03_IllegalOccupation_008.mp4 0 173 3 96 | frames/abnormal/03_IllegalOccupation_009.mp4 0 240 3 97 | frames/abnormal/03_IllegalOccupation_010.mp4 0 254 3 98 | frames/abnormal/03_IllegalOccupation_011.mp4 0 247 3 99 | frames/abnormal/03_IllegalOccupation_012.mp4 0 335 3 100 | frames/abnormal/03_IllegalOccupation_013.mp4 0 285 3 101 | frames/abnormal/03_IllegalOccupation_014.mp4 0 339 3 102 | frames/abnormal/03_IllegalOccupation_015.mp4 0 222 3 103 | frames/abnormal/03_IllegalOccupation_016.mp4 0 234 3 104 | frames/abnormal/03_IllegalOccupation_017.mp4 0 274 3 105 | frames/abnormal/03_IllegalOccupation_018.mp4 0 255 3 106 | frames/abnormal/04_Retrograde_001.mp4 0 600 4 107 | frames/abnormal/04_Retrograde_003.mp4 0 278 4 108 | frames/abnormal/04_Retrograde_005.mp4 0 349 4 109 | frames/abnormal/04_Retrograde_006.mp4 0 734 4 110 | frames/abnormal/04_Retrograde_007.mp4 0 192 4 111 | frames/abnormal/04_Retrograde_008.mp4 0 275 4 112 | frames/abnormal/05_else_001.mp4 0 304 5 113 | frames/abnormal/05_else_002.mp4 0 78 5 114 | frames/abnormal/05_else_003.mp4 0 168 5 115 | frames/abnormal/05_else_004.mp4 0 1051 
5 116 | frames/abnormal/05_else_005.mp4 0 336 5 117 | frames/abnormal/05_else_006.mp4 0 210 5 118 | frames/abnormal/05_else_007.mp4 0 240 5 119 | frames/abnormal/05_else_008.mp4 0 178 5 120 | frames/abnormal/05_else_009.mp4 0 153 5 121 | frames/abnormal/05_else_011.mp4 0 270 5 122 | frames/abnormal/05_else_012.mp4 0 176 5 123 | frames/abnormal/05_else_013.mp4 0 332 5 124 | frames/abnormal/05_else_014.mp4 0 231 5 125 | frames/abnormal/05_else_015.mp4 0 430 5 126 | frames/abnormal/05_else_016.mp4 0 121 5 127 | frames/abnormal/05_else_017.mp4 0 84 5 128 | frames/abnormal/05_else_018.mp4 0 202 5 129 | frames/abnormal/05_else_019.mp4 0 269 5 130 | frames/abnormal/05_else_021.mp4 0 249 5 131 | frames/abnormal/05_else_022.mp4 0 201 5 132 | frames/abnormal/05_else_024.mp4 0 145 5 133 | frames/abnormal/05_else_025.mp4 0 150 5 134 | frames/abnormal/05_else_027.mp4 0 538 5 135 | frames/abnormal/05_else_028.mp4 0 165 5 136 | frames/abnormal/05_else_029.mp4 0 227 5 137 | frames/abnormal/05_else_033.mp4 0 393 5 138 | frames/abnormal/05_else_034.mp4 0 125 5 139 | frames/abnormal/05_else_035.mp4 0 769 5 140 | frames/abnormal/05_else_036.mp4 0 804 5 141 | frames/abnormal/05_else_039.mp4 0 460 5 142 | frames/abnormal/05_else_040.mp4 0 530 5 143 | frames/abnormal/05_else_041.mp4 0 236 5 144 | frames/abnormal/05_else_042.mp4 0 244 5 145 | frames/abnormal/05_else_043.mp4 0 544 5 146 | frames/abnormal/05_else_044.mp4 0 426 5 147 | frames/abnormal/05_else_046.mp4 0 215 5 148 | frames/abnormal/05_else_048.mp4 0 892 5 149 | frames/abnormal/05_else_050.mp4 0 358 5 150 | frames/abnormal/05_else_051.mp4 0 39 5 151 | frames/abnormal/05_else_052.mp4 0 411 5 152 | frames/abnormal/05_else_053.mp4 0 95 5 153 | frames/abnormal/06_PedestrianOnRoad_003.mp4 0 243 6 154 | frames/abnormal/06_PedestrianOnRoad_005.mp4 0 293 6 155 | frames/abnormal/06_PedestrianOnRoad_007.mp4 0 787 6 156 | frames/abnormal/06_PedestrianOnRoad_008.mp4 0 945 6 157 | frames/abnormal/06_PedestrianOnRoad_010.mp4 0 243 6 158 | frames/abnormal/06_PedestrianOnRoad_011.mp4 0 374 6 159 | frames/abnormal/06_PedestrianOnRoad_012.mp4 0 736 6 160 | frames/abnormal/06_PedestrianOnRoad_013.mp4 0 400 6 161 | frames/abnormal/06_PedestrianOnRoad_015.mp4 0 626 6 162 | frames/abnormal/06_PedestrianOnRoad_018.mp4 0 401 6 163 | frames/abnormal/06_PedestrianOnRoad_020.mp4 0 405 6 164 | frames/abnormal/06_PedestrianOnRoad_023.mp4 0 232 6 165 | frames/abnormal/06_PedestrianOnRoad_024.mp4 0 258 6 166 | frames/abnormal/06_PedestrianOnRoad_026.mp4 0 990 6 167 | frames/abnormal/07_RoadSpills_001.mp4 0 457 7 168 | frames/abnormal/07_RoadSpills_002.mp4 0 783 7 169 | frames/abnormal/07_RoadSpills_004.mp4 0 338 7 170 | frames/abnormal/07_RoadSpills_005.mp4 0 373 7 171 | frames/abnormal/07_RoadSpills_006.mp4 0 624 7 172 | frames/abnormal/07_RoadSpills_007.mp4 0 488 7 173 | frames/abnormal/07_RoadSpills_009.mp4 0 528 7 174 | frames/abnormal/07_RoadSpills_010.mp4 0 518 7 175 | frames/abnormal/07_RoadSpills_011.mp4 0 278 7 176 | frames/abnormal/07_RoadSpills_012.mp4 0 151 7 177 | frames/abnormal/07_RoadSpills_015.mp4 0 382 7 178 | frames/abnormal/07_RoadSpills_016.mp4 0 385 7 179 | frames/abnormal/07_RoadSpills_017.mp4 0 570 7 180 | frames/abnormal/07_RoadSpills_018.mp4 0 270 7 181 | frames/abnormal/07_RoadSpills_019.mp4 0 268 7 182 | frames/abnormal/07_RoadSpills_020.mp4 0 311 7 183 | frames/abnormal/07_RoadSpills_021.mp4 0 152 7 184 | frames/abnormal/07_RoadSpills_022.mp4 0 341 7 185 | frames/abnormal/07_RoadSpills_023.mp4 0 159 7 186 | frames/abnormal/07_RoadSpills_024.mp4 0 1166 7 
187 | frames/abnormal/07_RoadSpills_025.mp4 0 360 7 188 | frames/abnormal/07_RoadSpills_026.mp4 0 178 7 189 | frames/abnormal/07_RoadSpills_027.mp4 0 213 7 190 | frames/abnormal/07_RoadSpills_028.mp4 0 96 7 191 | frames/normal/Normal_001.mp4 0 141 0 192 | frames/normal/Normal_002.mp4 0 64 0 193 | frames/normal/Normal_003.mp4 0 45 0 194 | frames/normal/Normal_004.mp4 0 54 0 195 | frames/normal/Normal_005.mp4 0 63 0 196 | frames/normal/Normal_006.mp4 0 65 0 197 | frames/normal/Normal_007.mp4 0 83 0 198 | frames/normal/Normal_009.mp4 0 48 0 199 | frames/normal/Normal_010.mp4 0 53 0 200 | frames/normal/Normal_012.mp4 0 70 0 201 | frames/normal/Normal_013.mp4 0 74 0 202 | frames/normal/Normal_015.mp4 0 75 0 203 | frames/normal/Normal_016.mp4 0 99 0 204 | frames/normal/Normal_020.mp4 0 749 0 205 | frames/normal/Normal_021.mp4 0 227 0 206 | frames/normal/Normal_022.mp4 0 256 0 207 | frames/normal/Normal_023.mp4 0 261 0 208 | frames/normal/Normal_024.mp4 0 256 0 209 | frames/normal/Normal_025.mp4 0 145 0 210 | frames/normal/Normal_027.mp4 0 266 0 211 | frames/normal/Normal_028.mp4 0 285 0 212 | frames/normal/Normal_029.mp4 0 335 0 213 | frames/normal/Normal_031.mp4 0 562 0 214 | frames/normal/Normal_032.mp4 0 595 0 215 | frames/normal/Normal_033.mp4 0 450 0 216 | frames/normal/Normal_034.mp4 0 890 0 217 | frames/normal/Normal_035.mp4 0 1289 0 218 | frames/normal/Normal_036.mp4 0 977 0 219 | frames/normal/Normal_037.mp4 0 1313 0 220 | frames/normal/Normal_038.mp4 0 507 0 221 | frames/normal/Normal_039.mp4 0 280 0 222 | frames/normal/Normal_040.mp4 0 937 0 223 | frames/normal/Normal_042.mp4 0 238 0 224 | frames/normal/Normal_044.mp4 0 390 0 225 | frames/normal/Normal_045.mp4 0 365 0 226 | frames/normal/Normal_046.mp4 0 1289 0 227 | frames/normal/Normal_049.mp4 0 180 0 228 | frames/normal/Normal_051.mp4 0 585 0 229 | frames/normal/Normal_054.mp4 0 492 0 230 | frames/normal/Normal_055.mp4 0 529 0 231 | frames/normal/Normal_056.mp4 0 120 0 232 | frames/normal/Normal_057.mp4 0 90 0 233 | frames/normal/Normal_058.mp4 0 90 0 234 | frames/normal/Normal_059.mp4 0 75 0 235 | frames/normal/Normal_060.mp4 0 180 0 236 | frames/normal/Normal_061.mp4 0 132 0 237 | frames/normal/Normal_062.mp4 0 21 0 238 | frames/normal/Normal_063.mp4 0 140 0 239 | frames/normal/Normal_064.mp4 0 60 0 240 | frames/normal/Normal_065.mp4 0 150 0 241 | frames/normal/Normal_067.mp4 0 94 0 242 | frames/normal/Normal_068.mp4 0 273 0 243 | frames/normal/Normal_070.mp4 0 132 0 244 | frames/normal/Normal_071.mp4 0 93 0 245 | frames/normal/Normal_072.mp4 0 150 0 246 | frames/normal/Normal_074.mp4 0 150 0 247 | frames/normal/Normal_076.mp4 0 30 0 248 | frames/normal/Normal_077.mp4 0 75 0 249 | frames/normal/Normal_078.mp4 0 78 0 250 | frames/normal/Normal_080.mp4 0 48 0 251 | frames/normal/Normal_082.mp4 0 52 0 252 | frames/normal/Normal_083.mp4 0 120 0 253 | frames/normal/Normal_084.mp4 0 96 0 254 | frames/normal/Normal_086.mp4 0 263 0 255 | frames/normal/Normal_088.mp4 0 69 0 256 | frames/normal/Normal_089.mp4 0 45 0 257 | frames/normal/Normal_090.mp4 0 607 0 258 | frames/normal/Normal_091.mp4 0 73 0 259 | frames/normal/Normal_093.mp4 0 308 0 260 | frames/normal/Normal_094.mp4 0 144 0 261 | frames/normal/Normal_096.mp4 0 135 0 262 | frames/normal/Normal_097.mp4 0 207 0 263 | frames/normal/Normal_098.mp4 0 165 0 264 | frames/normal/Normal_099.mp4 0 235 0 265 | frames/normal/Normal_100.mp4 0 153 0 266 | frames/normal/Normal_101.mp4 0 63 0 267 | frames/normal/Normal_102.mp4 0 84 0 268 | frames/normal/Normal_103.mp4 0 167 0 269 | 
frames/normal/Normal_104.mp4 0 105 0 270 | frames/normal/Normal_105.mp4 0 51 0 271 | frames/normal/Normal_106.mp4 0 169 0 272 | frames/normal/Normal_107.mp4 0 4796 0 273 | frames/normal/Normal_108.mp4 0 2672 0 274 | frames/normal/Normal_109.mp4 0 1665 0 275 | frames/normal/Normal_110.mp4 0 1799 0 276 | frames/normal/Normal_111.mp4 0 3597 0 277 | frames/normal/Normal_112.mp4 0 5684 0 278 | frames/normal/Normal_113.mp4 0 3532 0 279 | frames/normal/Normal_114.mp4 0 3597 0 280 | frames/normal/Normal_115.mp4 0 676 0 281 | frames/normal/Normal_116.mp4 0 5395 0 282 | frames/normal/Normal_117.mp4 0 735 0 283 | frames/normal/Normal_118.mp4 0 3597 0 284 | frames/normal/Normal_119.mp4 0 591 0 285 | frames/normal/Normal_121.mp4 0 900 0 286 | frames/normal/Normal_123.mp4 0 240 0 287 | frames/normal/Normal_124.mp4 0 3000 0 288 | frames/normal/Normal_125.mp4 0 450 0 289 | frames/normal/Normal_127.mp4 0 516 0 290 | frames/normal/Normal_128.mp4 0 2251 0 291 | frames/normal/Normal_129.mp4 0 420 0 292 | frames/normal/Normal_130.mp4 0 2848 0 293 | frames/normal/Normal_131.mp4 0 3108 0 294 | frames/normal/Normal_132.mp4 0 5395 0 295 | frames/normal/Normal_133.mp4 0 180 0 296 | frames/normal/Normal_134.mp4 0 972 0 297 | frames/normal/Normal_135.mp4 0 15000 0 298 | frames/normal/Normal_137.mp4 0 5839 0 299 | frames/normal/Normal_138.mp4 0 4916 0 300 | frames/normal/Normal_140.mp4 0 219 0 301 | frames/normal/Normal_141.mp4 0 8992 0 302 | frames/normal/Normal_142.mp4 0 4040 0 303 | frames/normal/Normal_144.mp4 0 3257 0 304 | frames/normal/Normal_145.mp4 0 2892 0 305 | frames/normal/Normal_146.mp4 0 1755 0 306 | frames/normal/Normal_147.mp4 0 216 0 307 | frames/normal/Normal_148.mp4 0 1799 0 308 | frames/normal/Normal_149.mp4 0 3597 0 309 | frames/normal/Normal_150.mp4 0 421 0 310 | frames/normal/Normal_151.mp4 0 1222 0 311 | frames/normal/Normal_152.mp4 0 1807 0 312 | frames/normal/Normal_153.mp4 0 1093 0 313 | frames/normal/Normal_154.mp4 0 965 0 314 | frames/normal/Normal_156.mp4 0 1619 0 315 | frames/normal/Normal_157.mp4 0 7193 0 316 | frames/normal/Normal_158.mp4 0 930 0 317 | frames/normal/Normal_159.mp4 0 600 0 318 | frames/normal/Normal_161.mp4 0 214 0 319 | frames/normal/Normal_162.mp4 0 6133 0 320 | frames/normal/Normal_163.mp4 0 7440 0 321 | frames/normal/Normal_165.mp4 0 3597 0 322 | frames/normal/Normal_166.mp4 0 509 0 323 | frames/normal/Normal_167.mp4 0 1500 0 324 | frames/normal/Normal_168.mp4 0 723 0 325 | frames/normal/Normal_169.mp4 0 480 0 326 | frames/normal/Normal_171.mp4 0 750 0 327 | frames/normal/Normal_172.mp4 0 180 0 328 | frames/normal/Normal_173.mp4 0 1799 0 329 | frames/normal/Normal_174.mp4 0 720 0 330 | frames/normal/Normal_175.mp4 0 180 0 331 | frames/normal/Normal_176.mp4 0 830 0 332 | frames/normal/Normal_177.mp4 0 3000 0 333 | frames/normal/Normal_179.mp4 0 7951 0 334 | frames/normal/Normal_180.mp4 0 7113 0 335 | frames/normal/Normal_181.mp4 0 1504 0 336 | frames/normal/Normal_182.mp4 0 837 0 337 | frames/normal/Normal_184.mp4 0 3597 0 338 | frames/normal/Normal_185.mp4 0 9565 0 339 | frames/normal/Normal_186.mp4 0 2258 0 340 | frames/normal/Normal_187.mp4 0 2462 0 341 | frames/normal/Normal_188.mp4 0 5095 0 342 | frames/normal/Normal_189.mp4 0 450 0 343 | frames/normal/Normal_191.mp4 0 1799 0 344 | frames/normal/Normal_193.mp4 0 1448 0 345 | frames/normal/Normal_195.mp4 0 1813 0 346 | frames/normal/Normal_196.mp4 0 1799 0 347 | frames/normal/Normal_197.mp4 0 274 0 348 | frames/normal/Normal_198.mp4 0 8992 0 349 | frames/normal/Normal_200.mp4 0 7463 0 350 | 
frames/normal/Normal_201.mp4 0 480 0 351 | frames/normal/Normal_202.mp4 0 2938 0 352 | frames/normal/Normal_204.mp4 0 1049 0 353 | frames/normal/Normal_205.mp4 0 7500 0 354 | frames/normal/Normal_206.mp4 0 7500 0 355 | frames/normal/Normal_207.mp4 0 3597 0 356 | frames/normal/Normal_208.mp4 0 420 0 357 | frames/normal/Normal_209.mp4 0 7193 0 358 | frames/normal/Normal_210.mp4 0 180 0 359 | frames/normal/Normal_211.mp4 0 450 0 360 | frames/normal/Normal_212.mp4 0 120 0 361 | frames/normal/Normal_213.mp4 0 210 0 362 | frames/normal/Normal_214.mp4 0 3597 0 363 | frames/normal/Normal_216.mp4 0 1500 0 364 | frames/normal/Normal_217.mp4 0 228 0 365 | frames/normal/Normal_218.mp4 0 1288 0 366 | frames/normal/Normal_219.mp4 0 450 0 367 | frames/normal/Normal_220.mp4 0 1754 0 368 | frames/normal/Normal_221.mp4 0 8992 0 369 | frames/normal/Normal_223.mp4 0 3797 0 370 | frames/normal/Normal_224.mp4 0 2698 0 371 | frames/normal/Normal_226.mp4 0 5395 0 372 | frames/normal/Normal_227.mp4 0 1510 0 373 | frames/normal/Normal_228.mp4 0 3106 0 374 | frames/normal/Normal_229.mp4 0 4224 0 375 | frames/normal/Normal_231.mp4 0 3597 0 376 | frames/normal/Normal_232.mp4 0 1799 0 377 | frames/normal/Normal_233.mp4 0 738 0 378 | frames/normal/Normal_234.mp4 0 4130 0 379 | frames/normal/Normal_235.mp4 0 900 0 380 | frames/normal/Normal_236.mp4 0 8992 0 381 | frames/normal/Normal_237.mp4 0 7560 0 382 | frames/normal/Normal_239.mp4 0 3597 0 383 | frames/normal/Normal_242.mp4 0 1499 0 384 | frames/normal/Normal_243.mp4 0 2698 0 385 | frames/normal/Normal_244.mp4 0 1649 0 386 | frames/normal/Normal_245.mp4 0 893 0 387 | frames/normal/Normal_246.mp4 0 900 0 388 | frames/normal/Normal_247.mp4 0 3597 0 389 | frames/normal/Normal_248.mp4 0 1789 0 390 | frames/normal/Normal_249.mp4 0 180 0 391 | frames/normal/Normal_250.mp4 0 300 0 392 | frames/normal/Normal_251.mp4 0 7500 0 393 | frames/normal/Normal_252.mp4 0 10890 0 394 | frames/normal/Normal_253.mp4 0 3597 0 395 | frames/normal/Normal_254.mp4 0 371 0 396 | frames/normal/Normal_255.mp4 0 3597 0 397 | frames/normal/Normal_256.mp4 0 855 0 398 | frames/normal/Normal_257.mp4 0 3597 0 399 | frames/normal/Normal_258.mp4 0 581 0 400 | frames/normal/Normal_259.mp4 0 4416 0 -------------------------------------------------------------------------------- /datasets/rand_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 4 | pulished under an Apache License 2.0. 5 | 6 | COMMENT FROM ORIGINAL: 7 | AutoAugment, RandAugment, and AugMix for PyTorch 8 | This code implements the searched ImageNet policies with various tweaks and 9 | improvements and does not include any of the search code. AA and RA 10 | Implementation adapted from: 11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 12 | AugMix adapted from: 13 | https://github.com/google-research/augmix 14 | Papers: 15 | AutoAugment: Learning Augmentation Policies from Data 16 | https://arxiv.org/abs/1805.09501 17 | Learning Data Augmentation Strategies for Object Detection 18 | https://arxiv.org/abs/1906.11172 19 | RandAugment: Practical automated data augmentation... 
20 | https://arxiv.org/abs/1909.13719 21 | AugMix: A Simple Data Processing Method to Improve Robustness and 22 | Uncertainty https://arxiv.org/abs/1912.02781 23 | 24 | Hacked together by / Copyright 2020 Ross Wightman 25 | """ 26 | 27 | import math 28 | import numpy as np 29 | import random 30 | import re 31 | import PIL 32 | from PIL import Image, ImageEnhance, ImageOps 33 | 34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 35 | 36 | _FILL = (128, 128, 128) 37 | 38 | # This signifies the max integer that the controller RNN could predict for the 39 | # augmentation scheme. 40 | _MAX_LEVEL = 10.0 41 | 42 | _HPARAMS_DEFAULT = { 43 | "translate_const": 250, 44 | "img_mean": _FILL, 45 | } 46 | 47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 48 | 49 | 50 | def _interpolation(kwargs): 51 | interpolation = kwargs.pop("resample", Image.BILINEAR) 52 | if isinstance(interpolation, (list, tuple)): 53 | return random.choice(interpolation) 54 | else: 55 | return interpolation 56 | 57 | 58 | def _check_args_tf(kwargs): 59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 60 | kwargs.pop("fillcolor") 61 | kwargs["resample"] = _interpolation(kwargs) 62 | 63 | 64 | def shear_x(img, factor, **kwargs): 65 | _check_args_tf(kwargs) 66 | return img.transform( 67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 68 | ) 69 | 70 | 71 | def shear_y(img, factor, **kwargs): 72 | _check_args_tf(kwargs) 73 | return img.transform( 74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 75 | ) 76 | 77 | 78 | def translate_x_rel(img, pct, **kwargs): 79 | pixels = pct * img.size[0] 80 | _check_args_tf(kwargs) 81 | return img.transform( 82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 83 | ) 84 | 85 | 86 | def translate_y_rel(img, pct, **kwargs): 87 | pixels = pct * img.size[1] 88 | _check_args_tf(kwargs) 89 | return img.transform( 90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 91 | ) 92 | 93 | 94 | def translate_x_abs(img, pixels, **kwargs): 95 | _check_args_tf(kwargs) 96 | return img.transform( 97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 98 | ) 99 | 100 | 101 | def translate_y_abs(img, pixels, **kwargs): 102 | _check_args_tf(kwargs) 103 | return img.transform( 104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 105 | ) 106 | 107 | 108 | def rotate(img, degrees, **kwargs): 109 | _check_args_tf(kwargs) 110 | if _PIL_VER >= (5, 2): 111 | return img.rotate(degrees, **kwargs) 112 | elif _PIL_VER >= (5, 0): 113 | w, h = img.size 114 | post_trans = (0, 0) 115 | rotn_center = (w / 2.0, h / 2.0) 116 | angle = -math.radians(degrees) 117 | matrix = [ 118 | round(math.cos(angle), 15), 119 | round(math.sin(angle), 15), 120 | 0.0, 121 | round(-math.sin(angle), 15), 122 | round(math.cos(angle), 15), 123 | 0.0, 124 | ] 125 | 126 | def transform(x, y, matrix): 127 | (a, b, c, d, e, f) = matrix 128 | return a * x + b * y + c, d * x + e * y + f 129 | 130 | matrix[2], matrix[5] = transform( 131 | -rotn_center[0] - post_trans[0], 132 | -rotn_center[1] - post_trans[1], 133 | matrix, 134 | ) 135 | matrix[2] += rotn_center[0] 136 | matrix[5] += rotn_center[1] 137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 138 | else: 139 | return img.rotate(degrees, resample=kwargs["resample"]) 140 | 141 | 142 | def auto_contrast(img, **__): 143 | return ImageOps.autocontrast(img) 144 | 145 | 146 | def invert(img, **__): 147 | return ImageOps.invert(img) 148 | 149 | 150 | def equalize(img, **__): 151 | return 
ImageOps.equalize(img) 152 | 153 | 154 | def solarize(img, thresh, **__): 155 | return ImageOps.solarize(img, thresh) 156 | 157 | 158 | def solarize_add(img, add, thresh=128, **__): 159 | lut = [] 160 | for i in range(256): 161 | if i < thresh: 162 | lut.append(min(255, i + add)) 163 | else: 164 | lut.append(i) 165 | if img.mode in ("L", "RGB"): 166 | if img.mode == "RGB" and len(lut) == 256: 167 | lut = lut + lut + lut 168 | return img.point(lut) 169 | else: 170 | return img 171 | 172 | 173 | def posterize(img, bits_to_keep, **__): 174 | if bits_to_keep >= 8: 175 | return img 176 | return ImageOps.posterize(img, bits_to_keep) 177 | 178 | 179 | def contrast(img, factor, **__): 180 | return ImageEnhance.Contrast(img).enhance(factor) 181 | 182 | 183 | def color(img, factor, **__): 184 | return ImageEnhance.Color(img).enhance(factor) 185 | 186 | 187 | def brightness(img, factor, **__): 188 | return ImageEnhance.Brightness(img).enhance(factor) 189 | 190 | 191 | def sharpness(img, factor, **__): 192 | return ImageEnhance.Sharpness(img).enhance(factor) 193 | 194 | 195 | def _randomly_negate(v): 196 | """With 50% prob, negate the value""" 197 | return -v if random.random() > 0.5 else v 198 | 199 | 200 | def _rotate_level_to_arg(level, _hparams): 201 | # range [-30, 30] 202 | level = (level / _MAX_LEVEL) * 30.0 203 | level = _randomly_negate(level) 204 | return (level,) 205 | 206 | 207 | def _enhance_level_to_arg(level, _hparams): 208 | # range [0.1, 1.9] 209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 210 | 211 | 212 | def _enhance_increasing_level_to_arg(level, _hparams): 213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 214 | # range [0.1, 1.9] 215 | level = (level / _MAX_LEVEL) * 0.9 216 | level = 1.0 + _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _shear_level_to_arg(level, _hparams): 221 | # range [-0.3, 0.3] 222 | level = (level / _MAX_LEVEL) * 0.3 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_abs_level_to_arg(level, hparams): 228 | translate_const = hparams["translate_const"] 229 | level = (level / _MAX_LEVEL) * float(translate_const) 230 | level = _randomly_negate(level) 231 | return (level,) 232 | 233 | 234 | def _translate_rel_level_to_arg(level, hparams): 235 | # default range [-0.45, 0.45] 236 | translate_pct = hparams.get("translate_pct", 0.45) 237 | level = (level / _MAX_LEVEL) * translate_pct 238 | level = _randomly_negate(level) 239 | return (level,) 240 | 241 | 242 | def _posterize_level_to_arg(level, _hparams): 243 | # As per Tensorflow TPU EfficientNet impl 244 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 245 | # intensity/severity of augmentation decreases with level 246 | return (int((level / _MAX_LEVEL) * 4),) 247 | 248 | 249 | def _posterize_increasing_level_to_arg(level, hparams): 250 | # As per Tensorflow models research and UDA impl 251 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 252 | # intensity/severity of augmentation increases with level 253 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 254 | 255 | 256 | def _posterize_original_level_to_arg(level, _hparams): 257 | # As per original AutoAugment paper description 258 | # range [4, 8], 'keep 4 up to 8 MSB of image' 259 | # intensity/severity of augmentation decreases with level 260 | return (int((level / _MAX_LEVEL) * 4) + 4,) 261 | 262 | 263 | def _solarize_level_to_arg(level, _hparams): 264 | # range [0, 256] 265 | # intensity/severity of 
augmentation decreases with level 266 | return (int((level / _MAX_LEVEL) * 256),) 267 | 268 | 269 | def _solarize_increasing_level_to_arg(level, _hparams): 270 | # range [0, 256] 271 | # intensity/severity of augmentation increases with level 272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 273 | 274 | 275 | def _solarize_add_level_to_arg(level, _hparams): 276 | # range [0, 110] 277 | return (int((level / _MAX_LEVEL) * 110),) 278 | 279 | 280 | LEVEL_TO_ARG = { 281 | "AutoContrast": None, 282 | "Equalize": None, 283 | "Invert": None, 284 | "Rotate": _rotate_level_to_arg, 285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 286 | "Posterize": _posterize_level_to_arg, 287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 288 | "PosterizeOriginal": _posterize_original_level_to_arg, 289 | "Solarize": _solarize_level_to_arg, 290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 291 | "SolarizeAdd": _solarize_add_level_to_arg, 292 | "Color": _enhance_level_to_arg, 293 | "ColorIncreasing": _enhance_increasing_level_to_arg, 294 | "Contrast": _enhance_level_to_arg, 295 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 296 | "Brightness": _enhance_level_to_arg, 297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 298 | "Sharpness": _enhance_level_to_arg, 299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 300 | "ShearX": _shear_level_to_arg, 301 | "ShearY": _shear_level_to_arg, 302 | "TranslateX": _translate_abs_level_to_arg, 303 | "TranslateY": _translate_abs_level_to_arg, 304 | "TranslateXRel": _translate_rel_level_to_arg, 305 | "TranslateYRel": _translate_rel_level_to_arg, 306 | } 307 | 308 | 309 | NAME_TO_OP = { 310 | "AutoContrast": auto_contrast, 311 | "Equalize": equalize, 312 | "Invert": invert, 313 | "Rotate": rotate, 314 | "Posterize": posterize, 315 | "PosterizeIncreasing": posterize, 316 | "PosterizeOriginal": posterize, 317 | "Solarize": solarize, 318 | "SolarizeIncreasing": solarize, 319 | "SolarizeAdd": solarize_add, 320 | "Color": color, 321 | "ColorIncreasing": color, 322 | "Contrast": contrast, 323 | "ContrastIncreasing": contrast, 324 | "Brightness": brightness, 325 | "BrightnessIncreasing": brightness, 326 | "Sharpness": sharpness, 327 | "SharpnessIncreasing": sharpness, 328 | "ShearX": shear_x, 329 | "ShearY": shear_y, 330 | "TranslateX": translate_x_abs, 331 | "TranslateY": translate_y_abs, 332 | "TranslateXRel": translate_x_rel, 333 | "TranslateYRel": translate_y_rel, 334 | } 335 | 336 | 337 | class AugmentOp: 338 | """ 339 | Apply for video. 340 | """ 341 | 342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 343 | hparams = hparams or _HPARAMS_DEFAULT 344 | self.aug_fn = NAME_TO_OP[name] 345 | self.level_fn = LEVEL_TO_ARG[name] 346 | self.prob = prob 347 | self.magnitude = magnitude 348 | self.hparams = hparams.copy() 349 | self.kwargs = { 350 | "fillcolor": hparams["img_mean"] 351 | if "img_mean" in hparams 352 | else _FILL, 353 | "resample": hparams["interpolation"] 354 | if "interpolation" in hparams 355 | else _RANDOM_INTERPOLATION, 356 | } 357 | 358 | # If magnitude_std is > 0, we introduce some randomness 359 | # in the usually fixed policy and sample magnitude from a normal distribution 360 | # with mean `magnitude` and std-dev of `magnitude_std`. 361 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
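# (In this copy, magnitude_std reaches hparams through the 'mstd' key of the
# rand augment config string parsed by rand_augment_transform further below;
# __call__ then samples random.gauss(magnitude, magnitude_std) and clips the
# result to [0, _MAX_LEVEL] before mapping it to op arguments.)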
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 363 | 364 | def __call__(self, img_list): 365 | if self.prob < 1.0 and random.random() > self.prob: 366 | return img_list 367 | magnitude = self.magnitude 368 | if self.magnitude_std and self.magnitude_std > 0: 369 | magnitude = random.gauss(magnitude, self.magnitude_std) 370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 371 | level_args = ( 372 | self.level_fn(magnitude, self.hparams) 373 | if self.level_fn is not None 374 | else () 375 | ) 376 | 377 | if isinstance(img_list, list): 378 | return [ 379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 380 | ] 381 | else: 382 | return self.aug_fn(img_list, *level_args, **self.kwargs) 383 | 384 | 385 | _RAND_TRANSFORMS = [ 386 | "AutoContrast", 387 | "Equalize", 388 | "Invert", 389 | "Rotate", 390 | "Posterize", 391 | "Solarize", 392 | "SolarizeAdd", 393 | "Color", 394 | "Contrast", 395 | "Brightness", 396 | "Sharpness", 397 | "ShearX", 398 | "ShearY", 399 | "TranslateXRel", 400 | "TranslateYRel", 401 | ] 402 | 403 | 404 | _RAND_INCREASING_TRANSFORMS = [ 405 | "AutoContrast", 406 | "Equalize", 407 | "Invert", 408 | "Rotate", 409 | "PosterizeIncreasing", 410 | "SolarizeIncreasing", 411 | "SolarizeAdd", 412 | "ColorIncreasing", 413 | "ContrastIncreasing", 414 | "BrightnessIncreasing", 415 | "SharpnessIncreasing", 416 | "ShearX", 417 | "ShearY", 418 | "TranslateXRel", 419 | "TranslateYRel", 420 | ] 421 | 422 | 423 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 424 | # They may not result in increased performance, but could likely be tuned to so. 425 | _RAND_CHOICE_WEIGHTS_0 = { 426 | "Rotate": 0.3, 427 | "ShearX": 0.2, 428 | "ShearY": 0.2, 429 | "TranslateXRel": 0.1, 430 | "TranslateYRel": 0.1, 431 | "Color": 0.025, 432 | "Sharpness": 0.025, 433 | "AutoContrast": 0.025, 434 | "Solarize": 0.005, 435 | "SolarizeAdd": 0.005, 436 | "Contrast": 0.005, 437 | "Brightness": 0.005, 438 | "Equalize": 0.005, 439 | "Posterize": 0, 440 | "Invert": 0, 441 | } 442 | 443 | 444 | def _select_rand_weights(weight_idx=0, transforms=None): 445 | transforms = transforms or _RAND_TRANSFORMS 446 | assert weight_idx == 0 # only one set of weights currently 447 | rand_weights = _RAND_CHOICE_WEIGHTS_0 448 | probs = [rand_weights[k] for k in transforms] 449 | probs /= np.sum(probs) 450 | return probs 451 | 452 | 453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 454 | hparams = hparams or _HPARAMS_DEFAULT 455 | transforms = transforms or _RAND_TRANSFORMS 456 | return [ 457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 458 | for name in transforms 459 | ] 460 | 461 | 462 | class RandAugment: 463 | def __init__(self, ops, num_layers=2, choice_weights=None): 464 | self.ops = ops 465 | self.num_layers = num_layers 466 | self.choice_weights = choice_weights 467 | 468 | def __call__(self, img): 469 | # no replacement when using weighted choice 470 | ops = np.random.choice( 471 | self.ops, 472 | self.num_layers, 473 | replace=self.choice_weights is None, 474 | p=self.choice_weights, 475 | ) 476 | for op in ops: 477 | img = op(img) 478 | return img 479 | 480 | 481 | def rand_augment_transform(config_str, hparams): 482 | """ 483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 484 | 485 | Create a RandAugment transform 486 | :param config_str: String defining configuration of random augmentation. 
Consists of multiple sections separated by 487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 488 | sections, not order sepecific determine 489 | 'm' - integer magnitude of rand augment 490 | 'n' - integer num layers (number of transform ops selected per image) 491 | 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 492 | 'mstd' - float std deviation of magnitude noise applied 493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 497 | :return: A PyTorch compatible Transform 498 | """ 499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 500 | num_layers = 2 # default to 2 ops per image 501 | weight_idx = None # default to no probability weights for op choice 502 | transforms = _RAND_TRANSFORMS 503 | config = config_str.split("-") 504 | assert config[0] == "rand" 505 | config = config[1:] 506 | for c in config: 507 | cs = re.split(r"(\d.*)", c) 508 | if len(cs) < 2: 509 | continue 510 | key, val = cs[:2] 511 | if key == "mstd": 512 | # noise param injected via hparams for now 513 | hparams.setdefault("magnitude_std", float(val)) 514 | elif key == "inc": 515 | if bool(val): 516 | transforms = _RAND_INCREASING_TRANSFORMS 517 | elif key == "m": 518 | magnitude = int(val) 519 | elif key == "n": 520 | num_layers = int(val) 521 | elif key == "w": 522 | weight_idx = int(val) 523 | else: 524 | assert NotImplementedError 525 | ra_ops = rand_augment_ops( 526 | magnitude=magnitude, hparams=hparams, transforms=transforms 527 | ) 528 | choice_weights = ( 529 | None if weight_idx is None else _select_rand_weights(weight_idx) 530 | ) 531 | return ra_ops, num_layers, choice_weights 532 | # return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 533 | -------------------------------------------------------------------------------- /labels/UCF_test.txt: -------------------------------------------------------------------------------- 1 | frames/Abuse/Abuse028_x264.mp4 1412 1 165 240 -1 -1 2 | frames/Abuse/Abuse030_x264.mp4 1544 1 1275 1360 -1 -1 3 | frames/Arrest/Arrest001_x264.mp4 2374 2 1185 1485 -1 -1 4 | frames/Arrest/Arrest007_x264.mp4 3144 2 1530 2160 -1 -1 5 | frames/Arrest/Arrest024_x264.mp4 3629 2 1005 3105 -1 -1 6 | frames/Arrest/Arrest030_x264.mp4 8642 2 5535 7200 -1 -1 7 | frames/Arrest/Arrest039_x264.mp4 15835 2 7215 10335 -1 -1 8 | frames/Arson/Arson007_x264.mp4 6252 3 2250 5700 -1 -1 9 | frames/Arson/Arson009_x264.mp4 743 3 220 315 -1 -1 10 | frames/Arson/Arson010_x264.mp4 3159 3 885 1230 2300 2500 11 | frames/Arson/Arson011_x264.mp4 1266 3 150 420 680 1267 12 | frames/Arson/Arson016_x264.mp4 1795 3 666 1796 -1 -1 13 | frames/Arson/Arson018_x264.mp4 842 3 270 600 -1 -1 14 | frames/Arson/Arson022_x264.mp4 8640 3 3500 4000 -1 -1 15 | frames/Arson/Arson035_x264.mp4 1437 3 600 900 -1 -1 16 | frames/Arson/Arson041_x264.mp4 3754 3 2130 3615 -1 -1 17 | frames/Assault/Assault006_x264.mp4 8096 4 1185 8096 -1 -1 18 | frames/Assault/Assault010_x264.mp4 16177 4 11330 11680 12260 12930 19 | frames/Assault/Assault011_x264.mp4 2288 4 375 1345 -1 -1 20 | frames/Burglary/Burglary005_x264.mp4 7729 5 4710 5040 -1 -1 21 | 
frames/Burglary/Burglary017_x264.mp4 2113 5 150 1000 1250 1820 22 | frames/Burglary/Burglary018_x264.mp4 1125 5 510 1050 -1 -1 23 | frames/Burglary/Burglary021_x264.mp4 1537 5 60 200 840 1340 24 | frames/Burglary/Burglary024_x264.mp4 3605 5 60 1230 -1 -1 25 | frames/Burglary/Burglary032_x264.mp4 15795 5 1290 3690 -1 -1 26 | frames/Burglary/Burglary033_x264.mp4 1259 5 60 330 -1 -1 27 | frames/Burglary/Burglary035_x264.mp4 4050 5 120 1740 -1 -1 28 | frames/Burglary/Burglary037_x264.mp4 1920 5 240 390 540 1800 29 | frames/Burglary/Burglary061_x264.mp4 8990 5 4200 5700 -1 -1 30 | frames/Burglary/Burglary076_x264.mp4 12923 5 8400 9400 10500 12600 31 | frames/Burglary/Burglary079_x264.mp4 14853 5 7750 10710 -1 -1 32 | frames/Burglary/Burglary092_x264.mp4 625 5 240 500 -1 -1 33 | frames/Explosion/Explosion002_x264.mp4 4013 6 1500 2100 -1 -1 34 | frames/Explosion/Explosion004_x264.mp4 1902 6 75 225 -1 -1 35 | frames/Explosion/Explosion007_x264.mp4 16289 6 1590 2280 -1 -1 36 | frames/Explosion/Explosion008_x264.mp4 1748 6 1005 1245 -1 -1 37 | frames/Explosion/Explosion010_x264.mp4 2498 6 285 1080 -1 -1 38 | frames/Explosion/Explosion011_x264.mp4 1571 6 795 945 -1 -1 39 | frames/Explosion/Explosion013_x264.mp4 3317 6 2520 2970 -1 -1 40 | frames/Explosion/Explosion016_x264.mp4 963 6 180 450 -1 -1 41 | frames/Explosion/Explosion017_x264.mp4 1643 6 990 1440 -1 -1 42 | frames/Explosion/Explosion020_x264.mp4 1291 6 60 680 -1 -1 43 | frames/Explosion/Explosion021_x264.mp4 782 6 135 600 -1 -1 44 | frames/Explosion/Explosion022_x264.mp4 3594 6 2230 2420 -1 -1 45 | frames/Explosion/Explosion025_x264.mp4 505 6 260 420 -1 -1 46 | frames/Explosion/Explosion027_x264.mp4 776 6 105 180 -1 -1 47 | frames/Explosion/Explosion028_x264.mp4 1705 6 280 700 -1 -1 48 | frames/Explosion/Explosion029_x264.mp4 2410 6 1830 2020 -1 -1 49 | frames/Explosion/Explosion033_x264.mp4 3154 6 970 1350 1550 3156 50 | frames/Explosion/Explosion035_x264.mp4 2865 6 250 350 -1 -1 51 | frames/Explosion/Explosion036_x264.mp4 5327 6 1950 2070 -1 -1 52 | frames/Explosion/Explosion039_x264.mp4 998 6 60 150 678 750 53 | frames/Explosion/Explosion043_x264.mp4 7646 6 4460 4600 -1 -1 54 | frames/Fighting/Fighting003_x264.mp4 3102 7 1820 3103 -1 -1 55 | frames/Fighting/Fighting018_x264.mp4 1389 7 80 420 -1 -1 56 | frames/Fighting/Fighting033_x264.mp4 1105 7 570 840 -1 -1 57 | frames/Fighting/Fighting042_x264.mp4 2237 7 290 1200 -1 -1 58 | frames/Fighting/Fighting047_x264.mp4 4459 7 200 1830 -1 -1 59 | frames/Normal/Normal_Videos_003_x264.mp4 2822 0 -1 -1 -1 -1 60 | frames/Normal/Normal_Videos_006_x264.mp4 450 0 -1 -1 -1 -1 61 | frames/Normal/Normal_Videos_010_x264.mp4 1053 0 -1 -1 -1 -1 62 | frames/Normal/Normal_Videos_014_x264.mp4 1499 0 -1 -1 -1 -1 63 | frames/Normal/Normal_Videos_015_x264.mp4 480 0 -1 -1 -1 -1 64 | frames/Normal/Normal_Videos_018_x264.mp4 1181 0 -1 -1 -1 -1 65 | frames/Normal/Normal_Videos_019_x264.mp4 2843 0 -1 -1 -1 -1 66 | frames/Normal/Normal_Videos_024_x264.mp4 1076 0 -1 -1 -1 -1 67 | frames/Normal/Normal_Videos_025_x264.mp4 602 0 -1 -1 -1 -1 68 | frames/Normal/Normal_Videos_027_x264.mp4 4922 0 -1 -1 -1 -1 69 | frames/Normal/Normal_Videos_033_x264.mp4 1680 0 -1 -1 -1 -1 70 | frames/Normal/Normal_Videos_034_x264.mp4 1318 0 -1 -1 -1 -1 71 | frames/Normal/Normal_Videos_041_x264.mp4 1269 0 -1 -1 -1 -1 72 | frames/Normal/Normal_Videos_042_x264.mp4 3154 0 -1 -1 -1 -1 73 | frames/Normal/Normal_Videos_048_x264.mp4 1650 0 -1 -1 -1 -1 74 | frames/Normal/Normal_Videos_050_x264.mp4 4198 0 -1 -1 -1 -1 75 | 
frames/Normal/Normal_Videos_051_x264.mp4 2358 0 -1 -1 -1 -1 76 | frames/Normal/Normal_Videos_056_x264.mp4 1572 0 -1 -1 -1 -1 77 | frames/Normal/Normal_Videos_059_x264.mp4 1835 0 -1 -1 -1 -1 78 | frames/Normal/Normal_Videos_063_x264.mp4 355 0 -1 -1 -1 -1 79 | frames/Normal/Normal_Videos_067_x264.mp4 1068 0 -1 -1 -1 -1 80 | frames/Normal/Normal_Videos_070_x264.mp4 993 0 -1 -1 -1 -1 81 | frames/Normal/Normal_Videos_100_x264.mp4 627 0 -1 -1 -1 -1 82 | frames/Normal/Normal_Videos_129_x264.mp4 467 0 -1 -1 -1 -1 83 | frames/Normal/Normal_Videos_150_x264.mp4 864 0 -1 -1 -1 -1 84 | frames/Normal/Normal_Videos_168_x264.mp4 1740 0 -1 -1 -1 -1 85 | frames/Normal/Normal_Videos_175_x264.mp4 8847 0 -1 -1 -1 -1 86 | frames/Normal/Normal_Videos_182_x264.mp4 4094 0 -1 -1 -1 -1 87 | frames/Normal/Normal_Videos_189_x264.mp4 737 0 -1 -1 -1 -1 88 | frames/Normal/Normal_Videos_196_x264.mp4 2004 0 -1 -1 -1 -1 89 | frames/Normal/Normal_Videos_203_x264.mp4 2571 0 -1 -1 -1 -1 90 | frames/Normal/Normal_Videos_210_x264.mp4 5408 0 -1 -1 -1 -1 91 | frames/Normal/Normal_Videos_217_x264.mp4 1817 0 -1 -1 -1 -1 92 | frames/Normal/Normal_Videos_224_x264.mp4 6958 0 -1 -1 -1 -1 93 | frames/Normal/Normal_Videos_246_x264.mp4 4993 0 -1 -1 -1 -1 94 | frames/Normal/Normal_Videos_247_x264.mp4 8211 0 -1 -1 -1 -1 95 | frames/Normal/Normal_Videos_248_x264.mp4 1140 0 -1 -1 -1 -1 96 | frames/Normal/Normal_Videos_251_x264.mp4 405 0 -1 -1 -1 -1 97 | frames/Normal/Normal_Videos_289_x264.mp4 863 0 -1 -1 -1 -1 98 | frames/Normal/Normal_Videos_310_x264.mp4 2518 0 -1 -1 -1 -1 99 | frames/Normal/Normal_Videos_312_x264.mp4 1261 0 -1 -1 -1 -1 100 | frames/Normal/Normal_Videos_317_x264.mp4 929 0 -1 -1 -1 -1 101 | frames/Normal/Normal_Videos_345_x264.mp4 209 0 -1 -1 -1 -1 102 | frames/Normal/Normal_Videos_352_x264.mp4 5403 0 -1 -1 -1 -1 103 | frames/Normal/Normal_Videos_360_x264.mp4 984 0 -1 -1 -1 -1 104 | frames/Normal/Normal_Videos_365_x264.mp4 6626 0 -1 -1 -1 -1 105 | frames/Normal/Normal_Videos_401_x264.mp4 1626 0 -1 -1 -1 -1 106 | frames/Normal/Normal_Videos_417_x264.mp4 1077 0 -1 -1 -1 -1 107 | frames/Normal/Normal_Videos_439_x264.mp4 4226 0 -1 -1 -1 -1 108 | frames/Normal/Normal_Videos_452_x264.mp4 443 0 -1 -1 -1 -1 109 | frames/Normal/Normal_Videos_453_x264.mp4 5322 0 -1 -1 -1 -1 110 | frames/Normal/Normal_Videos_478_x264.mp4 4502 0 -1 -1 -1 -1 111 | frames/Normal/Normal_Videos_576_x264.mp4 11275 0 -1 -1 -1 -1 112 | frames/Normal/Normal_Videos_597_x264.mp4 2229 0 -1 -1 -1 -1 113 | frames/Normal/Normal_Videos_603_x264.mp4 3277 0 -1 -1 -1 -1 114 | frames/Normal/Normal_Videos_606_x264.mp4 1233 0 -1 -1 -1 -1 115 | frames/Normal/Normal_Videos_621_x264.mp4 4802 0 -1 -1 -1 -1 116 | frames/Normal/Normal_Videos_634_x264.mp4 13459 0 -1 -1 -1 -1 117 | frames/Normal/Normal_Videos_641_x264.mp4 3600 0 -1 -1 -1 -1 118 | frames/Normal/Normal_Videos_656_x264.mp4 1815 0 -1 -1 -1 -1 119 | frames/Normal/Normal_Videos_686_x264.mp4 2410 0 -1 -1 -1 -1 120 | frames/Normal/Normal_Videos_696_x264.mp4 3625 0 -1 -1 -1 -1 121 | frames/Normal/Normal_Videos_702_x264.mp4 2523 0 -1 -1 -1 -1 122 | frames/Normal/Normal_Videos_704_x264.mp4 1694 0 -1 -1 -1 -1 123 | frames/Normal/Normal_Videos_710_x264.mp4 1797 0 -1 -1 -1 -1 124 | frames/Normal/Normal_Videos_717_x264.mp4 1255 0 -1 -1 -1 -1 125 | frames/Normal/Normal_Videos_722_x264.mp4 8731 0 -1 -1 -1 -1 126 | frames/Normal/Normal_Videos_725_x264.mp4 920 0 -1 -1 -1 -1 127 | frames/Normal/Normal_Videos_745_x264.mp4 305 0 -1 -1 -1 -1 128 | frames/Normal/Normal_Videos_758_x264.mp4 1589 0 -1 -1 -1 -1 129 | 
frames/Normal/Normal_Videos_778_x264.mp4 1262 0 -1 -1 -1 -1 130 | frames/Normal/Normal_Videos_780_x264.mp4 2021 0 -1 -1 -1 -1 131 | frames/Normal/Normal_Videos_781_x264.mp4 3975 0 -1 -1 -1 -1 132 | frames/Normal/Normal_Videos_782_x264.mp4 5543 0 -1 -1 -1 -1 133 | frames/Normal/Normal_Videos_783_x264.mp4 9590 0 -1 -1 -1 -1 134 | frames/Normal/Normal_Videos_798_x264.mp4 6001 0 -1 -1 -1 -1 135 | frames/Normal/Normal_Videos_801_x264.mp4 2744 0 -1 -1 -1 -1 136 | frames/Normal/Normal_Videos_828_x264.mp4 930 0 -1 -1 -1 -1 137 | frames/Normal/Normal_Videos_831_x264.mp4 448 0 -1 -1 -1 -1 138 | frames/Normal/Normal_Videos_866_x264.mp4 1198 0 -1 -1 -1 -1 139 | frames/Normal/Normal_Videos_867_x264.mp4 624 0 -1 -1 -1 -1 140 | frames/Normal/Normal_Videos_868_x264.mp4 2401 0 -1 -1 -1 -1 141 | frames/Normal/Normal_Videos_869_x264.mp4 2401 0 -1 -1 -1 -1 142 | frames/Normal/Normal_Videos_870_x264.mp4 601 0 -1 -1 -1 -1 143 | frames/Normal/Normal_Videos_871_x264.mp4 4358 0 -1 -1 -1 -1 144 | frames/Normal/Normal_Videos_872_x264.mp4 530 0 -1 -1 -1 -1 145 | frames/Normal/Normal_Videos_873_x264.mp4 1799 0 -1 -1 -1 -1 146 | frames/Normal/Normal_Videos_874_x264.mp4 4226 0 -1 -1 -1 -1 147 | frames/Normal/Normal_Videos_875_x264.mp4 2565 0 -1 -1 -1 -1 148 | frames/Normal/Normal_Videos_876_x264.mp4 351 0 -1 -1 -1 -1 149 | frames/Normal/Normal_Videos_877_x264.mp4 10025 0 -1 -1 -1 -1 150 | frames/Normal/Normal_Videos_878_x264.mp4 265 0 -1 -1 -1 -1 151 | frames/Normal/Normal_Videos_879_x264.mp4 1143 0 -1 -1 -1 -1 152 | frames/Normal/Normal_Videos_880_x264.mp4 18054 0 -1 -1 -1 -1 153 | frames/Normal/Normal_Videos_881_x264.mp4 224 0 -1 -1 -1 -1 154 | frames/Normal/Normal_Videos_882_x264.mp4 1654 0 -1 -1 -1 -1 155 | frames/Normal/Normal_Videos_883_x264.mp4 326 0 -1 -1 -1 -1 156 | frames/Normal/Normal_Videos_884_x264.mp4 9029 0 -1 -1 -1 -1 157 | frames/Normal/Normal_Videos_885_x264.mp4 474 0 -1 -1 -1 -1 158 | frames/Normal/Normal_Videos_886_x264.mp4 2747 0 -1 -1 -1 -1 159 | frames/Normal/Normal_Videos_887_x264.mp4 7630 0 -1 -1 -1 -1 160 | frames/Normal/Normal_Videos_888_x264.mp4 574 0 -1 -1 -1 -1 161 | frames/Normal/Normal_Videos_889_x264.mp4 314 0 -1 -1 -1 -1 162 | frames/Normal/Normal_Videos_890_x264.mp4 3579 0 -1 -1 -1 -1 163 | frames/Normal/Normal_Videos_891_x264.mp4 1800 0 -1 -1 -1 -1 164 | frames/Normal/Normal_Videos_892_x264.mp4 1770 0 -1 -1 -1 -1 165 | frames/Normal/Normal_Videos_893_x264.mp4 6374 0 -1 -1 -1 -1 166 | frames/Normal/Normal_Videos_894_x264.mp4 2575 0 -1 -1 -1 -1 167 | frames/Normal/Normal_Videos_895_x264.mp4 3030 0 -1 -1 -1 -1 168 | frames/Normal/Normal_Videos_896_x264.mp4 2303 0 -1 -1 -1 -1 169 | frames/Normal/Normal_Videos_897_x264.mp4 875 0 -1 -1 -1 -1 170 | frames/Normal/Normal_Videos_898_x264.mp4 1005 0 -1 -1 -1 -1 171 | frames/Normal/Normal_Videos_899_x264.mp4 1380 0 -1 -1 -1 -1 172 | frames/Normal/Normal_Videos_900_x264.mp4 1455 0 -1 -1 -1 -1 173 | frames/Normal/Normal_Videos_901_x264.mp4 1170 0 -1 -1 -1 -1 174 | frames/Normal/Normal_Videos_902_x264.mp4 1384 0 -1 -1 -1 -1 175 | frames/Normal/Normal_Videos_903_x264.mp4 790 0 -1 -1 -1 -1 176 | frames/Normal/Normal_Videos_904_x264.mp4 908 0 -1 -1 -1 -1 177 | frames/Normal/Normal_Videos_905_x264.mp4 1196 0 -1 -1 -1 -1 178 | frames/Normal/Normal_Videos_906_x264.mp4 676 0 -1 -1 -1 -1 179 | frames/Normal/Normal_Videos_907_x264.mp4 598 0 -1 -1 -1 -1 180 | frames/Normal/Normal_Videos_908_x264.mp4 889 0 -1 -1 -1 -1 181 | frames/Normal/Normal_Videos_909_x264.mp4 870 0 -1 -1 -1 -1 182 | frames/Normal/Normal_Videos_910_x264.mp4 567 0 -1 -1 -1 -1 183 | 
frames/Normal/Normal_Videos_911_x264.mp4 776 0 -1 -1 -1 -1 184 | frames/Normal/Normal_Videos_912_x264.mp4 746 0 -1 -1 -1 -1 185 | frames/Normal/Normal_Videos_913_x264.mp4 609 0 -1 -1 -1 -1 186 | frames/Normal/Normal_Videos_914_x264.mp4 880 0 -1 -1 -1 -1 187 | frames/Normal/Normal_Videos_915_x264.mp4 1245 0 -1 -1 -1 -1 188 | frames/Normal/Normal_Videos_923_x264.mp4 18224 0 -1 -1 -1 -1 189 | frames/Normal/Normal_Videos_924_x264.mp4 107997 0 -1 -1 -1 -1 190 | frames/Normal/Normal_Videos_925_x264.mp4 7726 0 -1 -1 -1 -1 191 | frames/Normal/Normal_Videos_926_x264.mp4 1796 0 -1 -1 -1 -1 192 | frames/Normal/Normal_Videos_927_x264.mp4 1633 0 -1 -1 -1 -1 193 | frames/Normal/Normal_Videos_928_x264.mp4 918 0 -1 -1 -1 -1 194 | frames/Normal/Normal_Videos_929_x264.mp4 924 0 -1 -1 -1 -1 195 | frames/Normal/Normal_Videos_930_x264.mp4 3187 0 -1 -1 -1 -1 196 | frames/Normal/Normal_Videos_931_x264.mp4 1765 0 -1 -1 -1 -1 197 | frames/Normal/Normal_Videos_932_x264.mp4 1784 0 -1 -1 -1 -1 198 | frames/Normal/Normal_Videos_933_x264.mp4 1771 0 -1 -1 -1 -1 199 | frames/Normal/Normal_Videos_934_x264.mp4 1765 0 -1 -1 -1 -1 200 | frames/Normal/Normal_Videos_935_x264.mp4 107994 0 -1 -1 -1 -1 201 | frames/Normal/Normal_Videos_936_x264.mp4 1150 0 -1 -1 -1 -1 202 | frames/Normal/Normal_Videos_937_x264.mp4 1150 0 -1 -1 -1 -1 203 | frames/Normal/Normal_Videos_938_x264.mp4 4827 0 -1 -1 -1 -1 204 | frames/Normal/Normal_Videos_939_x264.mp4 801 0 -1 -1 -1 -1 205 | frames/Normal/Normal_Videos_940_x264.mp4 36017 0 -1 -1 -1 -1 206 | frames/Normal/Normal_Videos_941_x264.mp4 2020 0 -1 -1 -1 -1 207 | frames/Normal/Normal_Videos_943_x264.mp4 1020 0 -1 -1 -1 -1 208 | frames/Normal/Normal_Videos_944_x264.mp4 7170 0 -1 -1 -1 -1 209 | frames/RoadAccidents/RoadAccidents001_x264.mp4 1366 8 210 300 -1 -1 210 | frames/RoadAccidents/RoadAccidents002_x264.mp4 347 8 240 300 -1 -1 211 | frames/RoadAccidents/RoadAccidents004_x264.mp4 389 8 140 189 -1 -1 212 | frames/RoadAccidents/RoadAccidents009_x264.mp4 918 8 210 240 -1 -1 213 | frames/RoadAccidents/RoadAccidents010_x264.mp4 528 8 230 270 -1 -1 214 | frames/RoadAccidents/RoadAccidents011_x264.mp4 2159 8 260 300 -1 -1 215 | frames/RoadAccidents/RoadAccidents012_x264.mp4 468 8 250 390 -1 -1 216 | frames/RoadAccidents/RoadAccidents016_x264.mp4 2192 8 530 720 -1 -1 217 | frames/RoadAccidents/RoadAccidents017_x264.mp4 243 8 60 130 -1 -1 218 | frames/RoadAccidents/RoadAccidents019_x264.mp4 1314 8 750 900 -1 -1 219 | frames/RoadAccidents/RoadAccidents020_x264.mp4 1773 8 610 730 -1 -1 220 | frames/RoadAccidents/RoadAccidents021_x264.mp4 155 8 30 90 -1 -1 221 | frames/RoadAccidents/RoadAccidents022_x264.mp4 716 8 120 220 490 560 222 | frames/RoadAccidents/RoadAccidents121_x264.mp4 1835 8 330 390 -1 -1 223 | frames/RoadAccidents/RoadAccidents122_x264.mp4 647 8 300 360 -1 -1 224 | frames/RoadAccidents/RoadAccidents123_x264.mp4 1005 8 130 210 -1 -1 225 | frames/RoadAccidents/RoadAccidents124_x264.mp4 1495 8 250 420 -1 -1 226 | frames/RoadAccidents/RoadAccidents125_x264.mp4 1771 8 490 600 -1 -1 227 | frames/RoadAccidents/RoadAccidents127_x264.mp4 2580 8 2160 2300 -1 -1 228 | frames/RoadAccidents/RoadAccidents128_x264.mp4 565 8 90 200 -1 -1 229 | frames/RoadAccidents/RoadAccidents131_x264.mp4 1524 8 180 240 -1 -1 230 | frames/RoadAccidents/RoadAccidents132_x264.mp4 1862 8 220 320 -1 -1 231 | frames/RoadAccidents/RoadAccidents133_x264.mp4 673 8 270 450 -1 -1 232 | frames/Robbery/Robbery048_x264.mp4 1409 9 450 930 -1 -1 233 | frames/Robbery/Robbery050_x264.mp4 1701 9 495 1410 -1 -1 234 | 
frames/Robbery/Robbery102_x264.mp4 1827 9 1080 1560 -1 -1 235 | frames/Robbery/Robbery106_x264.mp4 1197 9 480 600 -1 -1 236 | frames/Robbery/Robbery137_x264.mp4 2193 9 135 1950 -1 -1 237 | frames/Shooting/Shooting002_x264.mp4 1206 10 1020 1100 -1 -1 238 | frames/Shooting/Shooting004_x264.mp4 1793 10 500 660 -1 -1 239 | frames/Shooting/Shooting007_x264.mp4 1430 10 45 165 -1 -1 240 | frames/Shooting/Shooting008_x264.mp4 1625 10 75 315 -1 -1 241 | frames/Shooting/Shooting010_x264.mp4 2641 10 1095 1260 -1 -1 242 | frames/Shooting/Shooting011_x264.mp4 4003 10 1480 1750 -1 -1 243 | frames/Shooting/Shooting013_x264.mp4 1073 10 860 945 -1 -1 244 | frames/Shooting/Shooting015_x264.mp4 1713 10 855 1715 -1 -1 245 | frames/Shooting/Shooting018_x264.mp4 1799 10 315 480 -1 -1 246 | frames/Shooting/Shooting019_x264.mp4 2756 10 1020 1455 -1 -1 247 | frames/Shooting/Shooting021_x264.mp4 1275 10 480 630 -1 -1 248 | frames/Shooting/Shooting022_x264.mp4 4554 10 2850 3300 -1 -1 249 | frames/Shooting/Shooting024_x264.mp4 2003 10 720 1305 1720 1780 250 | frames/Shooting/Shooting026_x264.mp4 1403 10 195 600 -1 -1 251 | frames/Shooting/Shooting028_x264.mp4 1898 10 285 555 -1 -1 252 | frames/Shooting/Shooting032_x264.mp4 21681 10 7995 8205 -1 -1 253 | frames/Shooting/Shooting033_x264.mp4 3630 10 1680 2000 -1 -1 254 | frames/Shooting/Shooting034_x264.mp4 1409 10 960 1050 -1 -1 255 | frames/Shooting/Shooting037_x264.mp4 305 10 140 260 -1 -1 256 | frames/Shooting/Shooting043_x264.mp4 1874 10 945 1750 -1 -1 257 | frames/Shooting/Shooting046_x264.mp4 5088 10 4005 4230 4760 5088 258 | frames/Shooting/Shooting047_x264.mp4 8287 10 2160 3900 4860 6600 259 | frames/Shooting/Shooting048_x264.mp4 2741 10 1410 1730 -1 -1 260 | frames/Shoplifting/Shoplifting001_x264.mp4 4344 11 1550 2000 -1 -1 261 | frames/Shoplifting/Shoplifting004_x264.mp4 6673 11 2200 4900 -1 -1 262 | frames/Shoplifting/Shoplifting005_x264.mp4 1967 11 720 930 -1 -1 263 | frames/Shoplifting/Shoplifting007_x264.mp4 5124 11 550 760 4630 4920 264 | frames/Shoplifting/Shoplifting010_x264.mp4 2736 11 750 920 1550 1970 265 | frames/Shoplifting/Shoplifting015_x264.mp4 2256 11 2010 2160 -1 -1 266 | frames/Shoplifting/Shoplifting016_x264.mp4 1483 11 630 720 -1 -1 267 | frames/Shoplifting/Shoplifting017_x264.mp4 457 11 360 420 -1 -1 268 | frames/Shoplifting/Shoplifting020_x264.mp4 5770 11 2340 2460 -1 -1 269 | frames/Shoplifting/Shoplifting021_x264.mp4 3551 11 2070 2220 -1 -1 270 | frames/Shoplifting/Shoplifting022_x264.mp4 2191 11 270 420 1440 1560 271 | frames/Shoplifting/Shoplifting027_x264.mp4 1873 11 1080 1160 1470 1710 272 | frames/Shoplifting/Shoplifting028_x264.mp4 1357 11 570 840 -1 -1 273 | frames/Shoplifting/Shoplifting029_x264.mp4 2176 11 1020 1470 -1 -1 274 | frames/Shoplifting/Shoplifting031_x264.mp4 447 11 120 330 -1 -1 275 | frames/Shoplifting/Shoplifting033_x264.mp4 899 11 630 750 -1 -1 276 | frames/Shoplifting/Shoplifting034_x264.mp4 11937 11 7350 7470 -1 -1 277 | frames/Shoplifting/Shoplifting037_x264.mp4 1386 11 1140 1200 -1 -1 278 | frames/Shoplifting/Shoplifting039_x264.mp4 2803 11 2190 2340 -1 -1 279 | frames/Shoplifting/Shoplifting044_x264.mp4 14555 11 11070 11250 -1 -1 280 | frames/Shoplifting/Shoplifting049_x264.mp4 2149 11 1020 1350 -1 -1 281 | frames/Stealing/Stealing019_x264.mp4 4911 12 2730 2790 4170 4350 282 | frames/Stealing/Stealing036_x264.mp4 2503 12 1260 1590 -1 -1 283 | frames/Stealing/Stealing058_x264.mp4 4991 12 570 3660 -1 -1 284 | frames/Stealing/Stealing062_x264.mp4 1560 12 360 1050 -1 -1 285 | 
frames/Stealing/Stealing079_x264.mp4 5846 12 2550 3210 3510 4500 286 | frames/Vandalism/Vandalism007_x264.mp4 1146 13 240 750 -1 -1 287 | frames/Vandalism/Vandalism015_x264.mp4 2982 13 2010 2700 -1 -1 288 | frames/Vandalism/Vandalism017_x264.mp4 1011 13 270 330 780 840 289 | frames/Vandalism/Vandalism028_x264.mp4 4495 13 1830 1980 2400 2670 290 | frames/Vandalism/Vandalism036_x264.mp4 1443 13 540 780 990 1080 291 | -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | from logging import Logger 2 | from torch.utils.data import DataLoader 3 | import torch.distributed as dist 4 | import torch 5 | import numpy as np 6 | from functools import partial 7 | import random 8 | 9 | import io 10 | import os 11 | import os.path as osp 12 | import shutil 13 | import warnings 14 | from collections.abc import Mapping, Sequence 15 | from mmcv.utils import Registry, build_from_cfg 16 | from torch.utils.data import Dataset 17 | import copy 18 | import os.path as osp 19 | import warnings 20 | from abc import ABCMeta, abstractmethod 21 | from collections import OrderedDict, defaultdict 22 | import os.path as osp 23 | import mmcv 24 | import numpy as np 25 | import torch 26 | import math 27 | import tarfile 28 | from .pipeline import * 29 | from torch.utils.data import DataLoader 30 | from torch.utils.data.dataloader import default_collate 31 | from mmcv.parallel import collate 32 | import pandas as pd 33 | 34 | PIPELINES = Registry('pipeline') 35 | img_norm_cfg = dict( 36 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 37 | 38 | 39 | class RawFramesTestRecord(object): 40 | def __init__(self, row, temp_label=None): 41 | self._data = row 42 | self.temp_label = temp_label 43 | 44 | @property 45 | def path(self): 46 | return self._data[0] 47 | 48 | @property 49 | def num_frames(self): 50 | if int(self._data[1])==0: 51 | return int(self._data[2]) 52 | else: 53 | return int(self._data[1]) 54 | 55 | @property 56 | def label(self): 57 | if len(self._data) == 15: 58 | label = np.zeros((32, 1)) 59 | class_onehot = [] 60 | for i in range(2, len(self._data)): 61 | class_onehot.append(int(self._data[i])) 62 | if self.temp_label is None: 63 | return label 64 | else: 65 | for j in range(len(self.temp_label)): 66 | start = round(self.temp_label[j][0] * 32.0 / self.num_frames) 67 | end = round(self.temp_label[j][1] * 32.0 / self.num_frames) 68 | label[start:end] += 1 69 | 70 | if start > 0: 71 | label[start - 1] = 2 72 | if end < 32: 73 | label[end] = 2 74 | if end < 31: 75 | label[end - 1] = 2 76 | label[start] = 2 77 | 78 | return label 79 | else: 80 | if len(self._data) > 2 and int(self._data[1])>0: 81 | return int(self._data[2]) 82 | elif len(self._data) > 2 and int(self._data[1])==0: 83 | return int(self._data[3]) 84 | else: 85 | return 0 86 | 87 | 88 | class BaseDataset(Dataset, metaclass=ABCMeta): 89 | def __init__(self, 90 | ann_file, 91 | pipeline, 92 | repeat = 1, 93 | pipeline_=None, 94 | data_prefix=None, 95 | test_mode=False, 96 | multi_class=False, 97 | num_classes=None, 98 | start_index=1, 99 | modality='RGB', 100 | sample_by_class=False, 101 | filename_tmpl='img_{:08}.jpg', 102 | seg_interval=30, 103 | power=0, 104 | dynamic_length=False,): 105 | super().__init__() 106 | self.use_tar_format = True if ".tar" in data_prefix else False 107 | data_prefix = data_prefix.replace(".tar", "") 108 | self.ann_file = ann_file 109 | self.repeat = repeat 110 | 
self.data_prefix = osp.realpath( 111 | data_prefix) if data_prefix is not None and osp.isdir( 112 | data_prefix) else data_prefix 113 | self.filename_tmpl = filename_tmpl 114 | self.test_mode = test_mode 115 | self.multi_class = multi_class 116 | self.num_classes = num_classes 117 | self.start_index = start_index 118 | self.modality = modality 119 | self.sample_by_class = sample_by_class 120 | self.power = power 121 | self.seg_interval = seg_interval 122 | self.dynamic_length = dynamic_length 123 | 124 | assert not (self.multi_class and self.sample_by_class) 125 | 126 | self.pipeline = Compose(pipeline) 127 | if pipeline_ is not None: 128 | self.pipeline_ = Compose(pipeline_) 129 | self.repeat += 1 130 | self.video_infos = self.load_annotations() 131 | if self.sample_by_class: 132 | self.video_infos_by_class = self.parse_by_class() 133 | 134 | class_prob = [] 135 | for _, samples in self.video_infos_by_class.items(): 136 | class_prob.append(len(samples) / len(self.video_infos)) 137 | class_prob = [x**self.power for x in class_prob] 138 | 139 | summ = sum(class_prob) 140 | class_prob = [x / summ for x in class_prob] 141 | 142 | self.class_prob = dict(zip(self.video_infos_by_class, class_prob)) 143 | 144 | @abstractmethod 145 | def load_annotations(self): 146 | """Load the annotation according to ann_file into video_infos.""" 147 | 148 | # json annotations already looks like video_infos, so for each dataset, 149 | # this func should be the same 150 | def load_json_annotations(self): 151 | """Load json annotation file to get video information.""" 152 | video_infos = mmcv.load(self.ann_file) 153 | num_videos = len(video_infos) 154 | path_key = 'frame_dir' if 'frame_dir' in video_infos[0] else 'filename' 155 | for i in range(num_videos): 156 | path_value = video_infos[i][path_key] 157 | if self.data_prefix is not None: 158 | path_value = osp.join(self.data_prefix, path_value) 159 | video_infos[i][path_key] = path_value 160 | if self.multi_class: 161 | assert self.num_classes is not None 162 | else: 163 | assert len(video_infos[i]['label']) == 1 164 | video_infos[i]['label'] = video_infos[i]['label'][0] 165 | return video_infos 166 | 167 | def parse_by_class(self): 168 | video_infos_by_class = defaultdict(list) 169 | for item in self.video_infos: 170 | label = item['label'] 171 | video_infos_by_class[label].append(item) 172 | return video_infos_by_class 173 | 174 | @staticmethod 175 | def label2array(num, label): 176 | arr = np.zeros(num, dtype=np.float32) 177 | arr[label] = 1. 178 | return arr 179 | 180 | @staticmethod 181 | def dump_results(results, out): 182 | """Dump data to json/yaml/pickle strings or files.""" 183 | return mmcv.dump(results, out) 184 | 185 | def prepare_train_frames(self, idx): 186 | """Prepare the frames for training given the index.""" 187 | results = copy.deepcopy(self.video_infos[idx]) 188 | results['modality'] = self.modality 189 | results['start_index'] = self.start_index 190 | results['filename_tmpl'] = self.filename_tmpl 191 | # prepare tensor in getitem 192 | # If HVU, type(results['label']) is dict 193 | if self.multi_class and isinstance(results['label'], list): 194 | onehot = torch.zeros(self.num_classes) 195 | onehot[results['label']] = 1. 
196 | results['label'] = onehot 197 | 198 | aug1 = self.pipeline(results) 199 | # import pdb;pdb.set_trace() 200 | if self.repeat > 1: 201 | aug2 = self.pipeline_(results) 202 | ret = {"imgs": torch.stack((aug1['imgs'], aug2['imgs']), 0), 203 | "label": aug1['label'].repeat(2), 204 | "vid": aug1['vid'], 205 | 'frame_inds': aug1['frame_inds'], 206 | 'total_frames': aug1['total_frames'], 207 | } 208 | return ret 209 | else: 210 | return aug1 211 | 212 | def prepare_test_frames(self, idx): 213 | """Prepare the frames for testing given the index.""" 214 | results = copy.deepcopy(self.video_infos[idx]) 215 | results['modality'] = self.modality 216 | results['start_index'] = self.start_index 217 | 218 | # prepare tensor in getitem 219 | # If HVU, type(results['label']) is dict 220 | if self.multi_class and isinstance(results['label'], list): 221 | onehot = torch.zeros(self.num_classes) 222 | onehot[results['label']] = 1. 223 | results['label'] = onehot 224 | 225 | return self.pipeline(results) 226 | 227 | def __len__(self): 228 | """Get the size of the dataset.""" 229 | return len(self.video_infos) 230 | 231 | def __getitem__(self, idx): 232 | """Get the sample for either training or testing given index.""" 233 | if self.test_mode: 234 | return self.prepare_test_frames(idx) 235 | 236 | return self.prepare_train_frames(idx) 237 | 238 | 239 | class FrameDataset(BaseDataset): 240 | def __init__(self, ann_file, pipeline, labels_file, start_index=0, **kwargs): 241 | super().__init__(ann_file, pipeline, start_index=start_index, **kwargs) 242 | self.labels_file = labels_file 243 | 244 | @property 245 | def classes(self): 246 | classes_all = pd.read_csv(self.labels_file) 247 | return classes_all.values.tolist() 248 | 249 | def load_annotations(self): 250 | """Load annotation file to get video information.""" 251 | if self.ann_file.endswith('.json'): 252 | return self.load_json_annotations() 253 | vid = 0 254 | video_infos = [] 255 | with open(self.ann_file, 'r') as fin: 256 | for line in fin: 257 | line_split = line.strip().split() 258 | if self.multi_class: 259 | assert self.num_classes is not None 260 | filename, label = line_split[0], line_split[1:] 261 | label = list(map(int, label)) 262 | else: 263 | if len(line_split) == 4: 264 | filename, start, end, label = line_split 265 | elif len(line_split) == 5: 266 | filename, end, label, _, _ = line_split 267 | start = 0 268 | else: 269 | filename, end, label = line_split[:3] 270 | start = 0 271 | label = int(label) 272 | if self.data_prefix is not None and self.data_prefix not in filename: 273 | filename = osp.join(self.data_prefix, filename) 274 | video_infos.append(dict(frame_dir=filename, label=label, total_frames=int(end)-int(start), tar=self.use_tar_format, vid=vid)) 275 | vid += 1 276 | return video_infos 277 | 278 | 279 | class RawFramesTestDataset(BaseDataset): 280 | def __init__(self, ann_file, pipeline, labels_file, start_index=0, **kwargs): 281 | super().__init__(ann_file, pipeline, start_index=start_index, **kwargs) 282 | self.labels_file = labels_file 283 | 284 | @property 285 | def classes(self): 286 | classes_all = pd.read_csv(self.labels_file) 287 | return classes_all.values.tolist() 288 | 289 | def load_annotations(self): 290 | segs = [] 291 | path_list = [] 292 | vid = 0 293 | for x in open(self.ann_file): 294 | video_info = RawFramesTestRecord(x.strip().split(' ')) 295 | path_list.append(video_info.path) 296 | num_segs = math.ceil(video_info.num_frames / self.seg_interval) 297 | for i in range(num_segs): 298 | start = self.seg_interval 
* i 299 | end = min(self.seg_interval * (i + 1), video_info.num_frames) 300 | if end - start < 5: 301 | continue 302 | filename = video_info.path 303 | if self.data_prefix is not None: 304 | filename = osp.join(self.data_prefix, filename) 305 | seg = dict(frame_dir=filename, label=video_info.label, start=int(start), total_frames=int(end)-int(start), tar=self.use_tar_format, vid=vid) 306 | segs.append(seg) 307 | vid += 1 308 | return segs 309 | 310 | 311 | class VideoDataset(BaseDataset): 312 | def __init__(self, ann_file, pipeline, labels_file, start_index=0, **kwargs): 313 | super().__init__(ann_file, pipeline, start_index=start_index, **kwargs) 314 | self.labels_file = labels_file 315 | 316 | @property 317 | def classes(self): 318 | classes_all = pd.read_csv(self.labels_file) 319 | return classes_all.values.tolist() 320 | 321 | def load_annotations(self): 322 | """Load annotation file to get video information.""" 323 | if self.ann_file.endswith('.json'): 324 | return self.load_json_annotations() 325 | 326 | video_infos = [] 327 | with open(self.ann_file, 'r') as fin: 328 | for line in fin: 329 | line_split = line.strip().split() 330 | # import pdb; 331 | # pdb.set_trace() 332 | if self.multi_class: 333 | assert self.num_classes is not None 334 | filename, label = line_split[0], line_split[1:] 335 | label = list(map(int, label)) 336 | else: 337 | filename, label = line_split 338 | label = int(label) 339 | if self.data_prefix is not None: 340 | filename = osp.join(self.data_prefix, filename) 341 | video_infos.append(dict(filename=filename, label=label, tar=self.use_tar_format)) 342 | 343 | return video_infos 344 | 345 | 346 | class SubsetRandomSampler(torch.utils.data.Sampler): 347 | r"""Samples elements randomly from a given list of indices, without replacement.
348 | 349 | Arguments: 350 | indices (sequence): a sequence of indices 351 | """ 352 | 353 | def __init__(self, indices): 354 | self.epoch = 0 355 | self.indices = indices 356 | 357 | def __iter__(self): 358 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 359 | 360 | def __len__(self): 361 | return len(self.indices) 362 | 363 | def set_epoch(self, epoch): 364 | self.epoch = epoch 365 | 366 | 367 | def mmcv_collate(batch, samples_per_gpu=1): 368 | if not isinstance(batch, Sequence): 369 | raise TypeError(f'{batch.dtype} is not supported.') 370 | if isinstance(batch[0], Sequence): 371 | transposed = zip(*batch) 372 | return [collate(samples, samples_per_gpu) for samples in transposed] 373 | elif isinstance(batch[0], Mapping): 374 | return { 375 | key: mmcv_collate([d[key] for d in batch], samples_per_gpu) 376 | for key in batch[0] 377 | } 378 | else: 379 | return default_collate(batch) 380 | 381 | 382 | def build_dataloader(logger, config): 383 | scale_resize = int(256 / 224 * config.DATA.INPUT_SIZE) 384 | 385 | train_pipeline = [ 386 | dict(type='SampleFrames', clip_len=config.DATA.NUM_FRAMES, frame_interval=config.DATA.FRAME_INTERVAL, num_clips=config.DATA.NUM_CLIPS), 387 | dict(type='RawFrameDecode'), 388 | dict(type='Resize', scale=(-1, scale_resize)), 389 | dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE), 390 | dict(type='Normalize', **img_norm_cfg), 391 | dict(type='FormatShape', input_format='NCTHW'), 392 | dict(type='Collect', keys=['imgs', 'label', 'vid', 'frame_inds', 'total_frames'], meta_keys=[]), 393 | dict(type='ToTensor', keys=['imgs', 'label']), 394 | ] 395 | 396 | train_pipeline_S = [ 397 | dict(type='SampleFrames', clip_len=config.DATA.NUM_FRAMES, frame_interval=config.DATA.FRAME_INTERVAL, 398 | num_clips=config.DATA.NUM_CLIPS), 399 | dict(type='RawFrameDecode'), 400 | dict(type='Resize', scale=(-1, scale_resize)), 401 | dict( 402 | type='MultiScaleCrop', 403 | input_size=config.DATA.INPUT_SIZE, 404 | scales=(1, 0.875, 0.75, 0.66), 405 | random_crop=False, 406 | max_wh_scale_gap=1), 407 | dict(type='Resize', scale=(config.DATA.INPUT_SIZE, config.DATA.INPUT_SIZE), keep_ratio=False), 408 | dict(type='ColorJitter', p=config.AUG.COLOR_JITTER), 409 | dict(type='GrayScale', p=config.AUG.GRAY_SCALE), 410 | dict(type='RandAugment', auto_augment='rand-n{}-m{}-mstd0.5'.format(2, 10)), 411 | dict(type='Normalize', **img_norm_cfg), 412 | dict(type='FormatShape', input_format='NCTHW'), 413 | dict(type='Collect', keys=['imgs', 'label', 'vid', 'frame_inds', 'total_frames'], meta_keys=[]), 414 | dict(type='ToTensor', keys=['imgs', 'label']), 415 | ] 416 | 417 | 418 | train_data = FrameDataset(ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT, 419 | filename_tmpl=config.DATA.FILENAME_TMPL, labels_file=config.DATA.LABEL_LIST, 420 | pipeline=train_pipeline, pipeline_=train_pipeline_S) 421 | num_tasks = dist.get_world_size() 422 | global_rank = dist.get_rank() 423 | sampler_train = torch.utils.data.DistributedSampler( 424 | train_data, num_replicas=num_tasks, rank=global_rank, shuffle=True 425 | ) 426 | train_loader = DataLoader( 427 | train_data, sampler=sampler_train, 428 | batch_size=config.TRAIN.BATCH_SIZE, 429 | num_workers=8, 430 | pin_memory=True, 431 | drop_last=True, 432 | collate_fn=partial(mmcv_collate, samples_per_gpu=config.TRAIN.BATCH_SIZE), 433 | ) 434 | train_loader_umil = DataLoader( 435 | train_data, sampler=sampler_train, 436 | batch_size=config.TRAIN.BATCH_SIZE_UMIL, 437 | num_workers=8, 438 | pin_memory=True, 439 | 
drop_last=True, 440 | collate_fn=partial(mmcv_collate, samples_per_gpu=config.TRAIN.BATCH_SIZE_UMIL), 441 | ) 442 | 443 | val_pipeline = [ 444 | dict(type='SampleFrames', clip_len=config.DATA.NUM_FRAMES, frame_interval=config.DATA.FRAME_INTERVAL, num_clips=config.DATA.NUM_CLIPS, test_mode=True), 445 | dict(type='RawFrameDecode'), 446 | dict(type='Resize', scale=(-1, scale_resize)), 447 | dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE), 448 | dict(type='Normalize', **img_norm_cfg), 449 | dict(type='FormatShape', input_format='NCTHW'), 450 | dict(type='Collect', keys=['imgs', 'label', 'vid'], meta_keys=[]), 451 | dict(type='ToTensor', keys=['imgs']) 452 | ] 453 | 454 | if config.TEST.NUM_CROP == 3: 455 | val_pipeline[3] = dict(type='Resize', scale=(-1, config.DATA.INPUT_SIZE)) 456 | val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE) 457 | if config.TEST.NUM_CLIP > 1: 458 | val_pipeline[0] = dict(type='SampleFrames', clip_len=config.DATA.NUM_FRAMES, frame_interval=config.DATA.FRAME_INTERVAL, num_clips=config.DATA.NUM_CLIPS, multiview=config.TEST.NUM_CLIP) 459 | 460 | val_data = FrameDataset(ann_file=config.DATA.VAL_FILE, data_prefix=config.DATA.ROOT, labels_file=config.DATA.LABEL_LIST, filename_tmpl=config.DATA.FILENAME_TMPL, pipeline=val_pipeline) 461 | 462 | sampler_val = torch.utils.data.SequentialSampler(val_data) 463 | val_loader = DataLoader( 464 | val_data, sampler=sampler_val, 465 | batch_size=2, 466 | num_workers=16, 467 | pin_memory=True, 468 | drop_last=False, 469 | collate_fn=partial(mmcv_collate, samples_per_gpu=2), 470 | ) 471 | 472 | test_pipeline = [ 473 | dict(type='SampleFrames', clip_len=config.DATA.NUM_FRAMES, frame_interval=config.DATA.FRAME_INTERVAL, 474 | num_clips=1, test_mode=True), 475 | dict(type='RawFrameDecode'), 476 | dict(type='Resize', scale=(-1, scale_resize)), 477 | dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE), 478 | dict(type='Normalize', **img_norm_cfg), 479 | dict(type='FormatShape', input_format='NCTHW'), 480 | dict(type='Collect', keys=['imgs', 'label', 'vid'], meta_keys=[]), 481 | dict(type='ToTensor', keys=['imgs']) 482 | ] 483 | test_data = RawFramesTestDataset(ann_file=config.DATA.VAL_FILE, data_prefix=config.DATA.ROOT, 484 | labels_file=config.DATA.LABEL_LIST, filename_tmpl=config.DATA.FILENAME_TMPL, 485 | pipeline=test_pipeline, seg_interval=config.DATA.NUM_FRAMES*config.DATA.FRAME_INTERVAL) 486 | 487 | sampler_test = torch.utils.data.SequentialSampler(test_data) 488 | test_loader = DataLoader( 489 | test_data, sampler=sampler_test, 490 | batch_size=64, 491 | num_workers=8, 492 | pin_memory=True, 493 | drop_last=False, 494 | collate_fn=partial(mmcv_collate, samples_per_gpu=2), 495 | ) 496 | 497 | train_pipeline_test = [ 498 | dict(type='SampleFrames', clip_len=config.DATA.NUM_FRAMES, frame_interval=config.DATA.FRAME_INTERVAL, 499 | num_clips=1, test_mode=True), 500 | dict(type='RawFrameDecode'), 501 | dict(type='Resize', scale=(-1, scale_resize)), 502 | dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE), 503 | dict(type='Normalize', **img_norm_cfg), 504 | dict(type='FormatShape', input_format='NCTHW'), 505 | dict(type='Collect', keys=['imgs', 'label', 'vid'], meta_keys=[]), 506 | dict(type='ToTensor', keys=['imgs']) 507 | ] 508 | train_data_test = RawFramesTestDataset(ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT, 509 | labels_file=config.DATA.LABEL_LIST, filename_tmpl=config.DATA.FILENAME_TMPL, 510 | pipeline=train_pipeline_test, 511 | seg_interval=config.DATA.NUM_FRAMES * 
config.DATA.FRAME_INTERVAL) 512 | 513 | train_sampler_test = torch.utils.data.SequentialSampler(train_data_test) 514 | train_loader_test = DataLoader( 515 | train_data_test, sampler=train_sampler_test, 516 | batch_size=64, 517 | num_workers=16, 518 | pin_memory=True, 519 | drop_last=False, 520 | collate_fn=partial(mmcv_collate, samples_per_gpu=2), 521 | ) 522 | 523 | return train_data, val_data, test_data, train_loader, val_loader, test_loader, train_loader_test, train_loader_umil -------------------------------------------------------------------------------- /main_umil.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.backends.cudnn as cudnn 5 | import torch.distributed as dist 6 | from torch.autograd import Variable 7 | import argparse 8 | import datetime 9 | import shutil 10 | from pathlib import Path 11 | from utils.config import get_config 12 | from utils.optimizer import build_optimizer, build_scheduler 13 | from utils.tools import AverageMeter, reduce_tensor, epoch_saving, load_checkpoint, generate_text, auto_resume_helper, evaluate_result, match 14 | from utils.cluster import ClusterLoss, Normalize, BCE, NCLMemory, PairEnum 15 | from datasets.build import build_dataloader 16 | from utils.logger import create_logger 17 | import time 18 | import numpy as np 19 | import random 20 | import mmcv 21 | from apex import amp 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 23 | from datasets.blending import CutmixMixupBlending 24 | from utils.config import get_config 25 | from models import xclip 26 | from einops import rearrange 27 | import glob 28 | import torch.nn.functional as F 29 | import pdb 30 | 31 | def parse_option(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--config', '-cfg', required=True, type=str, default='configs/k400/32_8.yaml') 34 | parser.add_argument( 35 | "--opts", 36 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 37 | default=None, 38 | nargs='+', 39 | ) 40 | parser.add_argument('--output', type=str, default="exp") 41 | parser.add_argument('--resume', type=str) 42 | parser.add_argument('--pretrained', type=str) 43 | parser.add_argument('--only_test', action='store_true') 44 | parser.add_argument('--batch-size', type=int) 45 | parser.add_argument('--batch-size-umil', type=int) 46 | parser.add_argument('--accumulation-steps', type=int) 47 | # model parameters 48 | parser.add_argument('--umil-epoch', default=30, type=float) 49 | parser.add_argument('--threshold', default=0.8, type=float) 50 | parser.add_argument('--cluster-threshold', default=0.8, type=float) 51 | parser.add_argument("--local_rank", type=int, default=-1, help='local rank for DistributedDataParallel') 52 | parser.add_argument("--bce-type", type=str, default='cos', help="Type of clustering techniques: cos or RK") 53 | parser.add_argument('--cosine-threshold', default=0.7, type=float, help='cosine similarity threshold for clustering') 54 | parser.add_argument('--topk', default=2, type=int, help='rank statistics threshold for clustering') 55 | parser.add_argument('--confidence-threshold', default=0.3, type=float, help='threshold for high-confident instance selection') 56 | 57 | parser.add_argument('--w-smooth', default=0.01, type=float, help='weight of smooth loss') 58 | parser.add_argument('--w-sparse', default=0.001, type=float, help='weight of sparse loss') 59 | parser.add_argument('--w-compat-u', default=1.0, type=float, help='weight of u2l loss compared to l2u') 60 | parser.add_argument('--w-compat', default=1.0, type=float, help='weight of compatibility loss') 61 | parser.add_argument('--w-cluster', default=0.01, type=float, help='weight of cluster loss') 62 | parser.add_argument('--w-con', default=1.0, type=float, help='weight of consistency loss') 63 | parser.add_argument('--w-con1', default=1.0, type=float, help='weight of clustering all consistency loss') 64 | parser.add_argument('--w-mil', default=1.0, type=float, help='weight of mil loss') 65 | parser.add_argument('--w-ce', default=0.0, type=float, help='weight of ce loss') 66 | 67 | parser.add_argument('--w-cls', default=0, type=float, help='weight of cluster anomaly score') 68 | args = parser.parse_args() 69 | 70 | config = get_config(args) 71 | 72 | return args, config 73 | 74 | 75 | def main(config): 76 | 77 | train_data, val_data, test_data, train_loader, val_loader, test_loader, val_loader_train, train_loader_umil = build_dataloader(logger, config) 78 | model, _ = xclip.load(config.MODEL.PRETRAINED, config.MODEL.ARCH, 79 | device="cpu", jit=False, 80 | T=config.DATA.NUM_FRAMES, 81 | droppath=config.MODEL.DROP_PATH_RATE, 82 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 83 | use_cache=config.MODEL.FIX_TEXT, 84 | logger=logger, 85 | ) 86 | model = model.cuda() 87 | 88 | criterion = ClusterLoss(config.DATA.NUM_CLASSES, args.bce_type, 89 | args.cosine_threshold, args.topk 90 | ) 91 | 92 | optimizer, optimizer_umil = build_optimizer(config, model) 93 | lr_scheduler = build_scheduler(config, optimizer, len(train_loader)) 94 | lr_scheduler_umil = build_scheduler(config, optimizer_umil, len(train_loader_umil)) 95 | 96 | if config.TRAIN.OPT_LEVEL != 'O0': 97 | model, [optimizer, optimizer_umil] = amp.initialize(models=model, optimizers=[optimizer, optimizer_umil], opt_level=config.TRAIN.OPT_LEVEL) 98 | 99 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False, find_unused_parameters=False) 100 | 101 | start_epoch, 
best_epoch, max_auc = 0, 0, 0.0 102 | 103 | if config.TRAIN.AUTO_RESUME: 104 | resume_file = auto_resume_helper(config.OUTPUT) 105 | if resume_file: 106 | config.defrost() 107 | config.MODEL.RESUME = resume_file 108 | config.freeze() 109 | logger.info(f'auto resuming from {resume_file}') 110 | else: 111 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 112 | 113 | if config.MODEL.RESUME: 114 | start_epoch, _ = load_checkpoint(config, model.module, optimizer, lr_scheduler, logger) 115 | 116 | text_labels = generate_text(train_data) 117 | 118 | if config.TEST.ONLY_TEST: 119 | if not os.path.isdir(config.MODEL.PRETRAINED): 120 | # evaluate on val set 121 | out_path = config.MODEL.PRETRAINED.replace('pth','pkl') 122 | if os.path.exists(out_path): 123 | scores_dict = mmcv.load(out_path) 124 | else: 125 | scores_dict = validate(test_loader, text_labels, model, config, out_path) 126 | 127 | tmp_dict = {} 128 | for v_name in scores_dict["cls"].keys(): 129 | p_scores = np.array(scores_dict["prd"][v_name]).copy() 130 | c_scores = np.array(scores_dict["cls"][v_name]).copy() 131 | 132 | if p_scores.shape[0] == 1: 133 | # 1,32,2 134 | tmp_dict[v_name] = [p_scores[0, :, 1] + args.w_cls * c_scores[0, :, 1]] 135 | else: 136 | # T,1,2 137 | tmp_dict[v_name] = [p_scores[:, 0, 1] + args.w_cls * c_scores[:, 0, 1]] 138 | 139 | auc_all, auc_ano = evaluate_result(tmp_dict, config.DATA.VAL_FILE, os.path.dirname(out_path)) 140 | 141 | logger.info(f"AUC@all/ano of version {out_path.split('/')[-2]} on epoch {out_path.split('/')[-1].split('_')[-1][:-4]} : {auc_all:.4f}({auc_ano:.4f})") 142 | return 143 | 144 | data_dict = {} 145 | data_dict['mask'] = {} 146 | data_dict['label'] = {} 147 | data_dict['prd'] = {} 148 | data_dict['cls'] = {} 149 | data_dict['length'] = 0 150 | 151 | anno_file = config.DATA.TRAIN_FILE 152 | vid_list = [] 153 | with open(anno_file, 'r') as fin: 154 | for line in fin: 155 | line_split = line.strip().split() 156 | filename = line_split[0].split('/')[-1] 157 | vid_list.append(filename) 158 | 159 | for epoch in range(start_epoch, config.TRAIN.EPOCHS): 160 | train_loader.sampler.set_epoch(epoch) 161 | mil_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, text_labels, config, data_dict, vid_list) 162 | 163 | # calculate training statics 164 | if epoch % 1 == 0 and epoch >= (args.umil_epoch-5): 165 | out_path = os.path.join(config.OUTPUT, 'train_data_' + str(epoch) + '.pkl') 166 | epoch_data_dict = validate(val_loader_train, text_labels, model, config, out_path) 167 | 168 | for key in epoch_data_dict['prd'].keys(): 169 | if data_dict['length'] == 0: 170 | data_dict['prd'][key] = np.stack(epoch_data_dict['prd'][key], 1) 171 | data_dict['cls'][key] = np.stack(epoch_data_dict['cls'][key], 1) 172 | else: 173 | data_dict['prd'][key] = np.concatenate([data_dict['prd'][key], np.stack(epoch_data_dict['prd'][key], 1)], 0) 174 | data_dict['cls'][key] = np.concatenate([data_dict['cls'][key], np.stack(epoch_data_dict['cls'][key], 1)], 0) 175 | 176 | data_dict['length'] = data_dict['length'] + 1 177 | if data_dict['length'] > 1: 178 | history_scores = [] 179 | flags = [] 180 | for key in data_dict['prd'].keys(): 181 | history_scores.append(data_dict['prd'][key]) 182 | if 'Normal' in key: 183 | flag = np.zeros(data_dict['prd'][key].shape[1]) 184 | else: 185 | flag = np.ones(data_dict['prd'][key].shape[1]) 186 | flags.append(flag) 187 | his_scores = np.concatenate(history_scores, 1) 188 | flags = np.concatenate(flags, 0) 189 | his_scores = his_scores[:, flags 
== 1] 190 | pseudo_label = np.argmax(his_scores.mean(0), -1) 191 | scores_var = his_scores[:, :, 1].var(0) 192 | ano_idx = pseudo_label == 1 193 | ano_scores_var = scores_var[ano_idx] 194 | nor_idx = pseudo_label == 0 195 | nor_scores_var = scores_var[nor_idx] 196 | 197 | var_sort_idx = np.argsort(ano_scores_var) 198 | pseudo_threshold_ano = ano_scores_var[ 199 | var_sort_idx[int(ano_scores_var.shape[0] * args.confidence_threshold)]] 200 | var_sort_idx = np.argsort(nor_scores_var) 201 | pseudo_threshold_nor = nor_scores_var[ 202 | var_sort_idx[int(nor_scores_var.shape[0] * args.confidence_threshold)]] 203 | data_dict['mask'] = {} 204 | data_dict['label'] = {} 205 | for key in data_dict['prd'].keys(): 206 | score_var = data_dict['prd'][key][:, :, 1].var(0) 207 | pseudo_label = np.argmax(data_dict['prd'][key].mean(0), -1) 208 | if 'Normal' in key: 209 | data_dict['mask'][key] = np.ones_like(score_var) 210 | data_dict['label'][key] = np.zeros_like(score_var) 211 | else: 212 | data_dict['mask'][key] = (pseudo_label == 0) * (score_var < pseudo_threshold_nor) + \ 213 | (pseudo_label == 1) * (score_var < pseudo_threshold_ano) 214 | data_dict['label'][key] = pseudo_label 215 | # val 216 | out_path = os.path.join(config.OUTPUT, 'mil_epoch_'+str(epoch)+'.pkl') 217 | scores_dict = validate(val_loader, text_labels, model, config, out_path) 218 | 219 | tmp_dict = {} 220 | for v_name in scores_dict["cls"].keys(): 221 | p_scores = np.array(scores_dict["prd"][v_name]).copy() 222 | c_scores = np.array(scores_dict["cls"][v_name]).copy() 223 | 224 | if p_scores.shape[0] == 1: 225 | # 1,32,2 226 | tmp_dict[v_name] = [p_scores[0, :, 1] + args.w_cls * c_scores[0, :, 1]] 227 | else: 228 | # T,1,2 229 | tmp_dict[v_name] = [p_scores[:, 0, 1] + args.w_cls * c_scores[:, 0, 1]] 230 | auc_all, auc_ano = evaluate_result(tmp_dict, config.DATA.VAL_FILE) 231 | is_best = auc_all > max_auc 232 | max_auc = max(max_auc, auc_all) 233 | logger.info(f"Auc of MIL on epoch {epoch}: {auc_all:.4f}({auc_ano:.4f})") 234 | logger.info(f'Max AUC@all epoch {epoch} : {max_auc:.4f}') 235 | 236 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 237 | epoch_saving(config, epoch, model.module, max_auc, optimizer, optimizer_umil, lr_scheduler, lr_scheduler_umil, logger, config.OUTPUT, is_best) 238 | 239 | if epoch >= args.umil_epoch: 240 | train_loader_umil.sampler.set_epoch(epoch) 241 | umil_one_epoch(epoch, model, criterion, optimizer_umil, lr_scheduler_umil, train_loader_umil, text_labels, config, data_dict, 242 | vid_list) 243 | #val 244 | out_path = os.path.join(config.OUTPUT, 'umil_epoch_' + str(epoch) + '.pkl') 245 | scores_dict = validate(val_loader, text_labels, model, config, out_path) 246 | tmp_dict = {} 247 | for v_name in scores_dict["cls"].keys(): 248 | p_scores = np.array(scores_dict["prd"][v_name]).copy() 249 | c_scores = np.array(scores_dict["cls"][v_name]).copy() 250 | 251 | if p_scores.shape[0] == 1: 252 | # 1,32,2 253 | tmp_dict[v_name] = [p_scores[0, :, 1] + args.w_cls * c_scores[0, :, 1]] 254 | else: 255 | # T,1,2 256 | tmp_dict[v_name] = [p_scores[:, 0, 1] + args.w_cls * c_scores[:, 0, 1]] 257 | auc_all_u, auc_ano_u = evaluate_result(tmp_dict, config.DATA.VAL_FILE) 258 | is_best = auc_all_u > max_auc 259 | max_auc = max(max_auc, auc_all_u) 260 | logger.info(f"Auc of UMIL on epoch {epoch}: {auc_all_u:.4f}({auc_ano_u:.4f})") 261 | logger.info(f'Max AUC@all epoch {epoch} : {max_auc:.4f}') 262 | 263 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == 
(config.TRAIN.EPOCHS - 1)) and auc_all_u>auc_all: 264 | epoch_saving(config, epoch, model.module, max_auc, optimizer, optimizer_umil, lr_scheduler, lr_scheduler_umil, logger, config.OUTPUT, is_best) 265 | 266 | def mil_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, text_labels, config, data_dict, vid_list): 267 | model.train() 268 | 269 | optimizer.zero_grad() 270 | 271 | num_steps = len(train_loader) 272 | batch_time = AverageMeter() 273 | tot_loss_meter = AverageMeter() 274 | mil_loss_meter = AverageMeter() 275 | mar_loss_meter = AverageMeter() 276 | cst_loss_meter = AverageMeter() 277 | bce_loss_meter = AverageMeter() 278 | sm_loss_meter = AverageMeter() 279 | sp_loss_meter = AverageMeter() 280 | 281 | start = time.time() 282 | end = time.time() 283 | 284 | l2norm = Normalize(2) 285 | 286 | texts = text_labels.cuda(non_blocking=True) 287 | 288 | for idx, batch_data in enumerate(train_loader): 289 | images = batch_data["imgs"].cuda(non_blocking=True) 290 | label_id = batch_data["label"].cuda(non_blocking=True) 291 | label_id = label_id.reshape(-1) 292 | bz = images.shape[0] 293 | a_aug = images.shape[1] 294 | n_clips = images.shape[2] 295 | images = rearrange(images, 'b a k c t h w -> (b a k) t c h w')# bz*num_aug*num_clips,num_frames,h,w 296 | 297 | if texts.shape[0] == 1: 298 | texts = texts.view(1, -1) 299 | 300 | output = model(images, texts) 301 | # mil loss on max scores among bags, view instance of max scores as labeled data 302 | logits = rearrange(output['y'], '(b a k) c -> (b a) k c', b=bz, a=a_aug, ) 303 | scores = F.softmax(logits, dim=-1) 304 | 305 | scores_ano = scores[:, :, 1] 306 | scores_nor = scores[:, :, 0] 307 | max_prob_ano, max_ind = torch.max(scores_ano, dim=-1) 308 | max_prob_nor, _ = torch.max(scores_nor, dim=-1) 309 | 310 | logits_video = torch.gather(logits, 1, max_ind[:, None, None].repeat((1, 1, 2))).squeeze(1) 311 | margin_video = scores_ano.max(-1)[0] - scores_ano.min(-1)[0] 312 | max_prob_video, _ = torch.max(torch.gather(scores, 1, max_ind[:, None, None].repeat((1, 1, 2))).squeeze(1), 313 | dim=-1) 314 | labels_binary = label_id > 0 315 | loss_mil = F.cross_entropy(logits_video, labels_binary.long(), reduction='none') 316 | loss_mil = loss_mil * max_prob_video 317 | loss_mil = loss_mil.mean() 318 | loss_mar = F.binary_cross_entropy(margin_video, labels_binary.float(), reduction='none') 319 | loss_mar = loss_mar * max_prob_video 320 | loss_mar = loss_mar.mean() 321 | # pseudo loss 322 | logits = rearrange(logits, '(b a) k c -> b a k c', b=bz, a=a_aug, ) 323 | logits_alt = rearrange(output['y_cluster_all'], '(b a k) c -> b a k c', b=bz, a=a_aug, ) 324 | 325 | scores = F.softmax(logits, dim=-1) 326 | scores_alt = F.softmax(logits_alt, dim=-1) 327 | 328 | if data_dict['length'] < 0: 329 | vids = np.array(vid_list)[batch_data["vid"]] 330 | pseudo_labels = [] 331 | masks = [] 332 | if bz==1: 333 | vids = [vids] 334 | for ind in range(bz): 335 | tmp_label = data_dict['label'][vids[ind]].copy() 336 | tmp_mask = data_dict['mask'][vids[ind]].copy() 337 | tmp_ind = batch_data["frame_inds"][ind] 338 | tmp_ind = tmp_ind*tmp_label.shape[0]//batch_data['total_frames'][ind] 339 | tmp_ind = tmp_ind.reshape(n_clips, -1)[:, config.DATA.NUM_FRAMES // 2] 340 | pseudo_labels.append(tmp_label[tmp_ind].copy()) 341 | masks.append(tmp_mask[tmp_ind].copy()) 342 | 343 | pseudo_labels_np = np.stack(pseudo_labels,0) 344 | mask_source_np = np.stack(masks,0) 345 | pseudo_labels.clear() 346 | masks.clear() 347 | with torch.no_grad(): 348 | pseudo_labels_source 
= torch.from_numpy(pseudo_labels_np).cuda() 349 | mask_source = torch.from_numpy(mask_source_np).cuda() 350 | pseudo_labels_source = pseudo_labels_source.view(bz,1,-1).tile((1,a_aug,1)) 351 | mask_source = (mask_source==1).view(bz,1,-1).tile((1,a_aug,1)) 352 | mask_target = ~mask_source 353 | else: 354 | pseudo_labels_source = torch.max(scores,-1)[0] 355 | mask_target = pseudo_labels_source>=0 356 | 357 | # generate target pseudo-labels, [:bz]=weak aug;[bz:]=strong aug; 358 | max_prob, pseudo_labels = torch.max(scores[:, 0], dim=-1) 359 | max_prob_alt, pseudo_labels_alt = torch.max(scores_alt[:, 0], dim=-1) 360 | 361 | consistency_loss = (F.cross_entropy(logits[:,1].contiguous().view(-1,2), pseudo_labels.view(-1), reduction='none') \ 362 | * max_prob.view(-1).ge(args.threshold).float().detach()).mean() 363 | 364 | # Cluster consistency loss 365 | consistency_loss_alt = (F.cross_entropy(logits_alt[:,1].contiguous().view(-1,2), pseudo_labels_alt.view(-1), reduction='none') \ 366 | * max_prob_alt.view(-1).ge(args.cluster_threshold).float().detach()).mean() # 367 | 368 | if mask_target.sum() > 2: 369 | bk_feat = rearrange(output['feature_v'], '(b a k) c -> b a k c', b=bz, a=a_aug, ) 370 | cls_nograd = rearrange(output['y_cluster_all'], '(b a k) c -> b a k c', b=bz, a=a_aug, ) 371 | inputs = { 372 | "x1": bk_feat[:, 0].reshape(-1, bk_feat.shape[-1]), 373 | "x1_norm": l2norm(bk_feat[:, 0].reshape(-1, bk_feat.shape[-1])), 374 | "preds1_u": cls_nograd[:, 0].reshape(-1, 2), 375 | "x2": bk_feat[:, 1].reshape(-1, bk_feat.shape[-1]), 376 | "x2_norm": l2norm(bk_feat[:, 1].reshape(-1, bk_feat.shape[-1])), 377 | "preds2_u": cls_nograd[:, 1].reshape(-1, 2), 378 | "labels": pseudo_labels_source[:, 0].reshape(-1), 379 | "labels_": label_id, 380 | "mask": ~mask_target[:, 0].reshape(-1), 381 | } 382 | bce_loss, _ = criterion.compute_losses(inputs) 383 | else: 384 | bce_loss = torch.zeros_like(loss_mil) 385 | 386 | scores_all = scores + args.w_cls * scores_alt 387 | smoothed_scores = (scores_all[:,:,1:,1] - scores_all[:,:,:-1,1]) 388 | smoothed_loss = smoothed_scores.pow(2).sum(dim=-1).mean() 389 | 390 | sparsity_loss = scores_all[:,:,:,1].sum(dim=-1).mean() 391 | 392 | if epoch >= args.umil_epoch: 393 | w_mil = args.w_mil 394 | w_con = args.w_con 395 | w_con1 = args.w_con1 396 | w_cluster = args.w_cluster 397 | w_smooth = args.w_smooth 398 | w_sparse = args.w_sparse 399 | else: 400 | w_mil = args.w_mil 401 | w_con = args.w_con 402 | w_con1 = args.w_con1 403 | w_cluster = args.w_cluster 404 | w_smooth = args.w_smooth 405 | w_sparse = args.w_sparse 406 | 407 | 408 | total_loss = (loss_mil + loss_mar) * w_mil + \ 409 | consistency_loss * w_con + consistency_loss_alt * w_con1 + \ 410 | smoothed_loss * w_smooth + sparsity_loss * w_sparse + \ 411 | bce_loss * w_cluster 412 | 413 | 414 | total_loss = total_loss / config.TRAIN.ACCUMULATION_STEPS 415 | # print(idx,total_loss) 416 | 417 | if config.TRAIN.ACCUMULATION_STEPS == 1: 418 | optimizer.zero_grad() 419 | if config.TRAIN.OPT_LEVEL != 'O0': 420 | with amp.scale_loss(total_loss, optimizer) as scaled_loss: 421 | scaled_loss.backward() 422 | else: 423 | total_loss.backward() 424 | if config.TRAIN.ACCUMULATION_STEPS > 1: 425 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 426 | optimizer.step() 427 | optimizer.zero_grad() 428 | lr_scheduler.step_update(epoch * num_steps + idx) 429 | else: 430 | optimizer.step() 431 | lr_scheduler.step_update(epoch * num_steps + idx) 432 | 433 | torch.cuda.synchronize() 434 | 435 | tot_loss_meter.update(total_loss.item(), 
len(label_id)) 436 | mil_loss_meter.update((loss_mil * w_mil).item(), len(label_id)) 437 | mar_loss_meter.update((loss_mar * w_mil).item(), len(label_id)) 438 | cst_loss_meter.update((consistency_loss * w_con + consistency_loss_alt * w_con1).item(), len(label_id)) 439 | bce_loss_meter.update((bce_loss * w_cluster).item(), len(label_id)) 440 | sp_loss_meter.update((sparsity_loss * w_sparse).item(), len(label_id)) 441 | sm_loss_meter.update((smoothed_loss * w_smooth).item(), len(label_id)) 442 | batch_time.update(time.time() - end) 443 | end = time.time() 444 | 445 | if idx % config.PRINT_FREQ == 0: 446 | lr = optimizer.param_groups[0]['lr'] 447 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 448 | etas = batch_time.avg * (num_steps - idx) 449 | logger.info( 450 | f'MIL: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 451 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.9f}\t' 452 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 453 | f'tot {tot_loss_meter.val:.4f} ({tot_loss_meter.avg:.4f})\t' 454 | f'mil {mil_loss_meter.val:.4f} ({mil_loss_meter.avg:.4f})\t' 455 | f'mar {mar_loss_meter.val:.4f} ({mar_loss_meter.avg:.4f})\t' 456 | f'cst {cst_loss_meter.val:.4f} ({cst_loss_meter.avg:.4f})\t' 457 | # f'bce {bce_loss_meter.val:.4f} ({bce_loss_meter.avg:.4f})\t' 458 | f'sm {sm_loss_meter.val:.4f} ({sm_loss_meter.avg:.4f})\t' 459 | f'sp {sp_loss_meter.val:.4f} ({sp_loss_meter.avg:.4f})\t' 460 | f'mem {memory_used:.0f}MB') 461 | 462 | epoch_time = time.time() - start 463 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 464 | 465 | 466 | def umil_one_epoch(epoch, model, criterion, optimizer_umil, lr_scheduler_umil, train_loader, text_labels, config, data_dict, 467 | vid_list): 468 | assert data_dict['length'] > 1 469 | model.train() 470 | 471 | optimizer_umil.zero_grad() 472 | 473 | num_steps = len(train_loader) 474 | batch_time = AverageMeter() 475 | tot_loss_meter = AverageMeter() 476 | bce_loss_meter = AverageMeter() 477 | cmp_loss_meter = AverageMeter() 478 | 479 | start = time.time() 480 | end = time.time() 481 | 482 | l2norm = Normalize(2) 483 | bce = BCE() 484 | 485 | texts = text_labels.cuda(non_blocking=True) 486 | 487 | for idx, batch_data in enumerate(train_loader): 488 | images = batch_data["imgs"].cuda(non_blocking=True) 489 | label_id = batch_data["label"].cuda(non_blocking=True) 490 | label_id = label_id.reshape(-1) 491 | if label_id.sum()==0: 492 | continue 493 | bz = images.shape[0] 494 | a_aug = images.shape[1] 495 | n_clips = images.shape[2] 496 | images = rearrange(images, 'b a k c t h w -> (b a k) t c h w') # bz*num_aug*num_clips,num_frames,h,w 497 | 498 | if texts.shape[0] == 1: 499 | texts = texts.view(1, -1) 500 | 501 | # mil loss on max scores among bags, view instance of max scores as labeled data 502 | 503 | vids = np.array(vid_list)[batch_data["vid"]] 504 | vids = [vids] if bz == 1 else vids 505 | pseudo_labels = [] 506 | masks = [] 507 | 508 | for ind in range(bz): 509 | tmp_label = data_dict['label'][vids[ind]].copy() 510 | tmp_mask = data_dict['mask'][vids[ind]].copy() 511 | tmp_ind = batch_data["frame_inds"][ind] 512 | tmp_ind = tmp_ind * tmp_label.shape[0] // batch_data['total_frames'][ind] 513 | tmp_ind = tmp_ind.reshape(n_clips, -1)[:, config.DATA.NUM_FRAMES // 2] 514 | pseudo_labels.append(tmp_label[tmp_ind].copy()) 515 | masks.append(tmp_mask[tmp_ind].copy()) 516 | 517 | pseudo_labels_np = np.stack(pseudo_labels, 0) 518 | mask_source_np = np.stack(masks, 0) 519 | 
pseudo_labels.clear() 520 | masks.clear() 521 | 522 | with torch.no_grad(): 523 | pseudo_labels_source = torch.from_numpy(pseudo_labels_np).cuda() 524 | mask_source = torch.from_numpy(mask_source_np).cuda() 525 | pseudo_labels_source = pseudo_labels_source.view(bz, 1, -1).tile((1, a_aug, 1)) 526 | mask_source = (mask_source == 1).view(bz, 1, -1).tile((1, a_aug, 1)) 527 | mask_target = ~mask_source 528 | 529 | if mask_target.sum() > 2: 530 | output = model(images, texts) 531 | bk_feat = rearrange(output['feature_v'], '(b a k) c -> b a k c', b=bz, a=a_aug, ) 532 | cls_nograd = rearrange(output['y_cluster_all_nograd'], '(b a k) c -> b a k c', b=bz, a=a_aug, ) 533 | inputs = { 534 | "x1": bk_feat[:, 0].reshape(-1, bk_feat.shape[-1]), 535 | "x1_norm": l2norm(bk_feat[:, 0].reshape(-1, bk_feat.shape[-1])), 536 | "preds1_u": cls_nograd[:, 0].reshape(-1, 2), 537 | "x2": bk_feat[:, 1].reshape(-1, bk_feat.shape[-1]), 538 | "x2_norm": l2norm(bk_feat[:, 1].reshape(-1, bk_feat.shape[-1])), 539 | "preds2_u": cls_nograd[:, 1].reshape(-1, 2), 540 | "labels": pseudo_labels_source[:, 0].reshape(-1), 541 | "labels_": label_id, 542 | "mask": ~mask_target[:, 0].reshape(-1), 543 | } 544 | 545 | bce_loss, sim_matrix_all = criterion.compute_losses(inputs) 546 | # refine unlabeled similarity matrix 547 | # labeled head ('y') compatibility with the cluster head on target samples 548 | logits_nograd = rearrange(output['y'], '(b a k) c -> b a k c', b=bz, a=a_aug, ) 549 | p_nograd = F.softmax(logits_nograd, dim=-1) 550 | pairs1, _ = PairEnum(p_nograd[:, 0][mask_target[:, 0]].reshape(-1, 2)) 551 | _, pairs2 = PairEnum(p_nograd[:, 1][mask_target[:, 0]].reshape(-1, 2)) 552 | lu_compatibility_loss = bce(pairs1, pairs2, sim_matrix_all) 553 | compatibility_loss = lu_compatibility_loss 554 | else: 555 | continue 556 | 557 | w_cluster = 1.0 558 | w_compat = args.w_compat 559 | 560 | total_loss = bce_loss * w_cluster + compatibility_loss * w_compat 561 | 562 | total_loss = total_loss / config.TRAIN.ACCUMULATION_STEPS 563 | 564 | if config.TRAIN.ACCUMULATION_STEPS == 1: 565 | optimizer_umil.zero_grad() 566 | if config.TRAIN.OPT_LEVEL != 'O0': 567 | with amp.scale_loss(total_loss, optimizer_umil) as scaled_loss: 568 | scaled_loss.backward() 569 | else: 570 | total_loss.backward() 571 | if config.TRAIN.ACCUMULATION_STEPS > 1: 572 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 573 | optimizer_umil.step() 574 | optimizer_umil.zero_grad() 575 | lr_scheduler_umil.step_update(epoch * num_steps + idx) 576 | else: 577 | optimizer_umil.step() 578 | lr_scheduler_umil.step_update(epoch * num_steps + idx) 579 | 580 | torch.cuda.synchronize() 581 | 582 | tot_loss_meter.update(total_loss.item(), len(label_id)) 583 | bce_loss_meter.update((bce_loss * w_cluster).item(), len(label_id)) 584 | cmp_loss_meter.update((compatibility_loss * w_compat).item(), len(label_id)) 585 | batch_time.update(time.time() - end) 586 | end = time.time() 587 | 588 | if idx % config.PRINT_FREQ == 0: 589 | lr = optimizer_umil.param_groups[0]['lr'] 590 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 591 | etas = batch_time.avg * (num_steps - idx) 592 | logger.info( 593 | f'UMIL: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 594 | f'tot {tot_loss_meter.val:.4f} ({tot_loss_meter.avg:.4f})\t' 595 | f'bce {bce_loss_meter.val:.4f} ({bce_loss_meter.avg:.4f})\t' 596 | f'cmp {cmp_loss_meter.val:.4f} ({cmp_loss_meter.avg:.4f})\t' 597 | f'mem {memory_used:.0f}MB') 598 | 599 | epoch_time = time.time() - start 600 | logger.info(f"EPOCH {epoch} training takes 
{datetime.timedelta(seconds=int(epoch_time))}") 601 | 602 | 603 | @torch.no_grad() 604 | def validate(data_loader, text_labels, model, config, out_path): 605 | model.eval() 606 | vid_list = [] 607 | if 'train' in out_path: 608 | anno_file = config.DATA.TRAIN_FILE 609 | else: 610 | anno_file = config.DATA.VAL_FILE 611 | 612 | with open(anno_file, 'r') as fin: 613 | for line in fin: 614 | line_split = line.strip().split() 615 | filename = line_split[0].split('/')[-1] 616 | vid_list.append(filename) 617 | 618 | scores_dict = dict() 619 | scores_dict['cls'] = dict() 620 | scores_dict['prd'] = dict() 621 | scores_dict['fea'] = dict() 622 | 623 | with torch.no_grad(): 624 | text_inputs = text_labels.cuda() 625 | logger.info(f"{config.TEST.NUM_CLIP * config.TEST.NUM_CROP} views inference") 626 | for idx, batch_data in enumerate(data_loader): 627 | _image = batch_data["imgs"] 628 | label_id = batch_data["label"] 629 | label_id = label_id.reshape(-1) 630 | b, n, c, t, h, w = _image.size() 631 | _image = rearrange(_image, 'b n c t h w -> (b n) t c h w') 632 | 633 | output = model(_image, text_inputs) 634 | 635 | scores_prd = F.softmax(output['y'], dim=-1) 636 | scores_cls = F.softmax(output['y_cluster_all'], dim=-1) 637 | 638 | scores_prd = rearrange(scores_prd, '(b n) c -> b n c', b=b) 639 | scores_np_prd = scores_prd.cpu().data.numpy().copy() 640 | scores_cls = rearrange(scores_cls, '(b n) c -> b n c', b=b) 641 | scores_np_cls = scores_cls.cpu().data.numpy().copy() 642 | 643 | for ind in range(scores_np_prd.shape[0]): 644 | v_name = vid_list[batch_data["vid"][ind]] 645 | if v_name not in scores_dict['prd']: 646 | scores_dict['prd'][v_name] = [] 647 | scores_dict['cls'][v_name] = [] 648 | scores_dict['prd'][v_name].append(scores_np_prd[ind]) 649 | scores_dict['cls'][v_name].append(scores_np_cls[ind]) 650 | if idx % 500 == 0 and len(data_loader) >= 500 and 'train' not in out_path: 651 | logger.info( 652 | f'Test: [{idx}/{len(data_loader)}]\t' 653 | ) 654 | elif idx % 1000 == 0 and len(data_loader) >= 1000 and 'train' in out_path: 655 | logger.info( 656 | f'Train: [{idx}/{len(data_loader)}]\t' 657 | f'Vid: {v_name}\t' 658 | ) 659 | if 'train' not in out_path: 660 | tmp_dict = {} 661 | for v_name in scores_dict["cls"].keys(): 662 | p_scores = np.array(scores_dict["prd"][v_name]).copy() 663 | c_scores = np.array(scores_dict["cls"][v_name]).copy() 664 | if p_scores.shape[0] == 1: 665 | # 1,T,2 666 | tmp_dict[v_name] = [p_scores[0, :, 1] + args.w_cls * c_scores[0, :, 1]] 667 | else: 668 | # T,1,2 669 | tmp_dict[v_name] = [p_scores[:, 0, 1] + args.w_cls * c_scores[:, 0, 1]] 670 | 671 | auc_all, auc_ano = evaluate_result(tmp_dict, config.DATA.VAL_FILE) 672 | logger.info( 673 | f'AUC: [{auc_all:.3f}/{auc_ano:.3f}]\t' 674 | ) 675 | logger.info(f'writing results to {out_path}') 676 | 677 | mmcv.dump(scores_dict, out_path) 678 | 679 | return scores_dict 680 | 681 | 682 | if __name__ == '__main__': 683 | # prepare config 684 | args, config = parse_option() 685 | 686 | # init_distributed 687 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 688 | rank = int(os.environ["RANK"]) 689 | world_size = int(os.environ['WORLD_SIZE']) 690 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 691 | else: 692 | rank = -1 693 | world_size = -1 694 | torch.cuda.set_device(args.local_rank) 695 | torch.distributed.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=7200), 696 | world_size=world_size, rank=rank) 697 | 
torch.distributed.barrier(device_ids=[args.local_rank]) 698 | 699 | seed = config.SEED + dist.get_rank() 700 | torch.manual_seed(seed) 701 | np.random.seed(seed) 702 | random.seed(seed) 703 | cudnn.benchmark = True 704 | 705 | # create working_dir 706 | Path(config.OUTPUT).mkdir(parents=True, exist_ok=True) 707 | 708 | # logger 709 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.ARCH}") 710 | logger.info(f"working dir: {config.OUTPUT}") 711 | 712 | # save config 713 | if dist.get_rank() == 0: 714 | logger.info(config) 715 | shutil.copy(args.config, config.OUTPUT) 716 | 717 | main(config) --------------------------------------------------------------------------------
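A minimal offline sketch of the score fusion that validate() and the evaluation branches of main() perform: the prediction-head ('prd') and cluster-head ('cls') softmax scores stored in the saved *.pkl dictionaries are combined per frame as p + w_cls * c before the AUC is computed. The constant W_CLS and the example path below are illustrative assumptions, not values fixed by this repository.

import mmcv
import numpy as np

W_CLS = 0.0  # assumed fusion weight; matches the default of the --w-cls argument above

def fuse_scores(pkl_path, w_cls=W_CLS):
    # Combine prediction-head and cluster-head softmax scores into one
    # anomaly curve (probability of the 'abnormal' class) per video.
    scores_dict = mmcv.load(pkl_path)  # {'prd': {...}, 'cls': {...}, ...}
    fused = {}
    for v_name in scores_dict['cls']:
        p = np.array(scores_dict['prd'][v_name])  # (1, T, 2) for multi-clip forwards, else (T, 1, 2)
        c = np.array(scores_dict['cls'][v_name])
        if p.shape[0] == 1:
            fused[v_name] = p[0, :, 1] + w_cls * c[0, :, 1]
        else:
            fused[v_name] = p[:, 0, 1] + w_cls * c[:, 0, 1]
    return fused

# e.g. curves = fuse_scores('output/umil/mil_epoch_29.pkl')  # path is hypothetical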