├── README.md
├── baselines
├── __pycache__
│ ├── dua.cpython-36.pyc
│ ├── dua_utils.cpython-36.pyc
│ ├── norm.cpython-36.pyc
│ ├── setup_baseline.cpython-36.pyc
│ ├── shot.cpython-36.pyc
│ ├── shot_utils.cpython-36.pyc
│ ├── t3a.cpython-36.pyc
│ └── tent.cpython-36.pyc
├── dua.py
├── dua_utils.py
├── norm.py
├── pseudo_label.py
├── setup_baseline.py
├── shot.py
├── shot_utils.py
├── t3a.py
└── tent.py
├── compute_stats
├── compute_spatiotemp_stats_clean_train_swin.py
└── compute_spatiotemp_stats_clean_train_tanet.py
├── corpus
├── __pycache__
│ ├── basics.cpython-36.pyc
│ ├── main_eval.cpython-36.pyc
│ └── main_train.cpython-36.pyc
├── basics.py
├── main_eval.py
└── main_train.py
├── datasets_
├── __pycache__
│ ├── dataset_deprecated.cpython-36.pyc
│ └── video_dataset.cpython-36.pyc
├── dataset_deprecated.py
└── video_dataset.py
├── models
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── i3d.cpython-36.pyc
│ ├── i3d_incep.cpython-36.pyc
│ └── r2plus1d.cpython-36.pyc
├── backbones
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ └── resnet3d.cpython-36.pyc
│ └── resnet3d.py
├── i3d.py
├── i3d_incep.py
├── r2plus1d.py
├── tanet_models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── basic_ops.cpython-36.pyc
│ │ ├── tanet.cpython-36.pyc
│ │ ├── temporal_module.cpython-36.pyc
│ │ ├── transforms.cpython-36.pyc
│ │ └── video_dataset.cpython-36.pyc
│ ├── basic_ops.py
│ ├── tanet.py
│ ├── temporal_module.py
│ ├── transforms.py
│ └── video_dataset.py
├── videomae_models
│ └── modeling_finetune.py
└── videoswintransformer_models
│ ├── __pycache__
│ ├── i3d_head.cpython-36.pyc
│ ├── recognizer3d.cpython-36.pyc
│ ├── swin_transformer.cpython-36.pyc
│ ├── transforms_backup.cpython-36.pyc
│ └── video_dataset.cpython-36.pyc
│ ├── base.py
│ ├── i3d_head.py
│ ├── recognizer3d.py
│ ├── swin_transformer.py
│ ├── transforms.py
│ ├── transforms_backup.py
│ └── video_dataset.py
├── requirements.txt
├── sourceonly_swin_ucf101_corr.py
├── sourceonly_tanet_ucf101_corr.py
├── tta_swin_ucf101.py
├── tta_tanet_ucf101.py
└── utils
├── BNS_utils.py
├── __pycache__
├── BNS_utils.cpython-36.pyc
├── norm_stats_utils.cpython-36.pyc
├── opts.cpython-36.pyc
├── pred_consistency_utils.cpython-36.pyc
├── transforms.cpython-36.pyc
└── utils_.cpython-36.pyc
├── norm_stats_utils.py
├── opts.py
├── pred_consistency_utils.py
├── relation_map_utils.py
├── transforms.py
└── utils_.py

/README.md:
--------------------------------------------------------------------------------
1 | # Video Test-Time Adaptation for Action Recognition (CVPR 2023)
2 | [ProjectPage](https://wlin-at.github.io/vitta)
3 | 
4 | **ViTTA** is the first approach for test-time adaptation of video action recognition models against common distribution shifts. ViTTA is tailored to spatio-temporal models and is capable of adapting on a single video sample at a time. It consists of a feature distribution alignment technique that aligns online estimates of test set statistics towards the training statistics. It further enforces prediction consistency over temporally augmented views of the same test video sample.
5 | 
6 | Official implementation of ViTTA [[`arXiv`](https://arxiv.org/abs/2211.15393)]
7 | Author [HomePage](https://wlin-at.github.io/)
8 | [🤗 Dataset](https://huggingface.co/datasets/wlin21at/ViTTA) (12 corruption types of Kinetics 400 and Something-Something v2, and UCF101 data)
9 | 
10 | ## Requirements
11 | * Our experiments run on Python 3.6, PyTorch 1.7, mmcv-full 1.3.12.
Other versions should work but are not tested.
12 | * Dependency of mmaction2 (for [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer)):
13 | $ `pip install mmcv-full==1.3.12`
14 | $ `git clone https://github.com/SwinTransformer/Video-Swin-Transformer.git && cd Video-Swin-Transformer`
15 | $ `pip install -v -e . --user`
16 | * Other relevant dependencies can be found in `requirements.txt`
17 | 
18 | ---
19 | ## Data Preparation
20 | * Download
21 | Download the required data for experiments on UCF101 from [here](https://files.icg.tugraz.at/d/3551df694e3d4d6b89da/)
22 | `list_video_perturbations_ucf`: file lists for corrupted videos of the UCF101 validation set (in 12 corruption types)
23 | `model_swin_ucf`: Video Swin Transformer trained on UCF101 training set
24 | `model_tanet_ucf`: TANet trained on UCF101 training set
25 | `source_statistics_tanet_ucf`: precomputed source (UCF101 training set) statistics on TANet
26 | `source_statistics_swin_ucf`: precomputed source (UCF101 training set) statistics on Video Swin Transformer
27 | `ucf_corrupted_videos`: a folder of 12 compressed files containing videos of the UCF101 validation set (in 12 corruption types)
28 | `ucf_corrupted_videos.zip`: a single compressed file (83.8 GB) containing videos of the UCF101 validation set (in 12 corruption types)
29 | * Data structure
30 | Lines in the file lists are in the format
31 | `video_path n_frames class_id`
32 | Video dataset structure:
33 | ```
34 | level_5_ucf_val_split_1_/
35 |     gauss/
36 |         ApplyEyeMakeup/
37 |             v_ApplyEyeMakeup_g01_c01.mp4
38 |             v_ApplyEyeMakeup_g01_c02.mp4
39 |             ...
40 |         ApplyLipstick/
41 |         ...
42 |     contrast/
43 |         ApplyEyeMakeup/
44 |             v_ApplyEyeMakeup_g01_c01.mp4
45 |             v_ApplyEyeMakeup_g01_c02.mp4
46 |             ...
47 |         ApplyLipstick/
48 |         ...
49 | ```
50 | 
51 | ---
52 | ## Usage
53 | Specify the data paths in the scripts accordingly (see the comments in each script)
54 | * Precompute source statistics on the training set
55 | Precompute source (UCF101 training set) statistics on TANet:
56 | $ `python compute_stats/compute_spatiotemp_stats_clean_train_tanet.py`
57 | Precompute source (UCF101 training set) statistics on Video Swin Transformer:
58 | $ `python compute_stats/compute_spatiotemp_stats_clean_train_swin.py`
59 | * Test-time adaptation
60 | $ `python tta_tanet_ucf101.py` test-time adaptation of TANet on UCF101
61 | $ `python tta_swin_ucf101.py` test-time adaptation of Video Swin Transformer on UCF101
62 | * Source-only evaluation on corrupted validation data
63 | $ `python sourceonly_tanet_ucf101_corr.py` source-only evaluation of TANet on UCF101
64 | $ `python sourceonly_swin_ucf101_corr.py` source-only evaluation of Video Swin Transformer on UCF101
65 | ---
66 | ## Citation
67 | Thanks for citing our paper:
68 | ```bibtex
69 | @inproceedings{lin2023video,
70 |   title={Video Test-Time Adaptation for Action Recognition},
71 |   author={Lin, Wei and Mirza, Muhammad Jehanzeb and Kozinski, Mateusz and Possegger, Horst and Kuehne, Hilde and Bischof, Horst},
72 |   booktitle={CVPR},
73 |   year={2023},
74 | }
75 | ```
76 | 
--------------------------------------------------------------------------------
/baselines/__pycache__/dua.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/dua.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/__pycache__/dua_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/dua_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/__pycache__/norm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/norm.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/__pycache__/setup_baseline.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/setup_baseline.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/__pycache__/shot.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/shot.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/__pycache__/shot_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/shot_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/__pycache__/t3a.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/t3a.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/__pycache__/tent.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/__pycache__/tent.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/dua.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.utils_ import * 3 | # from corpus.main_train import validate_brief 4 | from corpus.basics import validate_brief 5 | from baselines.dua_utils import rotate_batch 6 | 7 | 8 | def DUA(model): 9 | model = configure_model(model) 10 | return model 11 | 12 | 13 | def configure_model(model): 14 | """Configure model for adaptation by test-time normalization.""" 15 | for m in model.modules(): 16 | if isinstance(m, nn.modules.batchnorm._BatchNorm): 17 | m.train() 18 | return model 19 | 20 | 21 | def dua_adaptation(args, model, te_loader, adapt_loader, logger, batchsize, augmentations, no_vids): 22 | """ 23 | :param model: After configuring the DUA model 24 | :param te_loader: The test set for Test-Time-Training 25 | :param logger: Logger for logging results 26 | :param batchsize: Batchsize to use for adaptation 27 | :param augmentations: augmentations to form a batch from a single video 28 | :param no_vids: total number of videos for adaptation 29 | 30 | """ 31 | if args.arch == 'tanet': 32 | from models.tanet_models.transforms import ToTorchFormatTensor_TANet_dua, GroupNormalize_TANet_dua 33 | adapt_transforms = torchvision.transforms.Compose([ 34 | augmentations, # GroupMultiScaleCrop amd GroupRandomHorizontalFlip 35 | ToTorchFormatTensor_TANet_dua(div=True), 36 | GroupNormalize_TANet_dua(args.input_mean, args.input_std) 37 | ]) 38 | else: 39 | adapt_transforms = torchvision.transforms.Compose([ 40 | augmentations, # GroupMultiScaleCrop amd GroupRandomHorizontalFlip 41 | fromListToTorchFormatTensor(clip_len=args.clip_length, num_clips=args.num_clips), 42 | GroupNormalize(args.input_mean, args.input_std) 43 | # Normalize later in the DUA adaptation loop after making a batch 44 | ]) 45 | logger.debug('---- Starting adaptation for DUA ----') 46 | all_acc = [] 47 | for i, (inputs, target) in enumerate(adapt_loader): 48 | model.train() 49 | for m in model.modules(): 50 | if isinstance(m, nn.modules.batchnorm._BatchNorm): 51 | m.train() 52 | 53 | with torch.no_grad(): 54 | if args.arch == 'tanet': 55 | n_clips = int(args.sample_style.split("-")[-1]) 56 | inputs = inputs.cuda() 57 | actual_bz = inputs.shape[0] 58 | 59 | inputs = inputs.view(-1, 3, inputs.size(2), inputs.size(3)) 60 | inputs = inputs.view(actual_bz * args.test_crops * n_clips, 61 | args.clip_length, 3, inputs.size(2), inputs.size(3)) # [1, 16, 3, 224, 224] 62 | inputs = [(adapt_transforms([inputs, target])[0]) for _ in 63 | range(batchsize)] # pass image, label together 64 | inputs = torch.stack(inputs) # only stack images 65 | inputs = inputs.cuda() 66 | rot_img = rotate_batch(inputs) 67 | _ = model(rot_img.float()) 68 | else: 69 | inputs = [(adapt_transforms([inputs, target])[0]) for _ in 70 | range(batchsize)] # pass image, label together 71 | inputs = torch.stack(inputs) # only stack images 72 | inputs = inputs.cuda() 73 | inputs = inputs.reshape( 74 | (-1,) + 
inputs.shape[2:]) # [b, channel, frames, h, w] 75 | rot_img = rotate_batch(inputs) 76 | _ = model(rot_img) 77 | 78 | logger.debug(f'---- Starting evaluation for DUA after video {i} ----') 79 | 80 | if i % 1 == 0 or i == len(adapt_loader) - 1: 81 | top1 = validate_brief(eval_loader=te_loader, model=model, global_iter=i, args=args, 82 | logger=logger, writer=None, epoch=i) 83 | all_acc.append(top1) 84 | 85 | if len(all_acc) >= 3: 86 | if all(top1 < i for i in all_acc[-3:]): 87 | logger.debug('---- Model Performance Degrading Consistently ::: Quitting Now ----') 88 | return max(all_acc) 89 | 90 | if i == no_vids: 91 | logger.debug(f' --- Best Accuracy for {args.corruptions} --- {max(all_acc)}') 92 | logger.debug(f' --- Stopping DUA adaptation ---') 93 | return max(all_acc) 94 | 95 | return max(all_acc) 96 | 97 | 98 | -------------------------------------------------------------------------------- /baselines/dua_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tensor_rot_90(x): 5 | return x.flip(3).transpose(2, 3) 6 | 7 | 8 | def tensor_rot_180(x): 9 | return x.flip(3).flip(2) 10 | 11 | 12 | def tensor_rot_270(x): 13 | return x.transpose(2, 3).flip(3) 14 | 15 | 16 | def rotate_batch_with_labels(batch, labels): 17 | images = [] 18 | for img, label in zip(batch, labels): 19 | if label == 1: 20 | img = tensor_rot_90(img) 21 | elif label == 2: 22 | for image in batch: 23 | img = tensor_rot_180(image) 24 | elif label == 3: 25 | for image in batch: 26 | img = tensor_rot_270(image) 27 | images.append(img) 28 | return torch.stack(images) 29 | 30 | 31 | def rotate_batch(batch): # input [b, channel, frames, h, w] if arch != tanet else [b, frames, channels, h, w] 32 | 33 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 34 | list_batch = [batch[i, :, :, :, :] for i in range(len(batch))] # make a list of tensors 35 | return rotate_batch_with_labels(list_batch, labels) 36 | -------------------------------------------------------------------------------- /baselines/norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def Norm(model, eps=1e-5, momentum=0.1, 5 | reset_stats=False, no_stats=False): 6 | model = model 7 | model = configure_model(model, eps, momentum, reset_stats, 8 | no_stats) 9 | return model 10 | 11 | 12 | def collect_stats(model): 13 | """Collect the normalization stats from batch norms. 14 | 15 | Walk the model's modules and collect all batch normalization stats. 16 | Return the stats and their names. 
17 | """ 18 | stats = [] 19 | names = [] 20 | for nm, m in model.named_modules(): 21 | if isinstance(m, nn.modules.batchnorm._BatchNorm): 22 | state = m.state_dict() 23 | if m.affine: 24 | del state['weight'], state['bias'] 25 | for ns, s in state.items(): 26 | stats.append(s) 27 | names.append(f"{nm}.{ns}") 28 | return stats, names 29 | 30 | 31 | def configure_model(model, eps, momentum, reset_stats, no_stats): 32 | """Configure model for adaptation by test-time normalization.""" 33 | for m in model.modules(): 34 | if isinstance(m, nn.modules.batchnorm._BatchNorm): 35 | m.train() 36 | # configure epsilon for stability, and momentum for updates 37 | m.eps = eps 38 | m.momentum = momentum 39 | if reset_stats: 40 | # reset state to estimate test stats without train stats 41 | m.reset_running_stats() 42 | m.running_mean = None 43 | m.running_var = None 44 | if no_stats: 45 | # disable state entirely and use only batch stats 46 | m.track_running_stats = False 47 | m.running_mean = None 48 | m.running_var = None 49 | return model -------------------------------------------------------------------------------- /baselines/pseudo_label.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/baselines/pseudo_label.py -------------------------------------------------------------------------------- /baselines/setup_baseline.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | import baselines.tent as tent 3 | import baselines.norm as norm 4 | from utils.opts import parser 5 | import baselines.shot as shot 6 | import baselines.dua as dua 7 | 8 | # hard coding tent arguments in order not to incorporate configs (taken from their github) 9 | tent_args = parser.parse_args() 10 | tent_args.STEPS = 1 11 | tent_args.LR = 1e-5 # for i3d (batchsize 16) set it to 1e-3 --- (tanet batchsize is 12) 12 | tent_args.BETA = 0.9 13 | tent_args.WD = 0. 14 | tent_args.EPISODIC = False 15 | 16 | 17 | def setup_model(args, base_model, logger): 18 | """ 19 | :param args: argument from the main file 20 | :param base_model: model to set up for adaptation 21 | :param logger: logger for keeping the logs 22 | :return: returns the base_model after setting up adaptation baseline 23 | """ 24 | 25 | if args.baseline == "source": 26 | logger.info("test-time adaptation: NONE") 27 | model = setup_source(args, base_model, logger) # set model to eval() 28 | return model 29 | elif args.baseline == "norm": 30 | logger.info("test-time adaptation: NORM") 31 | model = setup_norm(args, base_model, logger) 32 | return model 33 | elif args.baseline == "tent": 34 | logger.info("test-time adaptation: TENT") 35 | model, optimizer = setup_tent(args, base_model, logger) 36 | return model, optimizer 37 | elif args.baseline == "shot": 38 | optimizer, classfier, ext = setup_shot(args, base_model, logger) 39 | return optimizer, classfier, ext 40 | elif args.baseline == "dua": 41 | model = setup_dua(args, base_model, logger) 42 | return model 43 | else: 44 | raise NotImplementedError('Baseline not implemented') 45 | 46 | 47 | def setup_source(args, model, logger): 48 | """Set up the baseline source model without adaptation.""" 49 | model.eval() 50 | # if args.verbose: 51 | # logger.info(f"model for adaptation: %s", model) 52 | return model 53 | 54 | 55 | def setup_dua(args, model, logger): 56 | """ 57 | Set up DUA model. 58 | Do not reset stats. 
Freeze entire model except the Batch Normalization layer. 59 | """ 60 | dua_model = dua.DUA(model) 61 | if args.verbose: 62 | logger.info(f"model for adaptation: %s", model) 63 | return dua_model 64 | 65 | 66 | def setup_shot(args, model, logger): 67 | """Set up test-time shot. 68 | 69 | Adapts the feature extractor by keeping source predictions as hypothesis and entropy minimization. 70 | """ 71 | optimizer, classifier, ext = shot.configure_shot(model, logger, args) 72 | if args.verbose: 73 | logger.info(f"model for adaptation: %s", model) 74 | return optimizer, classifier, ext 75 | 76 | 77 | def setup_norm(args, model, logger): 78 | """Set up test-time normalization adaptation. 79 | 80 | Adapt by normalizing features with test batch statistics. 81 | The statistics are measured independently for each batch; 82 | no running average or other cross-batch estimation is used. 83 | """ 84 | norm_model = norm.Norm(model) 85 | stats, stat_names = norm.collect_stats(model) 86 | if args.verbose: 87 | logger.info(f"model for adaptation: %s", model) 88 | logger.debug(f"stats for adaptation: %s", stat_names) 89 | return norm_model 90 | 91 | 92 | def setup_tent(args, model, logger): 93 | """Set up tent adaptation. 94 | 95 | Configure the model for training + feature modulation by batch statistics, 96 | collect the parameters for feature modulation by gradient optimization, 97 | set up the optimizer, and then tent the model. 98 | """ 99 | model = tent.configure_model(model) # set only Batchnorm3d layers to trainable, freeze all the other layers 100 | params, param_names = tent.collect_params(model) # collecting gamma and beta in all Batchnorm3d layers 101 | optimizer = setup_optimizer(params) # todo hyperparameters are hard-coded above 102 | if args.verbose: 103 | logger.debug(f"model for adaptation: %s", model) 104 | logger.debug(f"params for adaptation: %s", param_names) 105 | logger.debug(f"optimizer for adaptation: %s", optimizer) 106 | return model, optimizer 107 | 108 | 109 | def setup_optimizer(params): 110 | """Set up optimizer for tent adaptation. 111 | 112 | Tent needs an optimizer for test-time entropy minimization. 113 | In principle, tent could make use of any gradient optimizer. 114 | In practice, we advise choosing Adam or SGD+momentum. 115 | For optimization settings, we advise to use the settings from the end of 116 | trainig, if known, or start with a low learning rate (like 0.001) if not. 117 | 118 | For best results, try tuning the learning rate and batch size. 
119 | """ 120 | return optim.Adam(params, 121 | lr=tent_args.LR, 122 | betas=(tent_args.BETA, 0.999), 123 | weight_decay=tent_args.WD) 124 | -------------------------------------------------------------------------------- /baselines/shot.py: -------------------------------------------------------------------------------- 1 | from baselines.shot_utils import * 2 | import torch.optim as optim 3 | import argparse 4 | from utils import utils_ as utils 5 | from corpus.main_train import validate 6 | 7 | parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument('--lr', default=0.00005, type=float) 10 | parser.add_argument('--nepoch', default=1, type=int, help='maximum number of epoch for SHOT') 11 | parser.add_argument('--bnepoch', default=2, type=int, help='first few epochs to update bn stat') 12 | parser.add_argument('--delayepoch', default=0, type=int) 13 | parser.add_argument('--stopepoch', default=25, type=int) 14 | ######################################################################## 15 | parser.add_argument('--outf', default='.') 16 | ######################################################################## 17 | parser.add_argument('--level', default=5, type=int) 18 | parser.add_argument('--corruption', default='') 19 | parser.add_argument('--resume', default=None, help='directory of pretrained model') 20 | parser.add_argument('--ckpt', default=None, type=int) 21 | parser.add_argument('--fix_ssh', action='store_true') 22 | parser.add_argument('--batch_size', default=12, type=int) 23 | ######################################################################## 24 | parser.add_argument('--method', default='shot', choices=['shot']) 25 | ######################################################################## 26 | parser.add_argument('--model', default='resnet50', help='resnet50') 27 | parser.add_argument('--save_every', default=100, type=int) 28 | ######################################################################## 29 | parser.add_argument('--tsne', action='store_true') 30 | ######################################################################## 31 | parser.add_argument('--cls_par', type=float, default=0.001) 32 | parser.add_argument('--ent_par', type=float, default=1.0) 33 | parser.add_argument('--gent', type=bool, default=True) 34 | parser.add_argument('--ent', type=bool, default=True) 35 | ######################################################################## 36 | parser.add_argument('--seed', default=0, type=int) 37 | 38 | args_shot = parser.parse_args() 39 | 40 | 41 | def configure_shot(net, logger, args): 42 | logger.debug('---- Configuring SHOT ----') 43 | if args.arch == 'tanet': 44 | classifier = net.module.new_fc 45 | ext = net 46 | ext.module.new_fc = nn.Identity() 47 | 48 | for k, v in classifier.named_parameters(): 49 | v.requires_grad = False 50 | else: 51 | for k, v in net.named_parameters(): 52 | if 'logits' in k: 53 | v.requires_grad = False # freeze the classifier 54 | classifier = nn.Sequential(*list(net.module.logits.children())) 55 | ext = list(net.module.children())[3:] + list(net.module.children())[:2] 56 | ext = nn.Sequential(*ext) 57 | 58 | optimizer = optim.SGD(ext.parameters(), lr=args_shot.lr, momentum=0.9) 59 | return optimizer, classifier, ext 60 | 61 | 62 | def train(args, criterion, optimizer, classifier, ext, teloader, logger): 63 | logger.debug('---- Training SHOT ----') 64 | losses = utils.AverageMeter() 65 | shot_acc = list() 66 | if args.arch == 'tanet': 67 | n_clips = int(args.sample_style.split("-")[-1]) 68 | 69 | for epoch in range(1, 
args_shot.nepoch + 1): 70 | ext.eval() 71 | mem_label = obtain_shot_label(teloader, ext, classifier, n_clips=n_clips, args=args) # compute the pseudo label 72 | mem_label = torch.from_numpy(mem_label).cuda() 73 | ext.train() 74 | 75 | for batch_idx, (inputs, labels) in enumerate(teloader): 76 | 77 | optimizer.zero_grad() 78 | actual_bz = inputs.shape[0] 79 | inputs = inputs.cuda() 80 | labels = labels.cuda() 81 | 82 | if args.arch == 'tanet': 83 | classifier_loss = 0 84 | inputs = inputs.view(-1, 3, inputs.size(2), inputs.size(3)) 85 | inputs = inputs.view(actual_bz * args.test_crops * n_clips, 86 | args.clip_length, 3, inputs.size(2), inputs.size(3)) 87 | features_test = ext(inputs.cuda()) 88 | 89 | outputs_test = classifier(features_test) 90 | 91 | outputs_test = torch.squeeze(outputs_test) 92 | outputs_test = outputs_test.reshape(actual_bz, args.test_crops * n_clips, -1).mean(1) 93 | 94 | else: 95 | classifier_loss = 0 96 | inputs = inputs.reshape( 97 | (-1,) + inputs.shape[2:]) 98 | features_test = ext(inputs.cuda()) 99 | outputs_test = classifier(features_test) 100 | outputs_test = torch.squeeze(outputs_test) 101 | 102 | if args_shot.cls_par > 0: 103 | pred = mem_label[batch_idx * args_shot.batch_size:(batch_idx + 1) * args_shot.batch_size] 104 | classifier_loss = args_shot.cls_par * nn.CrossEntropyLoss()(outputs_test, 105 | pred) # CE loss using the pseudo labels 106 | else: 107 | classifier_loss = torch.tensor(0.0).cuda() 108 | 109 | if args_shot.ent: 110 | softmax_out = nn.Softmax(dim=1)(outputs_test) 111 | entropy_loss = torch.mean(Entropy(softmax_out)) 112 | if args_shot.gent: 113 | msoftmax = softmax_out.mean(dim=0) 114 | entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5)) 115 | 116 | im_loss = entropy_loss * args_shot.ent_par 117 | classifier_loss += im_loss 118 | 119 | classifier_loss.backward() 120 | optimizer.step() 121 | losses.update(classifier_loss.item(), labels.size(0)) 122 | if args.verbose: 123 | logger.debug(('SHOT Training: [{0}/{1}]\t' 124 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(batch_idx, len(teloader), loss=losses))) 125 | 126 | if args.arch == 'tanet': 127 | ext.module.new_fc = classifier # simply put the classifier back in place of nn.Identity() 128 | top_1_acc = validate(teloader, ext, criterion, 0, epoch=epoch, args=args, logger=logger) 129 | shot_acc.append(top_1_acc) 130 | ext.module.new_fc = nn.Identity() 131 | else: 132 | adapted_model = nn.Sequential(*(list(ext.children()) + list(classifier.children()))) 133 | adapted_model = adapted_model.cuda() 134 | top_1_acc = validate(teloader, adapted_model, criterion, 0, epoch=epoch, args=args, logger=logger) 135 | shot_acc.append(top_1_acc) 136 | return max(shot_acc) 137 | -------------------------------------------------------------------------------- /baselines/shot_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from scipy.spatial.distance import cdist 4 | import numpy as np 5 | 6 | 7 | def obtain_shot_label(loader, ext, task_head, args, n_clips=None, c=None): 8 | start_test = True 9 | with torch.no_grad(): 10 | # iter_test = iter(loader) 11 | for batch_idx, (inputs, labels) in enumerate(loader): 12 | inputs = inputs.cuda() 13 | if args.arch == 'tanet': 14 | actual_bz = inputs.shape[0] 15 | inputs = inputs.view(-1, 3, inputs.size(2), inputs.size(3)) 16 | 17 | inputs = inputs.view(actual_bz * args.test_crops * n_clips, 18 | args.clip_length, 3, inputs.size(2), inputs.size(3)) 19 | feas = ext(inputs) 
20 | outputs = task_head(feas) 21 | outputs = torch.squeeze(outputs) 22 | outputs = outputs.reshape(actual_bz, args.test_crops * n_clips, -1).mean(1) 23 | else: 24 | inputs = inputs.reshape( 25 | (-1,) + inputs.shape[2:]) 26 | feas = ext(inputs) 27 | outputs = task_head(feas) 28 | outputs = torch.squeeze(outputs) 29 | 30 | if start_test: 31 | all_fea = feas.float().cpu() 32 | all_output = outputs.float().cpu() 33 | all_label = labels.float() 34 | start_test = False 35 | else: 36 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 37 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 38 | all_label = torch.cat((all_label, labels.float()), 0) 39 | 40 | all_output = nn.Softmax(dim=1)(all_output) 41 | _, predict = torch.max(all_output, 1) 42 | 43 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 44 | a = torch.ones(all_fea.size(0), 1) 45 | all_fea = torch.squeeze(all_fea) 46 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) # (bz, 1024 + 1) add one more dimension of ones 47 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() # (bz, 1025), normalize along feature dimension 48 | all_fea = all_fea.float().cpu().numpy() 49 | 50 | K = all_output.size(1) 51 | aff = all_output.float().cpu().numpy() 52 | initc = aff.transpose().dot(all_fea) 53 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 54 | dd = cdist(all_fea, initc, 'cosine') 55 | pred_label = dd.argmin(axis=1) 56 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 57 | 58 | for round in range(1): # todo udpate the pseudo labels once 59 | aff = np.eye(K)[pred_label] 60 | initc = aff.transpose().dot(all_fea) 61 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 62 | dd = cdist(all_fea, initc, 'cosine') 63 | pred_label = dd.argmin(axis=1) 64 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 65 | 66 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100) 67 | print(log_str+'\n') 68 | return pred_label.astype('int') 69 | 70 | 71 | def Entropy(input_): 72 | bs = input_.size(0) 73 | entropy = -input_ * torch.log(input_ + 1e-5) 74 | entropy = torch.sum(entropy, dim=1) 75 | return entropy 76 | 77 | 78 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 79 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 80 | for param_group in optimizer.param_groups: 81 | param_group['lr'] = param_group['lr0'] * decay 82 | param_group['weight_decay'] = 1e-3 83 | param_group['momentum'] = 0.9 84 | param_group['nesterov'] = True 85 | return optimizer 86 | 87 | 88 | def op_copy(optimizer): 89 | for param_group in optimizer.param_groups: 90 | param_group['lr0'] = param_group['lr'] 91 | return optimizer 92 | -------------------------------------------------------------------------------- /baselines/t3a.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from utils.utils_ import * 4 | 5 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 6 | """Entropy of softmax distribution from logits.""" 7 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 8 | 9 | 10 | def get_cls_ext(args, net): 11 | 12 | if args.arch == 'tanet': 13 | classifier = net.module.new_fc 14 | ext = net 15 | ext.module.new_fc = nn.Identity() 16 | for k, v in classifier.named_parameters(): 17 | v.requires_grad = False 18 | else: 19 | for k, v in net.named_parameters(): 20 | if 'logits' in k: 21 | v.requires_grad = False # freeze the 
classifier 22 | classifier = nn.Sequential(*list(net.module.logits.children())) 23 | ext = list(net.module.children())[3:] + list(net.module.children())[:2] 24 | ext = nn.Sequential(*ext) 25 | 26 | return ext, classifier 27 | 28 | 29 | class T3A(nn.Module): 30 | """ 31 | Test Time Template Adjustments (T3A) 32 | 33 | """ 34 | 35 | def __init__(self, args, ext, classifier): 36 | super().__init__() 37 | self.args = args 38 | self.model = ext 39 | self.classifier = classifier 40 | self.classifier.weight.requires_grad = False # To save memory ... 41 | self.classifier.bias.requires_grad = False # To save memory ... 42 | 43 | self.warmup_supports = self.classifier.weight.data 44 | warmup_prob = self.classifier(self.warmup_supports) 45 | self.warmup_ent = softmax_entropy(warmup_prob) 46 | self.warmup_labels = torch.nn.functional.one_hot(warmup_prob.argmax(1), num_classes=args.num_classes).float() 47 | 48 | self.supports = self.warmup_supports.data 49 | self.labels = self.warmup_labels.data 50 | self.ent = self.warmup_ent.data 51 | 52 | self.filter_K = args.t3a_filter_k 53 | self.num_classes = args.num_classes 54 | self.softmax = torch.nn.Softmax(-1) 55 | 56 | def forward(self, x): 57 | with torch.no_grad(): 58 | z = self.model(x) 59 | # online adaptation 60 | p = self.classifier(z) 61 | yhat = torch.nn.functional.one_hot(p.argmax(1), num_classes=self.num_classes).float() 62 | ent = softmax_entropy(p) 63 | 64 | # prediction 65 | self.supports = self.supports.to(z.device) 66 | self.labels = self.labels.to(z.device) 67 | self.ent = self.ent.to(z.device) 68 | self.supports = torch.cat([self.supports, z]) 69 | self.labels = torch.cat([self.labels, yhat]) 70 | self.ent = torch.cat([self.ent, ent]) 71 | 72 | supports, labels = self.select_supports() 73 | supports = torch.nn.functional.normalize(supports, dim=1) 74 | weights = (supports.T @ (labels)) 75 | return z @ torch.nn.functional.normalize(weights, dim=0) 76 | 77 | def select_supports(self): 78 | ent_s = self.ent 79 | y_hat = self.labels.argmax(dim=1).long() 80 | filter_K = self.filter_K 81 | if filter_K == -1: 82 | indices = torch.LongTensor(list(range(len(ent_s)))) 83 | 84 | indices = [] 85 | indices1 = torch.LongTensor(list(range(len(ent_s)))) 86 | for i in range(self.num_classes): 87 | _, indices2 = torch.sort(ent_s[y_hat == i]) 88 | indices.append(indices1[y_hat == i][indices2][:filter_K]) 89 | indices = torch.cat(indices) 90 | 91 | self.supports = self.supports[indices] 92 | self.labels = self.labels[indices] 93 | self.ent = self.ent[indices] 94 | 95 | return self.supports, self.labels 96 | 97 | 98 | def t3a_forward_and_adapt(args, ext, cls, val_loader): 99 | model = T3A(args, ext, cls) 100 | with torch.no_grad(): 101 | total = 0 102 | correct_list = [] 103 | top1 = AverageMeter() 104 | 105 | for i, (input, target) in enumerate(val_loader): # 106 | ext.eval() 107 | cls.eval() 108 | actual_bz = input.shape[0] 109 | input = input.cuda() 110 | target = target.cuda() 111 | if args.arch == 'tanet': 112 | n_clips = int(args.sample_style.split("-")[-1]) 113 | input = input.view(-1, 3, input.size(2), input.size(3)) 114 | input = input.view(actual_bz * args.test_crops * n_clips, 115 | args.clip_length, 3, input.size(2), input.size(3)) 116 | output = model(input) 117 | output = output.reshape(actual_bz, args.test_crops * n_clips, -1).mean(1) 118 | logits = torch.squeeze(output) 119 | prec1, prec5 = accuracy(logits.data, target, topk=(1, 5)) 120 | top1.update(prec1.item(), actual_bz) 121 | else: 122 | input = input.reshape((-1,) + input.shape[2:]) 123 | 
output = model(input) 124 | logits = torch.squeeze(output) 125 | prec1, prec5 = accuracy(logits.data, target, topk=(1, 5)) 126 | top1.update(prec1.item(), actual_bz) 127 | return top1.avg 128 | 129 | -------------------------------------------------------------------------------- /baselines/tent.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch.jit 3 | 4 | 5 | @torch.jit.script 6 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 7 | """Entropy of softmax distribution from logits.""" 8 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 9 | 10 | 11 | @torch.enable_grad() # ensure grads in possible no grad context for testing 12 | def forward_and_adapt(x, model, optimizer, args=None, actual_bz=None, n_clips=None): 13 | """Forward and adapt model on batch of data. 14 | 15 | Measure entropy of the model prediction, take gradients, and update params. 16 | """ 17 | # forward 18 | outputs = model(x) # (batch * n_views, 3, T, 224,224 ) -> (batch * n_views, n_class ) todo clip-level prediction 19 | if args.arch == 'tanet': 20 | outputs = outputs.reshape(actual_bz, args.test_crops * n_clips, -1).mean(1) 21 | # adapt 22 | loss = softmax_entropy(outputs).mean(0) # todo compute the entropy for all clip-level predictions then take the averaga among all samples 23 | loss.backward() 24 | optimizer.step() 25 | optimizer.zero_grad() 26 | return outputs 27 | 28 | 29 | def collect_params(model): 30 | """Collect the affine scale + shift parameters from batch norms. 31 | 32 | Walk the model's modules and collect all batch normalization parameters. 33 | Return the parameters and their names. 34 | 35 | Note: other choices of parameterization are possible! 36 | """ 37 | params = [] 38 | names = [] 39 | for nm, m in model.named_modules(): 40 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 41 | for np, p in m.named_parameters(): 42 | if np in ['weight', 'bias']: # weight is scale gamma, bias is shift beta 43 | params.append(p) 44 | names.append(f"{nm}.{np}") 45 | return params, names 46 | 47 | 48 | def copy_model_and_optimizer(model, optimizer): 49 | """Copy the model and optimizer states for resetting after adaptation.""" 50 | model_state = deepcopy(model.state_dict()) 51 | optimizer_state = deepcopy(optimizer.state_dict()) 52 | return model_state, optimizer_state 53 | 54 | 55 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 56 | """Restore the model and optimizer states from copies.""" 57 | model.load_state_dict(model_state, strict=True) 58 | optimizer.load_state_dict(optimizer_state) 59 | 60 | 61 | def configure_model(model): 62 | """Configure model for use with tent.""" 63 | model.train() 64 | model.requires_grad_(False) 65 | for m in model.modules(): 66 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 67 | m.requires_grad_(True) 68 | #m.track_running_stats = True # for original implementation this is False 69 | #m.running_mean = None # for original implementation uncomment this 70 | #m.running_var = None # for original implementation uncomment this 71 | return model 72 | 73 | 74 | def check_model(model): 75 | """Check model for compatability with tent.""" 76 | is_training = model.training 77 | assert is_training, "tent needs train mode: call model.train()" 78 | param_grads = [p.requires_grad for p in model.parameters()] 79 | has_any_params = any(param_grads) 80 | has_all_params = all(param_grads) 81 | assert has_any_params, "tent needs params to update: " \ 82 | "check which 
require grad" 83 | assert not has_all_params, "tent should not update all params: " \ 84 | "check which require grad" 85 | 86 | has_bn = any([isinstance(m, torch.nn.modules.batchnorm._BatchNorm) for m in model.modules()]) 87 | assert has_bn, "tent needs normalization for its optimization" 88 | -------------------------------------------------------------------------------- /compute_stats/compute_spatiotemp_stats_clean_train_swin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath('..')) # last level 4 | # import os.path as osp 5 | # from utils.opts import parser 6 | from utils.opts import get_opts 7 | from corpus.main_eval import eval 8 | 9 | corruptions = ['clean' ] 10 | 11 | if __name__ == '__main__': 12 | global args 13 | args = get_opts() 14 | args.gpus = [0] 15 | args.arch = 'videoswintransformer' 16 | args.dataset = 'ucf101' 17 | # todo ========================= To Specify ========================== 18 | args.model_path = '.../swin_base_patch244_window877_pretrain_kinetics400_30epoch_lr3e-5.pth' 19 | args.video_data_dir = '...' # main directory of the video data, [args.video_data_dir] + [path in file list] should be complete absolute path for a video file 20 | args.val_vid_list = '...' # list of training data for computing statistics, with lines in format : file_path n_frames class_id 21 | args.result_dir = '.../{}_{}/compute_norm_{}stats_{}_bz{}' 22 | # todo ========================= To Specify ========================== 23 | 24 | 25 | args.batch_size = 32 # 12 26 | args.clip_length = 16 27 | args.num_clips = 1 # number of temporal clips 28 | args.test_crops = 1 # number of spatial crops 29 | args.frame_uniform = True 30 | args.frame_interval = 2 31 | args.scale_size = 224 32 | 33 | args.patch_size = (2, 4, 4) 34 | args.window_size =(8, 7, 7) 35 | 36 | args.tta = True 37 | args.evaluate_baselines = not args.tta 38 | args.baseline = 'source' 39 | 40 | 41 | args.n_augmented_views = None 42 | args.n_epoch_adapat = 1 43 | 44 | args.compute_stat = 'mean_var' 45 | args.stat_type = 'spatiotemp' 46 | args.corruptions = 'clean' 47 | args.result_dir = args.result_dir.format(args.arch, args.dataset, args.stat_type, args.corruptions, args.batch_size) 48 | eval(args=args, ) 49 | 50 | 51 | -------------------------------------------------------------------------------- /compute_stats/compute_spatiotemp_stats_clean_train_tanet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | print(os.getcwd()) 4 | sys.path.append(os.path.abspath('..')) 5 | from utils.opts import get_opts 6 | from corpus.main_eval import eval 7 | 8 | 9 | if __name__ == '__main__': 10 | global args 11 | args = get_opts() 12 | args.gpus = [0] 13 | args.arch = 'tanet' 14 | args.dataset = 'ucf101' 15 | # todo ========================= To Specify ========================== 16 | args.model_path = '.../tanet_ucf.pth.tar' 17 | args.video_data_dir = '...' # main directory of the video data, [args.video_data_dir] + [path in file list] should be complete absolute path for a video file 18 | args.val_vid_list = '...' 
# list of training data for computing statistics, with lines in format : file_path n_frames class_id 19 | args.result_dir = '.../{}_{}/compute_norm_{}stats_{}_bz{}' 20 | # todo ========================= To Specify ========================== 21 | 22 | args.clip_length = 16 23 | args.batch_size = 32 # 12 24 | args.sample_style = 'uniform-1' # number of temporal clips 25 | args.test_crops = 1 # number of spatial crops 26 | 27 | args.tta = True 28 | args.evaluate_baselines = not args.tta 29 | args.baseline = 'source' 30 | 31 | args.compute_stat = 'mean_var' 32 | args.stat_type = 'spatiotemp' 33 | 34 | args.corruptions = 'clean' 35 | args.result_dir = args.result_dir.format(args.arch, args.dataset, args.stat_type, args.corruptions, args.batch_size) 36 | eval(args=args, ) 37 | 38 | 39 | -------------------------------------------------------------------------------- /corpus/__pycache__/basics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/corpus/__pycache__/basics.cpython-36.pyc -------------------------------------------------------------------------------- /corpus/__pycache__/main_eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/corpus/__pycache__/main_eval.cpython-36.pyc -------------------------------------------------------------------------------- /corpus/__pycache__/main_train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/corpus/__pycache__/main_train.cpython-36.pyc -------------------------------------------------------------------------------- /corpus/main_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch.utils.data.dataloader 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim 7 | 8 | from utils.transforms import * 9 | from utils.utils_ import make_dir, path_logger, model_analysis 10 | # from models.r2plus1d import MyR2plus1d 11 | # from models import i3d 12 | 13 | from corpus.basics import validate, get_dataset, get_dataset_tanet, get_dataset_videoswin, \ 14 | get_dataset_tanet_dua, get_model, test_time_adapt, compute_statistics, compute_cos_similarity, \ 15 | tta_standard 16 | import os.path as osp 17 | from tensorboardX import SummaryWriter 18 | from baselines.setup_baseline import setup_model 19 | from baselines.shot import train as train_shot 20 | from baselines.dua import dua_adaptation as adapt_dua 21 | from baselines.t3a import get_cls_ext, t3a_forward_and_adapt 22 | # import torch.nn as nn 23 | 24 | 25 | # def compute_temp_statistics(args = None,): 26 | 27 | 28 | 29 | 30 | def eval(args=None, model = None ): 31 | log_time = time.strftime("%Y%m%d_%H%M%S") 32 | make_dir(args.result_dir) 33 | logger = path_logger(args.result_dir, log_time) 34 | # writer = SummaryWriter(log_dir=osp.join(result_dir, f'{log_time}_tb')) 35 | if args.verbose: 36 | for arg in dir(args): 37 | if arg[0] != '_': 38 | logger.debug(f'{arg} {getattr(args, arg)}') 39 | num_class_dict = { 40 | 'ucf101' : 101, 41 | 'hmdb51': 51, 42 | 'kinetics': 400, 43 | 'somethingv2': 174, 44 | 'kth': 6, 45 | 'u2h':12, 46 | 'h2u':12, 47 | } 48 | num_classes = num_class_dict[args.dataset] 
49 | args.num_classes = num_classes 50 | 51 | if model is None: 52 | # todo initialize the model if the model is not provided 53 | model = get_model(args, num_classes, logger) 54 | # todo load model weights 55 | checkpoint = torch.load(args.model_path) 56 | logger.debug(f'Loading {args.model_path}') 57 | if args.arch == 'tanet': 58 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1'])) 59 | 60 | if 'module.' in list(checkpoint['state_dict'].keys())[0]: 61 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 62 | model.load_state_dict(checkpoint['state_dict']) 63 | else: 64 | model.load_state_dict(checkpoint['state_dict']) 65 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 66 | if args.verbose: 67 | model_analysis(model, logger) 68 | 69 | args.crop_size = args.input_size 70 | 71 | if args.modality == 'Flow': 72 | args.input_mean = [0.5] 73 | args.input_std = [np.mean(args.input_std)] 74 | 75 | # train_augmentation = get_augmentation(args.modality, args.input_size) # GroupMultiScaleCrop amd GroupRandomHorizontalFlip 76 | 77 | cudnn.benchmark = True 78 | 79 | if args.loss_type == 'nll': 80 | criterion = torch.nn.CrossEntropyLoss().cuda() 81 | else: 82 | raise ValueError("Unknown loss type") 83 | if args.tta: 84 | # TTA 85 | writer = SummaryWriter(log_dir=osp.join(args.result_dir, f'{log_time}_tb')) 86 | if args.compute_stat == 'mean_var': 87 | compute_statistics(model, args=args, log_time=log_time) 88 | epoch_result_list = None 89 | elif args.compute_stat == 'cossim': 90 | compute_cos_similarity(model, args=args, log_time=log_time) 91 | epoch_result_list = None 92 | elif args.compute_stat == False: 93 | if args.if_tta_standard: 94 | epoch_result_list = tta_standard(model, criterion, args=args, logger=logger, writer=writer) 95 | model = None 96 | else: 97 | # todo return the adapted model 98 | epoch_result_list, model = test_time_adapt(model, criterion, args=args, logger=logger, writer=writer) 99 | 100 | elif args.evaluate_baselines: 101 | # evaluate baselines 102 | if args.baseline == 'source': # source only evaluation 103 | if args.arch == 'tanet': 104 | val_loader = torch.utils.data.DataLoader( 105 | get_dataset_tanet(args, split='val', dataset_type='eval'), 106 | batch_size=args.batch_size, shuffle=False, 107 | num_workers=args.workers, pin_memory=True, ) 108 | elif args.arch == 'videoswintransformer': 109 | val_loader = torch.utils.data.DataLoader( 110 | get_dataset_videoswin(args, split='val', dataset_type='eval'), 111 | batch_size=args.batch_size, shuffle=False, 112 | num_workers=args.workers, pin_memory=True, ) 113 | else: 114 | # I3D 115 | val_loader = torch.utils.data.DataLoader( 116 | get_dataset(args, split='val'), 117 | batch_size=args.batch_size, shuffle=False, 118 | num_workers=args.workers, pin_memory=True) 119 | 120 | model_baseline = setup_model(args, base_model=model, logger=logger) # set model to eval mode 121 | 122 | 123 | top1_acc = validate(val_loader, model_baseline, criterion, 0, epoch=0, args=args, logger=logger) 124 | epoch_result_list = [top1_acc] 125 | 126 | elif args.baseline == 'norm': 127 | 128 | if args.arch == 'tanet': 129 | val_loader = torch.utils.data.DataLoader( 130 | get_dataset_tanet(args, split='val'), 131 | batch_size=args.batch_size, shuffle=False, 132 | num_workers=args.workers, pin_memory=True, ) 133 | else: 134 | val_loader = torch.utils.data.DataLoader( 135 | get_dataset(args, split='val'), 136 | batch_size=args.batch_size, shuffle=False, 137 | num_workers=args.workers, 
pin_memory=True) 138 | 139 | model_baseline = setup_model(args, base_model=model, logger=logger) 140 | top1_acc = validate(val_loader, model_baseline, criterion, 0, epoch=0, args=args, logger=logger) 141 | epoch_result_list = [top1_acc] 142 | 143 | elif args.baseline == 'tent': 144 | 145 | if args.arch == 'tanet': 146 | val_loader = torch.utils.data.DataLoader( 147 | get_dataset_tanet(args, split='val'), 148 | batch_size=args.batch_size, shuffle=False, 149 | num_workers=args.workers, pin_memory=True, ) 150 | else: 151 | val_loader = torch.utils.data.DataLoader( 152 | get_dataset(args, split='val'), 153 | batch_size=args.batch_size, shuffle=False, 154 | num_workers=args.workers, pin_memory=True) 155 | # set only Batchnorm3d layers to trainable, freeze all the other layers ; collecting gamma and beta in all Batchnorm3d layers 156 | model_baseline, optimizer = setup_model(args, base_model=model, logger=logger) 157 | top1_acc = validate(val_loader, model_baseline, criterion, 0, epoch=0, 158 | args=args, logger=logger, optimizer = optimizer) 159 | epoch_result_list = [top1_acc] 160 | 161 | elif args.baseline == 'shot': 162 | if args.arch == 'tanet': 163 | val_loader = torch.utils.data.DataLoader( 164 | get_dataset_tanet(args, split='val'), 165 | batch_size=args.batch_size, shuffle=False, 166 | num_workers=args.workers, pin_memory=True, ) 167 | else: 168 | val_loader = torch.utils.data.DataLoader( 169 | get_dataset(args, split='val'), 170 | batch_size=args.batch_size, shuffle=False, 171 | num_workers=args.workers, pin_memory=True) 172 | 173 | optimizer, classifier, ext = setup_model(args, base_model=model, logger=logger) 174 | top1_acc = train_shot(args, criterion, optimizer, classifier, ext, val_loader, logger) # train and validate 175 | epoch_result_list = [top1_acc] 176 | 177 | elif args.baseline == 'dua': 178 | from utils.utils_ import get_augmentation 179 | aug = get_augmentation(args, args.modality, 180 | args.input_size) 181 | if args.arch == 'tanet': 182 | val_loader_adapt = torch.utils.data.DataLoader( 183 | get_dataset_tanet_dua(args, tanet_model=model.module, split='val')[1], 184 | batch_size=1, shuffle=False, 185 | num_workers=args.workers, pin_memory=True, ) 186 | te_loader = torch.utils.data.DataLoader( 187 | get_dataset_tanet_dua(args, tanet_model=model.module, split='val')[0], 188 | batch_size=args.batch_size, shuffle=False, 189 | num_workers=args.workers, pin_memory=True, ) 190 | else: 191 | val_loader_adapt = torch.utils.data.DataLoader( 192 | get_dataset(args, split='val')[0], 193 | batch_size=1, shuffle=True, 194 | num_workers=args.workers, pin_memory=True 195 | ) 196 | te_loader = torch.utils.data.DataLoader( 197 | get_dataset(args, split='val')[1], 198 | batch_size=args.batch_size, shuffle=True, 199 | num_workers=args.workers, pin_memory=True 200 | ) 201 | 202 | dua_model = setup_model(args, base_model=model, logger=logger) 203 | top1_acc = adapt_dua(args=args, model=dua_model, batchsize=16, logger=logger, 204 | no_vids=int(len(val_loader_adapt) * 1 / 100), 205 | adapt_loader=val_loader_adapt, te_loader=te_loader, augmentations=aug) 206 | 207 | epoch_result_list = [top1_acc] 208 | 209 | elif args.baseline == 't3a': 210 | logger.debug(f'Baseline :::::: {args.baseline}') 211 | if args.arch == 'tanet': 212 | val_loader = torch.utils.data.DataLoader( 213 | get_dataset_tanet(args, split='val'), 214 | batch_size=args.batch_size, shuffle=False, 215 | num_workers=args.workers, pin_memory=True, ) 216 | else: 217 | val_loader = torch.utils.data.DataLoader( 218 | get_dataset(args, 
split='val'), 219 | batch_size=args.batch_size, shuffle=False, 220 | num_workers=args.workers, pin_memory=True) 221 | 222 | ext, classifier = get_cls_ext(args, model) 223 | top1_acc = t3a_forward_and_adapt(args, ext, classifier, val_loader) 224 | logger.debug(f'Top1 Accuracy After Adaptation ::: {args.corruptions} ::: {top1_acc}') 225 | epoch_result_list = [top1_acc] 226 | else: 227 | raise NotImplementedError('The Baseline is not Implemented') 228 | 229 | # validate(val_loader, model, criterion, iter, epoch = None, args = None, logger= None, writer = None) 230 | logger.handlers.clear() 231 | # todo return the adapted model, if no adaptation, returned model is None 232 | return epoch_result_list, model 233 | 234 | 235 | -------------------------------------------------------------------------------- /corpus/main_train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | # from torch.nn.utils import clip_grad_norm 4 | # import torch.nn as nn 5 | # from einops import rearrange 6 | import os.path as osp 7 | from tensorboardX import SummaryWriter 8 | import torch.backends.cudnn as cudnn 9 | # from datasets_.dataset import MyTSNDataset 10 | # from datasets_.video_dataset import MyTSNVideoDataset, MyVideoDataset 11 | 12 | # from models.r2plus1d import MyR2plus1d 13 | # from models import i3d 14 | # from models.i3d_incep import InceptionI3d 15 | # from models.tanet_models.tanet import TSN 16 | from utils.transforms import * 17 | from utils.utils_ import make_dir, path_logger, model_analysis, \ 18 | adjust_learning_rate, save_checkpoint 19 | # from utils.BNS_utils import BN3DFeatureHook, choose_BN_layers 20 | # import baselines.tent as tent 21 | from corpus.basics import train, validate, get_dataset, get_model 22 | 23 | def main_train(args=None, best_prec1=0, ): 24 | log_time = time.strftime("%Y%m%d_%H%M%S") 25 | 26 | make_dir(args.result_dir) 27 | logger = path_logger(args.result_dir, log_time) 28 | writer = SummaryWriter(log_dir=osp.join(args.result_dir, f'{log_time}_tb')) 29 | 30 | for arg in dir(args): 31 | logger.debug(f'{arg} {getattr(args, arg)}') 32 | 33 | if args.dataset == 'ucf101': 34 | num_classes = 101 35 | elif args.dataset == 'hmdb51': 36 | num_classes = 51 37 | elif args.dataset == 'kinetics': 38 | num_classes = 400 39 | elif args.dataset == 'kth': 40 | num_classes = 6 41 | elif args.dataset in ['u2h', 'h2u']: 42 | num_classes = 12 43 | else: 44 | raise ValueError('Unknown dataset ' + args.dataset) 45 | 46 | model = get_model(args, num_classes, logger) 47 | 48 | if args.verbose: 49 | model_analysis(model, logger) 50 | args.crop_size = args.input_size 51 | 52 | if args.modality == 'Flow': 53 | args.input_mean = [0.5] 54 | args.input_std = [np.mean(args.input_std)] 55 | 56 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 57 | # for k, v in model.named_parameters(): 58 | # print(k) 59 | # quit() 60 | 61 | if args.resume: 62 | if os.path.isfile(args.resume): 63 | logger.debug("=> loading checkpoint '{}'".format(args.resume)) 64 | checkpoint = torch.load(args.resume) 65 | args.start_epoch = checkpoint['epoch'] 66 | best_prec1 = checkpoint['best_prec1'] 67 | model.load_state_dict(checkpoint['state_dict']) 68 | logger.debug("=> loaded checkpoint '{}' (epoch {})" 69 | .format(args.evaluate, checkpoint['epoch'])) 70 | else: 71 | logger.debug("=> no checkpoint found at '{}'".format(args.resume)) 72 | 73 | cudnn.benchmark = True 74 | 75 | # Data loading code 76 | train_loader = torch.utils.data.DataLoader( # 
TSN video dataset 77 | get_dataset(args, split='train'), 78 | batch_size=args.batch_size, shuffle=True, 79 | num_workers=args.workers, pin_memory=True, ) 80 | 81 | val_loader = torch.utils.data.DataLoader( 82 | get_dataset(args, split='val'), 83 | batch_size=args.batch_size, shuffle=False, 84 | num_workers=args.workers, pin_memory=True 85 | ) 86 | 87 | # define loss function (criterion) and optimizer 88 | if args.loss_type == 'nll': 89 | criterion = torch.nn.CrossEntropyLoss().cuda() 90 | else: 91 | raise ValueError("Unknown loss type") 92 | 93 | optimizer = torch.optim.SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum, 94 | weight_decay=args.weight_decay) 95 | 96 | if args.evaluate: 97 | validate(val_loader, model, criterion, 0, epoch=0, args=args, logger=logger, writer=writer) 98 | return 99 | 100 | for epoch in range(args.start_epoch, args.epochs): 101 | adjust_learning_rate(optimizer, epoch, args.lr_steps, args=args) # learning rate decay is 0.1 102 | 103 | # train for one epoch 104 | train(train_loader, model, criterion, optimizer, epoch, args=args, logger=logger, writer=writer) 105 | 106 | # evaluate on validation set 107 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 108 | prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader), epoch=epoch, args=args, 109 | logger=logger, writer=writer) 110 | 111 | # remember best prec@1 and save checkpoint 112 | is_best = prec1 > best_prec1 113 | best_prec1 = max(prec1, best_prec1) 114 | 115 | if args.if_save_model: 116 | save_checkpoint({ 117 | 'epoch': epoch + 1, 118 | 'arch': args.arch, 119 | 'state_dict': model.state_dict(), 120 | 'best_prec1': best_prec1, 121 | }, is_best, result_dir=args.result_dir, log_time=log_time, logger=logger, args=args) 122 | logger.debug(f'Checkpoint epoch {epoch} saved!') 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /datasets_/__pycache__/dataset_deprecated.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/datasets_/__pycache__/dataset_deprecated.cpython-36.pyc -------------------------------------------------------------------------------- /datasets_/__pycache__/video_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/datasets_/__pycache__/video_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /datasets_/video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import numpy as np 5 | import os 6 | import os.path as osp 7 | from numpy.random import randint 8 | import cv2 9 | import decord 10 | import random 11 | 12 | 13 | class VideoRecord(object): 14 | def __init__(self, row): 15 | self._data = row 16 | 17 | @property 18 | def path(self): 19 | return self._data[0] 20 | 21 | @property 22 | def num_frames(self): 23 | return int(self._data[1]) 24 | 25 | @property 26 | def label(self): 27 | return int(self._data[2]) 28 | 29 | 30 | class MyVideoDataset(data.Dataset): 31 | """ 32 | sample several clips of consecutive frames 33 | """ 34 | def __init__(self, root_path, list_file, clip_length=64, frame_interval=1, num_clips=1, frame_size=(320, 240), 35 | modality='RGB', 
vid_format='.mp4', 36 | transform=None, test_mode=False, video_data_dir=None, debug=False, debug_vid=50): 37 | self.root_path = root_path 38 | self.list_file = list_file 39 | self.clip_len = clip_length 40 | self.frame_interval = frame_interval 41 | self.num_clips = num_clips 42 | # self.frame_size = frame_size 43 | self.modality = modality 44 | # self.image_tmpl = image_tmpl 45 | self.vid_format = vid_format 46 | self.transform = transform 47 | # self.random_shift = random_shift 48 | self.test_mode = test_mode 49 | self.video_data_dir = video_data_dir 50 | 51 | self.debug = debug 52 | self.debug_vid = debug_vid 53 | 54 | self._parse_list() 55 | 56 | def _parse_list(self): 57 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in 58 | open(os.path.join(self.list_file))] 59 | if self.debug: 60 | self.video_list = self.video_list[:self.debug_vid] 61 | 62 | def _sample_indices_old(self, record): 63 | if not self.test_mode and self.random_shift: 64 | average_duration = record.num_frames // self.clip_len 65 | if average_duration > 0: 66 | # uniformly divide and then randomly sample from each segment 67 | offsets = np.sort( 68 | np.multiply(list(range(self.clip_len)), average_duration) + randint(average_duration, 69 | size=self.clip_len)) 70 | else: 71 | # randomly sample with repetitions 72 | offsets = np.sort(randint(record.num_frames, size=self.clip_len)) 73 | else: # equi-distant sampling 74 | tick = record.num_frames / float(self.clip_len) 75 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.clip_len)]) 76 | return offsets + 1 77 | 78 | def _get_train_clips(self, num_frames): 79 | ori_clip_len = self.clip_len * self.frame_interval 80 | # the average interval between two clips (the clip can only be sampled within the video, last start position is (num_frames - ori_clip_len +1) 81 | avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips 82 | if avg_interval > 0: # randomly sample the starting position of each clip 83 | base_offsets = np.arange(self.num_clips) * avg_interval # starting positions of all clips 84 | clip_offsets = base_offsets + np.random.randint( 85 | avg_interval, size=self.num_clips) # randomly shift each starting posiiton within one interval 86 | elif num_frames > max(self.num_clips, ori_clip_len): 87 | clip_offsets = np.sort( 88 | np.random.randint( 89 | num_frames - ori_clip_len + 1, 90 | size=self.num_clips)) # in the interval of (0, num_frames - ori_clip_len + 1), randomly choose 4 starting positions 91 | elif avg_interval == 0: 92 | ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips 93 | clip_offsets = np.around(np.arange(self.num_clips) * ratio) 94 | else: 95 | clip_offsets = np.zeros((self.num_clips,), dtype=np.int) 96 | return clip_offsets 97 | 98 | def _get_test_clips(self, num_frames): # uniformly sample the starting position of each clip 99 | ori_clip_len = self.clip_len * self.frame_interval 100 | avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips) 101 | if num_frames > ori_clip_len - 1: 102 | base_offsets = np.arange(self.num_clips) * avg_interval 103 | clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int) 104 | else: 105 | clip_offsets = np.zeros((self.num_clips,), dtype=np.int) 106 | return clip_offsets 107 | 108 | def _sample_clips(self, num_frames): 109 | if self.test_mode: 110 | clip_offsets = self._get_test_clips(num_frames) 111 | else: 112 | clip_offsets = self._get_train_clips(num_frames) 113 | return clip_offsets 114 | 115 | def _sample_indices(self, record): 116 | num_frames = 
record.num_frames 117 | clip_offsets = self._sample_clips(num_frames) 118 | frame_inds = clip_offsets[:, None] + np.arange( 119 | self.clip_len)[None, :] * self.frame_interval 120 | frame_inds = np.concatenate(frame_inds) 121 | frame_inds = frame_inds.reshape((-1, self.clip_len)) # (clip_len, ) -> (n_clips, clip_len ) 122 | frame_inds = np.mod(frame_inds, num_frames) 123 | return frame_inds 124 | 125 | def __getitem__(self, index): 126 | record = self.video_list[index] # video file 127 | indices = self._sample_indices(record) # get the frame indices 128 | return self.get(record, indices) 129 | 130 | def get(self, record, indices): # indices (n_clips, clip_len ) 131 | vid_path = os.path.join(self.video_data_dir, f'{record.path}{self.vid_format}') 132 | container = decord.VideoReader(vid_path) 133 | frame_indices = np.concatenate(indices) # flatten the frame indices 134 | # accurate mode 135 | images = container.get_batch(frame_indices).asnumpy() 136 | images = list(images) 137 | images = [Image.fromarray(image).convert('RGB') for image in images] 138 | 139 | # process_data = self.transform(images) 140 | # return process_data, record.label 141 | process_data, label = self.transform( (images, record.label) ) 142 | return process_data, label 143 | 144 | def __len__(self): 145 | return len(self.video_list) 146 | 147 | 148 | class MyTSNVideoDataset(data.Dataset): 149 | """ 150 | sample 1 clip in uniform sampling (TSN style ) 151 | """ 152 | def __init__(self, args, root_path, list_file, clip_length=64, frame_interval=1, num_clips=1, frame_size=(320, 240), 153 | modality='RGB', image_tmpl='img_{:05d}.jpg', vid_format='.mp4', 154 | transform=None, random_shift=True, test_mode=False, video_data_dir=None, debug=False, debug_vid=50): 155 | self.root_path = root_path 156 | self.list_file = list_file 157 | self.clip_len = clip_length 158 | self.frame_interval = frame_interval 159 | self.num_clips = num_clips 160 | # self.frame_size = frame_size 161 | self.modality = modality 162 | # self.image_tmpl = image_tmpl 163 | self.vid_format = vid_format 164 | self.transform = transform 165 | # self.random_shift = random_shift 166 | self.test_mode = test_mode 167 | self.video_data_dir = video_data_dir 168 | 169 | self.debug = debug 170 | self.debug_vid = debug_vid 171 | 172 | self._parse_list() 173 | self.args = args 174 | 175 | def _parse_list(self): 176 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in 177 | open(os.path.join(self.list_file))] 178 | 179 | if self.debug: 180 | self.video_list = self.video_list[:self.debug_vid] 181 | 182 | def _sample_indices_old(self, record): 183 | if not self.test_mode and self.random_shift: 184 | average_duration = record.num_frames // self.clip_len 185 | if average_duration > 0: 186 | # uniformly divide and then randomly sample from each segment 187 | offsets = np.sort( 188 | np.multiply(list(range(self.clip_len)), average_duration) + randint(average_duration, 189 | size=self.clip_len)) 190 | else: 191 | # randomly sample with repetitions 192 | offsets = np.sort(randint(record.num_frames, size=self.clip_len)) 193 | else: # equi-distant sampling 194 | tick = record.num_frames / float(self.clip_len) 195 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.clip_len)]) 196 | return offsets + 1 197 | 198 | def _get_train_clips(self, num_frames): 199 | ori_clip_len = self.clip_len * self.frame_interval 200 | # the average interval between two clips (the clip can only be sampled within the video, last start position is (num_frames - ori_clip_len +1) 
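        # Worked example with illustrative numbers (assumed, not taken from the file):
        #   num_frames=300, clip_len=16, frame_interval=1, num_clips=4
        #   -> ori_clip_len = 16, avg_interval = (300 - 16 + 1) // 4 = 71
        #   -> base_offsets = [0, 71, 142, 213]; each clip start is then shifted by a
        #      random offset drawn from [0, 71), so clips stay inside the video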
201 | avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips 202 | if avg_interval > 0: # randomly sample the starting position of each clip 203 | base_offsets = np.arange(self.num_clips) * avg_interval # starting positions of all clips 204 | clip_offsets = base_offsets + np.random.randint( 205 | avg_interval, size=self.num_clips) # randomly shift each starting posiiton within one interval 206 | elif num_frames > max(self.num_clips, ori_clip_len): 207 | clip_offsets = np.sort( 208 | np.random.randint( 209 | num_frames - ori_clip_len + 1, 210 | size=self.num_clips)) # in the interval of (0, num_frames - ori_clip_len + 1), randomly choose 4 starting positions 211 | elif avg_interval == 0: 212 | ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips 213 | clip_offsets = np.around(np.arange(self.num_clips) * ratio) 214 | else: 215 | clip_offsets = np.zeros((self.num_clips,), dtype=np.int) 216 | return clip_offsets 217 | 218 | def _get_test_clips(self, num_frames): # uniformly sample the starting position of each clip 219 | ori_clip_len = self.clip_len * self.frame_interval 220 | avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips) 221 | if num_frames > ori_clip_len - 1: 222 | base_offsets = np.arange(self.num_clips) * avg_interval 223 | clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int) 224 | else: 225 | clip_offsets = np.zeros((self.num_clips,), dtype=np.int) 226 | return clip_offsets 227 | 228 | def _sample_clips(self, num_frames): 229 | if self.test_mode: 230 | clip_offsets = self._get_test_clips(num_frames) 231 | else: 232 | clip_offsets = self._get_train_clips(num_frames) 233 | return clip_offsets 234 | 235 | def uniform_divide_segment(self, record): 236 | vid_len = record.num_frames 237 | # n_segments = self.num_clips 238 | n_segments = self.clip_len 239 | seg_len = int(np.floor(float(vid_len) / n_segments)) 240 | seg_len_list = [seg_len] * n_segments 241 | for idx in range(vid_len - seg_len * n_segments): 242 | seg_len_list[idx] += 1 243 | return seg_len_list 244 | 245 | def _sample_indices_train(self, record): 246 | selected_frames_ = np.zeros((self.num_clips, self.clip_len)) 247 | 248 | if record.num_frames >= self.clip_len: 249 | seg_len_list = self.uniform_divide_segment(record) 250 | # selected_frames = [] 251 | 252 | for clip_id in range(self.num_clips): 253 | start = 0 254 | for seg_id, seg_len in enumerate(seg_len_list): 255 | end = start + seg_len - 1 # (0, 5) (6, 11), (12, 17) 256 | selected_frames_[clip_id, seg_id] = random.randint(start, 257 | end) # todo random.randint returns a random number, both borders are included 258 | # selected_frames_per_clip.append( random.randint( start, end ) ) 259 | start = end + 1 260 | else: # clip_len > num_frames 261 | selected_frames = list(range(record.num_frames)) + [record.num_frames - 1] * ( 262 | self.clip_len - record.num_frames) 263 | for clip_id in range(self.num_clips): 264 | selected_frames_[clip_id, :] = selected_frames 265 | return selected_frames_ 266 | 267 | def _sample_indices_test(self, record): 268 | # sample one clip for test 269 | selected_frames_ = np.zeros((1, self.clip_len)) 270 | if record.num_frames >= self.clip_len: 271 | seg_len = int(np.floor(float(record.num_frames) / self.clip_len)) 272 | half_seg_len = int(np.floor(seg_len / 2.0)) 273 | selected_frames = np.arange(self.clip_len) * seg_len + half_seg_len 274 | else: # clip_len > num_frames 275 | # repeat the last frame 276 | selected_frames = list(range(record.num_frames)) + [record.num_frames - 1] * ( 277 | 
self.clip_len - record.num_frames) 278 | selected_frames = np.array(selected_frames) 279 | selected_frames_[0, :] = selected_frames 280 | return selected_frames_ 281 | 282 | def _sample_indices(self, record): 283 | frame_inds = self._sample_indices_test(record) if self.test_mode else self._sample_indices_train(record) 284 | return frame_inds 285 | 286 | def __getitem__(self, index): 287 | record = self.video_list[index] # video file 288 | indices = self._sample_indices(record) # get the frame indices 289 | return self.get(record, indices) 290 | 291 | def get(self, record, indices): # indices (n_clips, clip_len ) 292 | vid_path = os.path.join(self.video_data_dir, f'{record.path}{self.vid_format}') 293 | container = decord.VideoReader(vid_path) # read video into a container 294 | frame_indices = np.concatenate(indices) # flatten the frame indices 295 | # accurate mode 296 | 297 | frame_indices = np.minimum(frame_indices, container._num_frame - 1) 298 | images = container.get_batch(frame_indices).asnumpy() 299 | # try: 300 | # images = container.get_batch(frame_indices).asnumpy() 301 | # except: 302 | # print(frame_indices) 303 | 304 | images = list(images) 305 | images = [Image.fromarray(image).convert('RGB') for image in images] 306 | 307 | # process_data = self.transform(images) 308 | process_data, label = self.transform( (images, record.label) ) 309 | return process_data, label 310 | 311 | def __len__(self): 312 | return len(self.video_list) 313 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * 2 | from .i3d import * 3 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/i3d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/__pycache__/i3d.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/i3d_incep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/__pycache__/i3d_incep.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/r2plus1d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/__pycache__/r2plus1d.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet3d import * 2 | -------------------------------------------------------------------------------- /models/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbones/__pycache__/resnet3d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/backbones/__pycache__/resnet3d.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbones/resnet3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | 4 | __all__ = ['resnet3d'] 5 | 6 | model_urls = { 7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 12 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 13 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 14 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 15 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 16 | } 17 | 18 | 19 | def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 20 | """3x3x3 convolution with padding""" 21 | if isinstance(stride, int): 22 | stride = (1, stride, stride) 23 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=dilation, dilation=dilation, groups=groups, bias=False) 25 | 26 | 27 | def conv1x1x1(in_planes, out_planes, stride=1): 28 | """1x1x1 convolution""" 29 | if isinstance(stride, int): 30 | stride = (1, stride, stride) 31 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class BasicBlock3d(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None): 39 | super(BasicBlock3d, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm3d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck3d(nn.Module): 75 | expansion = 4 76 | 77 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 78 | 
base_width=64, dilation=1, norm_layer=None): 79 | super(Bottleneck3d, self).__init__() 80 | if norm_layer is None: 81 | norm_layer = nn.BatchNorm3d 82 | width = int(planes * (base_width / 64.)) * groups 83 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 84 | self.conv1 = conv1x1x1(inplanes, width) 85 | self.bn1 = norm_layer(width) 86 | self.conv2 = conv3x3x3(width, width, stride, groups, dilation) 87 | self.bn2 = norm_layer(width) 88 | self.conv3 = conv1x1x1(width, planes * self.expansion) 89 | self.bn3 = norm_layer(planes * self.expansion) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.downsample = downsample 92 | self.stride = stride 93 | 94 | def forward(self, x): 95 | identity = x 96 | 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv2(out) 102 | out = self.bn2(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv3(out) 106 | out = self.bn3(out) 107 | 108 | if self.downsample is not None: 109 | identity = self.downsample(x) 110 | 111 | out += identity 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class ResNet3d(nn.Module): 118 | 119 | def __init__(self, block, layers, zero_init_residual=False, 120 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 121 | norm_layer=None, modality='RGB'): 122 | super(ResNet3d, self).__init__() 123 | if norm_layer is None: 124 | norm_layer = nn.BatchNorm3d 125 | self._norm_layer = norm_layer 126 | 127 | self.modality = modality 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | 140 | self._make_stem_layer() 141 | 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 144 | dilate=replace_stride_with_dilation[0]) 145 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 146 | dilate=replace_stride_with_dilation[1]) 147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 148 | dilate=replace_stride_with_dilation[2]) 149 | 150 | for m in self.modules(): # self.modules() --> Depth-First-Search the Net 151 | if isinstance(m, nn.Conv3d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
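        # (Zero-initialising the last BN weight is the trick reported in
        #  https://arxiv.org/abs/1706.02677: out = identity + 0 at the start of training,
        #  which tends to help very deep residual networks converge.)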
159 | if zero_init_residual: 160 | for m in self.modules(): 161 | if isinstance(m, Bottleneck3d): 162 | nn.init.constant_(m.bn3.weight, 0) 163 | elif isinstance(m, BasicBlock3d): 164 | nn.init.constant_(m.bn2.weight, 0) 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 167 | norm_layer = self._norm_layer 168 | downsample = None 169 | previous_dilation = self.dilation 170 | if dilate: 171 | self.dilation *= stride 172 | stride = 1 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | conv1x1x1(self.inplanes, planes * block.expansion, stride), 176 | norm_layer(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 181 | self.base_width, previous_dilation, norm_layer)) 182 | self.inplanes = planes * block.expansion 183 | for _ in range(1, blocks): 184 | layers.append(block(self.inplanes, planes, groups=self.groups, 185 | base_width=self.base_width, dilation=self.dilation, 186 | norm_layer=norm_layer)) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def _make_stem_layer(self): 191 | """Construct the stem layers consists of a conv+norm+act module and a 192 | pooling layer.""" 193 | if self.modality == 'RGB': 194 | inchannels = 3 195 | elif self.modality == 'Flow': 196 | inchannels = 2 197 | else: 198 | raise ValueError('Unknown modality: {}'.format(self.modality)) 199 | self.conv1 = nn.Conv3d(inchannels, self.inplanes, kernel_size=(5, 7, 7), 200 | stride=2, padding=(2, 3, 3), bias=False) 201 | self.bn1 = self._norm_layer(self.inplanes) 202 | self.relu = nn.ReLU(inplace=True) 203 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=2, 204 | padding=(0, 1, 1)) # kernel_size=(2, 3, 3) 205 | 206 | def _forward_impl(self, x): 207 | # See note [TorchScript super()] 208 | x = self.conv1(x) 209 | x = self.bn1(x) 210 | x = self.relu(x) 211 | x = self.maxpool(x) 212 | 213 | x = self.layer1(x) 214 | x = self.layer2(x) 215 | x = self.layer3(x) 216 | x = self.layer4(x) 217 | 218 | return x 219 | 220 | def forward(self, x): 221 | return self._forward_impl(x) 222 | 223 | def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d, 224 | inflated_param_names): 225 | """Inflate a conv module from 2d to 3d. 226 | 227 | Args: 228 | conv3d (nn.Module): The destination conv3d module. 229 | state_dict_2d (OrderedDict): The state dict of pretrained 2d models. 230 | module_name_2d (str): The name of corresponding conv module in the 231 | 2d models. 232 | inflated_param_names (list[str]): List of parameters that have been 233 | inflated. 234 | """ 235 | weight_2d_name = module_name_2d + '.weight' 236 | 237 | conv2d_weight = state_dict_2d[weight_2d_name] 238 | kernel_t = conv3d.weight.data.shape[2] 239 | 240 | new_weight = conv2d_weight.data.unsqueeze(2).expand_as(conv3d.weight) / kernel_t 241 | conv3d.weight.data.copy_(new_weight) 242 | inflated_param_names.append(weight_2d_name) 243 | 244 | if getattr(conv3d, 'bias') is not None: 245 | bias_2d_name = module_name_2d + '.bias' 246 | conv3d.bias.data.copy_(state_dict_2d[bias_2d_name]) 247 | inflated_param_names.append(bias_2d_name) 248 | 249 | def _inflate_bn_params(self, bn3d, state_dict_2d, module_name_2d, 250 | inflated_param_names): 251 | """Inflate a norm module from 2d to 3d. 252 | 253 | Args: 254 | bn3d (nn.Module): The destination bn3d module. 255 | state_dict_2d (OrderedDict): The state dict of pretrained 2d models. 
256 | module_name_2d (str): The name of corresponding bn module in the 257 | 2d models. 258 | inflated_param_names (list[str]): List of parameters that have been 259 | inflated. 260 | """ 261 | for param_name, param in bn3d.named_parameters(): 262 | param_2d_name = f'{module_name_2d}.{param_name}' 263 | param_2d = state_dict_2d[param_2d_name] 264 | param.data.copy_(param_2d) 265 | inflated_param_names.append(param_2d_name) 266 | 267 | for param_name, param in bn3d.named_buffers(): 268 | param_2d_name = f'{module_name_2d}.{param_name}' 269 | # some buffers like num_batches_tracked may not exist in old 270 | # checkpoints 271 | if param_2d_name in state_dict_2d: 272 | param_2d = state_dict_2d[param_2d_name] 273 | param.data.copy_(param_2d) 274 | inflated_param_names.append(param_2d_name) 275 | 276 | def inflate_weights(self, state_dict_r2d): 277 | """Inflate the resnet2d parameters to resnet3d. 278 | 279 | The differences between resnet3d and resnet2d mainly lie in an extra 280 | axis of conv kernel. To utilize the pretrained parameters in 2d models, 281 | the weight of conv2d models should be inflated to fit in the shapes of 282 | the 3d counterpart. 283 | 284 | """ 285 | 286 | inflated_param_names = [] 287 | for name, module in self.named_modules(): 288 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.BatchNorm3d): 289 | if name + '.weight' not in state_dict_r2d: 290 | print(f'Module not exist in the state_dict_r2d: {name}') 291 | else: 292 | shape_2d = state_dict_r2d[name + '.weight'].shape 293 | shape_3d = module.weight.data.shape 294 | if shape_2d != shape_3d[:2] + shape_3d[3:]: 295 | print(f'Weight shape mismatch for: {name}' 296 | f'3d weight shape: {shape_3d}; ' 297 | f'2d weight shape: {shape_2d}. ') 298 | else: 299 | if isinstance(module, nn.Conv3d): 300 | self._inflate_conv_params(module, state_dict_r2d, name, inflated_param_names) 301 | else: 302 | self._inflate_bn_params(module, state_dict_r2d, name, inflated_param_names) 303 | 304 | # check if any parameters in the 2d checkpoint are not loaded 305 | remaining_names = set(state_dict_r2d.keys()) - set(inflated_param_names) 306 | if remaining_names: 307 | print(f'These parameters in the 2d checkpoint are not loaded: {remaining_names}') 308 | 309 | 310 | def resnet3d(arch, progress=True, modality='RGB', pretrained2d=True, **kwargs): 311 | r""" 312 | Args: 313 | arch (str): The architecture of resnet 314 | modality (str): The modality of input, 'RGB' or 'Flow' 315 | progress (bool): If True, displays a progress bar of the download to stderr 316 | pretrained2d (bool): If True, utilize the pretrained parameters in 2d models 317 | """ 318 | 319 | arch_settings = { 320 | 'resnet18': (BasicBlock3d, (2, 2, 2, 2)), 321 | 'resnet34': (BasicBlock3d, (3, 4, 6, 3)), 322 | 'resnet50': (Bottleneck3d, (3, 4, 6, 3)), 323 | 'resnet101': (Bottleneck3d, (3, 4, 23, 3)), 324 | 'resnet152': (Bottleneck3d, (3, 8, 36, 3)) 325 | } 326 | 327 | model = ResNet3d(*arch_settings[arch], modality=modality, **kwargs) 328 | if pretrained2d: 329 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 330 | model.inflate_weights(state_dict) 331 | return model 332 | -------------------------------------------------------------------------------- /models/i3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .backbones import resnet3d 3 | 4 | __all__ = ['i3d_resnet18', 'i3d_resnet34', 'i3d_resnet50', 'i3d_resnet101', 'i3d_resnet152'] 5 | 6 | 7 | class I3D(nn.Module): 8 
| """ 9 | Implements a I3D Network for action recognition. 10 | 11 | Arguments: 12 | backbone (nn.Module): the network used to compute the features for the model. 13 | classifier (nn.Module): module that takes the features returned from the 14 | backbone and returns classification scores. 15 | """ 16 | 17 | def __init__(self, backbone, classifier): 18 | super(I3D, self).__init__() 19 | self.backbone = backbone 20 | self.classifier = classifier 21 | 22 | def forward(self, x): 23 | x = self.backbone(x) 24 | x = self.classifier(x) 25 | return x 26 | 27 | 28 | class I3DHead(nn.Module): 29 | """Classification head for I3D. 30 | 31 | Args: 32 | num_classes (int): Number of classes to be classified. 33 | in_channels (int): Number of channels in input feature. 34 | dropout_ratio (float): Probability of dropout layer. Default: 0.5. 35 | """ 36 | 37 | def __init__(self, num_classes, in_channels, dropout_ratio=0.5): 38 | super(I3DHead, self).__init__() 39 | self.num_classes = num_classes 40 | self.in_channels = in_channels 41 | self.dropout_ratio = dropout_ratio 42 | # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. 43 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 44 | if self.dropout_ratio != 0: 45 | self.dropout = nn.Dropout(p=self.dropout_ratio) 46 | else: 47 | self.dropout = None 48 | self.fc_cls = nn.Linear(self.in_channels, self.num_classes) 49 | 50 | def forward(self, x): 51 | """Defines the computation performed at every call. 52 | 53 | Args: 54 | x (torch.Tensor): The input data. 55 | 56 | Returns: 57 | torch.Tensor: The classification scores for input samples. 58 | """ 59 | # [N, in_channels, 4, 7, 7] 60 | x = self.avg_pool(x) 61 | # [N, in_channels, 1, 1, 1] 62 | if self.dropout is not None: 63 | x = self.dropout(x) 64 | # [N, in_channels, 1, 1, 1] 65 | x = x.view(x.shape[0], -1) 66 | # [N, in_channels] 67 | cls_score = self.fc_cls(x) # i3d_r18 (bz, 512) 33209676; i3d_r34 (bz, 512) 63519308; i3d_r50 (bz, 2048), 46204748 ; i3d_r101 (bz, 2048), 85250892; i3d_r101 (bz, 2048), 85250892; i3d_r152 (bz, 2048), 117409612 68 | # [N, num_classes] 69 | return cls_score 70 | 71 | 72 | def _load_model(backbone_name, progress, modality, pretrained2d, num_classes, in_channel, **kwargs): 73 | backbone = resnet3d(arch=backbone_name, progress=progress, modality=modality, pretrained2d=pretrained2d) 74 | classifier = I3DHead(num_classes=num_classes, in_channels=in_channel, **kwargs) 75 | model = I3D(backbone, classifier) 76 | return model 77 | 78 | 79 | def i3d_resnet18(modality='RGB', pretrained2d=True, progress=True, num_classes=21, in_channel = 2048, **kwargs): 80 | """Constructs a I3D model with a ResNet3d-18 backbone. 81 | 82 | Args: 83 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv 84 | accept a 3-channels input. (Default: RGB) 85 | pretrained2d (bool): If True, the backbone utilize the pretrained parameters in 2d 86 | models. (Default: True) 87 | progress (bool): If True, displays a progress bar of the download to stderr. 88 | (Default: True) 89 | num_classes (int): Number of dataset classes. (Default: 21) 90 | """ 91 | return _load_model('resnet18', progress, modality, pretrained2d, num_classes, in_channel = in_channel, **kwargs) 92 | 93 | 94 | def i3d_resnet34(modality='RGB', pretrained2d=True, progress=True, num_classes=21, in_channel = 2048, **kwargs): 95 | """Constructs a I3D model with a ResNet3d-34 backbone. 96 | 97 | Args: 98 | modality (str): The modality of input data (RGB or Flow). 
If 'RGB', the first Conv 99 | accept a 3-channels input. (Default: RGB) 100 | pretrained2d (bool): If True, the backbone utilize the pretrained parameters in 2d 101 | models. (Default: True) 102 | progress (bool): If True, displays a progress bar of the download to stderr. 103 | (Default: True) 104 | num_classes (int): Number of dataset classes. (Default: 21) 105 | """ 106 | return _load_model('resnet34', progress, modality, pretrained2d, num_classes, in_channel = in_channel, **kwargs) 107 | 108 | 109 | def i3d_resnet50(modality='RGB', pretrained2d=True, progress=True, num_classes=21, in_channel = 2048, **kwargs): 110 | """Constructs a I3D model with a ResNet3d-50 backbone. 111 | 112 | Args: 113 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv 114 | accept a 3-channels input. (Default: RGB) 115 | pretrained2d (bool): If True, the backbone utilize the pretrained parameters in 2d 116 | models. (Default: True) 117 | progress (bool): If True, displays a progress bar of the download to stderr. 118 | (Default: True) 119 | num_classes (int): Number of dataset classes. (Default: 21) 120 | """ 121 | return _load_model('resnet50', progress, modality, pretrained2d, num_classes, in_channel = in_channel, **kwargs) 122 | 123 | 124 | def i3d_resnet101(modality='RGB', pretrained2d=True, progress=True, num_classes=21, in_channel = 2048, **kwargs): 125 | """Constructs a I3D model with a ResNet3d-101 backbone. 126 | 127 | Args: 128 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv 129 | accept a 3-channels input. (Default: RGB) 130 | pretrained2d (bool): If True, the backbone utilize the pretrained parameters in 2d 131 | models. (Default: True) 132 | progress (bool): If True, displays a progress bar of the download to stderr. 133 | (Default: True) 134 | num_classes (int): Number of dataset classes. (Default: 21) 135 | """ 136 | return _load_model('resnet101', progress, modality, pretrained2d, num_classes, in_channel = in_channel, **kwargs) 137 | 138 | 139 | def i3d_resnet152(modality='RGB', pretrained2d=True, progress=True, num_classes=21, in_channel = 2048, **kwargs): 140 | """Constructs a I3D model with a ResNet3d-152 backbone. 141 | 142 | Args: 143 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv 144 | accept a 3-channels input. (Default: RGB) 145 | pretrained2d (bool): If True, the backbone utilize the pretrained parameters in 2d 146 | models. (Default: True) 147 | progress (bool): If True, displays a progress bar of the download to stderr. 148 | (Default: True) 149 | num_classes (int): Number of dataset classes. 
(Default: 21) 150 | """ 151 | return _load_model('resnet152', progress, modality, pretrained2d, num_classes, in_channel = in_channel, **kwargs) 152 | -------------------------------------------------------------------------------- /models/i3d_incep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | 8 | import os 9 | import sys 10 | from collections import OrderedDict 11 | 12 | 13 | class MaxPool3dSamePadding(nn.MaxPool3d): 14 | 15 | def compute_pad(self, dim, s): 16 | if s % self.stride[dim] == 0: 17 | return max(self.kernel_size[dim] - self.stride[dim], 0) 18 | else: 19 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 20 | 21 | def forward(self, x): 22 | # compute 'same' padding 23 | (batch, channel, t, h, w) = x.size() 24 | #print t,h,w 25 | out_t = np.ceil(float(t) / float(self.stride[0])) 26 | out_h = np.ceil(float(h) / float(self.stride[1])) 27 | out_w = np.ceil(float(w) / float(self.stride[2])) 28 | #print out_t, out_h, out_w 29 | pad_t = self.compute_pad(0, t) 30 | pad_h = self.compute_pad(1, h) 31 | pad_w = self.compute_pad(2, w) 32 | #print pad_t, pad_h, pad_w 33 | 34 | pad_t_f = pad_t // 2 35 | pad_t_b = pad_t - pad_t_f 36 | pad_h_f = pad_h // 2 37 | pad_h_b = pad_h - pad_h_f 38 | pad_w_f = pad_w // 2 39 | pad_w_b = pad_w - pad_w_f 40 | 41 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 42 | #print x.size() 43 | #print pad 44 | x = F.pad(x, pad) 45 | return super(MaxPool3dSamePadding, self).forward(x) 46 | 47 | 48 | class Unit3D(nn.Module): 49 | 50 | def __init__(self, in_channels, 51 | output_channels, 52 | kernel_shape=(1, 1, 1), 53 | stride=(1, 1, 1), 54 | padding=0, 55 | activation_fn=F.relu, 56 | use_batch_norm=True, 57 | use_bias=False, 58 | name='unit_3d'): 59 | 60 | """Initializes Unit3D module.""" 61 | super(Unit3D, self).__init__() 62 | 63 | self._output_channels = output_channels 64 | self._kernel_shape = kernel_shape 65 | self._stride = stride 66 | self._use_batch_norm = use_batch_norm 67 | self._activation_fn = activation_fn 68 | self._use_bias = use_bias 69 | self.name = name 70 | self.padding = padding 71 | 72 | self.conv3d = nn.Conv3d(in_channels=in_channels, 73 | out_channels=self._output_channels, 74 | kernel_size=self._kernel_shape, 75 | stride=self._stride, 76 | padding=0, # we always want padding to be 0 here. 
We will dynamically pad based on input size in forward function 77 | bias=self._use_bias) 78 | 79 | if self._use_batch_norm: 80 | self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01) 81 | 82 | def compute_pad(self, dim, s): 83 | if s % self._stride[dim] == 0: 84 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 85 | else: 86 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 87 | 88 | 89 | def forward(self, x): 90 | # compute 'same' padding 91 | (batch, channel, t, h, w) = x.size() 92 | #print t,h,w 93 | out_t = np.ceil(float(t) / float(self._stride[0])) 94 | out_h = np.ceil(float(h) / float(self._stride[1])) 95 | out_w = np.ceil(float(w) / float(self._stride[2])) 96 | #print out_t, out_h, out_w 97 | pad_t = self.compute_pad(0, t) 98 | pad_h = self.compute_pad(1, h) 99 | pad_w = self.compute_pad(2, w) 100 | #print pad_t, pad_h, pad_w 101 | 102 | pad_t_f = pad_t // 2 103 | pad_t_b = pad_t - pad_t_f 104 | pad_h_f = pad_h // 2 105 | pad_h_b = pad_h - pad_h_f 106 | pad_w_f = pad_w // 2 107 | pad_w_b = pad_w - pad_w_f 108 | 109 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 110 | #print x.size() 111 | #print pad 112 | x = F.pad(x, pad) 113 | #print x.size() 114 | 115 | x = self.conv3d(x) 116 | if self._use_batch_norm: 117 | x = self.bn(x) 118 | if self._activation_fn is not None: 119 | x = self._activation_fn(x) 120 | return x 121 | 122 | 123 | 124 | class InceptionModule(nn.Module): 125 | def __init__(self, in_channels, out_channels, name): 126 | super(InceptionModule, self).__init__() 127 | 128 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 129 | name=name+'/Branch_0/Conv3d_0a_1x1') 130 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 131 | name=name+'/Branch_1/Conv3d_0a_1x1') 132 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], 133 | name=name+'/Branch_1/Conv3d_0b_3x3') 134 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 135 | name=name+'/Branch_2/Conv3d_0a_1x1') 136 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], 137 | name=name+'/Branch_2/Conv3d_0b_3x3') 138 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 139 | stride=(1, 1, 1), padding=0) 140 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 141 | name=name+'/Branch_3/Conv3d_0b_1x1') 142 | self.name = name 143 | 144 | def forward(self, x): 145 | b0 = self.b0(x) 146 | b1 = self.b1b(self.b1a(x)) 147 | b2 = self.b2b(self.b2a(x)) 148 | b3 = self.b3b(self.b3a(x)) 149 | return torch.cat([b0,b1,b2,b3], dim=1) 150 | 151 | 152 | class InceptionI3d(nn.Module): 153 | """Inception-v1 I3D architecture. 154 | The model is introduced in: 155 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 156 | Joao Carreira, Andrew Zisserman 157 | https://arxiv.org/pdf/1705.07750v1.pdf. 158 | See also the Inception architecture, introduced in: 159 | Going deeper with convolutions 160 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 161 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 162 | http://arxiv.org/pdf/1409.4842v1.pdf. 163 | """ 164 | 165 | # Endpoints of the model in order. 
During construction, all the endpoints up 166 | # to a designated `final_endpoint` are returned in a dictionary as the 167 | # second return value. 168 | VALID_ENDPOINTS = ( 169 | 'Conv3d_1a_7x7', 170 | 'MaxPool3d_2a_3x3', 171 | 'Conv3d_2b_1x1', 172 | 'Conv3d_2c_3x3', 173 | 'MaxPool3d_3a_3x3', 174 | 'Mixed_3b', 175 | 'Mixed_3c', 176 | 'MaxPool3d_4a_3x3', 177 | 'Mixed_4b', 178 | 'Mixed_4c', 179 | 'Mixed_4d', 180 | 'Mixed_4e', 181 | 'Mixed_4f', 182 | 'MaxPool3d_5a_2x2', 183 | 'Mixed_5b', 184 | 'Mixed_5c', 185 | 'Logits', 186 | 'Predictions', 187 | ) 188 | 189 | def __init__(self, num_classes=400, spatial_squeeze=True, 190 | final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): 191 | """Initializes I3D model instance. 192 | Args: 193 | num_classes: The number of outputs in the logit layer (default 400, which 194 | matches the Kinetics datasets_). 195 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 196 | before returning (default True). 197 | final_endpoint: The model contains many possible endpoints. 198 | `final_endpoint` specifies the last endpoint for the model to be built 199 | up to. In addition to the output at `final_endpoint`, all the outputs 200 | at endpoints up to `final_endpoint` will also be returned, in a 201 | dictionary. `final_endpoint` must be one of 202 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 203 | name: A string (optional). The name of this module. 204 | Raises: 205 | ValueError: if `final_endpoint` is not recognized. 206 | """ 207 | 208 | if final_endpoint not in self.VALID_ENDPOINTS: 209 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 210 | 211 | super(InceptionI3d, self).__init__() 212 | self._num_classes = num_classes 213 | self._spatial_squeeze = spatial_squeeze 214 | self._final_endpoint = final_endpoint 215 | self.logits = None 216 | 217 | if self._final_endpoint not in self.VALID_ENDPOINTS: 218 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint) 219 | 220 | self.end_points = {} 221 | end_point = 'Conv3d_1a_7x7' 222 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], 223 | stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) 224 | if self._final_endpoint == end_point: return 225 | 226 | end_point = 'MaxPool3d_2a_3x3' 227 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 228 | padding=0) 229 | if self._final_endpoint == end_point: return 230 | 231 | end_point = 'Conv3d_2b_1x1' 232 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 233 | name=name+end_point) 234 | if self._final_endpoint == end_point: return 235 | 236 | end_point = 'Conv3d_2c_3x3' 237 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, 238 | name=name+end_point) 239 | if self._final_endpoint == end_point: return 240 | 241 | end_point = 'MaxPool3d_3a_3x3' 242 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 243 | padding=0) 244 | if self._final_endpoint == end_point: return 245 | 246 | end_point = 'Mixed_3b' 247 | self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) 248 | if self._final_endpoint == end_point: return 249 | 250 | end_point = 'Mixed_3c' 251 | self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) 252 | if self._final_endpoint == end_point: return 253 | 254 | 
end_point = 'MaxPool3d_4a_3x3' 255 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), 256 | padding=0) 257 | if self._final_endpoint == end_point: return 258 | 259 | end_point = 'Mixed_4b' 260 | self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) 261 | if self._final_endpoint == end_point: return 262 | 263 | end_point = 'Mixed_4c' 264 | self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) 265 | if self._final_endpoint == end_point: return 266 | 267 | end_point = 'Mixed_4d' 268 | self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) 269 | if self._final_endpoint == end_point: return 270 | 271 | end_point = 'Mixed_4e' 272 | self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) 273 | if self._final_endpoint == end_point: return 274 | 275 | end_point = 'Mixed_4f' 276 | self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) 277 | if self._final_endpoint == end_point: return 278 | 279 | end_point = 'MaxPool3d_5a_2x2' 280 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), 281 | padding=0) 282 | if self._final_endpoint == end_point: return 283 | 284 | end_point = 'Mixed_5b' 285 | self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) 286 | if self._final_endpoint == end_point: return 287 | 288 | end_point = 'Mixed_5c' 289 | self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) 290 | if self._final_endpoint == end_point: return 291 | 292 | end_point = 'Logits' 293 | # self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], 294 | # stride=(1, 1, 1)) 295 | # self.avg_pool = nn.AvgPool3d(kernel_size=[1, 7, 7], 296 | # stride=(1, 1, 1)) 297 | self.avg_pool = nn.AdaptiveAvgPool3d((1,1,1 )) 298 | self.dropout = nn.Dropout(dropout_keep_prob) 299 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 300 | kernel_shape=[1, 1, 1], 301 | padding=0, 302 | activation_fn=None, 303 | use_batch_norm=False, 304 | use_bias=True, 305 | name='logits') # todo the last layer is a 3D conv layer 306 | 307 | self.build() 308 | 309 | 310 | def replace_logits(self, num_classes): 311 | self._num_classes = num_classes 312 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 313 | kernel_shape=[1, 1, 1], 314 | padding=0, 315 | activation_fn=None, 316 | use_batch_norm=False, 317 | use_bias=True, 318 | name='logits') # todo the last layer is a 3D conv layer 319 | 320 | 321 | def build(self): 322 | for k in self.end_points.keys(): 323 | self.add_module(k, self.end_points[k]) 324 | 325 | def forward(self, x): 326 | for end_point in self.VALID_ENDPOINTS: 327 | if end_point in self.end_points: 328 | x = self._modules[end_point](x) # use _modules to work with dataparallel 329 | 330 | x = self.logits(self.dropout(self.avg_pool(x))) 331 | if self._spatial_squeeze: 332 | logits = x.squeeze(3).squeeze(3).squeeze(2) 333 | # logits is batch X time X classes, which is what we want to work with 334 | return logits 335 | 336 | 337 | def extract_features(self, x): 338 | for end_point in self.VALID_ENDPOINTS: 339 | if end_point in self.end_points: 340 | x = self._modules[end_point](x) 341 | return self.avg_pool(x) 342 | 343 | 344 | 345 | # class MyI3DIncep(nn.Module): 346 | # 
def __init__(self, num_classes, use_pretrained = True, pretrained_model = None, init_std=0.01,): 347 | # super(MyI3DIncep, self).__init__() 348 | # self.model_ft = InceptionI3d(num_classes=400, in_channels=3) 349 | # if use_pretrained: 350 | # self.model_ft.load_state_dict( torch.load(pretrained_model) ) 351 | # print( f'Loaded pretrained I3D Inception model {pretrained_model}' ) 352 | 353 | -------------------------------------------------------------------------------- /models/r2plus1d.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torchvision import models 3 | import torch.nn as nn 4 | 5 | def normal_init(module, mean=0, std=1, bias=0): 6 | if hasattr(module, 'weight') and module.weight is not None: 7 | nn.init.normal_(module.weight, mean, std) 8 | if hasattr(module, 'bias') and module.bias is not None: 9 | nn.init.constant_(module.bias, bias) 10 | 11 | # def initialize_model(arch_name, num_classes) 12 | 13 | class MyR2plus1d(nn.Module): 14 | def __init__(self, num_classes, use_pretrained = True, init_std=0.01, model_name = 'r2plus1d' ): 15 | super(MyR2plus1d, self).__init__() 16 | # self.model_ft = models.__dict__[model_name](pretrained=use_pretrained) 17 | self.model_ft = models.video.r2plus1d_18( pretrained=use_pretrained) 18 | num_ftrs = self.model_ft.fc.in_features 19 | self.init_std = init_std 20 | modules = list(self.model_ft.children())[:-1] 21 | self.model_ft = nn.Sequential(*modules) 22 | self.clsfr = nn.Linear( num_ftrs, num_classes ) 23 | normal_init(self.clsfr, std=self.init_std) 24 | def forward(self, x): 25 | feat= self.model_ft(x).squeeze() 26 | if len(feat.size()) == 1: 27 | feat = feat.unsqueeze(0) 28 | pred_cls = self.clsfr(feat) 29 | return pred_cls 30 | 31 | -------------------------------------------------------------------------------- /models/tanet_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__init__.py -------------------------------------------------------------------------------- /models/tanet_models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/tanet_models/__pycache__/basic_ops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__pycache__/basic_ops.cpython-36.pyc -------------------------------------------------------------------------------- /models/tanet_models/__pycache__/tanet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__pycache__/tanet.cpython-36.pyc -------------------------------------------------------------------------------- /models/tanet_models/__pycache__/temporal_module.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__pycache__/temporal_module.cpython-36.pyc -------------------------------------------------------------------------------- /models/tanet_models/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /models/tanet_models/__pycache__/video_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/tanet_models/__pycache__/video_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /models/tanet_models/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Identity(torch.nn.Module): 5 | def forward(self, input): 6 | return input 7 | 8 | 9 | class SegmentConsensus(torch.autograd.Function): 10 | 11 | def __init__(self, consensus_type, dim=1): 12 | self.consensus_type = consensus_type 13 | self.dim = dim 14 | self.shape = None 15 | # @staticmethod 16 | def forward(self, input_tensor): # input_tensor (bz, T, n_class) 17 | self.shape = input_tensor.size() 18 | if self.consensus_type == 'avg': # take the average of all frame-level prediction, which is video-level prediction 19 | output = input_tensor.mean(dim=self.dim, keepdim=True) 20 | elif self.consensus_type == 'identity': 21 | output = input_tensor 22 | else: 23 | output = None 24 | 25 | return output 26 | # @staticmethod 27 | def backward(self, grad_output): 28 | if self.consensus_type == 'avg': 29 | grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim]) 30 | elif self.consensus_type == 'identity': 31 | grad_in = grad_output 32 | else: 33 | grad_in = None 34 | 35 | return grad_in 36 | 37 | 38 | class SegmentAvg_static(torch.autograd.Function): 39 | @staticmethod 40 | def forward(self, input_tensor, ): # input_tensor (bz, T, n_class) 41 | dim_ = 1 42 | self.save_for_backward(input_tensor) 43 | output = input_tensor.mean(dim=dim_, keepdim= True) 44 | return output 45 | @staticmethod 46 | def backward(self, grad_output ): 47 | dim_ = 1 48 | input_tensor, = self.saved_tensors 49 | shape_ = input_tensor.size() 50 | grad_in = grad_output.expand(shape_) / float(shape_[dim_]) 51 | return grad_in 52 | 53 | class SegmentIdentity_static(torch.autograd.Function): 54 | @staticmethod 55 | def forward(self, input_tensor, ): # input_tensor (bz, T, n_class) 56 | # dim_ = 1 57 | # self.save_for_backward(input_tensor) 58 | output = input_tensor 59 | return output 60 | @staticmethod 61 | def backward(self, grad_output ): 62 | # dim_ = 1 63 | # input_tensor, = self.saved_tensors 64 | # shape_ = input_tensor.size() 65 | grad_in = grad_output 66 | return grad_in 67 | 68 | 69 | 70 | 71 | class ConsensusModule(torch.nn.Module): # contains no parameters 72 | 73 | def __init__(self, consensus_type, dim=1): 74 | super(ConsensusModule, self).__init__() 75 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 76 | self.dim = dim 77 | assert self.dim == 1 78 | 79 | def forward(self, input): # input_tensor (bz, T, n_class) 80 | 81 | if self.consensus_type == 'avg': 82 | 
return SegmentAvg_static.apply(input) 83 | # return input.mean(dim=self.dim, keepdim=True) 84 | elif self.consensus_type == 'identity': 85 | return SegmentIdentity_static.apply(input) 86 | 87 | # return SegmentConsensus(self.consensus_type, self.dim)(input) -------------------------------------------------------------------------------- /models/tanet_models/temporal_module.py: -------------------------------------------------------------------------------- 1 | # Code for "TAM: Temporal Adaptive Module for Video Recognition" 2 | # arXiv: 2005.06803 3 | # Zhaoyang liu*, Limin Wang, Wayne Wu, Chen Qian, Tong Lu 4 | # zyliumy@gmail.com 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | 11 | 12 | class TAM(nn.Module): # TAM is concatenate after conv1,bn1,relu 13 | def __init__(self, 14 | in_channels, # the feature dim after conv1 15 | n_segment, 16 | kernel_size=3, 17 | stride=1, 18 | padding=1): 19 | super(TAM, self).__init__() 20 | self.in_channels = in_channels 21 | self.n_segment = n_segment 22 | self.kernel_size = kernel_size 23 | self.stride = stride 24 | self.padding = padding 25 | print('TAM with kernel_size {}.'.format(kernel_size)) 26 | 27 | self.G = nn.Sequential( # todo global branch: to learn the dynamic/adaptive kernels that are used in convolution to aggregate the temporal info 28 | nn.Linear(n_segment, n_segment * 2, bias=False), 29 | nn.BatchNorm1d(n_segment * 2), nn.ReLU(inplace=True), 30 | nn.Linear(n_segment * 2, kernel_size, bias=False), nn.Softmax(-1)) 31 | 32 | self.L = nn.Sequential( # todo local branch: compute local importance map (attention) in shape C x T, using 2 temporal convolutions 33 | nn.Conv1d(in_channels, 34 | in_channels // 4, 35 | kernel_size, 36 | stride=1, 37 | padding=kernel_size // 2, 38 | bias=False), nn.BatchNorm1d(in_channels // 4), 39 | nn.ReLU(inplace=True), 40 | nn.Conv1d(in_channels // 4, in_channels, 1, bias=False), 41 | nn.Sigmoid()) 42 | 43 | def forward(self, x): # todo input is of shape batch*T, C, H, W, output is of the same shape batch*T, C, H, W 44 | # x.size = N*C*T*(H*W) 45 | nt, c, h, w = x.size() # x size : N*T, C, H, W 46 | t = self.n_segment 47 | n_batch = nt // t 48 | new_x = x.view(n_batch, t, c, h, w).permute(0, 2, 1, 3, 49 | 4).contiguous() # ( N*T, C, H, W) -> (N, C, T, H, W) 50 | out = F.adaptive_avg_pool2d(new_x.view(n_batch * c, t, h, w), (1, 1)) # (N, C, T, H, W) -> (N * C, T, 1, 1) 51 | out = out.view(-1, t) # (N * C, T, 1, 1) -> (N * C, T ) 52 | conv_kernel = self.G(out.view(-1, t)).view(n_batch * c, 1, -1, 1) # (N * C, T ) -> (N * C, 1, T, 1 ) 53 | local_activation = self.L(out.view(n_batch, c, 54 | t)).view(n_batch, c, t, 1, 1) # (N * C, T ) -> (N, C, T) -> (N , C, T, 1, 1 ) 55 | new_x = new_x * local_activation # (N, C, T, H, W) 56 | out = F.conv2d(new_x.view(1, n_batch * c, t, h * w), 57 | conv_kernel, 58 | bias=None, 59 | stride=(self.stride, 1), 60 | padding=(self.padding, 0), 61 | groups=n_batch * c) # (N, C, T, H, W) -> (1, N * C, T, H * W) -> (1, N * C, T, H * W) 62 | out = out.view(n_batch, c, t, h, w) # (1, N * C, T, H * W) -> (N, C, T, H, W) 63 | out = out.permute(0, 2, 1, 3, 4).contiguous().view(nt, c, h, w) # (N * T, C, H, W) 64 | 65 | return out 66 | 67 | 68 | class TemporalBottleneck(nn.Module): 69 | def __init__(self, 70 | net, 71 | n_segment=8, 72 | t_kernel_size=3, 73 | t_stride=1, 74 | t_padding=1): 75 | super(TemporalBottleneck, self).__init__() 76 | self.net = net # here net is a Bottleneck module 77 | assert isinstance(net, 
torchvision.models.resnet.Bottleneck) 78 | self.n_segment = n_segment 79 | self.tam = TAM(in_channels=net.conv1.out_channels, 80 | n_segment=n_segment, 81 | kernel_size=t_kernel_size, 82 | stride=t_stride, 83 | padding=t_padding) 84 | 85 | def forward(self, x): 86 | identity = x 87 | 88 | out = self.net.conv1(x) 89 | out = self.net.bn1(out) 90 | out = self.net.relu(out) 91 | out = self.tam(out) 92 | 93 | out = self.net.conv2(out) 94 | out = self.net.bn2(out) 95 | out = self.net.relu(out) 96 | 97 | out = self.net.conv3(out) 98 | out = self.net.bn3(out) 99 | 100 | if self.net.downsample is not None: 101 | identity = self.net.downsample(x) 102 | 103 | out += identity 104 | out = self.net.relu(out) 105 | 106 | return out 107 | 108 | 109 | def make_temporal_modeling(net, 110 | n_segment=8, 111 | t_kernel_size=3, 112 | t_stride=1, 113 | t_padding=1): 114 | if isinstance(net, torchvision.models.ResNet): 115 | n_round = 1 116 | 117 | def make_block_temporal(stage, # stage is layer1/2/3/4 in ResNet, each stage has 3/4/6/3 TemporalBottleneck module 118 | this_segment, 119 | t_kernel_size=3, 120 | t_stride=1, 121 | t_padding=1): 122 | blocks = list(stage.children()) 123 | print('=> Processing this stage with {} blocks residual'.format( 124 | len(blocks))) 125 | for i, b in enumerate(blocks): 126 | # if i >= len(blocks)//2: 127 | if i % n_round == 0: # todo turn each Bottleneck module into a TemporalBottleneck, TemporalBottleneck = (Bottleneck + TAM ) 128 | blocks[i] = TemporalBottleneck(b, this_segment, 129 | t_kernel_size, t_stride, 130 | t_padding) 131 | return nn.Sequential(*blocks) 132 | # add temporal modules to layer1, layer2, layer3, layer4 in ResNet, layer1/2/3/4 are 4 stages, each stage has 3/4/6/3 TemporalBottleneck module, 133 | net.layer1 = make_block_temporal(net.layer1, n_segment, t_kernel_size, 134 | t_stride, t_padding) 135 | net.layer2 = make_block_temporal(net.layer2, n_segment, t_kernel_size, 136 | t_stride, t_padding) 137 | net.layer3 = make_block_temporal(net.layer3, n_segment, t_kernel_size, 138 | t_stride, t_padding) 139 | net.layer4 = make_block_temporal(net.layer4, n_segment, t_kernel_size, 140 | t_stride, t_padding) 141 | 142 | 143 | if __name__ == '__main__': 144 | # test 145 | pass 146 | -------------------------------------------------------------------------------- /models/videomae_models/modeling_finetune.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | 9 | 10 | def _cfg(url='', **kwargs): 11 | return { 12 | 'url': url, 13 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None, 14 | 'crop_pct': .9, 'interpolation': 'bicubic', 15 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 16 | **kwargs 17 | } 18 | 19 | 20 | class DropPath(nn.Module): 21 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
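`make_temporal_modeling` above rewrites a torchvision ResNet in place, wrapping every Bottleneck with a TAM. A hedged usage sketch on a plain ResNet-50 (the TSN-style consensus over frames and the video-level head are omitted; `n_segment` must match how frames are stacked onto the batch axis):

```
import torch
import torchvision

from models.tanet_models.temporal_module import make_temporal_modeling

n_segment = 8
net = torchvision.models.resnet50()
make_temporal_modeling(net, n_segment=n_segment)   # mutates net: each Bottleneck becomes Bottleneck + TAM

frames = torch.randn(2 * n_segment, 3, 224, 224)   # (N * T, 3, H, W): frames stacked on the batch axis
logits = net(frames)                               # (N * T, 1000) frame-level scores from the stock ImageNet head
print(logits.shape)
```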
22 | """ 23 | 24 | def __init__(self, drop_prob=None): 25 | super(DropPath, self).__init__() 26 | self.drop_prob = drop_prob 27 | 28 | def forward(self, x): 29 | return drop_path(x, self.drop_prob, self.training) 30 | 31 | def extra_repr(self) -> str: 32 | return 'p={}'.format(self.drop_prob) 33 | 34 | 35 | class Mlp(nn.Module): 36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 37 | super().__init__() 38 | out_features = out_features or in_features 39 | hidden_features = hidden_features or in_features 40 | self.fc1 = nn.Linear(in_features, hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | # x = self.drop(x) 49 | # commit this for the orignal BERT implement 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__( 57 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 58 | proj_drop=0., attn_head_dim=None): 59 | super().__init__() 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | if attn_head_dim is not None: 63 | head_dim = attn_head_dim 64 | all_head_dim = head_dim * self.num_heads 65 | self.scale = qk_scale or head_dim ** -0.5 66 | 67 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) # () 68 | if qkv_bias: 69 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 70 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 71 | else: 72 | self.q_bias = None 73 | self.v_bias = None 74 | 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(all_head_dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape # (bz, 1568, 768) 81 | qkv_bias = None 82 | if self.q_bias is not None: 83 | qkv_bias = torch.cat( 84 | (self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # (2304, ) 85 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) # (bz, 1568, 768) -> (bz, 1568, 2304) 87 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 88 | 4) # (bz, 1568, 2304) -> (bz, 1568, 3, 12, 64) -> (3, bz, 12, 1568, 64) todo 2304 = 3x12x64 3 is for query, key and value, 12 is number of heads, 64 is hidden dim 89 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 90 | 91 | q = q * self.scale 92 | attn = (q @ k.transpose(-2, -1)) # (bz, 12, 1568, 1568) 93 | 94 | attn = attn.softmax(dim=-1) 95 | attn = self.attn_drop(attn) 96 | 97 | x = (attn @ v).transpose(1, 2).reshape(B, N, 98 | -1) # (bz, 12, 1568, 64) -> (bz, 1568, 12, 64 ) -> (bz, 1568, 768 ) 99 | x = self.proj(x) # (bz, 1568, 768 ) -> (bz, 1568, 768 ) 100 | x = self.proj_drop(x) 101 | return x 102 | 103 | 104 | class Block(nn.Module): 105 | 106 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 107 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 108 | attn_head_dim=None): 109 | super().__init__() 110 | self.norm1 = norm_layer(dim) 111 | self.attn = Attention( 112 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 113 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 114 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 115 | 
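The `Attention` module above rebuilds the fused qkv bias on every forward pass: a learnable q bias, an all-zero (non-learnable) k bias, and a learnable v bias are concatenated and fed to a single `F.linear`. A small sketch of just that step with toy tensors:

```
import torch
import torch.nn.functional as F

B, N, dim, num_heads = 2, 1568, 768, 12
head_dim = dim // num_heads                          # 64
x = torch.randn(B, N, dim)

qkv_weight = torch.randn(dim * 3, dim)               # the single fused projection, as in self.qkv
q_bias = torch.zeros(dim)                            # learnable parameter in the module above
v_bias = torch.zeros(dim)                            # learnable parameter in the module above
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias), v_bias))   # (2304,): k keeps a frozen zero bias

qkv = F.linear(x, qkv_weight, qkv_bias)                             # (B, N, 2304)
q, k, v = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
print(q.shape, k.shape, v.shape)                                    # each (B, 12, 1568, 64)
```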
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 116 | self.norm2 = norm_layer(dim) 117 | mlp_hidden_dim = int(dim * mlp_ratio) 118 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # 768 -> 3072 119 | 120 | if init_values > 0: 121 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 122 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 123 | else: 124 | self.gamma_1, self.gamma_2 = None, None 125 | 126 | def forward(self, x): 127 | if self.gamma_1 is None: 128 | x = x + self.drop_path(self.attn(self.norm1(x))) # (bz, 1568, 768) -> (bz, 1568, 768) 129 | x = x + self.drop_path(self.mlp(self.norm2(x))) # (bz, 1568, 768) -> (bz, 1568, 3072) -> (bz, 1568, 768) 130 | else: 131 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 132 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 133 | return x 134 | 135 | 136 | class PatchEmbed(nn.Module): 137 | """ Image to Patch Embedding 138 | """ 139 | 140 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): 141 | super().__init__() 142 | img_size = to_2tuple(img_size) 143 | patch_size = to_2tuple(patch_size) 144 | self.tubelet_size = int(tubelet_size) 145 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * ( 146 | num_frames // self.tubelet_size) 147 | self.img_size = img_size 148 | self.patch_size = patch_size 149 | self.num_patches = num_patches 150 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, 151 | kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]), 152 | stride=(self.tubelet_size, patch_size[0], patch_size[1])) 153 | 154 | def forward(self, x, **kwargs): 155 | B, C, T, H, W = x.shape 156 | # FIXME look at relaxing size constraints 157 | assert H == self.img_size[0] and W == self.img_size[1], \ 158 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
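The `PatchEmbed` above turns a 16-frame 224x224 clip into tokens with one 3D convolution: 16/2 tubelets x 14 x 14 patches = 1568 tokens of width `embed_dim`. A shape check with random weights:

```
import torch
import torch.nn as nn

proj = nn.Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))   # tubelet_size=2, patch_size=16
clip = torch.randn(1, 3, 16, 224, 224)                                  # (B, C, T, H, W)

tokens = proj(clip)                      # (1, 768, 8, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)
print(tokens.shape)                      # torch.Size([1, 1568, 768]) = (16//2) * (224//16) * (224//16) tokens
```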
159 | x = self.proj(x).flatten(2).transpose(1, 160 | 2) # (bz, 3, clip_len, 224, 224) -> (bz, 768, 8, 14, 14) -> (bz, 768, 1568) -> (bz, 1568, 768) 161 | return x 162 | 163 | 164 | # sin-cos position encoding 165 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 166 | def get_sinusoid_encoding_table(n_position, d_hid): 167 | ''' Sinusoid position encoding table ''' 168 | 169 | # TODO: make it with torch instead of numpy 170 | def get_position_angle_vec(position): 171 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 172 | 173 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 174 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 175 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 176 | 177 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 178 | 179 | 180 | class VisionTransformer(nn.Module): 181 | """ Vision Transformer with support for patch or hybrid CNN input stage 182 | """ 183 | 184 | def __init__(self, 185 | img_size=224, 186 | patch_size=16, 187 | in_chans=3, 188 | num_classes=1000, 189 | embed_dim=768, 190 | depth=12, # number of blocks 191 | num_heads=12, 192 | mlp_ratio=4., 193 | qkv_bias=False, 194 | qk_scale=None, 195 | drop_rate=0., 196 | attn_drop_rate=0., 197 | drop_path_rate=0., 198 | norm_layer=nn.LayerNorm, 199 | init_values=0., 200 | use_learnable_pos_emb=False, 201 | init_scale=0., 202 | all_frames=16, 203 | tubelet_size=2, 204 | use_mean_pooling=True): 205 | super().__init__() 206 | self.num_classes = num_classes 207 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 208 | self.tubelet_size = tubelet_size 209 | self.patch_embed = PatchEmbed( 210 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, 211 | tubelet_size=self.tubelet_size) 212 | num_patches = self.patch_embed.num_patches # num_patches 1568 213 | 214 | if use_learnable_pos_emb: 215 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 216 | else: 217 | # sine-cosine positional embeddings is on the way 218 | self.pos_embed = get_sinusoid_encoding_table(num_patches, 219 | embed_dim) # num_patches 1568, embed_dim 768 -> pos_embed (1, 1568, 768) 220 | 221 | self.pos_drop = nn.Dropout(p=drop_rate) 222 | 223 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 224 | self.blocks = nn.ModuleList([ 225 | Block( 226 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 227 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 228 | init_values=init_values) 229 | for i in range(depth)]) 230 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 231 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 232 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 233 | 234 | if use_learnable_pos_emb: 235 | trunc_normal_(self.pos_embed, std=.02) 236 | 237 | trunc_normal_(self.head.weight, std=.02) 238 | self.apply(self._init_weights) 239 | 240 | self.head.weight.data.mul_(init_scale) 241 | self.head.bias.data.mul_(init_scale) 242 | 243 | def _init_weights(self, m): 244 | if isinstance(m, nn.Linear): 245 | trunc_normal_(m.weight, std=.02) 246 | if isinstance(m, nn.Linear) and m.bias is not None: 247 | nn.init.constant_(m.bias, 0) 248 | elif 
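`get_sinusoid_encoding_table` above carries a TODO about building the table with torch instead of numpy. A hedged pure-torch equivalent (same layout: sin on even feature indices, cos on odd ones), offered as a sketch rather than a drop-in change to the file:

```
import torch

def sinusoid_table_torch(n_position, d_hid):
    """Return a (1, n_position, d_hid) sin-cos table matching get_sinusoid_encoding_table."""
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)   # (n_position, 1)
    pair_idx = torch.arange(d_hid) // 2                                     # dims 2i and 2i+1 share a frequency
    angle = position / torch.pow(10000.0, (2 * pair_idx).float() / d_hid)   # (n_position, d_hid)
    table = torch.zeros(n_position, d_hid)
    table[:, 0::2] = torch.sin(angle[:, 0::2])   # even dims: sin
    table[:, 1::2] = torch.cos(angle[:, 1::2])   # odd dims: cos
    return table.unsqueeze(0)

# agreement check against the numpy version above (float32 tolerance):
# torch.allclose(sinusoid_table_torch(1568, 768), get_sinusoid_encoding_table(1568, 768), atol=1e-5)
```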
isinstance(m, nn.LayerNorm): 249 | nn.init.constant_(m.bias, 0) 250 | nn.init.constant_(m.weight, 1.0) 251 | 252 | def get_num_layers(self): 253 | return len(self.blocks) 254 | 255 | @torch.jit.ignore 256 | def no_weight_decay(self): 257 | return {'pos_embed', 'cls_token'} 258 | 259 | def get_classifier(self): 260 | return self.head 261 | 262 | def reset_classifier(self, num_classes, global_pool=''): 263 | self.num_classes = num_classes 264 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 265 | 266 | def forward_features(self, x): 267 | x = self.patch_embed( 268 | x) ## (bz, 3, clip_len, 224, 224) ->3D Conv-> (bz, 768, 8, 14, 14) ->Flatten-> (bz, 768, 1568) -> (bz, 1568, 768) 269 | B, _, _ = x.size() 270 | 271 | if self.pos_embed is not None: 272 | x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to( 273 | x.device).clone().detach() # todo add the positional embedding 274 | x = self.pos_drop(x) 275 | 276 | for blk in self.blocks: # (bz, 1568, 768) 277 | x = blk(x) 278 | 279 | x = self.norm(x) 280 | if self.fc_norm is not None: 281 | return self.fc_norm(x.mean(1)) # todo take the average along the dimension of 1568 THW 282 | else: 283 | return x[:, 0] 284 | 285 | def forward(self, x): 286 | x = self.forward_features(x) # (bz, 3, clip_len, 224, 224) -> (bz, 1568, 768) -> (bz, 768) 287 | x = self.head(x) # (bz, 768) -> (bz, n_class) 288 | return x 289 | 290 | 291 | @register_model 292 | def vit_small_patch16_224(pretrained=False, **kwargs): 293 | model = VisionTransformer( 294 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 295 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 296 | model.default_cfg = _cfg() 297 | return model 298 | 299 | 300 | @register_model 301 | def vit_base_patch16_224(pretrained=False, **kwargs): 302 | model = VisionTransformer( 303 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 304 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 305 | model.default_cfg = _cfg() 306 | return model 307 | 308 | 309 | @register_model 310 | def vit_base_patch16_384(pretrained=False, **kwargs): 311 | model = VisionTransformer( 312 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 313 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 314 | model.default_cfg = _cfg() 315 | return model 316 | 317 | 318 | @register_model 319 | def vit_large_patch16_224(pretrained=False, **kwargs): 320 | model = VisionTransformer( 321 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 322 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 323 | model.default_cfg = _cfg() 324 | return model 325 | 326 | 327 | @register_model 328 | def vit_large_patch16_384(pretrained=False, **kwargs): 329 | model = VisionTransformer( 330 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 331 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 332 | model.default_cfg = _cfg() 333 | return model 334 | 335 | 336 | @register_model 337 | def vit_large_patch16_512(pretrained=False, **kwargs): 338 | model = VisionTransformer( 339 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 340 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 341 | model.default_cfg = _cfg() 342 | return model 343 | -------------------------------------------------------------------------------- 
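The `@register_model` factories above only build architectures; they do not load VideoMAE weights (`pretrained` is accepted but unused). A hedged smoke test for the base model on a dummy clip (class count and `init_scale` are illustrative, not the paper's fine-tuning config; loading an actual checkpoint is not shown):

```
import torch

from models.videomae_models.modeling_finetune import vit_base_patch16_224

# init_scale=1.0 keeps the randomly initialised head non-zero (the signature default of 0. zeroes it)
model = vit_base_patch16_224(num_classes=101, all_frames=16, init_scale=1.0).eval()

clip = torch.randn(1, 3, 16, 224, 224)    # (B, C, T, H, W): 16 frames of 224x224
with torch.no_grad():
    logits = model(clip)
print(logits.shape)                       # torch.Size([1, 101])
```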
/models/videoswintransformer_models/__pycache__/i3d_head.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/videoswintransformer_models/__pycache__/i3d_head.cpython-36.pyc -------------------------------------------------------------------------------- /models/videoswintransformer_models/__pycache__/recognizer3d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/videoswintransformer_models/__pycache__/recognizer3d.cpython-36.pyc -------------------------------------------------------------------------------- /models/videoswintransformer_models/__pycache__/swin_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/videoswintransformer_models/__pycache__/swin_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /models/videoswintransformer_models/__pycache__/transforms_backup.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/videoswintransformer_models/__pycache__/transforms_backup.cpython-36.pyc -------------------------------------------------------------------------------- /models/videoswintransformer_models/__pycache__/video_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/models/videoswintransformer_models/__pycache__/video_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /models/videoswintransformer_models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # from ...core import top_k_accuracy 7 | # from ..builder import build_loss 8 | 9 | 10 | class AvgConsensus(nn.Module): 11 | """Average consensus module. 12 | 13 | Args: 14 | dim (int): Decide which dim consensus function to apply. 15 | Default: 1. 16 | """ 17 | 18 | def __init__(self, dim=1): 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | """Defines the computation performed at every call.""" 24 | return x.mean(dim=self.dim, keepdim=True) 25 | 26 | 27 | class BaseHead(nn.Module, metaclass=ABCMeta): 28 | """Base class for head. 29 | 30 | All Head should subclass it. 31 | All subclass should overwrite: 32 | - Methods:``init_weights``, initializing weights in some modules. 33 | - Methods:``forward``, supporting to forward both for training and testing. 34 | 35 | Args: 36 | num_classes (int): Number of classes to be classified. 37 | in_channels (int): Number of channels in input feature. 38 | loss_cls (dict): Config for building loss. 39 | Default: dict(type='CrossEntropyLoss', loss_weight=1.0). 40 | multi_class (bool): Determines whether it is a multi-class 41 | recognition task. Default: False. 42 | label_smooth_eps (float): Epsilon used in label smooth. 43 | Reference: arxiv.org/abs/1906.02629. Default: 0. 
44 | """ 45 | 46 | def __init__(self, 47 | num_classes, 48 | in_channels, 49 | loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0), 50 | multi_class=False, 51 | label_smooth_eps=0.0): 52 | super().__init__() 53 | self.num_classes = num_classes 54 | self.in_channels = in_channels 55 | self.loss_cls = build_loss(loss_cls) 56 | self.multi_class = multi_class 57 | self.label_smooth_eps = label_smooth_eps 58 | 59 | @abstractmethod 60 | def init_weights(self): 61 | """Initiate the parameters either from existing checkpoint or from 62 | scratch.""" 63 | 64 | @abstractmethod 65 | def forward(self, x): 66 | """Defines the computation performed at every call.""" 67 | 68 | def loss(self, cls_score, labels, **kwargs): 69 | """Calculate the loss given output ``cls_score``, target ``labels``. 70 | 71 | Args: 72 | cls_score (torch.Tensor): The output of the model. 73 | labels (torch.Tensor): The target output of the model. 74 | 75 | Returns: 76 | dict: A dict containing field 'loss_cls'(mandatory) 77 | and 'top1_acc', 'top5_acc'(optional). 78 | """ 79 | losses = dict() 80 | if labels.shape == torch.Size([]): 81 | labels = labels.unsqueeze(0) 82 | elif labels.dim() == 1 and labels.size()[0] == self.num_classes \ 83 | and cls_score.size()[0] == 1: 84 | # Fix a bug when training with soft labels and batch size is 1. 85 | # When using soft labels, `labels` and `cls_socre` share the same 86 | # shape. 87 | labels = labels.unsqueeze(0) 88 | 89 | if not self.multi_class and cls_score.size() != labels.size(): 90 | top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(), 91 | labels.detach().cpu().numpy(), (1, 5)) 92 | losses['top1_acc'] = torch.tensor( 93 | top_k_acc[0], device=cls_score.device) 94 | losses['top5_acc'] = torch.tensor( 95 | top_k_acc[1], device=cls_score.device) 96 | 97 | elif self.multi_class and self.label_smooth_eps != 0: 98 | labels = ((1 - self.label_smooth_eps) * labels + 99 | self.label_smooth_eps / self.num_classes) 100 | 101 | loss_cls = self.loss_cls(cls_score, labels, **kwargs) 102 | # loss_cls may be dictionary or single tensor 103 | if isinstance(loss_cls, dict): 104 | losses.update(loss_cls) 105 | else: 106 | losses['loss_cls'] = loss_cls 107 | 108 | return losses 109 | -------------------------------------------------------------------------------- /models/videoswintransformer_models/i3d_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.cnn import normal_init 3 | 4 | # from ..builder import HEADS 5 | # from .base import BaseHead 6 | 7 | 8 | # @HEADS.register_module() 9 | # class I3DHead(BaseHead): 10 | class I3DHead(nn.Module): 11 | """Classification head for I3D. 12 | 13 | Args: 14 | num_classes (int): Number of classes to be classified. 15 | in_channels (int): Number of channels in input feature. 16 | loss_cls (dict): Config for building loss. 17 | Default: dict(type='CrossEntropyLoss') 18 | spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. 19 | dropout_ratio (float): Probability of dropout layer. Default: 0.5. 20 | init_std (float): Std value for Initiation. Default: 0.01. 21 | kwargs (dict, optional): Any keyword argument to be used to initialize 22 | the head. 
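`BaseHead` above still calls `build_loss` and `top_k_accuracy`, whose mmaction2 imports are commented out at the top of the file, so constructing it outside mmaction2 raises a NameError. If one wanted to exercise `loss()` standalone, minimal stand-ins could look like the following (purely illustrative helpers, not part of the repo; `loss_weight` in the config is ignored):

```
import numpy as np
import torch.nn as nn

def build_loss(cfg):
    """Minimal stand-in: only the default CrossEntropyLoss config is supported."""
    assert cfg.get('type') == 'CrossEntropyLoss'
    return nn.CrossEntropyLoss()

def top_k_accuracy(scores, labels, topk=(1,)):
    """Fraction of samples whose true label is among the top-k scores (numpy inputs)."""
    order = np.argsort(scores, axis=1)[:, ::-1]            # class indices sorted by descending score
    return [float((order[:, :k] == labels[:, None]).any(axis=1).mean()) for k in topk]
```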
23 | """ 24 | 25 | def __init__(self, 26 | num_classes, 27 | in_channels, 28 | # loss_cls=dict(type='CrossEntropyLoss'), 29 | spatial_type='avg', 30 | dropout_ratio=0.5, 31 | init_std=0.01, 32 | # **kwargs 33 | ): 34 | # super().__init__(num_classes, in_channels, loss_cls, **kwargs) 35 | super(I3DHead, self).__init__() 36 | self.num_classes = num_classes 37 | self.in_channels = in_channels 38 | self.spatial_type = spatial_type 39 | self.dropout_ratio = dropout_ratio 40 | self.init_std = init_std 41 | if self.dropout_ratio != 0: 42 | self.dropout = nn.Dropout(p=self.dropout_ratio) 43 | else: 44 | self.dropout = None 45 | self.fc_cls = nn.Linear(self.in_channels, self.num_classes) 46 | 47 | if self.spatial_type == 'avg': 48 | # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. 49 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 50 | else: 51 | self.avg_pool = None 52 | 53 | def init_weights(self): 54 | """Initiate the parameters from scratch.""" 55 | normal_init(self.fc_cls, std=self.init_std) 56 | 57 | def forward(self, x): 58 | """Defines the computation performed at every call. 59 | 60 | Args: 61 | x (torch.Tensor): The input data. 62 | 63 | Returns: 64 | torch.Tensor: The classification scores for input samples. 65 | """ 66 | # [N, in_channels, 4, 7, 7] 67 | if self.avg_pool is not None: 68 | x = self.avg_pool(x) 69 | # [N, in_channels, 1, 1, 1] 70 | if self.dropout is not None: 71 | x = self.dropout(x) 72 | # [N, in_channels, 1, 1, 1] 73 | x = x.view(x.shape[0], -1) 74 | # [N, in_channels] 75 | cls_score = self.fc_cls(x) 76 | # [N, num_classes] 77 | return cls_score 78 | -------------------------------------------------------------------------------- /models/videoswintransformer_models/recognizer3d.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from models.videoswintransformer_models.swin_transformer import SwinTransformer3D 4 | from models.videoswintransformer_models.i3d_head import I3DHead 5 | import torch.nn.functional as F 6 | 7 | # : (2, 4, 4) 8 | # : [2, 2, 18, 2] 9 | # : [4, 8, 16, 32] 10 | # : (8, 7, 7) 11 | # todo swin base 12 | # model = dict( 13 | # type='Recognizer3D', 14 | # backbone=dict( 15 | # type='SwinTransformer3D', 16 | # patch_size=(4,4,4), 17 | # embed_dim=128, 18 | # depths=[2, 2, 18, 2], 19 | # num_heads=[4, 8, 16, 32]), 20 | # window_size=(8,7,7), 21 | # mlp_ratio=4., 22 | # qkv_bias=True, 23 | # qk_scale=None, 24 | # drop_rate=0., 25 | # attn_drop_rate=0., 26 | # drop_path_rate=0.2, 27 | # patch_norm=True), 28 | # cls_head=dict( 29 | # type='I3DHead', 30 | # in_channels=1024, 31 | # num_classes=400, 32 | # spatial_type='avg', 33 | # dropout_ratio=0.5), 34 | # test_cfg = dict(average_clips='prob')) 35 | 36 | # todo swin_base_patch244_window1677_sthv2.py 37 | # model=dict(backbone=dict(patch_size=(2,4,4), window_size=(16,7,7), drop_path_rate=0.4), 38 | # cls_head=dict(num_classes=174), 39 | # test_cfg=dict(max_testing_views=2), 40 | # train_cfg=dict(blending=dict(type='LabelSmoothing', num_classes=174, smoothing=0.1))) 41 | 42 | # todo swin_base_patch244_window877_kinetics400_1k.py 43 | # model=dict(backbone=dict(patch_size=(2,4,4), drop_path_rate=0.3), test_cfg=dict(max_testing_views=4)) 44 | 45 | class Recognizer3D(nn.Module): 46 | def __init__(self, num_classes = None, patch_size = None, window_size = None, drop_path_rate = None, ): 47 | super(Recognizer3D, self).__init__() 48 | # backbone params 49 | self.pretrained = None 50 | self.pretrained2d = True 51 | self.patch_size 
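`I3DHead` above reduces a 5D backbone feature with global average pooling before the linear classifier (it needs `mmcv` installed for `normal_init`, as pinned in requirements.txt). A quick shape sketch with random weights, using the 1024-channel feature size noted in recognizer3d.py:

```
import torch

from models.videoswintransformer_models.i3d_head import I3DHead

head = I3DHead(num_classes=101, in_channels=1024).eval()
feat = torch.randn(2, 1024, 8, 7, 7)     # (N, in_channels, T', H', W') from the Swin backbone
with torch.no_grad():
    scores = head(feat)                  # avg-pool -> flatten -> fc
print(scores.shape)                      # torch.Size([2, 101])
```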
= patch_size 52 | self.in_chans = 3 53 | self.embed_dim = 128 54 | self.depths = [2, 2, 18, 2] 55 | self.num_heads = [4, 8, 16, 32] 56 | self.window_size = window_size 57 | self.mlp_ratio = 4.0 58 | self.qkv_bias = True 59 | self.qk_scale = None 60 | self.drop_rate = 0. 61 | self.attn_drop_rate = 0. 62 | self.drop_path_rate = drop_path_rate 63 | self.patch_norm = True 64 | 65 | # head params 66 | self.num_classes = num_classes 67 | self.in_channels = 1024 68 | self.spatial_type = 'avg' 69 | self.dropout_ratio = 0.5 70 | 71 | self.score_type = 'score' # 72 | 73 | self.backbone = SwinTransformer3D(pretrained= self.pretrained, 74 | pretrained2d= self.pretrained2d, 75 | patch_size= self.patch_size, 76 | in_chans= self.in_chans, 77 | embed_dim= self.embed_dim, 78 | depths= self.depths , 79 | num_heads= self.num_heads , 80 | window_size= self.window_size, 81 | mlp_ratio= self.mlp_ratio, 82 | qkv_bias= self.qkv_bias, 83 | qk_scale= self.qk_scale, 84 | drop_rate= self.drop_rate, 85 | attn_drop_rate= self.attn_drop_rate, 86 | drop_path_rate= self.drop_path_rate, 87 | patch_norm= self.patch_norm, 88 | ) 89 | self.cls_head = I3DHead(num_classes=self.num_classes, 90 | in_channels=self.in_channels, 91 | spatial_type=self.spatial_type, 92 | dropout_ratio=self.dropout_ratio 93 | ) 94 | 95 | def forward(self, x): # x (batch, n_views, C, T, H, W) 96 | n = x.shape[0] 97 | n_views = x.shape[1] # n_views n_spatial_crops * n_temporal clips 98 | x = x.reshape((-1,) + x.shape[2:]) # (N, n_views, C, T, H, W) -> (N * n_views, C, T, H, W) 99 | feat = self.backbone(x) # (N * n_views, C, T, H, W) -> (N * n_views, C, T/2, H/32, W/32) 100 | cls_score = self.cls_head(feat) # (N * n_views, 1024, 16, 7, 7) -> (N * n_views, 600) 101 | # cls_score = self.average_clips(cls_score, num_segs= n_views) # todo average over n_views 102 | vid_cls_score, view_cls_score = self.average_clips(cls_score, num_segs= n_views) 103 | return vid_cls_score, view_cls_score 104 | 105 | 106 | def average_clips(self, cls_score, num_segs = 1): 107 | bz = cls_score.shape[0] 108 | cls_score = cls_score.view(bz // num_segs, num_segs, -1) # (bz, n_views, n_class) 109 | if self.score_type == 'prob': 110 | cls_score = F.softmax(cls_score, dim=2).mean(dim=2) 111 | return cls_score 112 | elif self.score_type == 'score': 113 | 114 | vid_cls_score = cls_score.mean(dim=1) 115 | return vid_cls_score, cls_score 116 | 117 | 118 | -------------------------------------------------------------------------------- /models/videoswintransformer_models/video_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path as osp 3 | from models.tanet_models.video_dataset import VideoRecord 4 | import torch.utils.data as data 5 | from models.videoswintransformer_models.transforms_backup import DecordInit, SampleFrames, DecordDecode, Resize, CenterCrop, Flip, Normalize, FormatShape, Collect, ToTensor, \ 6 | RandomResizedCrop 7 | # import decord 8 | class Video_SwinDataset(data.Dataset): 9 | def __init__(self, list_file, 10 | num_segments=3, # clip_length 11 | frame_interval = 2, 12 | num_clips = 1, 13 | frame_uniform = True, 14 | test_mode = False, 15 | flip_ratio = None, 16 | scale_size = None, 17 | input_size=None, 18 | img_norm_cfg = None, 19 | 20 | vid_format='.mp4', 21 | video_data_dir = None, 22 | remove_missing = False, 23 | if_sample_tta_aug_views=None, 24 | tta_view_sample_style_list=None, 25 | n_augmented_views = None, 26 | debug = False, debug_vid = 50, 27 | ): 28 | self.list_file = list_file 29 | 
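`Recognizer3D.average_clips` above reshapes the flat `(N * n_views, n_class)` scores back to `(N, n_views, n_class)` and averages over the view axis; with the hard-coded `score_type = 'score'` the softmax branch is never taken (as written, that branch would reduce over the class axis rather than the views). The reduction in isolation:

```
import torch

n, n_views, n_class = 2, 3, 101
cls_score = torch.randn(n * n_views, n_class)                  # per-view scores from the head

vid_score = cls_score.view(n, n_views, n_class).mean(dim=1)    # average the views of each video
print(vid_score.shape)                                         # torch.Size([2, 101])
```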
self.num_segments = num_segments 30 | self.frame_interval = frame_interval 31 | self.num_clips = num_clips 32 | self.frame_uniform = frame_uniform 33 | self.test_mode = test_mode 34 | self.flip_ratio = flip_ratio 35 | if self.test_mode: 36 | assert self.flip_ratio == 0 37 | self.scale_size = scale_size 38 | self.input_size = input_size 39 | self.img_norm_cfg = img_norm_cfg 40 | 41 | self.vid_format = vid_format 42 | self.video_data_dir = video_data_dir 43 | self.remove_missing = remove_missing 44 | self.debug = debug 45 | self.debug_vid = debug_vid 46 | self.if_sample_tta_aug_views = if_sample_tta_aug_views 47 | self.tta_view_sample_style_list = tta_view_sample_style_list 48 | self.n_augmented_views = n_augmented_views 49 | self.__parse_list() 50 | def __parse_list(self): 51 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 52 | if not self.test_mode or self.remove_missing: 53 | tmp = [item for item in tmp if int(item[1]) >= 3] 54 | self.video_list = [VideoRecord(item) for item in tmp] 55 | 56 | if self.debug: 57 | self.video_list = self.video_list[:self.debug_vid] 58 | 59 | def __getitem__(self, index): 60 | record = self.video_list[index] 61 | vid_path = osp.join(self.video_data_dir, f'{record.path}{self.vid_format}') 62 | results = { 'filename': vid_path, 'start_index':0, 'modality': 'RGB'} 63 | if self.test_mode: 64 | if self.if_sample_tta_aug_views: 65 | func_list = [DecordInit(), 66 | SampleFrames(clip_len=self.num_segments, frame_interval=self.frame_interval, # todo frame_interval is only used for dense sampling 67 | num_clips=self.num_clips, 68 | frame_uniform=self.frame_uniform, test_mode=self.test_mode, 69 | if_sample_tta_aug_views= self.if_sample_tta_aug_views, tta_view_sample_style_list= self.tta_view_sample_style_list, 70 | n_augmented_views= self.n_augmented_views 71 | ), # todo uniform sampling (instead of dense sampling) 72 | DecordDecode(), 73 | Resize( scale=(-1, self.scale_size)), # todo always resize the height to 224 74 | RandomResizedCrop(), 75 | Resize(scale= (self.input_size, self.input_size), keep_ratio= False ), 76 | # CenterCrop(crop_size=( self.input_size)), 77 | Flip(flip_ratio= self.flip_ratio), 78 | Normalize(**self.img_norm_cfg), 79 | FormatShape(input_format='NCTHW' ), # , collapse=True # todo collapse = False default, (n_clips, 3, T, H, W ) 80 | Collect(keys=['imgs'], meta_keys=[] ), 81 | ToTensor(keys=['imgs']) 82 | ] 83 | else: 84 | func_list = [DecordInit(), 85 | SampleFrames(clip_len=self.num_segments, frame_interval=self.frame_interval, 86 | # todo frame_interval is only used for dense sampling 87 | num_clips=self.num_clips, 88 | frame_uniform=self.frame_uniform, test_mode=self.test_mode, 89 | if_sample_tta_aug_views=self.if_sample_tta_aug_views, 90 | tta_view_sample_style_list=self.tta_view_sample_style_list, 91 | n_augmented_views=self.n_augmented_views 92 | ), # todo uniform sampling (instead of dense sampling) 93 | DecordDecode(), 94 | Resize(scale=(-1, self.scale_size)), # todo always resize the height to 224 95 | CenterCrop(crop_size=(self.input_size)), 96 | Flip(flip_ratio=self.flip_ratio), 97 | Normalize(**self.img_norm_cfg), 98 | FormatShape(input_format='NCTHW'), 99 | # , collapse=True # todo collapse = False default, (n_clips, 3, T, H, W ) 100 | Collect(keys=['imgs'], meta_keys=[]), 101 | ToTensor(keys=['imgs']) 102 | ] 103 | else: 104 | raise NotImplementedError('Transformation for training not implemented ') 105 | for func_ in func_list: 106 | results = func_(results) 107 | return results['imgs'], record.label 108 | def 
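A hedged construction sketch for the `Video_SwinDataset` above (paths are placeholders; `img_norm_cfg` is the dict defined in utils/opts.py; decord and the mmaction-style transforms imported at the top of the file must be available, and `input_size=224` is assumed here):

```
from torch.utils.data import DataLoader

from models.videoswintransformer_models.video_dataset import Video_SwinDataset

img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)

dataset = Video_SwinDataset(list_file='<.../list_video_perturbations_ucf/gauss.txt>',   # placeholder
                            video_data_dir='<.../level_5_ucf_val_split_1>',             # placeholder
                            num_segments=16, frame_interval=2, num_clips=1,
                            frame_uniform=True, test_mode=True, flip_ratio=0,
                            scale_size=224, input_size=224, img_norm_cfg=img_norm_cfg,
                            vid_format='.mp4', if_sample_tta_aug_views=False)

loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)
imgs, labels = next(iter(loader))
print(imgs.shape)          # expected (B, n_clips, 3, 16, 224, 224) from the NCTHW pipeline above
```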
__len__(self): 109 | return len(self.video_list) 110 | 111 | def get(self, record, indices): 112 | pass -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | addict==2.4.0 3 | blessings==1.7 4 | cachetools==4.2.4 5 | certifi==2021.5.30 6 | charset-normalizer==2.0.12 7 | click==7.1.2 8 | colorama==0.4.4 9 | commonmark==0.9.1 10 | cycler==0.11.0 11 | dataclasses==0.8 12 | decord==0.6.0 13 | einops==0.4.1 14 | future==0.18.2 15 | google-auth==2.6.6 16 | google-auth-oauthlib==0.4.6 17 | gpustat==0.6.0 18 | grpcio==1.46.3 19 | idna==3.3 20 | importlib-metadata==4.8.3 21 | joblib==1.1.0 22 | kiwisolver==1.3.1 23 | Markdown==3.3.7 24 | matplotlib==3.3.4 25 | -e git+https://github.com/SwinTransformer/Video-Swin-Transformer@db018fb8896251711791386bbd2127562fd8d6a6#egg=mmaction2 26 | mmcv-full==1.3.12 27 | model-index==0.1.11 28 | numpy==1.19.5 29 | nvidia-ml-py3==7.352.0 30 | oauthlib==3.2.0 31 | opencv-contrib-python==4.5.5.64 32 | opencv-python==4.5.5.64 33 | openmim @ git+https://github.com/open-mmlab/mim.git@ffaff8a4ece8a833b20dc7c6c1a7f8647d9eec15 34 | ordered-set==4.0.2 35 | packaging==21.3 36 | pandas==1.1.5 37 | Pillow==8.4.0 38 | protobuf==3.19.4 39 | psutil==5.9.1 40 | pyasn1==0.4.8 41 | pyasn1-modules==0.2.8 42 | Pygments==2.13.0 43 | pyparsing==3.0.9 44 | python-dateutil==2.8.2 45 | pytz==2022.1 46 | PyYAML==6.0 47 | requests==2.27.1 48 | requests-oauthlib==1.3.1 49 | rich==12.5.1 50 | rsa==4.8 51 | scikit-learn==0.24.2 52 | scipy==1.5.4 53 | six==1.16.0 54 | tabulate==0.8.9 55 | tensorboard==2.9.0 56 | tensorboard-data-server==0.6.1 57 | tensorboard-plugin-wit==1.8.1 58 | tensorboardX==2.5 59 | threadpoolctl==3.1.0 60 | timm==0.6.7 61 | torch==1.7.1 62 | torchaudio==0.7.2 63 | torchvision==0.8.2 64 | typing_extensions==4.1.1 65 | urllib3==1.26.9 66 | Werkzeug==2.0.3 67 | yapf==0.32.0 68 | zipp==3.6.0 69 | -------------------------------------------------------------------------------- /sourceonly_swin_ucf101_corr.py: -------------------------------------------------------------------------------- 1 | # import os.path as osp 2 | from utils.opts import get_opts 3 | from utils.utils_ import get_env_id, get_writer_to_all_result 4 | from corpus.main_eval import eval 5 | 6 | # best_prec1 = 0 7 | 8 | corruptions = ['gauss', 'pepper', 'salt','shot', 9 | 'zoom', 'impulse', 'defocus', 'motion', 10 | 'jpeg', 'contrast', 'rain', 'h265_abr' ] 11 | 12 | if __name__ == '__main__': 13 | global args 14 | args = get_opts() 15 | args.gpus = [0] 16 | args.arch = 'videoswintransformer' 17 | args.dataset = 'ucf101' 18 | # todo ========================= To Specify ========================== 19 | args.model_path = '.../swin_base_patch244_window877_pretrain_kinetics400_30epoch_lr3e-5.pth' 20 | args.video_data_dir = '.../level_5_ucf_val_split_1' # main directory of the video data, [args.video_data_dir] + [path in file list] should be complete absolute path for a video file 21 | args.val_vid_list = '.../list_video_perturbations_ucf/{}.txt' 22 | args.result_dir = '.../{}_{}/tta_{}' 23 | # todo ========================= To Specify ========================== 24 | 25 | 26 | args.batch_size = 32 # 12 27 | args.clip_length = 16 # 32 28 | args.num_clips = 1 # number of temporal clips 29 | args.test_crops = 1 # number of spatial crops 30 | args.frame_uniform = True # todo uniform sampling (should be better than dense sampling when using only 1 clip ) 31 | 
args.frame_interval = 2 32 | args.scale_size = 224 33 | 34 | 35 | args.patch_size = (2,4,4) 36 | args.window_size = (8, 7, 7) 37 | 38 | args.tta = False 39 | args.tta_view_sample_style_list = None 40 | args.evaluate_baselines = not args.tta 41 | args.baseline = 'source' 42 | 43 | for corr_id, args.corruptions in enumerate(corruptions): 44 | print(f'####Starting Evaluation for ::: {args.corruptions} corruption####') 45 | args.val_vid_list = args.val_vid_list.format(args.corruptions) 46 | args.result_dir = args.result_dir.format( args.arch, args.dataset, args.corruptions ) 47 | epoch_result_list = eval(args=args, ) 48 | if corr_id == 0: 49 | f_write = get_writer_to_all_result(args) 50 | f_write.write(' '.join([str(round(float(xx), 3)) for xx in epoch_result_list]) + '\n') 51 | 52 | f_write.flush() 53 | if corr_id == len(corruptions) - 1: 54 | f_write.close() 55 | -------------------------------------------------------------------------------- /sourceonly_tanet_ucf101_corr.py: -------------------------------------------------------------------------------- 1 | # import os.path as osp 2 | # from utils.opts import parser 3 | from utils.opts import get_opts 4 | from utils.utils_ import get_writer_to_all_result 5 | from corpus.main_eval import eval 6 | 7 | # best_prec1 = 0 8 | 9 | corruptions = ['gauss', 'pepper', 'salt','shot', 10 | 'zoom', 'impulse', 'defocus', 'motion', 11 | 'jpeg', 'contrast', 'rain', 'h265_abr' ] 12 | 13 | if __name__ == '__main__': 14 | global args 15 | args = get_opts() 16 | args.gpus = [0] 17 | args.arch = 'tanet' 18 | args.dataset = 'ucf101' 19 | # todo ========================= To Specify ========================== 20 | args.model_path = '.../tanet_ucf.pth.tar' 21 | args.video_data_dir = '.../level_5_ucf_val_split_1' # main directory of the video data, [args.video_data_dir] + [path in file list] should be complete absolute path for a video file 22 | args.val_vid_list = '.../list_video_perturbations_ucf/{}.txt' 23 | args.result_dir = '.../{}_{}/eval_{}' 24 | # todo ========================= To Specify ========================== 25 | 26 | args.batch_size = 32 # 12 27 | args.clip_length = 16 # 32 28 | args.sample_style = 'uniform-1' # number of temporal clips 29 | args.test_crops = 1 # number of spatial crops 30 | 31 | args.tta = False 32 | args.evaluate_baselines = not args.tta 33 | args.baseline = 'source' 34 | 35 | for corr_id, args.corruptions in enumerate(corruptions): 36 | print(f'####Starting Evaluation for ::: {args.corruptions} corruption####') 37 | args.val_vid_list = args.val_vid_list.format(args.corruptions) 38 | args.result_dir = args.result_dir.format( args.arch, args.dataset, args.corruptions ) 39 | 40 | epoch_result_list = eval(args=args, ) 41 | 42 | if corr_id == 0: 43 | f_write = get_writer_to_all_result(args) 44 | f_write.write(' '.join([str(round(float(xx), 3)) for xx in epoch_result_list]) + '\n') 45 | 46 | f_write.flush() 47 | if corr_id == len(corruptions) - 1: 48 | f_write.close() 49 | -------------------------------------------------------------------------------- /tta_swin_ucf101.py: -------------------------------------------------------------------------------- 1 | # import os.path as osp 2 | from utils.opts import get_opts 3 | from utils.utils_ import get_writer_to_all_result 4 | from corpus.main_eval import eval 5 | 6 | corruptions = ['gauss_shuffled', 'pepper_shuffled', 'salt_shuffled', 'shot_shuffled', 7 | 'zoom_shuffled', 'impulse_shuffled', 'defocus_shuffled', 'motion_shuffled', 8 | 'jpeg_shuffled', 'contrast_shuffled', 'rain_shuffled', 
'h265_abr_shuffled', ] 9 | 10 | if __name__ == '__main__': 11 | global args 12 | args = get_opts() 13 | args.gpus = [0] 14 | args.arch = 'videoswintransformer' 15 | args.dataset = 'ucf101' 16 | # todo ========================= To Specify ========================== 17 | args.model_path = '.../swin_base_patch244_window877_pretrain_kinetics400_30epoch_lr3e-5.pth' 18 | args.video_data_dir = '.../level_5_ucf_val_split_1' # main directory of the video data, [args.video_data_dir] + [path in file list] should be complete absolute path for a video file 19 | args.spatiotemp_mean_clean_file = '.../source_statistics_tanet_ucf/list_spatiotemp_mean_20221004_192722.npy' 20 | args.spatiotemp_var_clean_file = '.../source_statistics_tanet_ucf/list_spatiotemp_var_20221004_192722.npy' 21 | args.val_vid_list = '.../list_video_perturbations_ucf/{}.txt' 22 | args.result_dir = '.../{}_{}/tta_{}' 23 | # todo ========================= To Specify ========================== 24 | 25 | 26 | 27 | args.clip_length = 16 28 | args.num_clips = 1 # number of temporal clips 29 | args.test_crops = 1 # number of spatial crops 30 | args.frame_uniform = True 31 | args.frame_interval = 2 32 | args.scale_size = 224 # different than TANet 33 | 34 | args.patch_size = (2,4,4) 35 | args.window_size = (8, 7, 7) 36 | 37 | args.lr = 0.00001 38 | args.lambda_pred_consis = 0.05 39 | args.momentum_mvg = 0.05 40 | args.chosen_blocks = ['module.backbone.layers.2', 'module.backbone.layers.3', 'module.backbone.norm'] 41 | 42 | 43 | for corr_id, args.corruptions in enumerate(corruptions): 44 | print(f'####Starting Evaluation for ::: {args.corruptions} corruption####') 45 | args.val_vid_list = args.val_vid_list.format(args.corruptions) 46 | args.result_dir = args.result_dir.format( args.arch, args.dataset, args.corruptions ) 47 | 48 | epoch_result_list, _ = eval(args=args, ) 49 | if corr_id == 0: 50 | f_write = get_writer_to_all_result(args) 51 | f_write.write(' '.join([str(round(float(xx), 3)) for xx in epoch_result_list]) + '\n') 52 | 53 | f_write.flush() 54 | if corr_id == len(corruptions) - 1: 55 | f_write.close() 56 | -------------------------------------------------------------------------------- /tta_tanet_ucf101.py: -------------------------------------------------------------------------------- 1 | # import os.path as osp 2 | # from utils.opts import parser 3 | # import argparse 4 | from utils.opts import get_opts 5 | from utils.utils_ import get_writer_to_all_result 6 | from corpus.main_eval import eval 7 | 8 | 9 | corruptions = ['gauss_shuffled', 'pepper_shuffled', 'salt_shuffled', 'shot_shuffled', 10 | 'zoom_shuffled', 'impulse_shuffled', 'defocus_shuffled', 'motion_shuffled', 11 | 'jpeg_shuffled', 'contrast_shuffled', 'rain_shuffled', 'h265_abr_shuffled', ] 12 | 13 | if __name__ == '__main__': 14 | global args 15 | args = get_opts() 16 | args.gpus = [0] 17 | args.arch = 'tanet' 18 | args.dataset = 'ucf101' 19 | # todo ========================= To Specify ========================== 20 | args.model_path = '.../tanet_ucf.pth.tar' 21 | args.video_data_dir = '.../level_5_ucf_val_split_1' # main directory of the video data, [args.video_data_dir] + [path in file list] should be complete absolute path for a video file 22 | args.spatiotemp_mean_clean_file = '.../source_statistics_tanet_swin/list_spatiotemp_mean_20220908_235138.npy' 23 | args.spatiotemp_var_clean_file = '.../source_statistics_tanet_swin/list_spatiotemp_var_20220908_235138.npy' 24 | args.val_vid_list = '.../list_video_perturbations_ucf/{}.txt' 25 | args.result_dir = 
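Note that the loops in the scripts above overwrite `args.val_vid_list` and `args.result_dir` with their formatted values on the first iteration; since `str.format` leaves a string without `{}` placeholders unchanged, every later corruption silently reuses the first corruption's paths. A sketch that keeps the templates separate (the `*_tmpl` names are ours, not the repo's):

```
val_vid_list_tmpl = args.val_vid_list     # keep the '{}' templates untouched
result_dir_tmpl = args.result_dir

for corr_id, corruption in enumerate(corruptions):
    args.corruptions = corruption
    args.val_vid_list = val_vid_list_tmpl.format(corruption)                       # fresh path per corruption
    args.result_dir = result_dir_tmpl.format(args.arch, args.dataset, corruption)
    epoch_result_list, _ = eval(args=args)
```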
'.../{}_{}/tta_{}' 26 | # todo ========================= To Specify ========================== 27 | 28 | 29 | 30 | 31 | for corr_id, args.corruptions in enumerate(corruptions): 32 | print(f'####Starting Evaluation for ::: {args.corruptions} corruption####') 33 | args.val_vid_list = args.val_vid_list.format(args.corruptions) 34 | args.result_dir = args.result_dir.format( args.arch, args.dataset, args.corruptions ) 35 | 36 | 37 | epoch_result_list, _ = eval(args=args, ) 38 | 39 | if corr_id == 0: 40 | f_write = get_writer_to_all_result(args) 41 | f_write.write(' '.join([str(round(float(xx), 3)) for xx in epoch_result_list]) + '\n') 42 | 43 | f_write.flush() 44 | if corr_id == len(corruptions) - 1: 45 | f_write.close() 46 | -------------------------------------------------------------------------------- /utils/BNS_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.utils_ import AverageMeter, AverageMeterTensor 4 | from utils.norm_stats_utils import compute_regularization 5 | 6 | l1_loss = nn.L1Loss(reduction='mean') 7 | 8 | def compute_kld(mean_true, mean_pred, var_true, var_pred): 9 | # mean1 and std1 are for true distribution 10 | # mean2 and std2 are for pred distribution 11 | # kld_mv = torch.log(std_pred / std_true) + (std_true ** 2 + (mean_true - mean_pred) ** 2) / (2 * std_pred ** 2) - 0.5 12 | 13 | kld_mv = 0.5 * torch.log(torch.div(var_pred, var_true)) + (var_true + (mean_true - mean_pred) ** 2) / \ 14 | (2 * var_pred) - 0.5 15 | kld_mv = torch.sum(kld_mv) 16 | return kld_mv 17 | 18 | 19 | class BNFeatureHook(): 20 | def __init__(self, module, reg_type='l2norm', running_manner = False, use_src_stat_in_reg = True, momentum = 0.1): 21 | 22 | self.hook = module.register_forward_hook(self.hook_fn) # register a hook func to a module 23 | self.reg_type = reg_type 24 | self.running_manner = running_manner 25 | self.use_src_stat_in_reg = use_src_stat_in_reg # whether to use the source statistics in regularization loss 26 | # todo keep the initial module.running_xx.data (the statistics of source model) 27 | # if BN layer is not set to eval, these statistics will change 28 | if self.use_src_stat_in_reg: 29 | self.source_mean = module.running_mean.data 30 | self.source_var = module.running_var.data 31 | if self.running_manner: 32 | # initialize the statistics of computation in running manner 33 | self.mean = torch.zeros_like( module.running_mean) 34 | self.var = torch.zeros_like(module.running_var) 35 | self.momentum = momentum 36 | 37 | def hook_fn(self, module, input, output): # input in shape (B, C, T, H, W) 38 | 39 | nch = input[0].shape[1] 40 | if isinstance(module, nn.BatchNorm1d): 41 | # input in shape (B, C) or (B, C, T) 42 | if len(input[0].shape) == 2: # todo BatchNorm1d in TAM G branch input is (N*C, T ) 43 | batch_mean = input[0].mean([0,]) 44 | batch_var = input[0].permute(1, 0,).contiguous().view([nch, -1]).var(1, unbiased=False) # compute the variance along each channel 45 | elif len(input[0].shape) == 3: # todo BatchNorm1d in TAM L branch input is (N, C, T) 46 | batch_mean = input[0].mean([0,2]) 47 | batch_var = input[0].permute(1, 0, 2).contiguous().view([nch, -1]).var(1, unbiased=False) # compute the variance along each channel 48 | elif isinstance(module, nn.BatchNorm2d): 49 | # input in shape (B, C, H, W) 50 | batch_mean = input[0].mean([0, 2, 3]) 51 | batch_var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) # compute the variance along each 
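`compute_kld` above is the closed-form KL divergence between two diagonal Gaussians, KL(N(mu_t, var_t) || N(mu_p, var_p)) = 0.5*log(var_p/var_t) + (var_t + (mu_t - mu_p)^2) / (2*var_p) - 0.5, summed over channels. A quick numeric sanity check (zero for identical statistics, positive for any mismatch):

```
import torch

def kld_gauss(mean_true, mean_pred, var_true, var_pred):
    # same per-channel formula as compute_kld in utils/BNS_utils.py, before the final sum
    return 0.5 * torch.log(var_pred / var_true) + (var_true + (mean_true - mean_pred) ** 2) / (2 * var_pred) - 0.5

mu = torch.tensor([0.0, 1.0])
var = torch.tensor([1.0, 2.0])
print(kld_gauss(mu, mu, var, var).sum())              # tensor(0.)   -- identical Gaussians
print(kld_gauss(mu, mu + 0.5, var, var).sum() > 0)    # tensor(True) -- shifted mean gives positive KL
```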
channel 52 | elif isinstance(module, nn.BatchNorm3d): 53 | # input in shape (B, C, T, H, W) 54 | batch_mean = input[0].mean([0, 2, 3, 4]) 55 | batch_var = input[0].permute(1, 0, 2, 3, 4).contiguous().view([nch, -1]).var(1, unbiased=False) # compute the variance along each channel 56 | 57 | self.mean = self.momentum * batch_mean + (1.0 - self.momentum) * self.mean.detach() if self.running_manner else batch_mean 58 | self.var = self.momentum * batch_var + (1.0 - self.momentum) * self.var.detach() if self.running_manner else batch_var 59 | # todo if BN layer is set to eval, these two are the same; otherwise, module.running_xx.data keeps changing 60 | self.mean_true = self.source_mean if self.use_src_stat_in_reg else module.running_mean.data 61 | self.var_true = self.source_var if self.use_src_stat_in_reg else module.running_var.data 62 | self.r_feature = compute_regularization(mean_true = self.mean_true, mean_pred = self.mean, var_true=self.var_true, var_pred = self.var, reg_type = self.reg_type) 63 | 64 | 65 | # if self.reg_type == 'l2norm': 66 | # self.r_feature = torch.norm(self.var_true - self.var, 2) + torch.norm(self.mean_true - self.mean,2) 67 | # if self.reg_type == 'l1_loss': 68 | # self.r_feature = torch.norm(self.var_true - self.var, 1) + torch.norm(self.mean_true - self.mean, 1) 69 | # elif self.reg_type == 'kld': 70 | # self.r_feature = compute_kld(mean_true=self.mean_true, mean_pred= self.mean, 71 | # var_true= self.var_true, var_pred= self.var) 72 | 73 | def add_hook_back(self, module): 74 | self.hook = module.register_forward_hook(self.hook_fn) # register a hook func to a module 75 | 76 | def close(self): 77 | self.hook.remove() 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | class TempStatsRegHook(): 86 | def __init__(self, module, clip_len = None, temp_stats_clean_tuple = None, reg_type='l2norm', ): 87 | 88 | self.hook = module.register_forward_hook(self.hook_fn) # register a hook func to a module 89 | self.clip_len = clip_len 90 | # self.temp_mean_clean, self.temp_var_clean = temp_stats_clean_tuple 91 | 92 | self.reg_type = reg_type 93 | # self.running_manner = running_manner 94 | # self.use_src_stat_in_reg = use_src_stat_in_reg # whether to use the source statistics in regularization loss 95 | # todo keep the initial module.running_xx.data (the statistics of source model) 96 | # if BN layer is not set to eval, these statistics will change 97 | # if self.use_src_stat_in_reg: 98 | # self.source_mean = module.running_mean.data 99 | # self.source_var = module.running_var.data 100 | self.source_mean, self.source_var = temp_stats_clean_tuple 101 | 102 | self.source_mean = torch.tensor(self.source_mean).cuda() 103 | self.source_var = torch.tensor(self.source_var).cuda() 104 | 105 | # self.source_mean = self.source_mean.mean((1,2)) 106 | # self.source_var = self.source_var.mean((1,2 )) 107 | 108 | # if self.running_manner: 109 | # # initialize the statistics of computation in running manner 110 | # self.mean = torch.zeros_like( self.source_mean) 111 | # self.var = torch.zeros_like( self.source_var) 112 | 113 | self.mean_avgmeter = AverageMeterTensor() 114 | self.var_avgmeter = AverageMeterTensor() 115 | 116 | # self.momentum = momentum 117 | 118 | def hook_fn(self, module, input, output): 119 | 120 | if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear): 121 | # output is in shape (N, C, T) or (N*C, T ) 122 | raise NotImplementedError('Temporal statistics computation for nn.Conv1d not implemented!') 123 | elif isinstance(module, nn.Conv2d): 124 | # output is in shape (N*T, C, 
H, W) 125 | nt, c, h, w = output.size() 126 | t = self.clip_len 127 | bz = nt // t 128 | 129 | output = output.view(bz, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() # # ( N*T, C, H, W) -> (N, C, T, H, W) 130 | elif isinstance(module, nn.Conv3d): 131 | # output is in shape (N, C, T, H, W) 132 | bz, c, t, h, w = output.size() 133 | output = output 134 | else: 135 | raise Exception(f'undefined module {module}') 136 | # spatial_dim = h * w 137 | # todo compute the statistics only along the temporal dimension T, then take the average for all samples N 138 | # the statistics are in shape (C, H, W), 139 | batch_mean = output.mean(2).mean(0) # (N, C, T, H, W) -> (N, C, H, W) -> (C, H, W) 140 | # temp_var = new_output.permute(1, 3, 4, 0, 2).contiguous().view([c, t, -1]).var(2, unbiased = False ) 141 | batch_var = output.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean(0) # (N, C, T, H, W) -> # (N, C, H, W, T) -> (N, C, H, W) -> (C, H, W) 142 | 143 | # batch_mean = output.mean(2).mean((0, 2,3)) # (N, C, T, H, W) -> (N, C, H, W) -> (C,) 144 | # batch_var = output.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean((0, 2,3)) # (N, C, T, H, W) -> # (N, C, H, W, T) -> (N, C, H, W) -> (C,) 145 | 146 | 147 | self.mean_avgmeter.update(batch_mean, n= bz) 148 | self.var_avgmeter.update(batch_var, n= bz) 149 | 150 | if self.reg_type == 'l2norm': 151 | # # todo sum of squared difference, averaged over h * w 152 | # self.r_feature = torch.sum(( self.source_var - self.var_avgmeter.avg )**2 ) / spatial_dim + torch.sum(( self.source_mean - self.mean_avgmeter.avg )**2 ) / spatial_dim 153 | self.r_feature = torch.norm(self.source_var - self.var_avgmeter.avg, 2) + torch.norm(self.source_mean - self.mean_avgmeter.avg, 2) 154 | else: 155 | raise NotImplementedError 156 | 157 | def close(self): 158 | self.hook.remove() 159 | 160 | 161 | 162 | 163 | class ComputeSpatioTemporalStatisticsHook(): 164 | def __init__(self, module, clip_len = None,): 165 | 166 | self.hook = module.register_forward_hook(self.hook_fn) # register a hook func to a module 167 | self.clip_len = clip_len 168 | 169 | def hook_fn(self, module, input, output): 170 | 171 | if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear): 172 | # output is in shape (N, C, T) or (N*C, T ) 173 | raise NotImplementedError('Temporal statistics computation for nn.Conv1d not implemented!') 174 | elif isinstance(module, nn.Conv2d): 175 | # output is in shape (N*T, C, H, W) 176 | nt, c, h, w = output.size() 177 | t = self.clip_len 178 | bz = nt // t 179 | output = output.view(bz, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() # # ( N*T, C, H, W) -> (N, C, T, H, W) 180 | elif isinstance(module, nn.Conv3d): 181 | # output is in shape (N, C, T, H, W) 182 | bz, c, t, h, w = output.size() 183 | output = output 184 | else: 185 | raise Exception(f'undefined module {module}') 186 | 187 | # todo compute the statistics only along the temporal dimension T, then take the average for all samples N 188 | # the statistics are in shape (C, H, W), 189 | self.temp_mean = output.mean((0, 2,3,4)).mean(0) # (N, C, T, H, W) -> (C, ) 190 | self.temp_var = output.permute(1, 0, 2, 3, 4).contiguous().view([c, -1]).var(1, unbiased=False) # (N, C, T, H, W) -> (C, N, T, H, W) -> (C, ) 191 | 192 | # batch_mean = input[0].mean([0, 2, 3, 4]) 193 | # batch_var = input[0].permute(1, 0, 2, 3, 4).contiguous().view([nch, -1]).var(1, unbiased=False) # compute the variance along each channel 194 | 195 | self.temp_mean = output.mean(2).mean(0) # (N, C, T, H, W) -> (N, 
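The statistics hooks above reshape 2D-CNN activations from `(N*T, C, H, W)` back to `(N, C, T, H, W)`, then take the mean over T and the variance over T before averaging over the batch, so the stored statistics are per-location tensors of shape `(C, H, W)`. A shape sketch of exactly that computation with toy sizes:

```
import torch

N, T, C, H, W = 2, 8, 64, 14, 14
feat = torch.randn(N * T, C, H, W)                                 # activations of a Conv2d inside TANet

x = feat.view(N, T, C, H, W).permute(0, 2, 1, 3, 4).contiguous()   # -> (N, C, T, H, W)
mean_sthw = x.mean(2).mean(0)                                      # temporal mean, then batch mean -> (C, H, W)
var_sthw = x.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean(0)   # temporal variance -> (C, H, W)
print(mean_sthw.shape, var_sthw.shape)                             # torch.Size([64, 14, 14]) twice
```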
C, H, W) -> (C, H, W) 196 | # temp_var = new_output.permute(1, 3, 4, 0, 2).contiguous().view([c, t, -1]).var(2, unbiased = False ) 197 | self.temp_var = output.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean(0) # (N, C, T, H, W) -> # (N, C, H, W, T) -> (N, C, H, W) -> (C, H, W) 198 | 199 | # self.temp_mean = output.mean(2).mean((0, 2, 3)) # (N, C, T, H, W) -> (N, C, H, W) -> (C,) 200 | # self.temp_var = output.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean((0, 2, 3) ) # (N, C, T, H, W) -> # (N, C, H, W, T) -> (N, C, H, W) -> (C,) 201 | 202 | 203 | def close(self): 204 | self.hook.remove() 205 | 206 | 207 | class ComputeTemporalStatisticsHook(): 208 | def __init__(self, module, clip_len = None,): 209 | 210 | self.hook = module.register_forward_hook(self.hook_fn) # register a hook func to a module 211 | self.clip_len = clip_len 212 | 213 | def hook_fn(self, module, input, output): 214 | 215 | if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear): 216 | # output is in shape (N, C, T) or (N*C, T ) 217 | raise NotImplementedError('Temporal statistics computation for nn.Conv1d not implemented!') 218 | elif isinstance(module, nn.Conv2d): 219 | # output is in shape (N*T, C, H, W) 220 | nt, c, h, w = output.size() 221 | t = self.clip_len 222 | bz = nt // t 223 | output = output.view(bz, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() # # ( N*T, C, H, W) -> (N, C, T, H, W) 224 | elif isinstance(module, nn.Conv3d): 225 | # output is in shape (N, C, T, H, W) 226 | bz, c, t, h, w = output.size() 227 | output = output 228 | else: 229 | raise Exception(f'undefined module {module}') 230 | 231 | # todo compute the statistics only along the temporal dimension T, then take the average for all samples N 232 | # the statistics are in shape (C, H, W), 233 | self.temp_mean = output.mean(2).mean(0) # (N, C, T, H, W) -> (N, C, H, W) -> (C, H, W) 234 | # temp_var = new_output.permute(1, 3, 4, 0, 2).contiguous().view([c, t, -1]).var(2, unbiased = False ) 235 | self.temp_var = output.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean(0) # (N, C, T, H, W) -> # (N, C, H, W, T) -> (N, C, H, W) -> (C, H, W) 236 | 237 | # self.temp_mean = output.mean(2).mean((0, 2, 3)) # (N, C, T, H, W) -> (N, C, H, W) -> (C,) 238 | # self.temp_var = output.permute(0, 1, 3, 4, 2).contiguous().var(-1, unbiased=False).mean((0, 2, 3) ) # (N, C, T, H, W) -> # (N, C, H, W, T) -> (N, C, H, W) -> (C,) 239 | 240 | 241 | def close(self): 242 | self.hook.remove() 243 | 244 | 245 | def choose_layers(model, candidate_layers): 246 | 247 | chosen_layers = [] 248 | # choose all the BN layers 249 | # candidate_layers = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d ] 250 | counter = [0] * len(candidate_layers) 251 | # for m in model.modules(): 252 | for nm, m in model.named_modules(): 253 | for candidate_idx, candidate in enumerate(candidate_layers): 254 | if isinstance(m, candidate): 255 | counter[candidate_idx] += 1 256 | chosen_layers.append((nm, m)) 257 | # for idx in range(len(candidate_layers)): 258 | # print(f'Number of {candidate_layers[idx]} : {counter[idx]}') 259 | return chosen_layers 260 | 261 | 262 | def freeze_except_bn(model, bn_condidiate_layers, ): 263 | """ 264 | freeze the model, except the BN layers 265 | :param model: 266 | :param bn_condidiate_layers: 267 | :return: 268 | """ 269 | 270 | model.train() # 271 | model.requires_grad_(False) 272 | for m in model.modules(): 273 | for candidate in bn_condidiate_layers: 274 | if isinstance(m, candidate): 275 | 
m.requires_grad_(True) 276 | return model 277 | 278 | def collect_bn_params(model, bn_candidate_layers): 279 | params = [] 280 | names = [] 281 | for nm, m in model.named_modules(): 282 | for candidate in bn_candidate_layers: 283 | if isinstance(m, candidate): 284 | for np, p in m.named_parameters(): 285 | if np in ['weight', 'bias']: # weight is scale gamma, bias is shift beta 286 | params.append(p) 287 | names.append( f"{nm}.{np}") 288 | return params, names -------------------------------------------------------------------------------- /utils/__pycache__/BNS_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/utils/__pycache__/BNS_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/norm_stats_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/utils/__pycache__/norm_stats_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/opts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/utils/__pycache__/opts.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pred_consistency_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/utils/__pycache__/pred_consistency_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/utils/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlin-at/ViTTA/c8e01fa63f8a821a2ebdf1f1272872a867b78cdb/utils/__pycache__/utils_.cpython-36.pyc -------------------------------------------------------------------------------- /utils/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | # for TANet 4 | input_mean = [0.485, 0.456, 0.406] 5 | input_std = [0.229, 0.224, 0.225] 6 | 7 | # for Video Swin Transformer 8 | img_norm_cfg = dict( 9 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 10 | 11 | parser = argparse.ArgumentParser(description="ViTTA") 12 | 13 | # ========================= Data Configs ========================== 14 | # parser.add_argument('--data_dir', default='/media/data_8T', type=str, help='main data directory') 15 | parser.add_argument('--dataset', type=str, default='ucf101', choices=['ucf101', 'somethingv2', 'kinetics']) 16 | parser.add_argument('--modality', type=str, default='RGB') 17 | parser.add_argument('--root_path', default='None', type=str) 18 | parser.add_argument('--video_data_dir', 19 | 
default='/home/ivanl/data/UCF-HMDB/video_pertubations/UCF101/level_5_ucf_val_split_1', type=str, 20 | help='directory of the corrupted videos') # to specify 21 | parser.add_argument('--vid_format', default='', type=str, 22 | help='if video format is not given in the filenames in the list, the video format can be specified here') 23 | parser.add_argument('--datatype', default='vid', type=str, choices=['vid', 'frame']) 24 | 25 | 26 | parser.add_argument('--spatiotemp_mean_clean_file', type=str, 27 | default='/home/ivanl/data/UCF-HMDB/UCF-HMDB_all/corruptions_results/source/tanet_ucf101/compute_norm_spatiotempstats_clean_train_bn2d/list_spatiotemp_mean_20220908_235138.npy', 28 | help='spatiotemporal statistics - mean') # to specify 29 | parser.add_argument('--spatiotemp_var_clean_file', type=str, 30 | default='/home/ivanl/data/UCF-HMDB/UCF-HMDB_all/corruptions_results/source/tanet_ucf101/compute_norm_spatiotempstats_clean_train_bn2d/list_spatiotemp_var_20220908_235138.npy', 31 | help='spatiotemporal statistics - variance') # to specify 32 | 33 | parser.add_argument('--val_vid_list', type=str, 34 | default='/home/ivanl/data/UCF-HMDB/video_pertubations/UCF101/list_video_perturbations/{}.txt', 35 | help='list of corrupted videos to adapt to, list is named after the corruption type name') # to specify 36 | 37 | parser.add_argument('--result_dir', type=str, 38 | default='/home/ivanl/data/UCF-HMDB/UCF-HMDB_all/corruptions_results/source/{}_{}/tta_{}', 39 | help='result directory') # to specify 40 | 41 | 42 | # ========================= Model Configs ========================== 43 | parser.add_argument('--arch', type=str, default='tanet', choices=['tanet', 'videoswintransformer'], 44 | help='network architecture') 45 | parser.add_argument('--model_path', type=str, 46 | default='/home/ivanl/data/DeepInversion_results/train_models/models/UCF/tanet/20220815_122340_ckpt.pth.tar') # to specify 47 | parser.add_argument('--img_feature_dim', type=int, default=256, help='dimension of image feature on ResNet50') 48 | parser.add_argument('--partial_bn', action='store_true', ) 49 | 50 | # ========================= Model Configs for Video Swin Transformer ========================== 51 | parser.add_argument('--num_clips', type=int, default=1, help='number of temporal clips') 52 | parser.add_argument('--frame_uniform', type=bool, default=True, help='whether uniform sampling or dense sampling') # uniform sampling is better than dense sampling when using only 1 clip 53 | parser.add_argument('--frame_interval', type=int, default=2) 54 | parser.add_argument('--flip_ratio', type=int, default=0) 55 | parser.add_argument('--img_norm_cfg', default=img_norm_cfg) 56 | parser.add_argument('--patch_size', default=(2,4,4)) 57 | parser.add_argument('--window_size', default=(8, 7, 7)) 58 | parser.add_argument('--drop_path_rate', default=0.2) 59 | 60 | 61 | # ========================= Runtime Configs ========================== 62 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 63 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 64 | help='number of data loading workers (default: 4)') 65 | parser.add_argument('--norm', action='store_true') 66 | parser.add_argument('--debug', action='store_true', help='if debug, loading only the first 50 videos in the list') 67 | parser.add_argument('--verbose', type=bool, default=True, help='more details in the logging file') 68 | parser.add_argument('--print-freq', '-p', default=20, type=int, metavar='N', help='print frequency (default: 10)') 69 | 70 | 
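# ========================= Note on the Learning Configs below ==========================
# Descriptive comments added for readability; they summarize the flags defined in the next section.
# --tta enables test-time adaptation; --fix_BNS keeps the target model's BN statistics fixed during
# the forward pass; --running_manner computes the target statistics manually in a running manner,
# with --momentum_mvg as the momentum of the moving-average estimates; --lambda_feature_reg and
# --lambda_pred_consis weight the feature-statistics regularization and the prediction-consistency
# losses; --n_augmented_views sets how many temporally augmented views of each test video are sampled.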
71 | # ========================= Learning Configs ========================== 72 | parser.add_argument('--tta', type=bool, default=True, help='perform test-time adaptation') 73 | parser.add_argument('--use_src_stat_in_reg', type=bool, default=True, help='whether to use source statistics in the regularization loss') 74 | parser.add_argument('--fix_BNS', type=bool, default=True, help='whether fix the BNS of target model during forward pass') 75 | parser.add_argument('--running_manner', type=bool, default=True, help='whether to manually compute the target statistics in running manner') 76 | parser.add_argument('--momentum_bns', type=float, default=0.1) 77 | parser.add_argument('--update_only_bn_affine', action='store_true') 78 | parser.add_argument('--compute_stat', action='store_true') 79 | parser.add_argument('--momentum_mvg', type=float, default=0.1) 80 | parser.add_argument('--stat_reg', type=str, default='mean_var', help='statistics regularization') 81 | parser.add_argument('--if_tta_standard', type=str, default='tta_online') 82 | parser.add_argument('--loss_type', type=str, default="nll", choices=['nll']) 83 | 84 | parser.add_argument('--if_sample_tta_aug_views', type=bool, default=True) 85 | parser.add_argument('--if_spatial_rand_cropping', type=bool, default=True) 86 | parser.add_argument('--if_pred_consistency', type=bool, default=True) 87 | parser.add_argument('--lambda_pred_consis', type=float, default=0.1) 88 | parser.add_argument('--lambda_feature_reg', type=int, default=1) 89 | parser.add_argument('--n_augmented_views', type=int, default=2) 90 | parser.add_argument('--tta_view_sample_style_list', default=['uniform_equidist']) 91 | parser.add_argument('--stat_type', default=['spatiotemp']) 92 | parser.add_argument('--before_norm', action='store_true') 93 | parser.add_argument('--reduce_dim', type=bool, default=True) 94 | parser.add_argument('--reg_type', type=str, default='l1_loss') 95 | 96 | parser.add_argument('--chosen_blocks', default=['layer3', 'layer4'] ) 97 | parser.add_argument('--moving_avg', type=bool, default=True ) 98 | 99 | parser.add_argument('--n_gradient_steps', type=int, default=1, help='number of gradient steps per sample') 100 | 101 | 102 | 103 | 104 | 105 | parser.add_argument('--full_res', action='store_true') 106 | parser.add_argument('--input_size', type=int, default=224) 107 | parser.add_argument('--scale_size', type=int, default=256) 108 | parser.add_argument('--batch_size', type=int, default=1) 109 | parser.add_argument('--clip_length', type=int, default=16) 110 | parser.add_argument('--sample_style', type=str, default='uniform-1', 111 | help="either 'dense-xx' (dense sampling, sample from 64 consecutive frames) or 'uniform-xx' (uniform sampling, TSN style), last number is the number of temporal clips") 112 | parser.add_argument('--test_crops', type=int, default=1, help="number of spatial crops") 113 | parser.add_argument('--use_pretrained', action='store_true', 114 | help='whether to use pretrained model for training, set to False during evaluation') 115 | parser.add_argument('--input_mean', default=input_mean) 116 | parser.add_argument('--input_std', default=input_std) 117 | 118 | parser.add_argument('--lr', default=0.00005) 119 | parser.add_argument('--n_epoch_adapat', default=1) 120 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum') 121 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W', help='weight decay (default: 5e-4)') 122 | 123 | 124 | 125 | 126 | def get_opts(): 127 | 
args = parser.parse_args() 128 | 129 | args.evaluate_baselines = not args.tta 130 | args.baseline = 'source' 131 | 132 | return args 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | # parser = argparse.ArgumentParser(description="Implementation of ViTTA") 142 | # parser.add_argument('--dataset', type=str, default='ucf101', choices=['ucf101', 'somethingv2', 'kinetics']) 143 | # parser.add_argument('--modality', type=str, default='RGB', choices=['RGB', 'Flow']) 144 | # parser.add_argument('--root_path', default='None', type=str) 145 | # parser.add_argument('--train_list', default='None', type=str) 146 | # parser.add_argument('--val_list', default='None', type=str) 147 | # 148 | # # ========================= Model Configs ========================== 149 | # parser.add_argument('--arch', type=str, default="i3d_resnet50") 150 | # parser.add_argument('--dropout', '--do', default=0.5, type=float, 151 | # metavar='DO', help='dropout ratio (default: 0.5)') # used in i3d 152 | # parser.add_argument('--clip_length', default=64, type=int, metavar='N', 153 | # help='length of sequential frames (default: 64)') 154 | # parser.add_argument('--input_size', default=224, type=int, metavar='N', 155 | # help='size of input (default: 224)') 156 | # parser.add_argument('--loss_type', type=str, default="nll", 157 | # choices=['nll']) 158 | # 159 | # # ========================= Learning Configs ========================== 160 | # parser.add_argument('--epochs', default=80, type=int, metavar='N', 161 | # help='number of total epochs to run') 162 | # parser.add_argument('-b', '--batch-size', default=16, type=int, 163 | # metavar='N', help='mini-batch size (default: 16)') 164 | # parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 165 | # metavar='LR', help='initial learning rate') 166 | # parser.add_argument('--lr_steps', default=[30, 60], type=float, nargs="+", 167 | # metavar='LRSteps', help='epochs to decay learning rate by 10') 168 | # parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 169 | # help='momentum') 170 | # parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 171 | # metavar='W', help='weight decay (default: 5e-4)') 172 | # parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 173 | # metavar='W', help='gradient norm clipping (default: disabled)') 174 | # 175 | # # ========================= Monitor Configs ========================== 176 | # parser.add_argument('--print-freq', '-p', default=20, type=int, 177 | # metavar='N', help='print frequency (default: 10)') 178 | # parser.add_argument('--eval-freq', '-ef', default=5, type=int, 179 | # metavar='N', help='evaluation frequency (default: 5)') 180 | # 181 | # # ========================= Runtime Configs ========================== 182 | # parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 183 | # help='number of data loading workers (default: 4)') 184 | # parser.add_argument('--resume', default='', type=str, metavar='PATH', 185 | # help='path to latest checkpoint (default: none)') 186 | # parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 187 | # help='evaluate model on validation set') 188 | # parser.add_argument('--snapshot_pref', type=str, default="") 189 | # parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 190 | # help='manual epoch number (useful on restarts)') 191 | # parser.add_argument('--gpus', nargs='+', type=int, default=None) 192 | # parser.add_argument('--flow_prefix', default="flow_", 
type=str) -------------------------------------------------------------------------------- /utils/pred_consistency_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # def kl_div(input, targets): 6 | # return F.kl_div(F.log_softmax(input, dim=1), targets, reduction='none').sum(1) 7 | 8 | l1_loss = nn.L1Loss(reduction='sum') 9 | 10 | 11 | def kl_div(input, targets): 12 | # return F.kl_div( input.log(), targets, reduction='none').sum(1) 13 | return F.kl_div( F.log_softmax(input, dim=1), targets, reduction='none' ).sum(1) 14 | 15 | def compute_pred_consis( preds ): 16 | """ 17 | :param preds: in shape (batch_size, n_views, n_class) before softmax 18 | :return: 19 | """ 20 | bz, n_views, n_class = preds.size() 21 | softmaxs = [] 22 | for view_id in range(n_views): 23 | softmaxs += [F.softmax( preds[:, view_id, :], dim=1)] 24 | 25 | # avg_softmax = torch.stack(softmaxs, dim=0).mean(0).detach() 26 | avg_softmax = torch.stack(softmaxs, dim=0).mean(0) 27 | 28 | loss_consis = [ l1_loss( softmaxs[view_id] , avg_softmax) for view_id in range(n_views) ] 29 | # loss_consis = [ kl_div( preds[:, view_id, :] , avg_softmax) for view_id in range(n_views) ] 30 | loss_consis = sum(loss_consis) / n_views 31 | return loss_consis 32 | 33 | # avg_softmax = sum(softmaxs) / n_views 34 | 35 | 36 | 37 | # softmaxs = [ F.softmax() for logit in preds[:, ]] -------------------------------------------------------------------------------- /utils/utils_.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") # last level 4 | import os 5 | import getpass 6 | import logging 7 | import numpy as np 8 | import torchvision 9 | # import torchvision.transforms 10 | 11 | 12 | from utils.transforms import * 13 | # from transforms import * 14 | import os.path as osp 15 | import csv 16 | import shutil 17 | import time 18 | 19 | def make_dir(dir_): 20 | if not os.path.exists(dir_): 21 | os.makedirs(dir_) 22 | 23 | 24 | def get_class_dict(file_): 25 | class_to_id_dict = dict() 26 | id_to_class_dict = dict() 27 | for line in open(file_): 28 | items = line.strip('\n').split(' ') 29 | class_id, class_name = int(items[0]), items[1] 30 | class_to_id_dict.update({class_name: class_id}) 31 | id_to_class_dict.update({class_id: class_name}) 32 | return class_to_id_dict, id_to_class_dict 33 | 34 | 35 | def read_mapping(file_): 36 | class_to_id_dict = dict() 37 | id_to_class_dict = dict() 38 | for line_id, line in enumerate(open(file_)): 39 | class_id = line_id 40 | class_name = line.strip('\n') 41 | # items = line.strip('\n').split(' ') 42 | # class_id, class_name = int(items[0]), items[1] 43 | class_to_id_dict.update({class_name: class_id}) 44 | id_to_class_dict.update({class_id: class_name}) 45 | return class_to_id_dict, id_to_class_dict 46 | 47 | 48 | def read_csv(csv_file, class_to_id_dict): 49 | vid_label_dict = dict() 50 | with open(csv_file, 'r') as csvfile: 51 | csvreader = csv.reader(csvfile) 52 | fields = next(csvreader) 53 | # rows = [] 54 | for row in csvreader: 55 | action = row[0] 56 | youtube_id = row[1] 57 | class_id = class_to_id_dict[action] 58 | vid_label_dict.update({youtube_id: (class_id, action)}) 59 | return vid_label_dict 60 | 61 | 62 | def get_env_id(): 63 | if getpass.getuser() == 'mirza': 64 | env_id = 0 65 | elif getpass.getuser() == 'jmie01': 66 | env_id = 1 67 | elif getpass.getuser() == 'lin': 68 | env_id = 2 69 | elif 
getpass.getuser() == 'ivanl': 70 | env_id = 3 71 | elif getpass.getuser() == 'eicg': 72 | env_id = 4 73 | elif getpass.getuser() == 'wlin': 74 | env_id = 5 75 | else: 76 | raise Exception("Unknown username!") 77 | return env_id 78 | 79 | 80 | def get_list_files_da(dataset_da, debug, data_dir): 81 | dummy_str = '_dummy' if debug else '' 82 | if dataset_da == 'u2h': 83 | train_list = osp.join('UCF-HMDB/UCF-HMDB12/list_nframes_label', f'list_ucf12_train_nframes{dummy_str}.txt') 84 | val_list = osp.join('UCF-HMDB/UCF-HMDB12/list_nframes_label', f'list_hmdb12_val_nframes{dummy_str}.txt') 85 | elif dataset_da == 'h2u': 86 | train_list = osp.join('UCF-HMDB/UCF-HMDB12/list_nframes_label', f'list_hmdb12_train_nframes{dummy_str}.txt') 87 | val_list = osp.join('UCF-HMDB/UCF-HMDB12/list_nframes_label', f'list_ucf12_val_nframes{dummy_str}.txt') 88 | train_list, val_list = osp.join(data_dir, train_list), osp.join(data_dir, val_list) 89 | return train_list, val_list 90 | 91 | 92 | def path_logger(result_dir, log_time): 93 | streamHandler = logging.StreamHandler() 94 | streamHandler.setLevel(logging.DEBUG) 95 | global logger 96 | 97 | logger = logging.getLogger('basic') 98 | logger.setLevel(logging.DEBUG) 99 | 100 | path_logging = os.path.join(result_dir, f'{log_time}') 101 | 102 | fileHandler = logging.FileHandler(path_logging, mode='w') 103 | fileHandler.setLevel(logging.DEBUG) 104 | 105 | formatter = logging.Formatter('%(asctime)s - %(levelno)s - %(filename)s - %(funcName)s - %(message)s') 106 | streamHandler.setFormatter(formatter) 107 | fileHandler.setFormatter(formatter) 108 | logger.addHandler(streamHandler) 109 | logger.addHandler(fileHandler) 110 | return logger 111 | 112 | 113 | def model_analysis(model, logger): 114 | print("Model Structure") 115 | print(model) 116 | 117 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 118 | params = sum([np.prod(p.size()) for p in model_parameters]) 119 | logger.debug('#################################################') 120 | logger.debug(f'Number of trainable parameters: {params}') 121 | logger.debug('#################################################') 122 | 123 | 124 | def get_augmentation(args, modality, input_size): 125 | if modality == 'Flow': 126 | raise NotImplementedError('Flow not implemented!') 127 | # return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size, [1, .875, .75]), 128 | # GroupRandomHorizontalFlip(is_flow=True)]) 129 | 130 | 131 | # todo for SSv2, labels for some classes are modified when GroupRandomHorizontalFlip is performed 132 | # label_transforms is hard coded to swap the labels for 3 groups of classes: 86 and 87, 93 and 94, 166 and 167 133 | # because after horizontal flip, "left to right" becomes "right to left" 134 | if args.dataset == 'somethingv2': 135 | label_transforms = { 136 | 86: 87, 137 | 87: 86, 138 | 93: 94, 139 | 94: 93, 140 | 166: 167, 141 | 167: 166 142 | } 143 | else: 144 | label_transforms = None 145 | 146 | if args.evaluate_baselines: 147 | if modality == 'RGB' and args.baseline != 'dua': 148 | return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size, [1, .875, .75, .66]), 149 | GroupRandomHorizontalFlip(is_flow=False, label_transforms=label_transforms)]) 150 | elif modality == 'RGB' and args.baseline == 'dua': 151 | 152 | from models.tanet_models.transforms import GroupMultiScaleCrop_TANet_tensor, GroupRandomHorizontalFlip_TANet 153 | 154 | if args.arch == 'tanet': 155 | return torchvision.transforms.Compose([ 156 | GroupMultiScaleCrop_TANet_tensor(input_size, 
[1, .875, .75, .66]), 157 | GroupRandomHorizontalFlip_TANet(is_flow=False, label_transforms=label_transforms)]) 158 | else: 159 | return torchvision.transforms.Compose([ 160 | GroupMultiScaleCrop_tensors(input_size, [1, .875, .75, .66]), 161 | GroupRandomHorizontalFlip(is_flow=False, label_transforms=label_transforms)]) 162 | # elif modality == 'RGB' and args.baseline == 'dua' and args.arch == 'tanet': 163 | # return torchvision.transforms.Compose([]) 164 | else: 165 | # todo pure evaluation or TTA 166 | if modality == 'RGB': 167 | return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size, [1, .875, .75, .66]), 168 | GroupRandomHorizontalFlip(is_flow=False, label_transforms=label_transforms)]) 169 | 170 | 171 | class AverageMeter(object): 172 | """Computes and stores the average and current value""" 173 | 174 | def __init__(self): 175 | self.reset() 176 | 177 | def reset(self): 178 | self.val = 0 179 | self.avg = 0 180 | self.sum = 0 181 | self.count = 0 182 | 183 | def update(self, val, n=1): 184 | self.val = val 185 | self.sum += val * n 186 | self.count += n 187 | self.avg = self.sum / self.count 188 | 189 | 190 | class AverageMeterTensor(object): 191 | def __init__(self): 192 | self.reset() 193 | def reset(self): 194 | self.val = torch.tensor(0).float().cuda() 195 | self.avg = torch.tensor(0).float().cuda() 196 | self.sum = torch.tensor(0).float().cuda() 197 | self.count = 0 198 | def update(self, val, n=1): 199 | self.val = val 200 | self.sum = self.sum.detach() + val * n 201 | self.count += n 202 | self.avg = self.sum / self.count 203 | 204 | class MovingAverageTensor(object): 205 | def __init__(self, momentum=0.1): 206 | self.momentum = momentum 207 | self.reset() 208 | def reset(self): 209 | self.avg = torch.tensor(0).float().cuda() 210 | def update(self, val ): 211 | self.avg = self.momentum * val + (1.0 - self.momentum) * self.avg.detach().to(val.device) 212 | 213 | 214 | def adjust_learning_rate(optimizer, epoch, lr_steps, args=None): 215 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 216 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 217 | lr = args.lr * decay 218 | decay = args.weight_decay 219 | for param_group in optimizer.param_groups: 220 | param_group['lr'] = lr 221 | param_group['weight_decay'] = decay 222 | 223 | 224 | def accuracy(output, target, topk=(1,)): 225 | """Computes the precision@k for the specified values of k""" 226 | maxk = max(topk) 227 | batch_size = target.size(0) 228 | 229 | _, pred = output.topk(maxk, 1, True, True) 230 | pred = pred.t() 231 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 232 | 233 | res = [] 234 | for k in topk: 235 | correct_k = correct[:k].reshape(-1).float().sum(0) 236 | res.append(correct_k.mul_(100.0 / batch_size)) 237 | return res 238 | 239 | 240 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', result_dir=None, log_time=None, logger=None, 241 | args=None): 242 | filename = '_'.join((log_time, args.snapshot_pref, args.modality.lower(), filename)) 243 | file_path = osp.join(result_dir, filename) 244 | torch.save(state, file_path) 245 | if is_best: 246 | best_name = '_'.join((log_time, args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar')) 247 | best_model_path = osp.join(result_dir, best_name) 248 | shutil.copyfile(file_path, best_model_path) 249 | logger.debug(f'Best Checkpoint saved!') 250 | 251 | 252 | def get_writer_to_all_result(args, custom_path=None): 253 | log_time = time.strftime("%Y%m%d_%H%M%S") 254 | 255 | if custom_path is None: 
256 | f_write = open(osp.join(args.result_dir, f'{log_time}_all_result'), 'w+') 257 | else: 258 | f_write = open(osp.join(custom_path, f'{args.baseline}_{log_time}_all_result'), 'w+') 259 | 260 | for arg in dir(args): 261 | if arg[0] != '_': 262 | f_write.write(f'{arg} {getattr(args, arg)}\n') 263 | f_write.write(f'#############################\n') 264 | f_write.write(f'#############################\n') 265 | f_write.write('\n') 266 | f_write.write('\n') 267 | return f_write 268 | --------------------------------------------------------------------------------
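The helpers shown above (the statistics hooks and BN utilities in `utils/BNS_utils.py`, the consistency loss in `utils/pred_consistency_utils.py`, and the options in `utils/opts.py`) are presumably wired together by the adaptation scripts (`tta_tanet_ucf101.py`, `tta_swin_ucf101.py`). The snippet below is only a minimal sketch of one way they could be combined for a single adaptation pass, not the repository's actual loop: `AlignStatsHook` is a placeholder for a statistics-alignment hook that exposes an `r_feature` loss term (like the hook shown above), and `model`, `loader`, and `args` are assumed to be the network under adaptation, a dataloader yielding temporally augmented views, and the parsed options from `utils/opts.py`.

```python
# Illustrative sketch only -- not the repository's adaptation script.
# Assumptions: `AlignStatsHook` exposes `r_feature` after each forward pass (as the hook above does),
# and `loader` yields (views, label) with views of shape (B, n_views, C, T, H, W).
import torch
import torch.nn as nn

from utils.BNS_utils import choose_layers, freeze_except_bn, collect_bn_params
from utils.pred_consistency_utils import compute_pred_consis


def adapt_one_pass(model, loader, args, AlignStatsHook):
    bn_types = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]

    if args.update_only_bn_affine:
        # train only the affine (gamma / beta) parameters of the normalization layers
        model = freeze_except_bn(model, bn_types)
        params, _ = collect_bn_params(model, bn_types)
    else:
        model.train()
        params = list(model.parameters())
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # one alignment hook per chosen layer; each hook accumulates target statistics during
    # the forward pass and stores its alignment loss term in `hook.r_feature`
    hooks = [AlignStatsHook(m, clip_len=args.clip_length)
             for _, m in choose_layers(model, bn_types)]

    for views, _label in loader:
        b, v = views.shape[:2]
        preds = model(views.flatten(0, 1)).view(b, v, -1)   # (B*n_views, ...) -> (B, n_views, n_class)

        loss = args.lambda_feature_reg * sum(h.r_feature for h in hooks)
        if args.if_pred_consistency:
            loss = loss + args.lambda_pred_consis * compute_pred_consis(preds)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for h in hooks:
        h.close()   # remove the registered forward hooks
    return model
```

When `--update_only_bn_affine` is set, the sketch mirrors `collect_bn_params`, which gathers only the `weight` (scale gamma) and `bias` (shift beta) of the chosen normalization layers, so the optimizer touches nothing else in the backbone.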