├── .DS_Store ├── LICENSE ├── README.md ├── eval_sth.sh ├── fig ├── intro.jpeg ├── neu.png └── smile.png ├── main.py ├── ops ├── __init__.py ├── backbone │ ├── AF_MobileNetv3.py │ ├── AF_ResNet.py │ ├── __init__.py │ ├── gumbel_softmax.py │ └── temporal_shift.py ├── basic_ops.py ├── dataset.py ├── dataset_config.py ├── models.py ├── models_mobilenet.py ├── transforms.py └── utils.py ├── opts.py └── train_sth.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 BeSpontaneous 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Look More but Care Less in Video Recognition (NeurIPS 2022) 2 | 3 |
4 | 5 | 6 |
7 | 8 | [arXiv](https://arxiv.org/abs/2211.09992) | Primary contact: [Yitian Zhang](mailto:markcheung9248@gmail.com) 9 | 10 |
11 | 12 |
13 | 14 | Comparisons between existing methods and our proposed Ample and Focal Network (AFNet). Most existing works reduce data redundancy at the beginning of the network, which leads to a loss of information. We propose a two-branch design that processes frames with different amounts of computation within the network while still preserving all input information. 15 | 16 | 17 | ## Requirements 18 | - python 3.7 19 | - pytorch 1.7.0 20 | - torchvision 0.9.0 21 | 22 | 23 | ## Datasets 24 | Please follow the instructions of [TSM](https://github.com/mit-han-lab/temporal-shift-module#data-preparation) to prepare the Something-Something V1/V2 datasets. 25 | 26 | 27 | ## Pretrained Models 28 | We provide AF-MobileNetv3, AF-ResNet50 and AF-ResNet101 pretrained on ImageNet, as well as all models trained on the Something-Something V1 dataset. 29 | 30 | ### Results on ImageNet 31 | Checkpoints are available through this [link](https://drive.google.com/drive/folders/1UzSckmKnwmgwWObF2_YxpkAIZ2k2mcHL?usp=share_link). 32 | | Model | Top-1 Acc. | GFLOPs | 33 | | --------------- | ------------- | ------------- | 34 | | AF-MobileNetv3 | 72.09% | 0.2 | 35 | | AF-ResNet50 | 77.24% | 2.9 | 36 | | AF-ResNet101 | 78.36% | 5.0 | 37 | 38 | ### Results on Something-Something V1 39 | Checkpoints and logs are available through this [link](https://drive.google.com/drive/folders/1-xmE6T6OADmDkkzJr4iM1vCJbA4ofcSO?usp=share_link). 40 | 41 | **Less is More**: 42 | | Model | Frame | Top-1 Acc. | GFLOPs | 43 | | --------------- | --------------- | ------------- | ------------- | 44 | | TSN | 8 | 18.6% | 32.7 | 45 | | AFNet(RT=0.50) | 8 | 26.8% | 19.5 | 46 | | AFNet(RT=0.25) | 8 | 27.7% | 18.3 | 47 | 48 | 49 | **More is Less**: 50 | | Model | Backbone | Frame | Top-1 Acc.
| GFLOPs | 51 | | --------------- | --------------- | ------------- | ------------- | ------------- | 52 | | TSM | ResNet50 | 8 | 45.6% | 32.7 | 53 | | AFNet-TSM(RT=0.4) | AF-ResNet50 | 12 | 49.0% | 27.9 | 54 | | AFNet-TSM(RT=0.8) | AF-ResNet50 | 12 | 49.9% | 31.7 | 55 | | AFNet-TSM(RT=0.4) | AF-MobileNetv3 | 12 | 45.3% | 2.2 | 56 | | AFNet-TSM(RT=0.8) | AF-MobileNetv3 | 12 | 45.9% | 2.3 | 57 | | AFNet-TSM(RT=0.4) | AF-ResNet101 | 12 | 49.8% | 42.1 | 58 | | AFNet-TSM(RT=0.8) | AF-ResNet101 | 12 | 50.1% | 48.9 | 59 | 60 | 61 | ## Training AFNet on Something-Something V1 62 | 1. Specify the dataset directory with `root_dataset` in `train_sth.sh`. 63 | 2. Download the ImageNet-pretrained backbone from [Google Drive](https://drive.google.com/drive/folders/1UzSckmKnwmgwWObF2_YxpkAIZ2k2mcHL?usp=share_link). 64 | 3. Specify the directory of the downloaded backbone with `path_backbone` in `train_sth.sh`. 65 | 4. Specify the ratio of selected frames with `rt` and run `bash train_sth.sh`. 66 | 67 | 68 | 69 | ## Evaluate pretrained models on Something-Something V1 70 | **Note that evaluation results vary slightly between runs because of Gumbel-Softmax sampling, so they may not exactly match the numbers in our paper. We provide the logs for Tab. 2 for verification.** 71 | 1. Specify the dataset directory with `root_dataset` in `eval_sth.sh`. 72 | 2. Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1-xmE6T6OADmDkkzJr4iM1vCJbA4ofcSO?usp=share_link). 73 | 3. Specify the path of the pretrained model with `resume` in `eval_sth.sh`. 74 | 4. Run `bash eval_sth.sh`.
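The run-to-run variance mentioned in the note above comes from random sampling. As a rough, self-contained illustration (plain Python, no PyTorch), the sketch below implements the generic Gumbel-Softmax trick: add Gumbel(0, 1) noise to the logits, divide by a temperature `tau`, apply a softmax, and take the argmax as a hard one-hot decision. It is not the repository's `GumbleSoftmax` module; the helper `sample_frame_mask`, its `[skip, keep]` logit layout and the example logits are hypothetical. Assuming fresh noise is drawn on every forward pass, identical logits can still yield different frame masks under different random states, which changes the kept-frame ratio and can therefore shift accuracy and FLOPs slightly.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Return (soft, hard) Gumbel-Softmax samples for one categorical choice.

    Generic textbook formulation; NOT the repository's GumbleSoftmax class.
    """
    eps = 1e-10
    scaled = []
    for x in logits:
        # u lies in [eps, 1 - eps), so both logarithms below stay finite.
        u = eps + (1.0 - 2.0 * eps) * rng.random()
        g = -math.log(-math.log(u))  # one Gumbel(0, 1) draw
        scaled.append((x + g) / tau)
    # Numerically stable softmax over the noisy, temperature-scaled logits.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    soft = [e / total for e in exps]
    # Hard sample: one-hot vector at the argmax of the soft sample.
    k = max(range(len(soft)), key=soft.__getitem__)
    hard = [1.0 if i == k else 0.0 for i in range(len(soft))]
    return soft, hard


def sample_frame_mask(frame_logits, tau, seed):
    """Hypothetical helper: draw one keep/skip decision per frame.

    Each element of frame_logits is a [skip, keep] logit pair; the mask entry
    is 1.0 when the hard sample selects "keep".
    """
    rng = random.Random(seed)
    return [gumbel_softmax(pair, tau, rng)[1][1] for pair in frame_logits]


if __name__ == "__main__":
    # Made-up logits for 12 frames with a mild preference for "keep".
    logits = [[0.0, 0.5]] * 12
    for seed in (0, 1):
        mask = sample_frame_mask(logits, tau=1.0, seed=seed)
        print(f"seed={seed} mask={mask} kept ratio={sum(mask) / len(mask):.2f}")
```

With seeds 0 and 1 the demo may print different masks for the same logits; fixing the seed makes the sketch repeatable. Lowering `tau` sharpens the soft sample toward the hard one-hot vector.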
75 | 76 | 77 | 78 | ## Reference 79 | If you find our code or paper useful for your research, please cite: 80 | ``` 81 | @article{zhang2022look, 82 | title={Look More but Care Less in Video Recognition}, 83 | author={Zhang, Yitian and Bai, Yue and Wang, Huan and Xu, Yi and Fu, Yun}, 84 | journal={arXiv preprint arXiv:2211.09992}, 85 | year={2022} 86 | } 87 | ``` -------------------------------------------------------------------------------- /eval_sth.sh: -------------------------------------------------------------------------------- 1 | ### evaluate AF-ResNet 2 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \ 3 | --arch_file AF_ResNet \ 4 | --arch AF_resnet50 --num_segments 12 \ 5 | --root_dataset 'path_dataset' \ 6 | --path_backbone 'path_backbone' \ 7 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \ 8 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \ 9 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \ 10 | --model_path 'models' \ 11 | --rt 0.5 --round test \ 12 | --resume 'path_pretrained_model' \ 13 | --evaluate; 14 | 15 | 16 | 17 | ### evaluate AF-ResNet-TSM 18 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \ 19 | --arch_file AF_ResNet \ 20 | --arch AF_resnet50 --num_segments 12 \ 21 | --root_dataset 'path_dataset' \ 22 | --path_backbone 'path_backbone' \ 23 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \ 24 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \ 25 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \ 26 | --model_path 'models' \ 27 | --shift \ 28 | --rt 0.5 --round test \ 29 | --resume 'path_pretrained_model' \ 30 | --evaluate; 31 | 32 | 33 | 34 | ### evaluate AF-MobileNetv3-TSM 35 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \ 36 | --arch_file AF_MobileNetv3 \ 37 | --arch AF_mobilenetv3 --num_segments 12 \ 38 | --root_dataset 'path_dataset' \ 39 | --path_backbone 'path_backbone' \ 40 | --batch-size 32 --lr 0.01 --lr_steps 25 45 
--epochs 55 \ 41 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \ 42 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \ 43 | --model_path 'models_mobilenet' \ 44 | --shift \ 45 | --rt 0.5 --round test \ 46 | --resume 'path_pretrained_model' \ 47 | --evaluate; -------------------------------------------------------------------------------- /fig/intro.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/fig/intro.jpeg -------------------------------------------------------------------------------- /fig/neu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/fig/neu.png -------------------------------------------------------------------------------- /fig/smile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/fig/smile.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import os 7 | import time 8 | import shutil 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | import torch.multiprocessing as mp 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torch.distributed as dist 17 | from torch.cuda.amp import autocast, GradScaler 18 | 19 | from torch.nn.utils import clip_grad_norm_ 20 | import pandas as pd 21 | from ops.dataset 
import TSNDataSet 22 | import importlib 23 | from ops.transforms import * 24 | from opts import parser 25 | from ops import dataset_config 26 | from ops.utils import AverageMeter, accuracy 27 | from ops.backbone.temporal_shift import make_temporal_pool 28 | from tensorboardX import SummaryWriter 29 | 30 | 31 | best_prec1 = 0 32 | val_acc_top1 = [] 33 | val_acc_top5 = [] 34 | val_FLOPs = [] 35 | 36 | tr_big_rate = [] 37 | val_big_rate = [] 38 | train_loss_ls = [] 39 | 40 | tr_acc_top1 = [] 41 | tr_acc_top5 = [] 42 | train_loss = [] 43 | train_loss_cls = [] 44 | valid_loss = [] 45 | epoch_log = [] 46 | 47 | 48 | def main(): 49 | global args, best_prec1 50 | global val_acc_top1 51 | global val_acc_top5 52 | global tr_acc_top1 53 | global tr_acc_top5 54 | global train_loss 55 | global train_loss_cls 56 | global valid_loss 57 | global epoch_log 58 | global tr_big_rate 59 | global val_big_rate 60 | global train_loss_ls 61 | global val_FLOPs 62 | args = parser.parse_args() 63 | 64 | if args.distributed: 65 | dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:8888', 66 | world_size=args.world_size, rank=args.local_rank) 67 | torch.cuda.set_device(args.local_rank) 68 | device = torch.device(f'cuda:{args.local_rank}') 69 | 70 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 71 | num_class, args.train_list, args.val_list, args.root_path, prefix \ 72 | = dataset_config.return_dataset(args.root_dataset, args.dataset, args.modality) 73 | str_round = str(args.round) 74 | args.store_name = f'{args.dataset}/{args.arch_file}/{args.arch}/frame{args.num_segments}/round{str_round}/' 75 | print('storing name: ' + args.store_name) 76 | check_rootfolders() 77 | 78 | path = str('ops.'+args.model_path) 79 | file = importlib.import_module(path) 80 | model = file.TSN(args.arch_file, num_class, args.num_segments, args.modality, args.path_backbone, 81 | base_model=args.arch, 82 | consensus_type=args.consensus_type, 83 | 
dropout=args.dropout, 84 | img_feature_dim=args.img_feature_dim, 85 | partial_bn=not args.no_partialbn, 86 | pretrain=args.pretrain, 87 | is_shift=args.shift, 88 | fc_lr5=not (args.tune_from and args.dataset in args.tune_from), 89 | temporal_pool=args.temporal_pool, 90 | non_local=args.non_local) 91 | 92 | crop_size = model.crop_size 93 | scale_size = model.scale_size 94 | input_mean = model.input_mean 95 | input_std = model.input_std 96 | policies = model.get_optim_policies() 97 | train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True) 98 | 99 | if args.distributed: 100 | model.to(device) 101 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 102 | output_device=args.local_rank, find_unused_parameters=True) 103 | else: 104 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 105 | 106 | optimizer = torch.optim.SGD(policies, 107 | args.lr, 108 | momentum=args.momentum, 109 | weight_decay=args.weight_decay) 110 | 111 | if args.resume: 112 | if args.temporal_pool: # early temporal pool so that we can load the state_dict 113 | make_temporal_pool(model.module.base_model, args.num_segments) 114 | if os.path.isfile(args.resume): 115 | print(("=> loading checkpoint '{}'".format(args.resume))) 116 | checkpoint = torch.load(args.resume) 117 | args.start_epoch = checkpoint['epoch'] 118 | best_prec1 = checkpoint['best_prec1'] 119 | model.load_state_dict(checkpoint['state_dict']) 120 | optimizer.load_state_dict(checkpoint['optimizer']) 121 | 122 | val_acc_top1 = checkpoint['val_acc_top1'] 123 | val_acc_top5 = checkpoint['val_acc_top5'] 124 | val_big_rate = checkpoint['val_big_rate'] 125 | val_FLOPs = checkpoint['val_FLOPs'] 126 | tr_acc_top1 = checkpoint['tr_acc_top1'] 127 | tr_acc_top5 = checkpoint['tr_acc_top5'] 128 | train_loss = checkpoint['train_loss'] 129 | tr_big_rate = checkpoint['tr_big_rate'] 130 | train_loss_cls = checkpoint['train_loss_cls'] 131 
| train_loss_ls = checkpoint['train_loss_ls'] 132 | valid_loss = checkpoint['valid_loss'] 133 | epoch_log = checkpoint['epoch_log'] 134 | 135 | print(("=> loaded checkpoint '{}' (epoch {})" 136 | .format(args.resume, checkpoint['epoch']))) 137 | else: 138 | print(("=> no checkpoint found at '{}'".format(args.resume))) 139 | 140 | if args.tune_from: 141 | print(("=> fine-tuning from '{}'".format(args.tune_from))) 142 | sd = torch.load(args.tune_from) 143 | sd = sd['state_dict'] 144 | model_dict = model.state_dict() 145 | replace_dict = [] 146 | for k, v in sd.items(): 147 | if k not in model_dict and k.replace('.net', '') in model_dict: 148 | print('=> Load after removing .net: ', k) 149 | replace_dict.append((k, k.replace('.net', ''))) 150 | for k, v in model_dict.items(): 151 | if k not in sd and k.replace('.net', '') in sd: 152 | print('=> Load after adding .net: ', k) 153 | replace_dict.append((k.replace('.net', ''), k)) 154 | 155 | for k, k_new in replace_dict: 156 | sd[k_new] = sd.pop(k) 157 | keys1 = set(list(sd.keys())) 158 | keys2 = set(list(model_dict.keys())) 159 | set_diff = (keys1 - keys2) | (keys2 - keys1) 160 | print('#### Notice: keys that failed to load: {}'.format(set_diff)) 161 | if args.dataset not in args.tune_from: # new dataset 162 | print('=> New dataset, do not load fc weights') 163 | sd = {k: v for k, v in sd.items() if 'fc' not in k} 164 | if args.modality == 'Flow' and 'Flow' not in args.tune_from: 165 | sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k} 166 | model_dict.update(sd) 167 | model.load_state_dict(model_dict) 168 | 169 | if args.temporal_pool and not args.resume: 170 | make_temporal_pool(model.module.base_model, args.num_segments) 171 | 172 | cudnn.benchmark = True 173 | 174 | # Data loading code 175 | if args.modality != 'RGBDiff': 176 | normalize = GroupNormalize(input_mean, input_std) 177 | else: 178 | normalize = IdentityTransform() 179 | 180 | if args.modality == 'RGB': 181 | data_length = 1 182 | elif
args.modality in ['Flow', 'RGBDiff']: 183 | data_length = 5 184 | 185 | train_dataset = TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments, 186 | new_length=data_length, 187 | modality=args.modality, 188 | image_tmpl=prefix, 189 | transform=torchvision.transforms.Compose([ 190 | GroupScale((240,320)), 191 | train_augmentation, 192 | Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])), 193 | ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])), 194 | normalize, 195 | ]), dense_sample=args.dense_sample) 196 | 197 | val_dataset = TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments, 198 | new_length=data_length, 199 | modality=args.modality, 200 | image_tmpl=prefix, 201 | random_shift=False, 202 | transform=torchvision.transforms.Compose([ 203 | GroupScale((240,320)), 204 | GroupCenterCrop(crop_size), 205 | Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])), 206 | ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])), 207 | normalize, 208 | ]), dense_sample=args.dense_sample) 209 | 210 | if args.distributed: 211 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 212 | else: 213 | train_sampler = None 214 | 215 | train_loader = torch.utils.data.DataLoader( 216 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 217 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, 218 | drop_last=True) # prevent something not % n_GPU 219 | 220 | val_loader = torch.utils.data.DataLoader( 221 | val_dataset, batch_size=args.batch_size, shuffle=False, 222 | num_workers=args.workers, pin_memory=True) 223 | 224 | # define loss function (criterion) and optimizer 225 | if args.loss_type == 'nll': 226 | criterion = torch.nn.CrossEntropyLoss().cuda() 227 | else: 228 | raise ValueError("Unknown loss type") 229 | 230 | for group in policies: 231 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 232 | 
group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 233 | 234 | if args.evaluate: 235 | validate(val_loader, model, criterion, 0, args.rt) 236 | return 237 | 238 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 239 | log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w') 240 | with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f: 241 | f.write(str(args)) 242 | tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name)) 243 | 244 | if args.amp: 245 | scaler = GradScaler() 246 | else: 247 | scaler = None 248 | 249 | for epoch in range(args.start_epoch, args.epochs): 250 | adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps) 251 | rt = adjust_ratio(epoch, args) 252 | # train for one epoch 253 | tr_acc1, tr_acc5, tr_loss, tr_loss_cls, tr_loss_rt, tr_ratios = train(train_loader, model, criterion, optimizer, epoch, rt, log_training, tf_writer, scaler) 254 | 255 | # evaluate on validation set 256 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 257 | val_acc1, val_acc5, val_loss, val_ratios, val_flops = validate(val_loader, model, criterion, epoch, rt, log_training, tf_writer) 258 | 259 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 260 | # remember best prec@1 and save checkpoint 261 | is_best = val_acc1 > best_prec1 262 | best_prec1 = max(val_acc1, best_prec1) 263 | tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch) 264 | 265 | output_best = 'Best Prec@1: %.3f\n' % (best_prec1) 266 | print(output_best) 267 | log_training.write(output_best + '\n') 268 | log_training.flush() 269 | 270 | val_acc_top1.append(val_acc1) 271 | val_acc_top5.append(val_acc5) 272 | val_big_rate.append(val_ratios) 273 | tr_big_rate.append(tr_ratios) 274 | val_FLOPs.append(val_flops) 275 | tr_acc_top1.append(tr_acc1) 276 | tr_acc_top5.append(tr_acc5) 277 | 
train_loss.append(tr_loss) 278 | train_loss_cls.append(tr_loss_cls) 279 | train_loss_ls.append(tr_loss_rt) 280 | valid_loss.append(val_loss) 281 | epoch_log.append(epoch) 282 | 283 | df = pd.DataFrame({'val_acc_top1': val_acc_top1, 'val_acc_top5': val_acc_top5, 284 | 'val_big_rate': val_big_rate, 'val_FLOPs': val_FLOPs, 285 | 'tr_big_rate': tr_big_rate, 'tr_acc_top1': tr_acc_top1, 'tr_acc_top5': tr_acc_top5, 286 | 'train_loss': train_loss, 'train_loss_cls': train_loss_cls, 'train_loss_ls': train_loss_ls, 287 | 'valid_loss': valid_loss, 'epoch_log': epoch_log}) 288 | 289 | log_file = os.path.join(args.root_log, args.store_name, 'log_epoch.txt') 290 | with open(log_file, "w") as f: 291 | df.to_csv(f) 292 | 293 | save_checkpoint({ 294 | 'epoch': epoch + 1, 295 | 'arch': args.arch, 296 | 'state_dict': model.state_dict(), 297 | 'optimizer': optimizer.state_dict(), 298 | 'best_prec1': best_prec1, 299 | 'val_acc_top1': val_acc_top1, 300 | 'val_acc_top5': val_acc_top5, 301 | 'val_big_rate': val_big_rate, 302 | 'val_FLOPs': val_FLOPs, 303 | 'tr_big_rate': tr_big_rate, 304 | 'tr_acc_top1': tr_acc_top1, 305 | 'tr_acc_top5': tr_acc_top5, 306 | 'train_loss': train_loss, 307 | 'train_loss_cls': train_loss_cls, 308 | 'train_loss_ls': train_loss_ls, 309 | 'valid_loss': valid_loss, 310 | 'epoch_log': epoch_log, 311 | }, is_best, epoch) 312 | 313 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 314 | file1 = pd.read_csv(log_file) 315 | acc1 = np.array(file1['val_acc_top1']) 316 | flops1 = np.array(file1['val_FLOPs']) 317 | loc = np.argmax(acc1) 318 | max_acc = acc1[loc] 319 | acc_flops = flops1[loc] 320 | with open(log_file, mode='a', encoding='utf-8') as fout: # append best top-1 and its FLOPs; the file is closed on exit 321 | fout.write("%.6f\t%.6f" % (max_acc, acc_flops)) 322 | 323 | 324 | def train(train_loader, model, criterion, optimizer, epoch, rt, log, tf_writer, scaler=None): 325 | batch_time = AverageMeter() 326 | data_time = AverageMeter() 327 |
losses = AverageMeter() 328 | losses_cls = AverageMeter() 329 | losses_rt = AverageMeter() 330 | top1 = AverageMeter() 331 | top5 = AverageMeter() 332 | real_ratios = AverageMeter() 333 | train_batches_num = len(train_loader) 334 | 335 | if args.no_partialbn: 336 | model.module.partialBN(False) 337 | else: 338 | model.module.partialBN(True) 339 | 340 | # switch to train mode 341 | model.train() 342 | 343 | end = time.time() 344 | 345 | if args.amp: 346 | assert scaler is not None 347 | 348 | for i, (input, target) in enumerate(train_loader): 349 | # measure data loading time 350 | data_time.update(time.time() - end) 351 | 352 | target = target.cuda() 353 | input_var = torch.autograd.Variable(input) 354 | target_var = torch.autograd.Variable(target) 355 | 356 | adjust_temperature(epoch, i, train_batches_num, args) 357 | optimizer.zero_grad() 358 | 359 | if args.amp: 360 | with autocast(): 361 | # compute output 362 | output, temporal_mask_ls = model(input_var, args.temp) 363 | loss_cls = criterion(output, target_var) 364 | 365 | real_ratio = 0.0 366 | loss_real_ratio = 0.0 367 | for temporal_mask in temporal_mask_ls: 368 | real_ratio += torch.mean(temporal_mask) 369 | loss_real_ratio += torch.pow(rt-torch.mean(temporal_mask), 2) 370 | real_ratio = torch.mean(real_ratio/len(temporal_mask_ls)) 371 | loss_real_ratio = torch.mean(loss_real_ratio/len(temporal_mask_ls)) 372 | loss_real_ratio = args.lambda_rt * loss_real_ratio 373 | loss = loss_cls + loss_real_ratio 374 | 375 | scaler.scale(loss).backward() 376 | scaler.step(optimizer) 377 | scaler.update() 378 | else: 379 | output, temporal_mask_ls = model(input_var, args.temp) 380 | loss_cls = criterion(output, target_var) 381 | 382 | real_ratio = 0.0 383 | loss_real_ratio = 0.0 384 | for temporal_mask in temporal_mask_ls: 385 | real_ratio += torch.mean(temporal_mask) 386 | loss_real_ratio += torch.pow(rt-torch.mean(temporal_mask), 2) 387 | real_ratio = torch.mean(real_ratio/len(temporal_mask_ls)) 388 | loss_real_ratio = 
torch.mean(loss_real_ratio/len(temporal_mask_ls)) 389 | loss_real_ratio = args.lambda_rt * loss_real_ratio 390 | loss = loss_cls + loss_real_ratio 391 | 392 | loss.backward() 393 | if args.clip_gradient is not None: 394 | total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient) 395 | optimizer.step() 396 | 397 | 398 | # measure accuracy and record loss 399 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 400 | real_ratios.update(real_ratio.item(), input.size(0)) 401 | losses_cls.update(loss_cls.item(), input.size(0)) 402 | losses_rt.update(loss_real_ratio.item(), input.size(0)) 403 | losses.update(loss.item(), input.size(0)) 404 | top1.update(prec1.item(), input.size(0)) 405 | top5.update(prec5.item(), input.size(0)) 406 | 407 | # measure elapsed time 408 | batch_time.update(time.time() - end) 409 | end = time.time() 410 | 411 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 412 | if i % args.print_freq == 0: 413 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 414 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 415 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 416 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 417 | 'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t' 418 | 'Loss_ls {loss_ls.val:.4f} ({loss_ls.avg:.4f})\t' 419 | 'Real_ratio {real_ratio.val:.4f} ({real_ratio.avg:.4f})\t' 420 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 421 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 422 | epoch, i, len(train_loader), batch_time=batch_time, 423 | data_time=data_time, loss=losses, loss_cls=losses_cls, loss_ls=losses_rt, real_ratio=real_ratios, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1)) # TODO 424 | print(output) 425 | log.write(output + '\n') 426 | log.flush() 427 | 428 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 429 | tf_writer.add_scalar('loss/train', losses.avg, epoch) 430 | tf_writer.add_scalar('acc/train_top1', 
top1.avg, epoch) 431 | tf_writer.add_scalar('acc/train_top5', top5.avg, epoch) 432 | tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch) 433 | 434 | return top1.avg, top5.avg, losses.avg, losses_cls.avg, losses_rt.avg, real_ratios.avg 435 | 436 | 437 | def validate(val_loader, model, criterion, epoch, rt, log=None, tf_writer=None): 438 | batch_time = AverageMeter() 439 | losses = AverageMeter() 440 | top1 = AverageMeter() 441 | top5 = AverageMeter() 442 | real_ratios = AverageMeter() 443 | FLOPs = AverageMeter() 444 | 445 | # switch to evaluate mode 446 | model.eval() 447 | 448 | end = time.time() 449 | with torch.no_grad(): 450 | for i, (input, target) in enumerate(val_loader): 451 | input = input.cuda() 452 | target = target.cuda() 453 | 454 | # compute output 455 | output, temporal_mask_ls, flops = model.module.forward_calc_flops(input, args.t1) 456 | flops /= 1e9 457 | loss_cls = criterion(output, target) 458 | 459 | real_ratio = 0.0 460 | loss_real_ratio = 0.0 461 | for temporal_mask in temporal_mask_ls: 462 | real_ratio += torch.mean(temporal_mask) 463 | loss_real_ratio += torch.pow(rt-torch.mean(temporal_mask), 2) 464 | real_ratio = torch.mean(real_ratio/len(temporal_mask_ls)) 465 | loss_real_ratio = torch.mean(loss_real_ratio/len(temporal_mask_ls)) 466 | loss_real_ratio = args.lambda_rt * loss_real_ratio 467 | 468 | loss = loss_cls + loss_real_ratio 469 | 470 | # measure accuracy and record loss 471 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 472 | 473 | FLOPs.update(flops.item(), input.size(0)) 474 | real_ratios.update(real_ratio.item(), input.size(0)) 475 | losses.update(loss.item(), input.size(0)) 476 | top1.update(prec1.item(), input.size(0)) 477 | top5.update(prec5.item(), input.size(0)) 478 | 479 | # measure elapsed time 480 | batch_time.update(time.time() - end) 481 | end = time.time() 482 | 483 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 484 | if i % args.print_freq == 0: 485 
| output = ('Test: [{0}/{1}]\t' 486 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 487 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 488 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 489 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 490 | i, len(val_loader), batch_time=batch_time, loss=losses, 491 | top1=top1, top5=top5)) 492 | print(output) 493 | if log is not None: 494 | log.write(output + '\n') 495 | log.flush() 496 | 497 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0): 498 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 499 | .format(top1=top1, top5=top5, loss=losses)) 500 | print(output) 501 | if log is not None: 502 | log.write(output + '\n') 503 | log.flush() 504 | 505 | if tf_writer is not None: 506 | tf_writer.add_scalar('loss/test', losses.avg, epoch) 507 | tf_writer.add_scalar('acc/test_top1', top1.avg, epoch) 508 | tf_writer.add_scalar('acc/test_top5', top5.avg, epoch) 509 | 510 | return top1.avg, top5.avg, losses.avg, real_ratios.avg, FLOPs.avg 511 | 512 | 513 | def save_checkpoint(state, is_best, epoch): 514 | filename = '%s/%s/ckpt.pth.tar' % (args.root_log, args.store_name) 515 | torch.save(state, filename) 516 | if is_best: 517 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 518 | 519 | 520 | def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps): 521 | """Set the LR of every param group: step decay (x0.1 at each epoch in lr_steps) or cosine annealing, scaled by the group's lr_mult and decay_mult""" 522 | if lr_type == 'step': 523 | lr_decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 524 | lr = args.lr * lr_decay 525 | decay = args.weight_decay 526 | elif lr_type == 'cos': 527 | import math 528 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs)) 529 | decay = args.weight_decay 530 | else: 531 | raise NotImplementedError 532 | for param_group in optimizer.param_groups: 533 | param_group['lr'] = lr * param_group['lr_mult'] 534 | param_group['weight_decay'] = decay *
param_group['decay_mult'] 535 | 536 | 537 | def check_rootfolders(): 538 | """Create the log folder and the per-run subfolder if they do not exist""" 539 | folders_util = [args.root_log, os.path.join(args.root_log, args.store_name)] 540 | for folder in folders_util: 541 | if not os.path.exists(folder): 542 | print('creating folder ' + folder) 543 | os.makedirs(folder) 544 | 545 | 546 | 547 | def adjust_temperature(epoch, step, len_epoch, args): 548 | if epoch >= args.t_end: 549 | args.temp = args.t1 # annealing finished: hold the final temperature 550 | else: 551 | T_total = args.t_end * len_epoch 552 | T_cur = epoch * len_epoch + step 553 | alpha = math.pow(args.t1 / args.t0, 1 / T_total) 554 | args.temp = math.pow(alpha, T_cur) * args.t0 # geometric decay: t0 * (t1 / t0) ** (T_cur / T_total) 555 | 556 | 557 | def adjust_ratio(epoch, args): 558 | if epoch < args.rt_begin: 559 | rt = 1.0 # keep all frames before rt_begin 560 | elif epoch < args.rt_begin + (args.rt_end-args.rt_begin)//2: 561 | rt = args.rt + (1.0 - args.rt)/3*2 # one third of the way from 1.0 down to args.rt 562 | elif epoch < args.rt_end: 563 | rt = args.rt + (1.0 - args.rt)/3 # two thirds of the way down 564 | else: 565 | rt = args.rt 566 | return rt 567 | 568 | 569 | if __name__ == '__main__': 570 | main() 571 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /ops/backbone/AF_MobileNetv3.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import math 4 | import torch 5 | from .gumbel_softmax import GumbleSoftmax 6 | 7 | 8 | __all__ = ['AF_mobilenetv3'] 9 | 10 | 11 | 12 | class TSM(nn.Module): 13 | def __init__(self): 14 | super(TSM, self).__init__() 15 | self.fold_div = 8 16 | 17 | def forward(self, x, n_segment): 18 | x = self.shift(x, n_segment, fold_div=self.fold_div) 19 | return x 20 | 21 | @staticmethod 22 | def shift(x, n_segment, fold_div=3): 23 | if type(n_segment) is int: 24 | nt, c, h, w = x.size() 25 |
n_batch = nt // n_segment 26 | x = x.view(n_batch, n_segment, c, h, w) 27 | 28 | fold = c // fold_div 29 | out = torch.zeros_like(x) 30 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 31 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right 32 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 33 | shift_out = out.view(nt, c, h, w) 34 | else: 35 | num_segment = int(n_segment.sum()) 36 | ls = n_segment 37 | bool_list = ls > 0 38 | bool_list = bool_list.view(-1) 39 | 40 | shift_out = torch.zeros_like(x) 41 | x = x[bool_list] 42 | nt, c, h, w = x.size() 43 | x = x.view(-1, num_segment, c, h, w) 44 | 45 | fold = c // fold_div 46 | out = torch.zeros_like(x) 47 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 48 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right 49 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 50 | out = out.view(-1, c, h, w) 51 | shift_out[bool_list] = out 52 | 53 | return shift_out 54 | 55 | 56 | class dynamic_fusion(nn.Module): 57 | def __init__(self, channel, reduction=16): 58 | super(dynamic_fusion, self).__init__() 59 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 60 | self.reduction = reduction 61 | self.fc = nn.Sequential( 62 | nn.Linear(channel, int(channel // reduction)), 63 | nn.ReLU(inplace=True), 64 | nn.Linear(int(channel // reduction), channel), 65 | nn.Sigmoid() 66 | ) 67 | for m in self.modules(): 68 | if isinstance(m, nn.Linear): 69 | nn.init.normal_(m.weight, 0, 0.01) 70 | 71 | def forward(self, x): 72 | b, c, h, w = x.size() 73 | y = self.avg_pool(x).view(b,c) 74 | attention = self.fc(y) 75 | return attention.view(b,c,1,1) 76 | 77 | def forward_calc_flops(self, x): 78 | b, c, h, w = x.size() 79 | flops = c*h*w 80 | y = self.avg_pool(x).view(b,c) 81 | attention = self.fc(y) 82 | flops += c*c//self.reduction*2 + c 83 | return attention.view(b,c,1,1), flops 84 | 85 | 86 | def _make_divisible(v, divisor, min_value=None): 87 | """ 88 | This function is taken from the original tf repo. 
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        self.reduction = reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, _make_divisible(channel // reduction, 8)),
            nn.ReLU(inplace=True),
            nn.Linear(_make_divisible(channel // reduction, 8), channel),
            h_sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

    def forward_calc_flops(self, x):
        b, c, h, w = x.size()
        flops = c*h*w
        y = self.avg_pool(x).view(b,c)
        y = self.fc(y).view(b, c, 1, 1)
        flops += c*c//self.reduction*2 + c
        return x * y, flops


def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        h_swish()
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        h_swish()
    )


class InvertedResidual_ample(nn.Module):
    def __init__(self, n_segment, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(InvertedResidual_ample, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup
        self.tsm = TSM()
        self.inp = inp
        self.hidden_dim = hidden_dim
        self.use_se = use_se

        if inp == hidden_dim:
            # dw
            self.conv1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
            self.bn1 = nn.BatchNorm2d(hidden_dim)
            self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
            # Squeeze-and-Excite
            self.se = SELayer(hidden_dim) if use_se else nn.Identity()
            # pw-linear
            self.conv2 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)
            self.bn2 = nn.BatchNorm2d(oup)
        else:
            # pw
            self.conv1 = nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False)
            self.bn1 = nn.BatchNorm2d(hidden_dim)
            self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
            # dw
            self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
            self.bn2 = nn.BatchNorm2d(hidden_dim)
            # Squeeze-and-Excite
            self.se = SELayer(hidden_dim) if use_se else nn.Identity()
            self.act2 = h_swish() if use_hs else nn.ReLU(inplace=True)
            # pw-linear
            self.conv3 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)
            self.bn3 = nn.BatchNorm2d(oup)

    def forward(self, x, list_little):
        residual = x
        if self.inp == self.hidden_dim:
            x = self.tsm(x, list_little)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.act1(x)
            x = self.se(x)
            x = self.conv2(x)
            x = self.bn2(x)
        else:
            x = self.tsm(x, list_little)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.act1(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.se(x)
            x = self.act2(x)
            x = self.conv3(x)
            x = self.bn3(x)

        if self.identity:
            return x + residual
        else:
            return x

    def forward_calc_flops(self, x, list_little):
        flops = 0
        residual = x
        if self.inp == self.hidden_dim:
            x = self.tsm(x, list_little)

            c_in = x.shape[1]
            x = self.conv1(x)
            flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups

            x = self.bn1(x)
            x = self.act1(x)
            if self.use_se == True:
                x, _flops = self.se.forward_calc_flops(x)
                flops += _flops
            else:
                x = self.se(x)

            c_in = x.shape[1]
            x = self.conv2(x)
            flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
            x = self.bn2(x)
        else:
            x = self.tsm(x, list_little)

            c_in = x.shape[1]
            x = self.conv1(x)
            flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups
            x = self.bn1(x)
            x = self.act1(x)

            c_in = x.shape[1]
            x = self.conv2(x)
            flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
            x = self.bn2(x)
            if self.use_se == True:
                x, _flops = self.se.forward_calc_flops(x)
                flops += _flops
            else:
                x = self.se(x)
            x = self.act2(x)

            c_in = x.shape[1]
            x = self.conv3(x)
            flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv3.kernel_size[0] * self.conv3.kernel_size[1] / self.conv3.groups
            x = self.bn3(x)
        if self.identity:
            return x + residual, flops
        else:
            return x, flops


class InvertedResidual_focal(nn.Module):
    def __init__(self, n_segment, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(InvertedResidual_focal, self).__init__()
        assert stride in [1, 2]
        self.n_segment = n_segment
        self.identity = stride == 1 and inp == oup
        self.tsm = TSM()
        if stride != 1 or inp != oup:
            self.res_connect = nn.Sequential(
                nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False, groups=2),
                nn.BatchNorm2d(oup)
            )

        self.inp = inp
        self.hidden_dim = hidden_dim
        self.use_se = use_se

        if inp == hidden_dim:
            # dw
            self.conv1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
            self.bn1 = nn.BatchNorm2d(hidden_dim)
            self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
            # Squeeze-and-Excite
            self.se = SELayer(hidden_dim) if use_se else nn.Identity()
            # pw-linear
            self.conv2 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, groups=2, bias=False)
            self.bn2 = nn.BatchNorm2d(oup)
        else:
            # pw
            self.conv1 = nn.Conv2d(inp, hidden_dim, 1, 1, 0, groups=2, bias=False)
            self.bn1 = nn.BatchNorm2d(hidden_dim)
            self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
            # dw
            self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
            self.bn2 = nn.BatchNorm2d(hidden_dim)
            # Squeeze-and-Excite
            self.se = SELayer(hidden_dim) if use_se else nn.Identity()
            self.act2 = h_swish() if use_hs else nn.ReLU(inplace=True)
            # pw-linear
            self.conv3 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, groups=2, bias=False)
            self.bn3 = nn.BatchNorm2d(oup)

    def forward(self, x, list_big):
        if self.identity:
            residual = x
        else:
            residual = self.res_connect(x)

        if self.inp == self.hidden_dim:
            x = self.tsm(x, self.n_segment)
            x = x * list_big
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.act1(x)
            x = self.se(x)
            x = self.conv2(x)
            x = self.bn2(x)
        else:
            x = self.tsm(x, self.n_segment)
            x = x * list_big
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.act1(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.se(x)
            x = self.act2(x)
            x = self.conv3(x)
            x = self.bn3(x)

        return x + residual


    def forward_calc_flops(self, x, list_big):
        flops = 0

        if self.identity:
            residual = x
        else:
            c_in = x.shape[1]
            residual = self.res_connect(x)
            flops += c_in * residual.shape[1] * residual.shape[2] * residual.shape[3] / self.res_connect[0].groups

        if self.inp == self.hidden_dim:
            x = self.tsm(x, self.n_segment)
            x = x * list_big
            select_ratio = torch.mean(list_big)
            # select_ratio = 1

            c_in = x.shape[1]
            x = self.conv1(x)
            flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups

            x = self.bn1(x)
            x = self.act1(x)
            if self.use_se == True:
                x, _flops = self.se.forward_calc_flops(x)
                flops += select_ratio * _flops
            else:
                x = self.se(x)

            c_in = x.shape[1]
            x = self.conv2(x)
            flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
            x = self.bn2(x)
        else:
            x = self.tsm(x, self.n_segment)
            x = x * list_big
            select_ratio = torch.mean(list_big)
            # select_ratio = 1

            c_in = x.shape[1]
            x = self.conv1(x)
            flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups
            x = self.bn1(x)
            x = self.act1(x)

            c_in = x.shape[1]
            x = self.conv2(x)
            flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
            x = self.bn2(x)
            if self.use_se == True:
                x, _flops = self.se.forward_calc_flops(x)
                flops += select_ratio * _flops
            else:
                x = self.se(x)
            x = self.act2(x)

            c_in = x.shape[1]
            x = self.conv3(x)
            flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv3.kernel_size[0] * self.conv3.kernel_size[1] / self.conv3.groups
            x = self.bn3(x)

        return x + residual, flops



class navigation(nn.Module):
    def __init__(self, inplanes=64, num_segments=8):
        super(navigation,self).__init__()
        self.num_segments = num_segments
        self.conv_pool = nn.Conv2d(inplanes, 2, kernel_size=1, padding=0, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.conv_gs = nn.Conv2d(2*num_segments, 2*num_segments, kernel_size=1, padding=0, stride=1, bias=True, groups=num_segments)
        # Per-frame logit pairs: index 0 gets bias 1.0 and index 1 (read out below as list_big)
        # gets bias 10.0, so the hard Gumbel gates start out strongly favoring list_big = 1.
        self.conv_gs.bias.data[:2*num_segments:2] = 1.0
        self.conv_gs.bias.data[1:2*num_segments+1:2] = 10.0
        self.gs = GumbleSoftmax()

    def forward(self, x, temperature=1.0):
        gates = self.pool(x)
        gates = self.conv_pool(gates)
        gates = self.bn(gates)
        gates = self.relu(gates)

        batch = x.shape[0] // self.num_segments

        gates = gates.view(batch, self.num_segments*2,1,1)
        gates = self.conv_gs(gates)

        gates = gates.view(batch, self.num_segments, 2, 1, 1)
        gates = self.gs(gates, temp=temperature, force_hard=True)
        list_big = gates[:, :, 1, :, :]
        list_big = list_big.view(x.shape[0],1,1,1)

        return list_big

    def forward_calc_flops(self, x, temperature=1.0):
        flops = 0

        flops += x.shape[1] * x.shape[2] * x.shape[3]
        gates = self.pool(x)

        c_in = gates.shape[1]
        gates = self.conv_pool(gates)
        flops += c_in * gates.shape[1] * gates.shape[2] * gates.shape[3]
        gates = self.bn(gates)
        gates = self.relu(gates)

        batch = x.shape[0] // self.num_segments

        gates = gates.view(batch, self.num_segments*2,1,1)
        gates = self.conv_gs(gates)
        flops += self.num_segments * 2 * gates.shape[1] * gates.shape[2] * gates.shape[3] / self.conv_gs.groups

        gates = gates.view(batch, self.num_segments, 2, 1, 1)
        gates = self.gs(gates, temp=temperature, force_hard=True)
        list_big = gates[:, :, 1, :, :]
        list_big = list_big.view(x.shape[0],1,1,1)

        return list_big, flops



class AFMobileNetV3(nn.Module):
    def __init__(self, num_segments, num_class, cfgs_head, cfgs_stage1, cfgs_stage2_ample,
                 cfgs_stage2_focal, cfgs_stage2_fuse, cfgs_stage3_ample, cfgs_stage3_focal,
                 cfgs_stage3_fuse, cfgs_stage4, cfgs_stage5, mode, width_mult=1.):
        super(AFMobileNetV3, self).__init__()
        # setting of inverted residual blocks
        self.num_segments = num_segments
        self.cfgs_head = cfgs_head
        self.cfgs_stage1 = cfgs_stage1
        self.cfgs_stage2_ample = cfgs_stage2_ample
        self.cfgs_stage2_focal = cfgs_stage2_focal
        self.cfgs_stage2_fuse = cfgs_stage2_fuse
        self.cfgs_stage3_ample = cfgs_stage3_ample
        self.cfgs_stage3_focal = cfgs_stage3_focal
        self.cfgs_stage3_fuse = cfgs_stage3_fuse
        self.cfgs_stage4 = cfgs_stage4
        self.cfgs_stage5 = cfgs_stage5
        assert mode in ['large', 'small']

        # building first layer
        input_channel = _make_divisible(16 * width_mult, 8)
        self.conv = nn.Conv2d(3, input_channel, 3, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(input_channel)
        self.act = h_swish()
        # building inverted residual blocks
        block_base = InvertedResidual_ample
        block_refine = InvertedResidual_focal

        layers = []
        for k, t, c, use_se, use_hs, s in self.cfgs_head:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features_head = nn.Sequential(*layers)


        ###### stage 1
        layers_stage1 = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage1:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage1.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features_stage1 = nn.Sequential(*layers_stage1)


        ###### stage 2
        input_channel_before = input_channel
        layers_stage2_ample = []
        frame_gen_list_stage2 = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage2_ample:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage2_ample.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            frame_gen_list_stage2.append(navigation(inplanes=input_channel,num_segments=num_segments))
            input_channel = output_channel
        self.list_gen2 = nn.ModuleList(frame_gen_list_stage2)
        self.features_stage2_base = nn.Sequential(*layers_stage2_ample)

        layers_stage2_focal = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage2_focal:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel_before * t, 8)
            layers_stage2_focal.append(block_refine(num_segments, input_channel_before, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel_before = output_channel
            input_channel = input_channel_before
        self.features_stage2_refine = nn.Sequential(*layers_stage2_focal)

        layers_stage2_fuse = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage2_fuse:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage2_fuse.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features_stage2_fuse = nn.Sequential(*layers_stage2_fuse)
        self.att_gen2 = dynamic_fusion(channel=input_channel, reduction=16)


        ###### stage 3
        input_channel_before = input_channel
        layers_stage3_ample = []
        frame_gen_list_stage3 = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage3_ample:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage3_ample.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            frame_gen_list_stage3.append(navigation(inplanes=input_channel,num_segments=num_segments))
            input_channel = output_channel
        self.list_gen3 = nn.ModuleList(frame_gen_list_stage3)
        self.features_stage3_base = nn.Sequential(*layers_stage3_ample)

        layers_stage3_focal = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage3_focal:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel_before * t, 8)
            layers_stage3_focal.append(block_refine(num_segments, input_channel_before, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel_before = output_channel
            input_channel = input_channel_before
        self.features_stage3_refine = nn.Sequential(*layers_stage3_focal)

        layers_stage3_fuse = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage3_fuse:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage3_fuse.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features_stage3_fuse = nn.Sequential(*layers_stage3_fuse)
        self.att_gen3 = dynamic_fusion(channel=input_channel, reduction=16)


        ###### stage 4
        layers_stage4 = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage4:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage4.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features_stage4 = nn.Sequential(*layers_stage4)

        ###### stage 5
        layers_stage5 = []
        for k, t, c, use_se, use_hs, s in self.cfgs_stage5:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers_stage5.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features_stage5 = nn.Sequential(*layers_stage5)

        # building last several layers
        # self.conv_last = conv_1x1_bn(input_channel, exp_size)
        self.conv_last = nn.Conv2d(input_channel, exp_size, 1, 1, 0, bias=False)
        self.bn_last = nn.BatchNorm2d(exp_size)
        self.act_last = h_swish()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        output_channel = {'large': 1280, 'small': 1024}
        self.output_channel_num = output_channel[mode]
        output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode]
        self.fc = nn.Sequential(
            nn.Linear(exp_size, output_channel),
            h_swish(),
            nn.Dropout(0.5),
            nn.Linear(output_channel, num_class),
        )

        self._initialize_weights()

    def forward(self, x, temperature=1e-8):
        _lists = []
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)

        x = self.features_head[0](x, self.num_segments)


        for i in range(len(self.features_stage1)):
            x = self.features_stage1[i](x, self.num_segments)


        x_base = x
        x_refine = x
        for i in range(len(self.features_stage2_base)):
            list_big = self.list_gen2[i](x_base, temperature=temperature)
            _lists.append(list_big)
            x_base = self.features_stage2_base[i](x_base, self.num_segments)
            x_refine = self.features_stage2_refine[i](x_refine, list_big)
        _,_,h,w = x_refine.shape
        x_base = F.interpolate(x_base, size = (h,w))
        att = self.att_gen2(x_base+x_refine)
        x = self.features_stage2_fuse[0](att*x_base + (1-att)*x_refine, self.num_segments)


        x_base = x
        x_refine = x
        for i in range(len(self.features_stage3_base)):
            list_big = self.list_gen3[i](x_base, temperature=temperature)
            _lists.append(list_big)
            x_base = self.features_stage3_base[i](x_base, self.num_segments)
            x_refine = self.features_stage3_refine[i](x_refine, list_big)
        _,_,h,w = x_refine.shape
        x_base = F.interpolate(x_base, size = (h,w))
        att = self.att_gen3(x_base+x_refine)
        x = self.features_stage3_fuse[0](att*x_base + (1-att)*x_refine, self.num_segments)


        for i in range(len(self.features_stage4)):
            x = self.features_stage4[i](x, self.num_segments)
        for i in range(len(self.features_stage5)):
            x = self.features_stage5[i](x, self.num_segments)


        x = self.conv_last(x)
        x = self.bn_last(x)
        x = self.act_last(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)


        x = self.fc(x)

        return x, _lists

    def forward_calc_flops(self, x, temperature=1e-8):
        flops = 0
        _lists = []

        c_in = x.shape[1]
        x = self.conv(x)
        flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv.kernel_size[0] * self.conv.kernel_size[1] / self.conv.groups
        x = self.bn(x)
        x = self.act(x)

        x, _flops = self.features_head[0].forward_calc_flops(x, self.num_segments)
        flops += _flops


        for i in range(len(self.features_stage1)):
            x, _flops = self.features_stage1[i].forward_calc_flops(x, self.num_segments)
            flops += _flops


        x_base = x
        x_refine = x
        for i in range(len(self.features_stage2_base)):
            list_big, _flops = self.list_gen2[i].forward_calc_flops(x_base, temperature=temperature)
            _lists.append(list_big)
            flops += _flops
            x_base, _flops = self.features_stage2_base[i].forward_calc_flops(x_base, self.num_segments)
            flops += _flops
            x_refine, _flops = self.features_stage2_refine[i].forward_calc_flops(x_refine, list_big)
            flops += _flops
        _,_,h,w = x_refine.shape
        x_base = F.interpolate(x_base, size = (h,w))
        att, _flops = self.att_gen2.forward_calc_flops(x_base+x_refine)
        flops += _flops
        x, _flops = self.features_stage2_fuse[0].forward_calc_flops(att*x_base + (1-att)*x_refine, self.num_segments)
        flops += _flops


        x_base = x
        x_refine = x
        for i in range(len(self.features_stage3_base)):
            list_big, _flops = self.list_gen3[i].forward_calc_flops(x_base, temperature=temperature)
            _lists.append(list_big)
            flops += _flops
            x_base, _flops = self.features_stage3_base[i].forward_calc_flops(x_base, self.num_segments)
            flops += _flops
            x_refine, _flops = self.features_stage3_refine[i].forward_calc_flops(x_refine, list_big)
            flops += _flops
        _,_,h,w = x_refine.shape
        x_base = F.interpolate(x_base, size = (h,w))
        att, _flops = self.att_gen3.forward_calc_flops(x_base+x_refine)
        flops += _flops
        x, _flops = self.features_stage3_fuse[0].forward_calc_flops(att*x_base + (1-att)*x_refine, self.num_segments)
        flops += _flops


        for i in range(len(self.features_stage4)):
            x, _flops = self.features_stage4[i].forward_calc_flops(x, self.num_segments)
            flops += _flops

        for i in range(len(self.features_stage5)):
            x, _flops = self.features_stage5[i].forward_calc_flops(x, self.num_segments)
            flops += _flops

        c_in = x.shape[1]
        x = self.conv_last(x)
        flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv_last.kernel_size[0] * self.conv_last.kernel_size[1] / self.conv_last.groups
        x = self.bn_last(x)
        x = self.act_last(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)


        c_in = x.shape[1]
        x = self.fc(x)
        c_out = x.shape[1]
        flops += c_in * self.output_channel_num + self.output_channel_num * c_out

        return x, _lists, self.num_segments * flops

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # if m.bias is not None:
                #     m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def af_mobilenetv3(num_segments, num_class, **kwargs):
    """
    Constructs a MobileNetV3-Large model
    """
    cfgs_head = [
        # k, t, c, SE, HS, s
        [3, 1, 16, 0, 0, 1]]
    cfgs_stage1 = [
        [3, 4, 24, 0, 0, 2],
        [3, 3, 24, 0, 0, 2]]
    cfgs_stage2_ample = [
        [5, 3, 40, 1, 0, 2],
        [5, 3, 40, 1, 0, 1]]
    cfgs_stage2_focal = [
        [5, 3, 40, 1, 0, 1],
        [5, 3, 40, 1, 0, 1]]
    cfgs_stage2_fuse = [
        [5, 3, 40, 1, 0, 2]]
    cfgs_stage3_ample = [
        [3, 6, 80, 0, 1, 2],
        [3, 2.5, 80, 0, 1, 1],
        [3, 2.3, 80, 0, 1, 1]]
    cfgs_stage3_focal = [
        [3, 6, 80, 0, 1, 1],
        [3, 2.5, 80, 0, 1, 1],
        [3, 2.3, 80, 0, 1, 1]]
    cfgs_stage3_fuse = [
        [3, 2.3, 80, 0, 1, 1]]
    cfgs_stage4 = [
        [3, 6, 112, 1, 1, 1],
        [3, 6, 112, 1, 1, 1]]
    cfgs_stage5 = [
        [5, 6, 160, 1, 1, 2],
        [5, 6, 160, 1, 1, 1],
        [5, 6, 160, 1, 1, 1]]
    return AFMobileNetV3(num_segments, num_class, cfgs_head, cfgs_stage1, cfgs_stage2_ample,
                         cfgs_stage2_focal, cfgs_stage2_fuse, cfgs_stage3_ample, cfgs_stage3_focal,
                         cfgs_stage3_fuse, cfgs_stage4, cfgs_stage5, mode='large', **kwargs)



def AF_mobilenetv3(pretrained=False, path_backbone = '.../.../checkpoint/ImageNet/AF-MobileNetv3.pth.tar', shift=False, num_segments=8, num_class=174, **kwargs):
    model = af_mobilenetv3(num_segments, num_class)
    if pretrained:
        checkpoint = torch.load(path_backbone, map_location='cpu')
        pretrained_dict = checkpoint['state_dict']
        new_state_dict = model.state_dict()
        for k, v in pretrained_dict.items():
            if (k[7:] in new_state_dict):
                new_state_dict.update({k[7:]:v})
        model.load_state_dict(new_state_dict)
    return model
--------------------------------------------------------------------------------
/ops/backbone/AF_ResNet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
# from torch.hub import load_state_dict_from_url
import torch
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
import math
from .gumbel_softmax import GumbleSoftmax



class dynamic_fusion(nn.Module):
    def __init__(self, channel, reduction=16):
        super(dynamic_fusion, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.reduction = reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(channel // reduction), channel, bias=False),
            nn.Sigmoid()
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b,c)
        attention = self.fc(y)
        return attention.view(b,c,1,1)

    def forward_calc_flops(self, x):
        b, c, h, w = x.size()
        flops = c*h*w
        y = self.avg_pool(x).view(b,c)
        attention = self.fc(y)
        flops += c*c//self.reduction*2 + c
        return attention.view(b,c,1,1), flops


class TSM(nn.Module):
    def __init__(self):
        super(TSM, self).__init__()
        self.fold_div = 8

    def forward(self, x, n_segment):
        x = self.shift(x, n_segment, fold_div=self.fold_div)
        return x

    @staticmethod
    def shift(x, n_segment, fold_div=3):
        if type(n_segment) is int:
            nt, c, h, w = x.size()
            n_batch = nt // n_segment
            x = x.view(n_batch, n_segment, c, h, w)

            fold = c // fold_div
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
            shift_out = out.view(nt, c, h, w)
        else:
            num_segment = int(n_segment.sum())
            ls = n_segment
            bool_list = ls > 0
            bool_list = bool_list.view(-1)

            shift_out = torch.zeros_like(x)
            x = x[bool_list]
            nt, c, h, w = x.size()
            x = x.view(-1, num_segment, c, h, w)

            fold = c // fold_div
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
            out = out.view(-1, c, h, w)
            shift_out[bool_list] = out

        return shift_out


class Bottleneck_ample(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, num_segments, stride=1, downsample=None, last_relu=True, patch_groups=1,
                 base_scale=2, is_first=False):
        super(Bottleneck_ample, self).__init__()
        self.num_segments = num_segments
        self.conv1 = nn.Conv2d(inplanes, planes // self.expansion, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes // self.expansion)
        self.conv2 = nn.Conv2d(planes // self.expansion, planes // self.expansion, kernel_size=3, stride=stride,
                               padding=1, bias=False, groups=1)
        self.bn2 = nn.BatchNorm2d(planes // self.expansion)
        self.conv3 = nn.Conv2d(planes // self.expansion, planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.tsm = TSM()

        self.downsample = downsample
        self.have_pool = False
        self.have_1x1conv2d = False

        self.first_downsample = nn.AvgPool2d(3, stride=2, padding=1) if (base_scale == 4 and is_first) else None

        if self.downsample is not None:
            self.have_pool = True
            if len(self.downsample) > 1:
                self.have_1x1conv2d = True

        self.stride = stride
        self.last_relu = last_relu

    def forward(self, x, list_little, activate_tsm=False):
        if self.first_downsample is not None:
            x = self.first_downsample(x)
        residual = x
        if self.downsample is not None:
            residual = self.downsample(x)

        if activate_tsm:
            out = self.tsm(x, list_little)
        else:
            out = x

        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        if self.last_relu:
            out = self.relu(out)
        return out

    def forward_calc_flops(self, x, list_little, activate_tsm=False):
        flops = 0
        if self.first_downsample is not None:
            x = self.first_downsample(x)
            _, c, h, w = x.shape
            flops += 9 * c * h * w

        residual = x
        if self.downsample is not None:
            c_in = x.shape[1]
            residual = self.downsample(x)
            _, c, h, w = residual.shape
            if self.have_pool:
                flops += 9 * c_in * h * w
            if self.have_1x1conv2d:
                flops += c_in * c * h * w

        if activate_tsm:
            out = self.tsm(x, list_little)
        else:
            out = x

        c_in = out.shape[1]
        out = self.conv1(out)
        _, c_out, h, w = out.shape
        flops += c_in * c_out * h * w / self.conv1.groups

        out = self.bn1(out)
        out = self.relu(out)

        c_in = c_out
        out = self.conv2(out)
        _, c_out, h, w = out.shape
        flops += c_in * c_out * h * w * 9 / self.conv2.groups
        out = self.bn2(out)
        out = self.relu(out)

        c_in = c_out
        out = self.conv3(out)
        _, c_out, h, w = out.shape
        flops += c_in * c_out * h * w / self.conv3.groups
        out = self.bn3(out)

        out += residual
        if self.last_relu:
            out = self.relu(out)

        return out, flops

class Bottleneck_focal(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, num_segments, stride=1, downsample=None, last_relu=True, patch_groups=1, base_scale=2, is_first=True):
        super(Bottleneck_focal, self).__init__()
        self.num_segments = num_segments
        self.conv1 = nn.Conv2d(inplanes, planes // self.expansion, kernel_size=1, bias=False, groups=patch_groups)
        self.bn1 = nn.BatchNorm2d(planes // self.expansion)
        self.conv2 = nn.Conv2d(planes // self.expansion, planes // self.expansion, kernel_size=3, stride=stride,
                               padding=1, bias=False, groups=patch_groups)
        self.bn2 = nn.BatchNorm2d(planes // self.expansion)
        self.conv3 = nn.Conv2d(planes // self.expansion, planes, kernel_size=1, bias=False, groups=patch_groups)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.tsm = TSM()
        self.downsample = downsample

        self.stride = stride
        self.last_relu = last_relu
        self.patch_groups = patch_groups

    def forward(self, x, mask, activate_tsm=False):
        residual = x
        if self.downsample is not None:  # skip connection before mask
            residual = self.downsample(x)

        if activate_tsm:
            out = self.tsm(x, self.num_segments)
        else:
            out = x
        out = out * mask

        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        if self.last_relu:
            out = self.relu(out)
        return out


    def forward_calc_flops(self, x, mask, activate_tsm=False):
        residual = x
        flops = 0
        if self.downsample is not None:  # skip connection before mask
            c_in = x.shape[1]
            residual = self.downsample(x)
            flops += c_in * residual.shape[1] * residual.shape[2] * residual.shape[3]

        if activate_tsm:
            out = self.tsm(x, self.num_segments)
        else:
            out = x
        out = out * mask
        select_ratio = torch.mean(mask)

        c_in = out.shape[1]
        out = self.conv1(out)
        _, c_out, h, w = out.shape
        flops += select_ratio * c_in * c_out * h * w / self.conv1.groups

        out = self.bn1(out)
        out = self.relu(out)

        c_in = c_out
        out = self.conv2(out)
        _, c_out, h, w = out.shape
        flops += select_ratio * c_in * c_out * h * w * 9 / self.conv2.groups
        out = self.bn2(out)
        out = self.relu(out)

        c_in = c_out
        out = self.conv3(out)
        _, c_out, h, w = out.shape
        flops += select_ratio * c_in * c_out * h * w / self.conv3.groups
        out = self.bn3(out)

        out += residual
        if self.last_relu:
            out = self.relu(out)

        return out, flops



class navigation(nn.Module):
    def __init__(self, inplanes=64, num_segments=8):
        super(navigation, self).__init__()
        self.num_segments = num_segments
        self.conv_pool = nn.Conv2d(inplanes, 2, kernel_size=1, padding=0, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_gs = nn.Conv2d(2*num_segments, 2*num_segments, kernel_size=1, padding=0, stride=1, bias=True, groups=num_segments)
        self.conv_gs.bias.data[:2*num_segments:2] = 1.0
        self.conv_gs.bias.data[1:2*num_segments+1:2] = 10.0
        self.gs = GumbleSoftmax()

    def forward(self, x, temperature=1.0):
        gates = self.pool(x)
        gates = self.conv_pool(gates)
        gates = self.bn(gates)
        gates = self.relu(gates)

        batch = x.shape[0] // self.num_segments

        gates = gates.view(batch, self.num_segments*2, 1, 1)
        gates = self.conv_gs(gates)

        gates = gates.view(batch, self.num_segments, 2, 1, 1)
        gates = self.gs(gates, temp=temperature, force_hard=True)
        mask = gates[:, :, 1, :, :]
        mask = mask.view(x.shape[0], 1, 1, 1)

        return mask

    def forward_calc_flops(self, x, temperature=1.0):
        flops = 0

        flops += x.shape[1] * x.shape[2] * x.shape[3]
        gates = self.pool(x)

        c_in = gates.shape[1]
        gates = self.conv_pool(gates)
        flops += c_in * gates.shape[1] * gates.shape[2] * gates.shape[3]
        gates = self.bn(gates)
        gates = self.relu(gates)

        batch = x.shape[0] // self.num_segments

        gates = gates.view(batch, self.num_segments*2, 1, 1)
        gates = self.conv_gs(gates)
        flops += self.num_segments * 2 * gates.shape[1] * gates.shape[2] * gates.shape[3] / self.conv_gs.groups

        gates = gates.view(batch, self.num_segments, 2, 1, 1)
        gates = self.gs(gates, temp=temperature, force_hard=True)
        mask = gates[:, :, 1, :, :]
        mask = mask.view(x.shape[0], 1, 1, 1)

        return mask, flops


class AFModule(nn.Module):
    def __init__(self, block_ample, block_focal, in_channels, out_channels, blocks, stride, patch_groups, alpha=1, num_segments=8):
        super(AFModule, self).__init__()
        self.num_segments = num_segments
        self.patch_groups = patch_groups
        self.relu = nn.ReLU(inplace=True)

        frame_gen_list = []
        for i in range(blocks - 1):
            # The first navigation module sees the stage input; later ones see the base-branch output.
            if i == 0:
                frame_gen_list.append(navigation(inplanes=in_channels, num_segments=num_segments))
            else:
                frame_gen_list.append(navigation(inplanes=int(out_channels // alpha), num_segments=num_segments))
        self.list_gen = nn.ModuleList(frame_gen_list)

        self.base_module = self._make_layer(block_ample, in_channels, int(out_channels // alpha), num_segments, blocks - 1, 2, last_relu=False)
        self.refine_module = self._make_layer(block_focal, in_channels, out_channels, num_segments, blocks - 1, 1, last_relu=False)

        self.alpha = alpha
        if alpha != 1:
            self.base_transform = nn.Sequential(
                nn.Conv2d(int(out_channels // alpha), out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        self.att_gen = dynamic_fusion(channel=out_channels, reduction=16)
        self.fusion = self._make_layer(block_ample, out_channels, out_channels, num_segments, 1, stride=stride)

    def _make_layer(self, block, inplanes, planes, num_segments, blocks, stride=1, last_relu=True, base_scale=2):
        downsample = []
        if stride != 1:
            downsample.append(nn.AvgPool2d(3, stride=2, padding=1))
        if inplanes != planes:
            downsample.append(nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False))
            downsample.append(nn.BatchNorm2d(planes))
        downsample = None if downsample == [] else nn.Sequential(*downsample)
        layers = []
        if blocks == 1:  # fuse, is not the first of a base branch
            layers.append(block(inplanes, planes, num_segments, stride=stride, downsample=downsample,
                                patch_groups=self.patch_groups, base_scale=base_scale, is_first=False))
        else:
            layers.append(block(inplanes, planes, num_segments, stride, downsample, patch_groups=self.patch_groups,
                                base_scale=base_scale, is_first=True))
            for i in range(1, blocks):
                layers.append(block(planes, planes, num_segments,
                                    last_relu=last_relu if i == blocks - 1 else True,
                                    patch_groups=self.patch_groups, base_scale=base_scale, is_first=False))

        return nn.ModuleList(layers)

    def forward(self, x, temperature=1e-8, activate_tsm=False):
        b, c, h, w = x.size()
        x_big = x
        x_little = x
        _masks = []

        for i in range(len(self.base_module)):
            mask = self.list_gen[i](x_little, temperature=temperature)
            _masks.append(mask)

            x_little = self.base_module[i](x_little, self.num_segments, activate_tsm)
            x_big = self.refine_module[i](x_big, mask, activate_tsm)

        if self.alpha != 1:
            x_little = self.base_transform(x_little)

        _, _, h, w = x_big.shape
        x_little = F.interpolate(x_little, size=(h, w))
        att = self.att_gen(x_little + x_big)
        out = self.relu(att * x_little + (1 - att) * x_big)
        out = self.fusion[0](out, self.num_segments, activate_tsm)
        return out, _masks

    def forward_calc_flops(self, x, temperature=1e-8, activate_tsm=False):
        flops = 0
        b, c, h, w = x.size()

        x_big = x
        x_little = x
        _masks = []

        for i in range(len(self.base_module)):
            mask, _flops = self.list_gen[i].forward_calc_flops(x_little, temperature=temperature)
            _masks.append(mask)
            flops += _flops * b

            x_little, _flops = self.base_module[i].forward_calc_flops(x_little, self.num_segments, activate_tsm)
            flops += _flops * b
            x_big, _flops = self.refine_module[i].forward_calc_flops(x_big, mask, activate_tsm)
            flops += _flops * b

        c = x_little.shape[1]
        _, _, h, w = x_big.shape
        if self.alpha != 1:
            x_little = self.base_transform(x_little)
            flops += b * c * x_little.shape[1] * x_little.shape[2] * x_little.shape[3]

        x_little = F.interpolate(x_little, size=(h, w))
        att, _flops = self.att_gen.forward_calc_flops(x_little + x_big)
        flops += _flops * b
        out = self.relu(att * x_little + (1 - att) * x_big)
        out, _flops = self.fusion[0].forward_calc_flops(out, self.num_segments, activate_tsm)
        flops += _flops * b

        seg = b / self.num_segments
        flops = flops / seg

        return out, _masks, flops

class AFResNet(nn.Module):
    def __init__(self, block_ample, block_focal, layers, width=1.0, patch_groups=1, alpha=1, shift=True, num_segments=8, num_classes=1000):
        num_channels = [int(64*width), int(128*width), int(256*width), 512]

        self.num_segments = num_segments
        self.activate_tsm = shift
        self.inplanes = 64
        super(AFResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, num_channels[0], kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_channels[0])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = AFModule(block_ample, block_focal, num_channels[0], num_channels[0]*block_ample.expansion,
                               layers[0], stride=2, patch_groups=patch_groups, alpha=alpha, num_segments=num_segments)
        self.layer2 = AFModule(block_ample, block_focal, num_channels[0]*block_ample.expansion,
                               num_channels[1]*block_ample.expansion, layers[1], stride=2, patch_groups=patch_groups, alpha=alpha, num_segments=num_segments)
        self.layer3 = AFModule(block_ample, block_focal, num_channels[1]*block_ample.expansion,
                               num_channels[2]*block_ample.expansion, layers[2], stride=1, patch_groups=patch_groups, alpha=alpha, num_segments=num_segments)
        self.layer4 = self._make_layer(num_segments,
                                       block_ample, num_channels[2]*block_ample.expansion, num_channels[3]*block_ample.expansion, layers[3], stride=2)
        self.gappool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(num_channels[3]*block_ample.expansion, num_classes)

        for k, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                # if 'gs' in str(k):
                #     m.weight.data.normal_(0, 0.001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each block.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        for m in self.modules():
            if isinstance(m, Bottleneck_ample):
                nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, num_segments, block, inplanes, planes, blocks, stride=1):
        downsample = []
        if stride != 1:
            downsample.append(nn.AvgPool2d(3, stride=2, padding=1))
        if inplanes != planes:
            downsample.append(nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False))
            downsample.append(nn.BatchNorm2d(planes))
        downsample = None if downsample == [] else nn.Sequential(*downsample)

        layers = []
        layers.append(block(inplanes, planes, num_segments, stride, downsample))
        for _ in range(1, blocks):
            layers.append(block(planes, planes, num_segments))

        return nn.ModuleList(layers)

    def forward(self, x, temperature=1.0):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        _masks = []
        x1, mask = self.layer1(x, temperature=temperature, activate_tsm=self.activate_tsm)
        _masks.extend(mask)
        x2, mask = self.layer2(x1, temperature=temperature, activate_tsm=self.activate_tsm)
        _masks.extend(mask)
        x3, mask = self.layer3(x2, temperature=temperature, activate_tsm=self.activate_tsm)
        _masks.extend(mask)
        x4 = x3
        for i in range(len(self.layer4)):
            x4 = self.layer4[i](x4, self.num_segments, self.activate_tsm)

        x = self.gappool(x4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x, _masks

    def forward_calc_flops(self, x, temperature=1.0):
        flops = 0
        c_in = x.shape[1]
        x = self.conv1(x)
        flops += self.num_segments * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.weight.shape[2] * self.conv1.weight.shape[3]
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool(x)
        flops += self.num_segments * x.numel() / x.shape[0] * 9

        _masks = []
        x1, mask, _flops = self.layer1.forward_calc_flops(x, temperature=temperature, activate_tsm=self.activate_tsm)
        _masks.extend(mask)
        flops += _flops
        x2, mask, _flops = self.layer2.forward_calc_flops(x1, temperature=temperature, activate_tsm=self.activate_tsm)
        _masks.extend(mask)
        flops += _flops
        x3, mask, _flops = self.layer3.forward_calc_flops(x2, temperature=temperature, activate_tsm=self.activate_tsm)
        _masks.extend(mask)
        flops += _flops
        x4 = x3
        for i in range(len(self.layer4)):
            x4, _flops = self.layer4[i].forward_calc_flops(x4, self.num_segments, self.activate_tsm)
            flops += _flops * self.num_segments
        flops += self.num_segments * x4.shape[1] * x4.shape[2] * x4.shape[3]
        x = self.gappool(x4)
        x = x.view(x.size(0), -1)
        c_in = x.shape[1]
        x = self.fc(x)
        flops += self.num_segments * c_in * x.shape[1]

        return x, _masks, flops


def AF_resnet(depth, patch_groups=1, width=1.0, alpha=1, shift=False, num_segments=8, **kwargs):
    layers = {
        50: [3, 4, 6, 3],
        101: [4, 8, 18, 3],
    }[depth]
    block = Bottleneck_ample
    block_focal = Bottleneck_focal
    model = AFResNet(block_ample=block, block_focal=block_focal, layers=layers, patch_groups=patch_groups,
                     width=width, alpha=alpha, shift=shift, num_segments=num_segments, **kwargs)
    return model


def AF_resnet50(pretrained=False, path_backbone='.../.../checkpoint/ImageNet/AF-ResNet50.pth.tar', shift=False, num_segments=8, **kwargs):
    model = AF_resnet(depth=50, patch_groups=2, alpha=2, shift=shift, num_segments=num_segments, **kwargs)
    if pretrained:
        checkpoint = torch.load(path_backbone, map_location='cpu')
        pretrained_dict = checkpoint['state_dict']
        new_state_dict = model.state_dict()
        for k, v in pretrained_dict.items():
            # k[7:] strips a 7-character key prefix, presumably 'module.' from a DataParallel-wrapped checkpoint
            if k[7:] in new_state_dict:
                new_state_dict.update({k[7:]: v})
        model.load_state_dict(new_state_dict)
    return model


def AF_resnet101(pretrained=False, path_backbone='.../.../checkpoint/ImageNet/AF-ResNet101.pth.tar', shift=False, num_segments=8, **kwargs):
    model = AF_resnet(depth=101, patch_groups=2, alpha=2, shift=shift, num_segments=num_segments, **kwargs)
    if pretrained:
        checkpoint = torch.load(path_backbone, map_location='cpu')
        pretrained_dict = checkpoint['state_dict']
        new_state_dict = model.state_dict()
        for k, v in pretrained_dict.items():
            # k[7:] strips a 7-character key prefix, presumably 'module.' from a DataParallel-wrapped checkpoint
            if k[7:] in new_state_dict:
                new_state_dict.update({k[7:]: v})
        model.load_state_dict(new_state_dict)
    return model
--------------------------------------------------------------------------------
/ops/backbone/__init__.py:
--------------------------------------------------------------------------------
from .AF_ResNet import *
from .AF_MobileNetv3 import *
--------------------------------------------------------------------------------
/ops/backbone/gumbel_softmax.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable

"""
Gumbel Softmax Sampler
Expects the categories along dim 2 (the softmax, argmax and scatter below all use dim=2);
`navigation` passes logits shaped [batch, num_segments, 2, 1, 1].

Does not support single binary category. Use two dimensions with softmax instead.
"""

class GumbleSoftmax(torch.nn.Module):
    def __init__(self, hard=False):
        super(GumbleSoftmax, self).__init__()
        self.hard = hard
        # self.gpu = False

    # def cuda(self):
    #     self.gpu = True

    # def cpu(self):
    #     self.gpu = False

    def sample_gumbel(self, shape, eps=1e-10):
        """Sample from Gumbel(0, 1)"""
        noise = torch.rand(shape)
        noise.add_(eps).log_().neg_()
        noise.add_(eps).log_().neg_()
        # if self.gpu:
        return Variable(noise).cuda()
        # else:
        #     return Variable(noise)

    def sample_gumbel_like(self, template_tensor, eps=1e-10):
        uniform_samples_tensor = template_tensor.clone().uniform_()
        gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps))
        return gumble_samples_tensor

    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        dim = logits.size(-1)
        gumble_samples_tensor = self.sample_gumbel_like(logits.data)
        gumble_trick_log_prob_samples = logits + Variable(gumble_samples_tensor)
        soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, dim=2)
        return soft_samples

    def gumbel_softmax(self, logits, temperature, hard=False):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: unnormalized log-probs, with the classes on dim 2
            temperature: positive scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            sample from the Gumbel-Softmax distribution, same shape as logits.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if hard:
            _, max_value_indexes = y.data.max(2, keepdim=True)
            y_hard = logits.data.clone().zero_().scatter_(2, max_value_indexes, 1)
            # Straight-through: the value equals y_hard, the gradient flows through the soft sample y.
            y = Variable(y_hard - y.data) + y
        return y

    def forward(self, logits, temp=1, force_hard=False):
        samplesize = logits.size()
        if self.training and not force_hard:
            return self.gumbel_softmax(logits, temperature=temp, hard=False)
        else:
            return self.gumbel_softmax(logits, temperature=temp, hard=True)
--------------------------------------------------------------------------------
/ops/backbone/temporal_shift.py:
--------------------------------------------------------------------------------
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalShift(nn.Module):
    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super(TemporalShift, self).__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # Due to some out of order error when performing parallel computing.
            # May need to write a CUDA kernel.
            raise NotImplementedError
            # out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift

        return out.view(nt, c, h, w)


class InplaceShift(torch.autograd.Function):
    # Special thanks to @raoyongming for the help to this function
    @staticmethod
    def forward(ctx, input, fold):
        # not support higher order gradient
        # input = input.detach_()
        ctx.fold_ = fold
        n, t, c, h, w = input.size()
        buffer = input.data.new(n, t, fold, h, w).zero_()
        buffer[:, :-1] = input.data[:, 1:, :fold]
        input.data[:, :, :fold] = buffer
        buffer.zero_()
        buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
        input.data[:, :, fold: 2 * fold] = buffer
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output = grad_output.detach_()
        fold = ctx.fold_
        n, t, c, h, w = grad_output.size()
        buffer = grad_output.data.new(n, t, fold, h, w).zero_()
        buffer[:, 1:] = grad_output.data[:, :-1, :fold]
        grad_output.data[:, :, :fold] = buffer
        buffer.zero_()
        buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
        grad_output.data[:, :, fold: 2 * fold] = buffer
        return grad_output, None


class TemporalPool(nn.Module):
    def __init__(self, net, n_segment):
        super(TemporalPool, self).__init__()
        self.net = net
        self.n_segment = n_segment

    def forward(self, x):
        x = self.temporal_pool(x, n_segment=self.n_segment)
        return self.net(x)

    @staticmethod
    def temporal_pool(x, n_segment):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w
        x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
        return x


def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
    if temporal_pool:
        n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
    else:
        n_segment_list = [n_segment] * 4
    assert n_segment_list[-1] > 0
    print('=> n_segment per stage: {}'.format(n_segment_list))

    import torchvision
    if isinstance(net, torchvision.models.ResNet):
        if place == 'block':
            def make_block_temporal(stage, this_segment):
                blocks = list(stage.children())
                print('=> Processing stage with {} blocks'.format(len(blocks)))
                for i, b in enumerate(blocks):
                    blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
                return nn.Sequential(*(blocks))

            net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
            net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
            net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
            net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])

        elif 'blockres' in place:
            n_round = 1
            if len(list(net.layer3.children())) >= 23:
                n_round = 2
                print('=> Using n_round {} to insert temporal shift'.format(n_round))

            def make_block_temporal(stage, this_segment):
                blocks = list(stage.children())
                print('=> Processing stage with {} blocks residual'.format(len(blocks)))
                for i, b in enumerate(blocks):
                    if i % n_round == 0:
                        blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
                return nn.Sequential(*blocks)

            net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
            net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
            net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
            net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
    else:
        raise NotImplementedError(place)


def make_temporal_pool(net, n_segment):
    import torchvision
    if isinstance(net, torchvision.models.ResNet):
        print('=> Injecting nonlocal pooling')
        net.layer2 = TemporalPool(net.layer2, n_segment)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    # test inplace shift v.s. vanilla shift
    tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False)
    tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True)
    # NOTE: shift() raises NotImplementedError when inplace=True, so this self-test
    # stops at the first tsm2(x) call unless the InplaceShift path is re-enabled.

    print('=> Testing CPU...')
    # test forward
    with torch.no_grad():
        for i in range(10):
            x = torch.rand(2 * 8, 3, 224, 224)
            y1 = tsm1(x)
            y2 = tsm2(x)
            assert torch.norm(y1 - y2).item() < 1e-5

    # test backward
    with torch.enable_grad():
        for i in range(10):
            x1 = torch.rand(2 * 8, 3, 224, 224)
            x1.requires_grad_()
            x2 = x1.clone()
            y1 = tsm1(x1)
            y2 = tsm2(x2)
            grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
            grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
            assert torch.norm(grad1 - grad2).item() < 1e-5

    print('=> Testing GPU...')
    tsm1.cuda()
    tsm2.cuda()
    # test forward
    with torch.no_grad():
        for i in range(10):
            x = torch.rand(2 * 8, 3, 224, 224).cuda()
            y1 = tsm1(x)
            y2 = tsm2(x)
            assert torch.norm(y1 - y2).item() < 1e-5

    # test backward
    with torch.enable_grad():
        for i in range(10):
            x1 = torch.rand(2 * 8, 3, 224, 224).cuda()
            x1.requires_grad_()
            x2 = x1.clone()
            y1 = tsm1(x1)
            y2 = tsm2(x2)
            grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
            grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
            assert torch.norm(grad1 - grad2).item() < 1e-5
    print('Test passed.')
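The zero-padded shift performed by `TemporalShift.shift` (non-in-place path) and by the `TSM` module in `AF_ResNet.py` is easiest to follow on a toy tensor. The sketch below is a standalone re-implementation of those three slice assignments, not code from the repository; the function name `temporal_shift`, the 4-frame/4-channel input and `fold_div=4` are illustrative assumptions (`make_temporal_shift` defaults to `n_div=8`). It shows what the "shift left"/"shift right" comments mean: the first fold of channels pulls features from the next frame, the second fold from the previous frame, and vacated positions are zero-filled.

```python
import torch


def temporal_shift(x: torch.Tensor, n_segment: int, fold_div: int = 8) -> torch.Tensor:
    """Illustrative zero-padded temporal shift on an (N*T, C, H, W) tensor."""
    nt, c, h, w = x.shape
    x = x.view(nt // n_segment, n_segment, c, h, w)
    fold = c // fold_div
    out = torch.zeros_like(x)
    # First fold of channels: frame t takes frame t+1's features ("shift left").
    out[:, :-1, :fold] = x[:, 1:, :fold]
    # Second fold: frame t takes frame t-1's features ("shift right").
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]
    # Remaining channels are copied unchanged.
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]
    return out.view(nt, c, h, w)


# Toy input: 1 video, 4 frames, 4 channels, 1x1 spatial; value = 10 * frame + channel.
T, C = 4, 4
x = torch.tensor([[10.0 * t + ch for ch in range(C)] for t in range(T)]).view(T, C, 1, 1)
y = temporal_shift(x, n_segment=T, fold_div=4)  # fold = 1 channel per shifted group
# Rows of y.view(T, C): [10, 0, 2, 3], [20, 1, 12, 13], [30, 11, 22, 23], [0, 21, 32, 33]
print(y.view(T, C))
```

The `TSM` branch that receives a tensor `n_segment` applies the same three assignments, but only to the frames whose entry in `n_segment` is positive, and then scatters the result back into a zero tensor of the original shape.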
--------------------------------------------------------------------------------
/ops/basic_ops.py:
--------------------------------------------------------------------------------
import torch


class Identity(torch.nn.Module):
    def forward(self, input):
        return input


class SegmentConsensus(torch.nn.Module):

    def __init__(self, consensus_type, dim=1):
        super(SegmentConsensus, self).__init__()
        self.consensus_type = consensus_type
        self.dim = dim
        self.shape = None

    def forward(self, input_tensor):
        self.shape = input_tensor.size()
        if self.consensus_type == 'avg':
            output = input_tensor.mean(dim=self.dim, keepdim=True)
        elif self.consensus_type == 'identity':
            output = input_tensor
        else:
            output = None

        return output


class ConsensusModule(torch.nn.Module):

    def __init__(self, consensus_type, dim=1):
        super(ConsensusModule, self).__init__()
        self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
        self.dim = dim

    def forward(self, input):
        return SegmentConsensus(self.consensus_type, self.dim)(input)
--------------------------------------------------------------------------------
/ops/dataset.py:
--------------------------------------------------------------------------------
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import torch.utils.data as data

from PIL import Image
import os
import numpy as np
from numpy.random import randint


class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])


class TSNDataSet(data.Dataset):
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=False, twice_sample=False):

        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as I3D
        self.twice_sample = twice_sample  # twice sample for more validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')

        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff

        self._parse_list()

    def _load_image(self, directory, idx):
        if self.modality == 'RGB' or self.modality == 'RGBDiff':
            try:
                return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
            except Exception:
                print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
                return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
        elif self.modality == 'Flow':
            if self.image_tmpl == 'flow_{}_{:05d}.jpg':  # ucf
                x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert('L')
                y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert('L')
            elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':  # something v1 flow
                x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)),
                                                self.image_tmpl.format(int(directory), 'x', idx))).convert('L')
                y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)),
                                                self.image_tmpl.format(int(directory), 'y', idx))).convert('L')
            else:
                try:
                    # idx_skip = 1 + (idx-1)*5
                    flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')
                except Exception:
                    print('error loading flow file:',
                          os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
                    flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
                # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
                flow_x, flow_y, _ = flow.split()
                x_img = flow_x.convert('L')
                y_img = flow_y.convert('L')

            return [x_img, y_img]

    def _parse_list(self):
        # keep only videos with at least 3 frames (filter applied below):
        tmp = [x.strip().split(' ') for x in open(self.list_file)]

        if any(len(items) >= 3 for items in tmp):
            tmp = [[' '.join(x[:-2]), x[-2], x[-1]] for x in tmp]

        if not self.test_mode or self.remove_missing:
            tmp = [item for item in tmp if int(item[1]) >= 3]
        self.video_list = [VideoRecord(item) for item in tmp]

        if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            for v in self.video_list:
                v._data[1] = int(v._data[1]) / 2
        print('video number:%d' % (len(self.video_list)))

    def _sample_indices(self, record):
        """

        :param record: VideoRecord
        :return: list
        """
        if self.dense_sample:  # i3d dense sample
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
            offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        else:  # normal sample
            average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
            if average_duration > 0:
                offsets = (np.multiply(list(range(self.num_segments)), average_duration)
                           + randint(average_duration, size=self.num_segments))
            elif record.num_frames > self.num_segments:
                offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
            else:
                offsets = np.zeros((self.num_segments,))
            return offsets + 1

    def _get_val_indices(self, record):
        if self.dense_sample:  # i3d dense sample
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
            offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        else:
            if record.num_frames > self.num_segments + self.new_length - 1:
                tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
                offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
            else:
                offsets = np.zeros((self.num_segments,))
            return offsets + 1

    def _get_test_indices(self, record):
        if self.dense_sample:
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int)
            offsets = []
            for start_idx in start_list.tolist():
                offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        elif self.twice_sample:
            tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)

            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] +
                               [int(tick * x) for x
in range(self.num_segments)]) 162 | 163 | return offsets + 1 164 | else: 165 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 166 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 167 | return offsets + 1 168 | 169 | def __getitem__(self, index): 170 | record = self.video_list[index] 171 | # check this is a legit video folder 172 | 173 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 174 | file_name = self.image_tmpl.format('x', 1) 175 | full_path = os.path.join(self.root_path, record.path, file_name) 176 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 177 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 178 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 179 | else: 180 | file_name = self.image_tmpl.format(1) 181 | full_path = os.path.join(self.root_path, record.path, file_name) 182 | 183 | while not os.path.exists(full_path): 184 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name)) 185 | index = np.random.randint(len(self.video_list)) 186 | record = self.video_list[index] 187 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 188 | file_name = self.image_tmpl.format('x', 1) 189 | full_path = os.path.join(self.root_path, record.path, file_name) 190 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 191 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 192 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 193 | else: 194 | file_name = self.image_tmpl.format(1) 195 | full_path = os.path.join(self.root_path, record.path, file_name) 196 | 197 | if not self.test_mode: 198 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 199 | else: 200 | segment_indices = self._get_test_indices(record) 201 | return self.get(record, segment_indices) 202 | 203 | def get(self, record, indices): 204 | 205 | images = list() 206 | for 
seg_ind in indices: 207 | p = int(seg_ind) 208 | for i in range(self.new_length): 209 | seg_imgs = self._load_image(record.path, p) 210 | images.extend(seg_imgs) 211 | if p < record.num_frames: 212 | p += 1 213 | 214 | process_data = self.transform(images) 215 | return process_data, record.label 216 | 217 | def __len__(self): 218 | return len(self.video_list) 219 | -------------------------------------------------------------------------------- /ops/dataset_config.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import os 7 | 8 | def return_actnet(ROOT_DATASET, modality): 9 | filename_categories = 'actnet/classInd.txt' 10 | root_data = ROOT_DATASET + 'actnet/frames' 11 | filename_imglist_train = 'actnet/actnet_train_split_newest.txt' 12 | filename_imglist_val = 'actnet/actnet_val_split_newest.txt' 13 | prefix = 'image_{:05d}.jpg' 14 | 15 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 16 | 17 | 18 | def return_ucf101(ROOT_DATASET, modality): 19 | filename_categories = 'UCF101/labels/classInd.txt' 20 | if modality == 'RGB': 21 | root_data = ROOT_DATASET + 'UCF101/jpg' 22 | filename_imglist_train = 'UCF101/file_list/ucf101_rgb_train_split_1.txt' 23 | filename_imglist_val = 'UCF101/file_list/ucf101_rgb_val_split_1.txt' 24 | prefix = 'img_{:05d}.jpg' 25 | elif modality == 'Flow': 26 | root_data = ROOT_DATASET + 'UCF101/jpg' 27 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt' 28 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt' 29 | prefix = 'flow_{}_{:05d}.jpg' 30 | else: 31 | raise NotImplementedError('no such modality:' + modality) 32 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 33 | 34 | 35 | 
def return_hmdb51(ROOT_DATASET, modality): 36 | filename_categories = 51 37 | if modality == 'RGB': 38 | root_data = ROOT_DATASET + 'HMDB51/images' 39 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt' 40 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt' 41 | prefix = 'img_{:05d}.jpg' 42 | elif modality == 'Flow': 43 | root_data = ROOT_DATASET + 'HMDB51/images' 44 | filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt' 45 | filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt' 46 | prefix = 'flow_{}_{:05d}.jpg' 47 | else: 48 | raise NotImplementedError('no such modality:' + modality) 49 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 50 | 51 | 52 | def return_mini_something(ROOT_DATASET, modality): 53 | filename_categories = 'mini_something_v1/category.txt' 54 | if modality == 'RGB': 55 | root_data = ROOT_DATASET + 'mini_something_v1' 56 | filename_imglist_train = 'mini_something_v1/train_videofolder.txt' 57 | filename_imglist_val = 'mini_something_v1/val_videofolder.txt' 58 | prefix = '{:05d}.jpg' 59 | elif modality == 'Flow': 60 | root_data = ROOT_DATASET + 'mini_something_v1/20bn-something-something-v1-flow' 61 | filename_imglist_train = 'mini_something_v1/train_videofolder_flow.txt' 62 | filename_imglist_val = 'mini_something_v1/val_videofolder_flow.txt' 63 | prefix = '{:06d}-{}_{:05d}.jpg' 64 | else: 65 | print('no such modality:'+modality) 66 | raise NotImplementedError 67 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 68 | 69 | 70 | def return_something(ROOT_DATASET, modality): 71 | filename_categories = 'something_v1/category.txt' 72 | if modality == 'RGB': 73 | root_data = ROOT_DATASET + 'something_v1' 74 | filename_imglist_train = 'something_v1/train_videofolder.txt' 75 | filename_imglist_val = 'something_v1/val_videofolder.txt' 76 | prefix = '{:05d}.jpg' 77 | elif modality == 'Flow': 78 | 
root_data = ROOT_DATASET + 'something_v1/20bn-something-something-v1-flow' 79 | filename_imglist_train = 'something_v1/train_videofolder_flow.txt' 80 | filename_imglist_val = 'something_v1/val_videofolder_flow.txt' 81 | prefix = '{:06d}-{}_{:05d}.jpg' 82 | else: 83 | print('no such modality:'+modality) 84 | raise NotImplementedError 85 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 86 | 87 | 88 | def return_somethingv2(ROOT_DATASET, modality): 89 | filename_categories = 'something_v2/category.txt' 90 | if modality == 'RGB': 91 | root_data = ROOT_DATASET + 'something_v2' 92 | filename_imglist_train = 'something_v2/train_videofolder.txt' 93 | filename_imglist_val = 'something_v2/val_videofolder.txt' 94 | prefix = '{:06d}.jpg' 95 | elif modality == 'Flow': 96 | root_data = ROOT_DATASET + 'something_v2/20bn-something-something-v2-flow' 97 | filename_imglist_train = 'something_v2/train_videofolder_flow.txt' 98 | filename_imglist_val = 'something_v2/val_videofolder_flow.txt' 99 | prefix = '{:06d}-{}_{:05d}.jpg' 100 | else: 101 | raise NotImplementedError('no such modality:'+modality) 102 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 103 | 104 | 105 | def return_jester(ROOT_DATASET, modality): 106 | filename_categories = 'jester/category.txt' 107 | if modality == 'RGB': 108 | prefix = '{:05d}.jpg' 109 | root_data = ROOT_DATASET + 'jester' 110 | filename_imglist_train = 'jester/train_videofolder.txt' 111 | filename_imglist_val = 'jester/val_videofolder.txt' 112 | else: 113 | raise NotImplementedError('no such modality:'+modality) 114 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 115 | 116 | 117 | def return_mini_kinetics(ROOT_DATASET, modality): 118 | filename_categories = 200 119 | if modality == 'RGB': 120 | root_data = ROOT_DATASET + 'mini-kinetics' 121 | filename_imglist_train = 'mini-kinetics/mini_train_videofolder.txt' 122 | 
filename_imglist_val = 'mini-kinetics/mini_val_videofolder.txt' 123 | prefix = 'image_{:05d}.jpg' 124 | else: 125 | raise NotImplementedError('no such modality:' + modality) 126 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 127 | 128 | 129 | def return_kinetics(ROOT_DATASET, modality): 130 | filename_categories = 400 131 | if modality == 'RGB': 132 | root_data = ROOT_DATASET + 'kinetics/images' 133 | filename_imglist_train = 'kinetics/labels/train_videofolder.txt' 134 | filename_imglist_val = 'kinetics/labels/val_videofolder.txt' 135 | prefix = 'image_{:05d}.jpg' 136 | else: 137 | raise NotImplementedError('no such modality:' + modality) 138 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 139 | 140 | 141 | def return_dataset(root_dataset, dataset, modality): 142 | ROOT_DATASET = root_dataset 143 | 144 | dict_single = {'jester': return_jester, 'mini_something': return_mini_something, 'something': return_something, 'somethingv2': return_somethingv2, 145 | 'ucf101': return_ucf101, 'hmdb51': return_hmdb51, 'actnet': return_actnet, 'mini_kinetics': return_mini_kinetics, 146 | 'kinetics': return_kinetics } 147 | if dataset in dict_single: 148 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](ROOT_DATASET, modality) 149 | else: 150 | raise ValueError('Unknown dataset '+dataset) 151 | 152 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train) 153 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val) 154 | if isinstance(file_categories, str): 155 | file_categories = os.path.join(ROOT_DATASET, file_categories) 156 | with open(file_categories) as f: 157 | lines = f.readlines() 158 | categories = [item.rstrip() for item in lines] 159 | else: # number of categories 160 | categories = [None] * file_categories 161 | n_class = len(categories) 162 | print('{}: {} classes'.format(dataset, n_class)) 163 | return 
n_class, file_imglist_train, file_imglist_val, root_data, prefix 164 | -------------------------------------------------------------------------------- /ops/models.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | from torch import nn 7 | import ops.backbone 8 | from ops.basic_ops import ConsensusModule 9 | from ops.transforms import * 10 | from torch.nn.init import normal_, constant_ 11 | 12 | 13 | class TSN(nn.Module): 14 | def __init__(self, arch_file, num_class, num_segments, modality, path_backbone, 15 | base_model='resnet101', new_length=None, 16 | consensus_type='avg', before_softmax=True, 17 | dropout=0.8, img_feature_dim=256, 18 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet', 19 | is_shift=False, fc_lr5=False, 20 | temporal_pool=False, non_local=False): 21 | super(TSN, self).__init__() 22 | self.modality = modality 23 | self.num_segments = num_segments 24 | self.reshape = True 25 | self.before_softmax = before_softmax 26 | self.dropout = dropout 27 | self.crop_num = crop_num 28 | self.consensus_type = consensus_type 29 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame 30 | self.pretrain = pretrain 31 | 32 | self.is_shift = is_shift 33 | self.base_model_name = base_model 34 | self.fc_lr5 = fc_lr5 35 | self.temporal_pool = temporal_pool 36 | self.non_local = non_local 37 | 38 | if not before_softmax and consensus_type != 'avg': 39 | raise ValueError("Only avg consensus can be used after Softmax") 40 | 41 | if new_length is None: 42 | self.new_length = 1 if modality == "RGB" else 5 43 | else: 44 | self.new_length = new_length 45 | if print_spec: 46 | print((""" 47 | Initializing TSN with base model: {}. 
48 | TSN Configurations: 49 | input_modality: {} 50 | num_segments: {} 51 | new_length: {} 52 | consensus_module: {} 53 | dropout_ratio: {} 54 | img_feature_dim: {} 55 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) 56 | 57 | self._prepare_base_model(arch_file, base_model, path_backbone) 58 | 59 | feature_dim = self._prepare_tsn(num_class) 60 | 61 | if self.modality == 'Flow': 62 | print("Converting the ImageNet model to a flow init model") 63 | self.base_model = self._construct_flow_model(self.base_model) 64 | print("Done. Flow model ready...") 65 | elif self.modality == 'RGBDiff': 66 | print("Converting the ImageNet model to RGB+Diff init model") 67 | self.base_model = self._construct_diff_model(self.base_model) 68 | print("Done. RGBDiff model ready.") 69 | 70 | self.consensus = ConsensusModule(consensus_type) 71 | 72 | if not self.before_softmax: 73 | self.softmax = nn.Softmax(dim=1) 74 | 75 | self._enable_pbn = partial_bn 76 | if partial_bn: 77 | self.partialBN(True) 78 | 79 | def _prepare_tsn(self, num_class): 80 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 81 | if self.dropout == 0: 82 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) 83 | self.new_fc = None 84 | else: 85 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 86 | self.new_fc = nn.Linear(feature_dim, num_class) 87 | 88 | std = 0.001 89 | if self.new_fc is None: 90 | normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) 91 | constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 92 | else: 93 | if hasattr(self.new_fc, 'weight'): 94 | normal_(self.new_fc.weight, 0, std) 95 | constant_(self.new_fc.bias, 0) 96 | return feature_dim 97 | 98 | def _prepare_base_model(self, arch_file, base_model, path_backbone): 99 | print('=> base model: 
{}'.format(base_model)) 100 | 101 | if 'resnet' in base_model: 102 | self.base_model = eval(f'ops.backbone.{arch_file}.{base_model}')(True if self.pretrain == 'imagenet' else False, path_backbone=path_backbone, shift=self.is_shift, num_segments=self.num_segments) 103 | self.base_model.last_layer_name = 'fc' 104 | self.input_size = 224 105 | self.input_mean = [0.485, 0.456, 0.406] 106 | self.input_std = [0.229, 0.224, 0.225] 107 | 108 | if self.modality == 'Flow': 109 | self.input_mean = [0.5] 110 | self.input_std = [np.mean(self.input_std)] 111 | elif self.modality == 'RGBDiff': 112 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 113 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 114 | else: 115 | raise ValueError('Unknown base model: {}'.format(base_model)) 116 | 117 | def train(self, mode=True): 118 | """ 119 | Override the default train() to freeze the BN parameters 120 | :return: 121 | """ 122 | super(TSN, self).train(mode) 123 | count = 0 124 | if self._enable_pbn and mode: 125 | print("Freezing BatchNorm2D except the first one.") 126 | for m in self.base_model.modules(): 127 | if isinstance(m, nn.BatchNorm2d): 128 | count += 1 129 | if count >= (2 if self._enable_pbn else 1): 130 | m.eval() 131 | # shutdown update in frozen mode 132 | m.weight.requires_grad = False 133 | m.bias.requires_grad = False 134 | 135 | def partialBN(self, enable): 136 | self._enable_pbn = enable 137 | 138 | def get_optim_policies(self): 139 | navi_conv_weight = [] 140 | navi_conv_bias = [] 141 | first_conv_weight = [] 142 | first_conv_bias = [] 143 | normal_weight = [] 144 | normal_bias = [] 145 | lr5_weight = [] 146 | lr10_bias = [] 147 | bn = [] 148 | navi_bn = [] 149 | custom_ops = [] 150 | 151 | conv_cnt = 0 152 | bn_cnt = 0 153 | for m in self.modules(): 154 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 155 | ps = list(m.parameters()) 156 | conv_cnt += 1 157 
| if conv_cnt == 1: 158 | first_conv_weight.append(ps[0]) 159 | if len(ps) == 2: 160 | first_conv_bias.append(ps[1]) 161 | elif ps[0].shape == torch.Size([self.num_segments*2,2,1,1]): 162 | navi_conv_weight.append(ps[0]) 163 | if len(ps) == 2: 164 | navi_conv_bias.append(ps[1]) 165 | elif ps[0].shape[0] == 2: 166 | navi_conv_weight.append(ps[0]) 167 | if len(ps) == 2: 168 | navi_conv_bias.append(ps[1]) 169 | else: 170 | normal_weight.append(ps[0]) 171 | if len(ps) == 2: 172 | normal_bias.append(ps[1]) 173 | elif isinstance(m, torch.nn.Linear): 174 | ps = list(m.parameters()) 175 | if self.fc_lr5: 176 | lr5_weight.append(ps[0]) 177 | else: 178 | normal_weight.append(ps[0]) 179 | if len(ps) == 2: 180 | if self.fc_lr5: 181 | lr10_bias.append(ps[1]) 182 | else: 183 | normal_bias.append(ps[1]) 184 | 185 | elif isinstance(m, torch.nn.BatchNorm2d): 186 | bn_cnt += 1 187 | # later BN's are frozen 188 | if list(m.parameters())[0].shape[0]==2: 189 | navi_bn.extend(list(m.parameters())) 190 | elif not self._enable_pbn or bn_cnt == 1: 191 | bn.extend(list(m.parameters())) 192 | elif isinstance(m, torch.nn.BatchNorm3d): 193 | bn_cnt += 1 194 | # later BN's are frozen 195 | if not self._enable_pbn or bn_cnt == 1: 196 | bn.extend(list(m.parameters())) 197 | elif len(m._modules) == 0: 198 | if len(list(m.parameters())) > 0: 199 | raise ValueError("New atomic module type: {}. 
Need to give it a learning policy".format(type(m))) 200 | 201 | return [ 202 | {'params': navi_bn, 'lr_mult': 10, 'decay_mult': 1, 203 | 'name': "navi_bn"}, 204 | {'params': navi_conv_weight, 'lr_mult': 10, 'decay_mult': 1, 205 | 'name': "navi_conv_weight"}, 206 | {'params': navi_conv_bias, 'lr_mult': 10, 'decay_mult': 1, 207 | 'name': "navi_conv_bias"}, 208 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 209 | 'name': "first_conv_weight"}, 210 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 211 | 'name': "first_conv_bias"}, 212 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 213 | 'name': "normal_weight"}, 214 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 215 | 'name': "normal_bias"}, 216 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 217 | 'name': "BN scale/shift"}, 218 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, 219 | 'name': "custom_ops"}, 220 | # for fc 221 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1, 222 | 'name': "lr5_weight"}, 223 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0, 224 | 'name': "lr10_bias"}, 225 | ] 226 | 227 | def forward(self, input, temperature, no_reshape=False): 228 | if not no_reshape: 229 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 230 | 231 | if self.modality == 'RGBDiff': 232 | sample_len = 3 * self.new_length 233 | input = self._get_diff(input) 234 | 235 | base_out, temporal_mask_ls = self.base_model(input.view((-1, sample_len) + input.size()[-2:]), temperature) 236 | else: 237 | base_out, temporal_mask_ls = self.base_model(input, temperature) 238 | 239 | if self.dropout > 0: 240 | base_out = self.new_fc(base_out) 241 | 242 | if not self.before_softmax: 243 | base_out = self.softmax(base_out) 244 | 245 | if self.reshape: 246 | if self.is_shift and self.temporal_pool: 247 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:]) 248 | else: 
249 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 250 | output = self.consensus(base_out) 251 | return output.squeeze(1), temporal_mask_ls 252 | 253 | def forward_calc_flops(self, input, temperature, no_reshape=False): 254 | if not no_reshape: 255 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 256 | 257 | if self.modality == 'RGBDiff': 258 | sample_len = 3 * self.new_length 259 | input = self._get_diff(input) 260 | 261 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input.view((-1, sample_len) + input.size()[-2:]), temperature) 262 | else: 263 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input, temperature) 264 | 265 | if self.dropout > 0: 266 | base_out = self.new_fc(base_out) 267 | 268 | if not self.before_softmax: 269 | base_out = self.softmax(base_out) 270 | 271 | if self.reshape: 272 | if self.is_shift and self.temporal_pool: 273 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:]) 274 | else: 275 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 276 | output = self.consensus(base_out) 277 | return output.squeeze(1), temporal_mask_ls, flops 278 | 279 | def _get_diff(self, input, keep_rgb=False): 280 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 281 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 282 | if keep_rgb: 283 | new_data = input_view.clone() 284 | else: 285 | new_data = input_view[:, :, 1:, :, :, :].clone() 286 | 287 | for x in reversed(list(range(1, self.new_length + 1))): 288 | if keep_rgb: 289 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 290 | else: 291 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 292 | 293 | return new_data 294 | 295 | def _construct_flow_model(self, base_model): 296 | # modify the convolution layers 297 | # 
Torch models are usually defined in a hierarchical way. 298 | # nn.Module.modules() returns all submodules in a DFS manner 299 | modules = list(self.base_model.modules()) 300 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 301 | conv_layer = modules[first_conv_idx] 302 | container = modules[first_conv_idx - 1] 303 | 304 | # modify parameters, assume the first blob contains the convolution kernels 305 | params = [x.clone() for x in conv_layer.parameters()] 306 | kernel_size = params[0].size() 307 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] 308 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 309 | 310 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 311 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 312 | bias=True if len(params) == 2 else False) 313 | new_conv.weight.data = new_kernels 314 | if len(params) == 2: 315 | new_conv.bias.data = params[1].data # add bias if necessary 316 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 317 | 318 | # replace the first convolution layer 319 | setattr(container, layer_name, new_conv) 320 | 321 | if self.base_model_name == 'BNInception': 322 | import torch.utils.model_zoo as model_zoo 323 | sd = model_zoo.load_url('https://www.dropbox.com/s/35ftw2t4mxxgjae/BNInceptionFlow-ef652051.pth.tar?dl=1') 324 | base_model.load_state_dict(sd) 325 | print('=> Loading pretrained Flow weight done...') 326 | else: 327 | print('#' * 30, 'Warning! No Flow pretrained model is found') 328 | return base_model 329 | 330 | def _construct_diff_model(self, base_model, keep_rgb=False): 331 | # modify the convolution layers 332 | # Torch models are usually defined in a hierarchical way.
333 | # nn.Module.modules() returns all submodules in a DFS manner 334 | modules = list(self.base_model.modules()) 335 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 336 | conv_layer = modules[first_conv_idx] 337 | container = modules[first_conv_idx - 1] 338 | 339 | # modify parameters, assume the first blob contains the convolution kernels 340 | params = [x.clone() for x in conv_layer.parameters()] 341 | kernel_size = params[0].size() 342 | if not keep_rgb: 343 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 344 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 345 | else: 346 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 347 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 348 | 1) 349 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:] 350 | 351 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 352 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 353 | bias=True if len(params) == 2 else False) 354 | new_conv.weight.data = new_kernels 355 | if len(params) == 2: 356 | new_conv.bias.data = params[1].data # add bias if necessary 357 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 358 | 359 | # replace the first convolution layer 360 | setattr(container, layer_name, new_conv) 361 | return base_model 362 | 363 | @property 364 | def crop_size(self): 365 | return self.input_size 366 | 367 | @property 368 | def scale_size(self): 369 | return self.input_size * 256 // 224 370 | 371 | def get_augmentation(self, flip=True): 372 | if self.modality == 'RGB': 373 | if flip: 374 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 375 | GroupRandomHorizontalFlip(is_flow=False)])
376 | else: 377 | print('#' * 20, 'NO FLIP!!!') 378 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 379 | elif self.modality == 'Flow': 380 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 381 | GroupRandomHorizontalFlip(is_flow=True)]) 382 | elif self.modality == 'RGBDiff': 383 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 384 | GroupRandomHorizontalFlip(is_flow=False)]) 385 | -------------------------------------------------------------------------------- /ops/models_mobilenet.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | from torch import nn 7 | import ops.backbone 8 | from ops.basic_ops import ConsensusModule 9 | from ops.transforms import * 10 | from torch.nn.init import normal_, constant_ 11 | 12 | 13 | class TSN(nn.Module): 14 | def __init__(self, arch_file, num_class, num_segments, modality, path_backbone, 15 | base_model='mobilenet', new_length=None, 16 | consensus_type='avg', before_softmax=True, 17 | dropout=0.8, img_feature_dim=256, 18 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet', 19 | is_shift=False, fc_lr5=False, 20 | temporal_pool=False, non_local=False): 21 | super(TSN, self).__init__() 22 | self.modality = modality 23 | self.num_segments = num_segments 24 | self.reshape = True 25 | self.before_softmax = before_softmax 26 | self.dropout = dropout 27 | self.crop_num = crop_num 28 | self.consensus_type = consensus_type 29 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame 30 | self.pretrain = pretrain 31 | 32 | self.is_shift = is_shift 33 | self.base_model_name = base_model 34 | self.fc_lr5 = fc_lr5 
35 | self.temporal_pool = temporal_pool 36 | self.non_local = non_local 37 | 38 | if not before_softmax and consensus_type != 'avg': 39 | raise ValueError("Only avg consensus can be used after Softmax") 40 | 41 | if new_length is None: 42 | self.new_length = 1 if modality == "RGB" else 5 43 | else: 44 | self.new_length = new_length 45 | if print_spec: 46 | print((""" 47 | Initializing TSN with base model: {}. 48 | TSN Configurations: 49 | input_modality: {} 50 | num_segments: {} 51 | new_length: {} 52 | consensus_module: {} 53 | dropout_ratio: {} 54 | img_feature_dim: {} 55 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) 56 | 57 | self._prepare_base_model(arch_file, base_model, num_class, path_backbone) 58 | 59 | # feature_dim = self._prepare_tsn(num_class) 60 | 61 | if self.modality == 'Flow': 62 | print("Converting the ImageNet model to a flow init model") 63 | self.base_model = self._construct_flow_model(self.base_model) 64 | print("Done. Flow model ready...") 65 | elif self.modality == 'RGBDiff': 66 | print("Converting the ImageNet model to RGB+Diff init model") 67 | self.base_model = self._construct_diff_model(self.base_model) 68 | print("Done. 
RGBDiff model ready.") 69 | 70 | self.consensus = ConsensusModule(consensus_type) 71 | 72 | if not self.before_softmax: 73 | self.softmax = nn.Softmax(dim=1) 74 | 75 | self._enable_pbn = partial_bn 76 | if partial_bn: 77 | self.partialBN(True) 78 | 79 | 80 | def _prepare_base_model(self, arch_file, base_model, num_class, path_backbone): 81 | print('=> base model: {}'.format(base_model)) 82 | 83 | if 'mobilenet' in base_model: 84 | self.base_model = eval(f'ops.backbone.{arch_file}.{base_model}')(True if self.pretrain == 'imagenet' else False, path_backbone=path_backbone, shift=self.is_shift, num_segments=self.num_segments, num_class=num_class) 85 | self.base_model.last_layer_name = 'fc' 86 | self.input_size = 224 87 | self.input_mean = [0.485, 0.456, 0.406] 88 | self.input_std = [0.229, 0.224, 0.225] 89 | 90 | if self.modality == 'Flow': 91 | self.input_mean = [0.5] 92 | self.input_std = [np.mean(self.input_std)] 93 | elif self.modality == 'RGBDiff': 94 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 95 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 96 | else: 97 | raise ValueError('Unknown base model: {}'.format(base_model)) 98 | 99 | def train(self, mode=True): 100 | """ 101 | Override the default train() to freeze the BN parameters 102 | :return: 103 | """ 104 | super(TSN, self).train(mode) 105 | count = 0 106 | if self._enable_pbn and mode: 107 | print("Freezing BatchNorm2D except the first one.") 108 | for m in self.base_model.modules(): 109 | if isinstance(m, nn.BatchNorm2d): 110 | count += 1 111 | if count >= (2 if self._enable_pbn else 1): 112 | m.eval() 113 | # disable gradient updates for the frozen BN layers 114 | m.weight.requires_grad = False 115 | m.bias.requires_grad = False 116 | 117 | def partialBN(self, enable): 118 | self._enable_pbn = enable 119 | 120 | def get_optim_policies(self): 121 | navi_conv_weight = [] 122 | navi_conv_bias = [] 123 | first_conv_weight = [] 124 | first_conv_bias = [] 125 | 
normal_weight = [] 126 | normal_bias = [] 127 | lr5_weight = [] 128 | lr10_bias = [] 129 | bn = [] 130 | navi_bn = [] 131 | custom_ops = [] 132 | 133 | conv_cnt = 0 134 | bn_cnt = 0 135 | for m in self.modules(): 136 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 137 | ps = list(m.parameters()) 138 | conv_cnt += 1 139 | if conv_cnt == 1: 140 | first_conv_weight.append(ps[0]) 141 | if len(ps) == 2: 142 | first_conv_bias.append(ps[1]) 143 | elif ps[0].shape == torch.Size([self.num_segments*2,2,1,1]): 144 | navi_conv_weight.append(ps[0]) 145 | if len(ps) == 2: 146 | navi_conv_bias.append(ps[1]) 147 | elif ps[0].shape[0] == 2: 148 | navi_conv_weight.append(ps[0]) 149 | if len(ps) == 2: 150 | navi_conv_bias.append(ps[1]) 151 | else: 152 | normal_weight.append(ps[0]) 153 | if len(ps) == 2: 154 | normal_bias.append(ps[1]) 155 | elif isinstance(m, torch.nn.Linear): 156 | ps = list(m.parameters()) 157 | 158 | if self.fc_lr5: 159 | if ps[0].shape[0] == 1280: 160 | lr5_weight.append(ps[0]) 161 | elif ps[0].shape[1] == 1280: 162 | lr5_weight.append(ps[0]) 163 | else: 164 | normal_weight.append(ps[0]) 165 | else: 166 | normal_weight.append(ps[0]) 167 | if len(ps) == 2: 168 | if self.fc_lr5: 169 | if ps[0].shape[0] == 1280: 170 | lr10_bias.append(ps[1]) 171 | elif ps[0].shape[1] == 1280: 172 | lr10_bias.append(ps[1]) 173 | else: 174 | normal_bias.append(ps[1]) 175 | else: 176 | normal_bias.append(ps[1]) 177 | 178 | elif isinstance(m, torch.nn.BatchNorm2d): 179 | bn_cnt += 1 180 | # later BN's are frozen 181 | if list(m.parameters())[0].shape[0]==2: 182 | navi_bn.extend(list(m.parameters())) 183 | elif not self._enable_pbn or bn_cnt == 1: 184 | bn.extend(list(m.parameters())) 185 | elif isinstance(m, torch.nn.BatchNorm3d): 186 | bn_cnt += 1 187 | # later BN's are frozen 188 | if not self._enable_pbn or bn_cnt == 1: 189 | bn.extend(list(m.parameters())) 190 | elif len(m._modules) == 0: 191 | if len(list(m.parameters())) >
0: 192 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m))) 193 | 194 | return [ 195 | {'params': navi_bn, 'lr_mult': 10, 'decay_mult': 1, 196 | 'name': "navi_bn"}, 197 | {'params': navi_conv_weight, 'lr_mult': 10, 'decay_mult': 1, 198 | 'name': "navi_conv_weight"}, 199 | {'params': navi_conv_bias, 'lr_mult': 10, 'decay_mult': 1, 200 | 'name': "navi_conv_bias"}, 201 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 202 | 'name': "first_conv_weight"}, 203 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 204 | 'name': "first_conv_bias"}, 205 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 206 | 'name': "normal_weight"}, 207 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 208 | 'name': "normal_bias"}, 209 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 210 | 'name': "BN scale/shift"}, 211 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, 212 | 'name': "custom_ops"}, 213 | # for fc 214 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1, 215 | 'name': "lr5_weight"}, 216 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0, 217 | 'name': "lr10_bias"}, 218 | ] 219 | 220 | def forward(self, input, temperature, no_reshape=False): 221 | if not no_reshape: 222 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 223 | 224 | if self.modality == 'RGBDiff': 225 | sample_len = 3 * self.new_length 226 | input = self._get_diff(input) 227 | 228 | base_out, temporal_mask_ls = self.base_model(input.view((-1, sample_len) + input.size()[-2:]), temperature) 229 | else: 230 | base_out, temporal_mask_ls = self.base_model(input, temperature) 231 | 232 | if not self.before_softmax: 233 | base_out = self.softmax(base_out) 234 | 235 | if self.reshape: 236 | if self.is_shift and self.temporal_pool: 237 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:]) 238 | else: 239 | base_out = 
base_out.view((-1, self.num_segments) + base_out.size()[1:]) 240 | output = self.consensus(base_out) 241 | return output.squeeze(1), temporal_mask_ls 242 | 243 | def forward_calc_flops(self, input, temperature, no_reshape=False): 244 | if not no_reshape: 245 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 246 | 247 | if self.modality == 'RGBDiff': 248 | sample_len = 3 * self.new_length 249 | input = self._get_diff(input) 250 | 251 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input.view((-1, sample_len) + input.size()[-2:]), temperature) 252 | else: 253 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input, temperature) 254 | 255 | if not self.before_softmax: 256 | base_out = self.softmax(base_out) 257 | 258 | if self.reshape: 259 | if self.is_shift and self.temporal_pool: 260 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:]) 261 | else: 262 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 263 | output = self.consensus(base_out) 264 | return output.squeeze(1), temporal_mask_ls, flops 265 | 266 | def _get_diff(self, input, keep_rgb=False): 267 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 268 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 269 | if keep_rgb: 270 | new_data = input_view.clone() 271 | else: 272 | new_data = input_view[:, :, 1:, :, :, :].clone() 273 | 274 | for x in reversed(list(range(1, self.new_length + 1))): 275 | if keep_rgb: 276 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 277 | else: 278 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 279 | 280 | return new_data 281 | 282 | def _construct_flow_model(self, base_model): 283 | # modify the convolution layers 284 | # Torch models are usually defined in a hierarchical way. 
285 | # nn.Module.modules() returns all submodules in DFS order 286 | modules = list(self.base_model.modules()) 287 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 288 | conv_layer = modules[first_conv_idx] 289 | container = modules[first_conv_idx - 1] 290 | 291 | # modify parameters, assume the first blob contains the convolution kernels 292 | params = [x.clone() for x in conv_layer.parameters()] 293 | kernel_size = params[0].size() 294 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] 295 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 296 | 297 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 298 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 299 | bias=True if len(params) == 2 else False) 300 | new_conv.weight.data = new_kernels 301 | if len(params) == 2: 302 | new_conv.bias.data = params[1].data # add bias if necessary 303 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 304 | 305 | # replace the first convolution layer 306 | setattr(container, layer_name, new_conv) 307 | 308 | if self.base_model_name == 'BNInception': 309 | import torch.utils.model_zoo as model_zoo 310 | sd = model_zoo.load_url('https://www.dropbox.com/s/35ftw2t4mxxgjae/BNInceptionFlow-ef652051.pth.tar?dl=1') 311 | base_model.load_state_dict(sd) 312 | print('=> Loading pretrained Flow weight done...') 313 | else: 314 | print('#' * 30, 'Warning! No Flow pretrained model is found') 315 | return base_model 316 | 317 | def _construct_diff_model(self, base_model, keep_rgb=False): 318 | # modify the convolution layers 319 | # Torch models are usually defined in a hierarchical way.
320 | # nn.Module.modules() returns all submodules in DFS order 321 | modules = list(self.base_model.modules()) 322 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 323 | conv_layer = modules[first_conv_idx] 324 | container = modules[first_conv_idx - 1] 325 | 326 | # modify parameters, assume the first blob contains the convolution kernels 327 | params = [x.clone() for x in conv_layer.parameters()] 328 | kernel_size = params[0].size() 329 | if not keep_rgb: 330 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 331 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 332 | else: 333 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 334 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 335 | 1) 336 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:] 337 | 338 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 339 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 340 | bias=True if len(params) == 2 else False) 341 | new_conv.weight.data = new_kernels 342 | if len(params) == 2: 343 | new_conv.bias.data = params[1].data # add bias if necessary 344 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 345 | 346 | # replace the first convolution layer 347 | setattr(container, layer_name, new_conv) 348 | return base_model 349 | 350 | @property 351 | def crop_size(self): 352 | return self.input_size 353 | 354 | @property 355 | def scale_size(self): 356 | return self.input_size * 256 // 224 357 | 358 | def get_augmentation(self, flip=True): 359 | if self.modality == 'RGB': 360 | if flip: 361 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 362 | GroupRandomHorizontalFlip(is_flow=False)])
363 | else: 364 | print('#' * 20, 'NO FLIP!!!') 365 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 366 | elif self.modality == 'Flow': 367 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 368 | GroupRandomHorizontalFlip(is_flow=True)]) 369 | elif self.modality == 'RGBDiff': 370 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 371 | GroupRandomHorizontalFlip(is_flow=False)]) -------------------------------------------------------------------------------- /ops/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 47 | """ 48 | def __init__(self, is_flow=False): 49 | self.is_flow = is_flow 50 | 51 | def __call__(self, img_group, is_flow=False): 52 | 
v = random.random() 53 | if v < 0.5: 54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 55 | if self.is_flow: 56 | for i in range(0, len(ret), 2): 57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 58 | return ret 59 | else: 60 | return img_group 61 | 62 | 63 | class GroupNormalize(object): 64 | def __init__(self, mean, std): 65 | self.mean = mean 66 | self.std = std 67 | 68 | def __call__(self, tensor): 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 70 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 71 | 72 | # TODO: make efficient 73 | for t, m, s in zip(tensor, rep_mean, rep_std): 74 | t.sub_(m).div_(s) 75 | 76 | return tensor 77 | 78 | 79 | class GroupScale(object): 80 | """ Rescales the input PIL.Image to the given 'size'. 81 | 'size' will be the size of the smaller edge. 82 | For example, if height > width, then image will be 83 | rescaled to (size * height / width, size) 84 | size: size of the smaller edge 85 | interpolation: Default: PIL.Image.BILINEAR 86 | """ 87 | 88 | def __init__(self, size, interpolation=Image.BILINEAR): 89 | self.worker = torchvision.transforms.Resize(size, interpolation) 90 | 91 | def __call__(self, img_group): 92 | return [self.worker(img) for img in img_group] 93 | 94 | 95 | class GroupOverSample(object): 96 | def __init__(self, crop_size, scale_size=None, flip=True): 97 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 98 | 99 | if scale_size is not None: 100 | self.scale_worker = GroupScale(scale_size) 101 | else: 102 | self.scale_worker = None 103 | self.flip = flip 104 | 105 | def __call__(self, img_group): 106 | 107 | if self.scale_worker is not None: 108 | img_group = self.scale_worker(img_group) 109 | 110 | image_w, image_h = img_group[0].size 111 | crop_w, crop_h = self.crop_size 112 | 113 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 114 | oversample_group = list() 115 
| for o_w, o_h in offsets: 116 | normal_group = list() 117 | flip_group = list() 118 | for i, img in enumerate(img_group): 119 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 120 | normal_group.append(crop) 121 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 122 | 123 | if img.mode == 'L' and i % 2 == 0: 124 | flip_group.append(ImageOps.invert(flip_crop)) 125 | else: 126 | flip_group.append(flip_crop) 127 | 128 | oversample_group.extend(normal_group) 129 | if self.flip: 130 | oversample_group.extend(flip_group) 131 | return oversample_group 132 | 133 | 134 | class GroupFullResSample(object): 135 | def __init__(self, crop_size, scale_size=None, flip=True): 136 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 137 | 138 | if scale_size is not None: 139 | self.scale_worker = GroupScale(scale_size) 140 | else: 141 | self.scale_worker = None 142 | self.flip = flip 143 | 144 | def __call__(self, img_group): 145 | 146 | if self.scale_worker is not None: 147 | img_group = self.scale_worker(img_group) 148 | 149 | image_w, image_h = img_group[0].size 150 | crop_w, crop_h = self.crop_size 151 | 152 | w_step = (image_w - crop_w) // 4 153 | h_step = (image_h - crop_h) // 4 154 | 155 | offsets = list() 156 | offsets.append((0 * w_step, 2 * h_step)) # left 157 | offsets.append((4 * w_step, 2 * h_step)) # right 158 | offsets.append((2 * w_step, 2 * h_step)) # center 159 | 160 | oversample_group = list() 161 | for o_w, o_h in offsets: 162 | normal_group = list() 163 | flip_group = list() 164 | for i, img in enumerate(img_group): 165 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 166 | normal_group.append(crop) 167 | if self.flip: 168 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 169 | 170 | if img.mode == 'L' and i % 2 == 0: 171 | flip_group.append(ImageOps.invert(flip_crop)) 172 | else: 173 | flip_group.append(flip_crop) 174 | 175 | oversample_group.extend(normal_group) 176 | 
oversample_group.extend(flip_group) 177 | return oversample_group 178 | 179 | 180 | class GroupMultiScaleCrop(object): 181 | 182 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 183 | self.scales = scales if scales is not None else [1, .875, .75, .66] 184 | self.max_distort = max_distort 185 | self.fix_crop = fix_crop 186 | self.more_fix_crop = more_fix_crop 187 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 188 | self.interpolation = Image.BILINEAR 189 | 190 | def __call__(self, img_group): 191 | 192 | im_size = img_group[0].size 193 | 194 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 195 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 196 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 197 | for img in crop_img_group] 198 | return ret_img_group 199 | 200 | def _sample_crop_size(self, im_size): 201 | image_w, image_h = im_size[0], im_size[1] 202 | 203 | # find a crop size 204 | base_size = min(image_w, image_h) 205 | crop_sizes = [int(base_size * x) for x in self.scales] 206 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 207 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 208 | 209 | pairs = [] 210 | for i, h in enumerate(crop_h): 211 | for j, w in enumerate(crop_w): 212 | if abs(i - j) <= self.max_distort: 213 | pairs.append((w, h)) 214 | 215 | crop_pair = random.choice(pairs) 216 | if not self.fix_crop: 217 | w_offset = random.randint(0, image_w - crop_pair[0]) 218 | h_offset = random.randint(0, image_h - crop_pair[1]) 219 | else: 220 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 221 | 222 | return crop_pair[0], crop_pair[1], w_offset, h_offset 223 | 224 | def _sample_fix_offset(self, image_w, 
image_h, crop_w, crop_h): 225 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 226 | return random.choice(offsets) 227 | 228 | @staticmethod 229 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 230 | w_step = (image_w - crop_w) // 4 231 | h_step = (image_h - crop_h) // 4 232 | 233 | ret = list() 234 | ret.append((0, 0)) # upper left 235 | ret.append((4 * w_step, 0)) # upper right 236 | ret.append((0, 4 * h_step)) # lower left 237 | ret.append((4 * w_step, 4 * h_step)) # lower right 238 | ret.append((2 * w_step, 2 * h_step)) # center 239 | 240 | if more_fix_crop: 241 | ret.append((0, 2 * h_step)) # center left 242 | ret.append((4 * w_step, 2 * h_step)) # center right 243 | ret.append((2 * w_step, 4 * h_step)) # lower center 244 | ret.append((2 * w_step, 0 * h_step)) # upper center 245 | 246 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 247 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 248 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 249 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 250 | 251 | return ret 252 | 253 | 254 | class GroupRandomSizedCrop(object): 255 | """Randomly crop the given PIL.Image to a random area of 0.08 to 1.0 of the original area 256 | and a random aspect ratio of 3/4 to 4/3, then resize the crop to (size, size). 257 | This is commonly used to train the Inception networks 258 | size: side length of the square output 259 | interpolation: Default: PIL.Image.BILINEAR 260 | """ 261 | def __init__(self, size, interpolation=Image.BILINEAR): 262 | self.size = size 263 | self.interpolation = interpolation 264 | 265 | def __call__(self, img_group): 266 | for attempt in range(10): 267 | area = img_group[0].size[0] * img_group[0].size[1] 268 | target_area = random.uniform(0.08, 1.0) * area 269 | aspect_ratio = random.uniform(3. / 4, 4.
/ 3) 270 | 271 | w = int(round(math.sqrt(target_area * aspect_ratio))) 272 | h = int(round(math.sqrt(target_area / aspect_ratio))) 273 | 274 | if random.random() < 0.5: 275 | w, h = h, w 276 | 277 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 278 | x1 = random.randint(0, img_group[0].size[0] - w) 279 | y1 = random.randint(0, img_group[0].size[1] - h) 280 | found = True 281 | break 282 | else: 283 | found = False 284 | x1 = 0 285 | y1 = 0 286 | 287 | if found: 288 | out_group = list() 289 | for img in img_group: 290 | img = img.crop((x1, y1, x1 + w, y1 + h)) 291 | assert(img.size == (w, h)) 292 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 293 | return out_group 294 | else: 295 | # Fallback 296 | scale = GroupScale(self.size, interpolation=self.interpolation) 297 | crop = GroupRandomCrop(self.size) 298 | return crop(scale(img_group)) 299 | 300 | 301 | class Stack(object): 302 | 303 | def __init__(self, roll=False): 304 | self.roll = roll 305 | 306 | def __call__(self, img_group): 307 | if img_group[0].mode == 'L': 308 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 309 | elif img_group[0].mode == 'RGB': 310 | if self.roll: 311 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 312 | else: 313 | return np.concatenate(img_group, axis=2) 314 | 315 | 316 | class ToTorchFormatTensor(object): 317 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 318 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 319 | def __init__(self, div=True): 320 | self.div = div 321 | 322 | def __call__(self, pic): 323 | if isinstance(pic, np.ndarray): 324 | # handle numpy array 325 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 326 | else: 327 | # handle PIL Image 328 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 329 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 330 | # put it from HWC to CHW 
format 331 | # yikes, this transpose takes 80% of the loading time/CPU 332 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 333 | return img.float().div(255) if self.div else img.float() 334 | 335 | 336 | class IdentityTransform(object): 337 | 338 | def __call__(self, data): 339 | return data 340 | 341 | 342 | if __name__ == "__main__": 343 | trans = torchvision.transforms.Compose([ 344 | GroupScale(256), 345 | GroupRandomCrop(224), 346 | Stack(), 347 | ToTorchFormatTensor(), 348 | GroupNormalize( 349 | mean=[.485, .456, .406], 350 | std=[.229, .224, .225] 351 | )] 352 | ) 353 | 354 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 355 | 356 | color_group = [im] * 3 357 | rst = trans(color_group) 358 | 359 | gray_group = [im.convert('L')] * 9 360 | gray_rst = trans(gray_group) 361 | 362 | trans2 = torchvision.transforms.Compose([ 363 | GroupRandomSizedCrop(256), 364 | Stack(), 365 | ToTorchFormatTensor(), 366 | GroupNormalize( 367 | mean=[.485, .456, .406], 368 | std=[.229, .224, .225]) 369 | ]) 370 | print(trans2(color_group)) -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | def softmax(scores): 6 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 7 | return es / es.sum(axis=-1)[..., None] 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def accuracy(output, target, topk=(1,)): 30 | """Computes the precision@k for the specified values of k""" 31 | maxk = max(topk) 32 | 
batch_size = target.size(0) 33 | 34 | _, pred = output.topk(maxk, 1, True, True) 35 | pred = pred.t() 36 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 37 | 38 | res = [] 39 | for k in topk: 40 | correct_k = correct[:k].reshape(-1).float().sum(0) 41 | res.append(correct_k.mul_(100.0 / batch_size)) 42 | return res 43 | 44 | 45 | def cal_map(output, old_test_y): 46 | batch_size = output.size(0) 47 | num_classes = output.size(1) 48 | ap = torch.zeros(num_classes) 49 | test_y = old_test_y.clone() 50 | 51 | gt = get_multi_hot(test_y, num_classes, False) 52 | 53 | probs = F.softmax(output, dim=1) 54 | 55 | rg = torch.arange(1, batch_size + 1).float() 56 | for k in range(num_classes): 57 | scores = probs[:, k] 58 | targets = gt[:, k] 59 | _, sortind = torch.sort(scores, 0, True) 60 | truth = targets[sortind] 61 | tp = truth.float().cumsum(0) 62 | precision = tp.div(rg) 63 | ap[k] = precision[truth.bool()].sum() / max(float(truth.sum()), 1) 64 | return ap.mean()*100, ap*100 65 | 66 | 67 | 68 | def get_multi_hot(test_y, classes, assumes_starts_zero=True): 69 | bs = test_y.shape[0] 70 | label_cnt = 0 71 | 72 | if not assumes_starts_zero: 73 | for label_val in torch.unique(test_y): 74 | if label_val >= 0: 75 | test_y[test_y == label_val] = label_cnt 76 | label_cnt += 1 77 | 78 | gt = torch.zeros(bs, classes + 1) 79 | for i in range(test_y.shape[1]): 80 | gt[torch.LongTensor(range(bs)), test_y[:, i]] = 1 81 | 82 | return gt[:, :classes] -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 9 | parser.add_argument('dataset', type=str) 10 |
parser.add_argument('modality', type=str, choices=['RGB', 'Flow']) 11 | parser.add_argument('--train_list', type=str, default="") 12 | parser.add_argument('--val_list', type=str, default="") 13 | parser.add_argument('--root_path', type=str, default="") 14 | parser.add_argument('--path_backbone', type=str) 15 | parser.add_argument('--root_dataset', type=str) 16 | # ========================= Model Configs ========================== 17 | parser.add_argument('--arch_file', type=str, default="resnet_TSM_mask") 18 | parser.add_argument('--arch', type=str, default="BNInception") 19 | parser.add_argument('--num_segments', type=int, default=3) 20 | parser.add_argument('--consensus_type', type=str, default='avg') 21 | parser.add_argument('--k', type=int, default=3) 22 | 23 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 24 | metavar='DO', help='dropout ratio (default: 0.5)') 25 | parser.add_argument('--loss_type', type=str, default="nll", 26 | choices=['nll']) 27 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame") 28 | parser.add_argument('--suffix', type=str, default=None) 29 | parser.add_argument('--pretrain', type=str, default='imagenet') 30 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint') 31 | 32 | # ========================= Learning Configs ========================== 33 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 34 | help='number of total epochs to run') 35 | parser.add_argument('-b', '--batch-size', default=128, type=int, 36 | metavar='N', help='mini-batch size (default: 128)') 37 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 38 | metavar='LR', help='initial learning rate') 39 | parser.add_argument('--lr_type', default='step', type=str, 40 | metavar='LRtype', help='learning rate type') 41 | parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+", 42 | metavar='LRSteps',
help='epochs at which to divide the learning rate by 10') 43 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 44 | help='momentum') 45 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 46 | metavar='W', help='weight decay (default: 5e-4)') 47 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 48 | metavar='W', help='gradient norm clipping (default: disabled)') 49 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 50 | 51 | # ========================= Monitor Configs ========================== 52 | parser.add_argument('--print-freq', '-p', default=20, type=int, 53 | metavar='N', help='print frequency (default: 20)') 54 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 55 | metavar='N', help='evaluation frequency (default: 5)') 56 | 57 | 58 | # ========================= Runtime Configs ========================== 59 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 60 | help='number of data loading workers (default: 8)') 61 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 62 | help='path to latest checkpoint (default: none)') 63 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 64 | help='evaluate model on validation set') 65 | parser.add_argument('--snapshot_pref', type=str, default="") 66 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 67 | help='manual epoch number (useful on restarts)') 68 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 69 | parser.add_argument('--flow_prefix', default="", type=str) 70 | parser.add_argument('--root_log', type=str, default='log') 71 | parser.add_argument('--root_model', type=str, default='checkpoint') 72 | 73 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models') 74 | 75 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal
pooling') 76 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block') 77 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset') 78 | 79 | parser.add_argument('--world_size', default=1, type=int) 80 | parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training') 81 | parser.add_argument('--distributed', default=False, action="store_true") 82 | parser.add_argument('--amp', default=False, action="store_true") 83 | 84 | parser.add_argument('--model_path', default='exp', type=str) 85 | parser.add_argument('--rt_begin', default=1, type=int) 86 | parser.add_argument('--rt_end', default=50, type=int) 87 | parser.add_argument('--rt', default=0, type=float) 88 | parser.add_argument('--t0', default=5.0, type=float) 89 | parser.add_argument('--t1', default=1e-2, type=float) 90 | parser.add_argument('--t_end', default=50, type=int) 91 | parser.add_argument('--temp', default=1, type=float) 92 | parser.add_argument('--lambda_rt', default=1, type=float) 93 | parser.add_argument('--round', default="", type=str) -------------------------------------------------------------------------------- /train_sth.sh: -------------------------------------------------------------------------------- 1 | ### train AF-ResNet(RT=0.5) 2 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \ 3 | --arch_file AF_ResNet \ 4 | --arch AF_resnet50 --num_segments 12 \ 5 | --root_dataset 'path_dataset' \ 6 | --path_backbone 'path_backbone' \ 7 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \ 8 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \ 9 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \ 10 | --model_path 'models' \ 11 | --rt 0.5 --round 1; 12 | 13 | 14 | 15 | ### train AF-ResNet-TSM(RT=0.5) 16 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \ 17 | --arch_file AF_ResNet \ 18 | --arch AF_resnet50 
--num_segments 12 \ 19 | --root_dataset 'path_dataset' \ 20 | --path_backbone 'path_backbone' \ 21 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \ 22 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \ 23 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \ 24 | --model_path 'models' \ 25 | --shift \ 26 | --rt 0.5 --round 1; 27 | 28 | 29 | 30 | ### train AF-MobileNetv3-TSM(RT=0.5) 31 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \ 32 | --arch_file AF_MobileNetv3 \ 33 | --arch AF_mobilenetv3 --num_segments 12 \ 34 | --root_dataset 'path_dataset' \ 35 | --path_backbone 'path_backbone' \ 36 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \ 37 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \ 38 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \ 39 | --model_path 'models_mobilenet' \ 40 | --shift \ 41 | --rt 0.5 --round 1; --------------------------------------------------------------------------------
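`get_optim_policies()` in ops/models.py only tags each parameter group with an `lr_mult` and a `decay_mult`; the code that consumes those multipliers lives in main.py, which is not part of this excerpt. The sketch below assumes main.py follows the TSM code base that opts.py credits: with `--lr_type step`, the base learning rate is divided by 10 at each epoch listed in `--lr_steps`, and every group then receives `lr * lr_mult` and `weight_decay * decay_mult`. The function name `adjust_learning_rate` and its signature are illustrative assumptions, and plain dicts stand in for `optimizer.param_groups` so the example runs without PyTorch.

```python
# Hypothetical sketch of a step schedule driving the parameter groups
# returned by TSN.get_optim_policies(). Assumes each group is a dict with
# 'lr_mult' and 'decay_mult' keys (as built in ops/models.py); the function
# name and signature are illustrative, not taken from main.py.


def adjust_learning_rate(param_groups, epoch, base_lr, lr_steps, weight_decay):
    """Divide base_lr by 10 for every milestone in lr_steps that `epoch`
    has reached, then scale each group's lr and weight decay by its
    multipliers. Returns the unscaled (base) learning rate."""
    num_decays = sum(1 for step in lr_steps if epoch >= step)
    lr = base_lr * (0.1 ** num_decays)
    for group in param_groups:
        group['lr'] = lr * group['lr_mult']
        group['weight_decay'] = weight_decay * group['decay_mult']
    return lr


if __name__ == "__main__":
    # Toy groups mirroring a few policy entries (the 'params' lists are
    # omitted); with PyTorch you would pass optimizer.param_groups instead.
    groups = [
        {'name': 'normal_weight', 'lr_mult': 1, 'decay_mult': 1},
        {'name': 'normal_bias', 'lr_mult': 2, 'decay_mult': 0},
        {'name': 'lr5_weight', 'lr_mult': 5, 'decay_mult': 1},
        {'name': 'lr10_bias', 'lr_mult': 10, 'decay_mult': 0},
    ]
    # Values from train_sth.sh (--lr 0.01 --lr_steps 25 45) and the
    # default --weight-decay of 5e-4 in opts.py.
    for epoch in (0, 25, 45):
        base = adjust_learning_rate(groups, epoch, 0.01, [25, 45], 5e-4)
        print(epoch, base, {g['name']: g['lr'] for g in groups})
```

Under this sketch, the train_sth.sh settings give a base rate of 0.01 for epochs 0 to 24, 0.001 for epochs 25 to 44, and 0.0001 from epoch 45 on, with groups such as `lr10_bias` scaled up from that base. Groups whose `decay_mult` is 0 (for example `normal_bias` and `BN scale/shift`) are never weight-decayed.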