├── .DS_Store
├── LICENSE
├── README.md
├── eval_sth.sh
├── fig
│   ├── intro.jpeg
│   ├── neu.png
│   └── smile.png
├── main.py
├── ops
│   ├── __init__.py
│   ├── backbone
│   │   ├── AF_MobileNetv3.py
│   │   ├── AF_ResNet.py
│   │   ├── __init__.py
│   │   ├── gumbel_softmax.py
│   │   └── temporal_shift.py
│   ├── basic_ops.py
│   ├── dataset.py
│   ├── dataset_config.py
│   ├── models.py
│   ├── models_mobilenet.py
│   ├── transforms.py
│   └── utils.py
├── opts.py
└── train_sth.sh
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 BeSpontaneous
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Look More but Care Less in Video Recognition (NeurIPS 2022)
2 |
 3 |
 4 |
 5 |
 6 |
 7 |
8 | [arXiv](https://arxiv.org/abs/2211.09992) | Primary contact: [Yitian Zhang](mailto:markcheung9248@gmail.com)
9 |
10 |
11 | ![intro](fig/intro.jpeg)
12 |
13 |
14 | Comparison between existing methods and our proposed Ample and Focal Network (AFNet). Most existing works reduce the redundancy in data at the beginning of the deep network, which leads to a loss of information. We propose a two-branch design that processes frames with different computational resources within the network while preserving all of the input information.
15 |
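To make the two-branch idea concrete, here is a minimal, illustrative PyTorch sketch. It is not the module shipped in this repository: the names (`TwoBranchBlock`, `ample`, `focal`) are made up for illustration, and the soft sigmoid gate stands in for AFNet's navigation module, which makes a hard per-frame Gumbel-Softmax decision.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoBranchBlock(nn.Module):
    """Illustrative ample/focal block: every frame passes through a cheap, downsampled
    ample branch, while a learned per-frame gate routes only the salient frames through
    the costlier focal branch; the two outputs are then fused."""
    def __init__(self, channels):
        super().__init__()
        self.ample = nn.Conv2d(channels, channels, 3, padding=1)  # applied to 2x-downsampled frames
        self.focal = nn.Conv2d(channels, channels, 3, padding=1)  # applied at full resolution
        self.gate = nn.Linear(channels, 1)                        # per-frame selection score

    def forward(self, x):  # x: (batch * num_segments, C, H, W), frames stacked along dim 0
        g = torch.sigmoid(self.gate(x.mean(dim=(2, 3)))).view(-1, 1, 1, 1)  # soft per-frame gate
        ample = F.interpolate(self.ample(F.avg_pool2d(x, 2)), size=x.shape[-2:])
        focal = self.focal(x * g)             # frames the gate turns off contribute little here
        return g * focal + (1 - g) * ample    # fuse while keeping information from every frame

frames = torch.randn(2 * 8, 16, 56, 56)       # 2 clips x 8 frames, 16 channels
print(TwoBranchBlock(16)(frames).shape)       # torch.Size([16, 16, 56, 56])
```

In the actual AFNet blocks (see `ops/backbone/AF_MobileNetv3.py` and `ops/backbone/AF_ResNet.py`), the gate is a hard per-frame Gumbel-Softmax selection and the fusion weight is produced by a small channel-attention module.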
16 |
17 | ## Requirements
18 | - python 3.7
19 | - pytorch 1.7.0
20 | - torchvision 0.9.0
21 |
22 |
23 | ## Datasets
24 | Please follow the instructions of [TSM](https://github.com/mit-han-lab/temporal-shift-module#data-preparation) to prepare the Something-Something V1/V2 datasets.
25 |
26 |
27 | ## Pretrained Models
28 | Here we provide AF-MobileNetv3, AF-ResNet50, and AF-ResNet101 pretrained on ImageNet, as well as all the pretrained models on the Something-Something V1 dataset.
29 |
30 | ### Results on ImageNet
31 | Checkpoints are available through the [link](https://drive.google.com/drive/folders/1UzSckmKnwmgwWObF2_YxpkAIZ2k2mcHL?usp=share_link).
32 | | Model | Top-1 Acc. | GFLOPs |
33 | | --------------- | ------------- | ------------- |
34 | | AF-MobileNetv3 | 72.09% | 0.2 |
35 | | AF-ResNet50 | 77.24% | 2.9 |
36 | | AF-ResNet101 | 78.36% | 5.0 |
37 |
38 | ### Results on Something-Something V1
39 | Checkpoints and logs are available through the [link](https://drive.google.com/drive/folders/1-xmE6T6OADmDkkzJr4iM1vCJbA4ofcSO?usp=share_link).
40 |
41 | **Less is More**:
42 | | Model | Frame | Top-1 Acc. | GFLOPs |
43 | | --------------- | --------------- | ------------- | ------------- |
44 | | TSN | 8 | 18.6% | 32.7 |
45 | | AFNet(RT=0.50) | 8 | 26.8% | 19.5 |
46 | | AFNet(RT=0.25) | 8 | 27.7% | 18.3 |
47 |
48 |
49 | **More is Less**:
50 | | Model | Backbone | Frame | Top-1 Acc. | GFLOPs |
51 | | --------------- | --------------- | ------------- |------------- | ------------- |
52 | | TSM | ResNet50 | 8 | 45.6% | 32.7 |
53 | | AFNet-TSM(RT=0.4) | AF-ResNet50 | 12 | 49.0% | 27.9 |
54 | | AFNet-TSM(RT=0.8) | AF-ResNet50 | 12 | 49.9% | 31.7 |
55 | | AFNet-TSM(RT=0.4) | AF-MobileNetv3 | 12 | 45.3% | 2.2 |
56 | | AFNet-TSM(RT=0.8) | AF-MobileNetv3 | 12 | 45.9% | 2.3 |
57 | | AFNet-TSM(RT=0.4) | AF-ResNet101 | 12 | 49.8% | 42.1 |
58 | | AFNet-TSM(RT=0.8) | AF-ResNet101 | 12 | 50.1% | 48.9 |
59 |
60 |
61 | ## Training AFNet on Something-Something V1
62 | 1. Specify the directory of datasets with `root_dataset` in `train_sth.sh`.
63 | 2. Please download the pretrained backbone on ImageNet from [Google Drive](https://drive.google.com/drive/folders/1UzSckmKnwmgwWObF2_YxpkAIZ2k2mcHL?usp=share_link).
64 | 3. Specify the directory of the downloaded backbone with `path_backbone` in `train_sth.sh`.
65 | 4. Specify the ratio of selected frames with `rt` and run `bash train_sth.sh`.
66 |
67 |
68 |
69 | ## Evaluate pretrained models on Something-Something V1
70 | **Note that there is small variance during evaluation because of the Gumbel-Softmax sampling, so the testing results may not exactly align with the numbers in our paper. We provide the logs for Table 2 for verification.**
71 | 1. Specify the directory of datasets with `root_dataset` in `eval_sth.sh`.
72 | 2. Please download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1-xmE6T6OADmDkkzJr4iM1vCJbA4ofcSO?usp=share_link).
73 | 3. Specify the directory of the pretrained model with `resume` in `eval_sth.sh`.
74 | 4. Run `bash eval_sth.sh`.
75 |
76 |
77 |
78 | ## Reference
79 | If you find our code or paper useful for your research, please cite:
80 | ```
81 | @article{zhang2022look,
82 | title={Look More but Care Less in Video Recognition},
83 | author={Zhang, Yitian and Bai, Yue and Wang, Huan and Xu, Yi and Fu, Yun},
84 | journal={arXiv preprint arXiv:2211.09992},
85 | year={2022}
86 | }
87 | ```
--------------------------------------------------------------------------------
/eval_sth.sh:
--------------------------------------------------------------------------------
1 | ### evaluate AF-ResNet
2 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \
3 | --arch_file AF_ResNet \
4 | --arch AF_resnet50 --num_segments 12 \
5 | --root_dataset 'path_dataset' \
6 | --path_backbone 'path_backbone' \
7 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \
8 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \
9 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \
10 | --model_path 'models' \
11 | --rt 0.5 --round test \
12 | --resume 'path_pretrained_model' \
13 | --evaluate;
14 |
15 |
16 |
17 | ### evaluate AF-ResNet-TSM
18 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \
19 | --arch_file AF_ResNet \
20 | --arch AF_resnet50 --num_segments 12 \
21 | --root_dataset 'path_dataset' \
22 | --path_backbone 'path_backbone' \
23 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \
24 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \
25 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \
26 | --model_path 'models' \
27 | --shift \
28 | --rt 0.5 --round test \
29 | --resume 'path_pretrained_model' \
30 | --evaluate;
31 |
32 |
33 |
34 | ### evaluate AF-MobileNetv3-TSM
35 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \
36 | --arch_file AF_MobileNetv3 \
37 | --arch AF_mobilenetv3 --num_segments 12 \
38 | --root_dataset 'path_dataset' \
39 | --path_backbone 'path_backbone' \
40 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \
41 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \
42 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \
43 | --model_path 'models_mobilenet' \
44 | --shift \
45 | --rt 0.5 --round test \
46 | --resume 'path_pretrained_model' \
47 | --evaluate;
--------------------------------------------------------------------------------
/fig/intro.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/fig/intro.jpeg
--------------------------------------------------------------------------------
/fig/neu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/fig/neu.png
--------------------------------------------------------------------------------
/fig/smile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/AFNet-pytorch/f20b92e6430f7978ef15e537d932381be575bdad/fig/smile.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import os
7 | import time
8 | import shutil
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 |
13 | import torch.multiprocessing as mp
14 | import torch.utils.data
15 | import torch.utils.data.distributed
16 | import torch.distributed as dist
17 | from torch.cuda.amp import autocast, GradScaler
18 |
19 | from torch.nn.utils import clip_grad_norm_
20 | import pandas as pd
21 | from ops.dataset import TSNDataSet
22 | import importlib
23 | from ops.transforms import *
24 | from opts import parser
25 | from ops import dataset_config
26 | from ops.utils import AverageMeter, accuracy
27 | from ops.backbone.temporal_shift import make_temporal_pool
28 | from tensorboardX import SummaryWriter
29 |
30 |
31 | best_prec1 = 0
32 | val_acc_top1 = []
33 | val_acc_top5 = []
34 | val_FLOPs = []
35 |
36 | tr_big_rate = []
37 | val_big_rate = []
38 | train_loss_ls = []
39 |
40 | tr_acc_top1 = []
41 | tr_acc_top5 = []
42 | train_loss = []
43 | train_loss_cls = []
44 | valid_loss = []
45 | epoch_log = []
46 |
47 |
48 | def main():
49 | global args, best_prec1
50 | global val_acc_top1
51 | global val_acc_top5
52 | global tr_acc_top1
53 | global tr_acc_top5
54 | global train_loss
55 | global train_loss_cls
56 | global valid_loss
57 | global epoch_log
58 | global tr_big_rate
59 | global val_big_rate
60 | global train_loss_ls
61 | global val_FLOPs
62 | args = parser.parse_args()
63 |
64 | if args.distributed:
65 | dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:8888',
66 | world_size=args.world_size, rank=args.local_rank)
67 | torch.cuda.set_device(args.local_rank)
68 | device = torch.device(f'cuda:{args.local_rank}')
69 |
70 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
71 | num_class, args.train_list, args.val_list, args.root_path, prefix \
72 | = dataset_config.return_dataset(args.root_dataset, args.dataset, args.modality)
73 | str_round = str(args.round)
74 | args.store_name = f'{args.dataset}/{args.arch_file}/{args.arch}/frame{args.num_segments}/round{str_round}/'
75 | print('storing name: ' + args.store_name)
76 | check_rootfolders()
77 |
78 | path = str('ops.'+args.model_path)
79 | file = importlib.import_module(path)
80 | model = file.TSN(args.arch_file, num_class, args.num_segments, args.modality, args.path_backbone,
81 | base_model=args.arch,
82 | consensus_type=args.consensus_type,
83 | dropout=args.dropout,
84 | img_feature_dim=args.img_feature_dim,
85 | partial_bn=not args.no_partialbn,
86 | pretrain=args.pretrain,
87 | is_shift=args.shift,
88 | fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
89 | temporal_pool=args.temporal_pool,
90 | non_local=args.non_local)
91 |
92 | crop_size = model.crop_size
93 | scale_size = model.scale_size
94 | input_mean = model.input_mean
95 | input_std = model.input_std
96 | policies = model.get_optim_policies()
97 | train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)
98 |
99 | if args.distributed:
100 | model.to(device)
101 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
102 | output_device=args.local_rank, find_unused_parameters=True)
103 | else:
104 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
105 |
106 | optimizer = torch.optim.SGD(policies,
107 | args.lr,
108 | momentum=args.momentum,
109 | weight_decay=args.weight_decay)
110 |
111 | if args.resume:
112 | if args.temporal_pool: # early temporal pool so that we can load the state_dict
113 | make_temporal_pool(model.module.base_model, args.num_segments)
114 | if os.path.isfile(args.resume):
115 | print(("=> loading checkpoint '{}'".format(args.resume)))
116 | checkpoint = torch.load(args.resume)
117 | args.start_epoch = checkpoint['epoch']
118 | best_prec1 = checkpoint['best_prec1']
119 | model.load_state_dict(checkpoint['state_dict'])
120 | optimizer.load_state_dict(checkpoint['optimizer'])
121 |
122 | val_acc_top1 = checkpoint['val_acc_top1']
123 | val_acc_top5 = checkpoint['val_acc_top5']
124 | val_big_rate = checkpoint['val_big_rate']
125 | val_FLOPs = checkpoint['val_FLOPs']
126 | tr_acc_top1 = checkpoint['tr_acc_top1']
127 | tr_acc_top5 = checkpoint['tr_acc_top5']
128 | train_loss = checkpoint['train_loss']
129 | tr_big_rate = checkpoint['tr_big_rate']
130 | train_loss_cls = checkpoint['train_loss_cls']
131 | train_loss_ls = checkpoint['train_loss_ls']
132 | valid_loss = checkpoint['valid_loss']
133 | epoch_log = checkpoint['epoch_log']
134 |
135 |             print(("=> loaded checkpoint '{}' (epoch {})"
136 |                   .format(args.resume, checkpoint['epoch'])))
137 | else:
138 | print(("=> no checkpoint found at '{}'".format(args.resume)))
139 |
140 | if args.tune_from:
141 | print(("=> fine-tuning from '{}'".format(args.tune_from)))
142 | sd = torch.load(args.tune_from)
143 | sd = sd['state_dict']
144 | model_dict = model.state_dict()
145 | replace_dict = []
146 | for k, v in sd.items():
147 | if k not in model_dict and k.replace('.net', '') in model_dict:
148 | print('=> Load after remove .net: ', k)
149 | replace_dict.append((k, k.replace('.net', '')))
150 | for k, v in model_dict.items():
151 | if k not in sd and k.replace('.net', '') in sd:
152 | print('=> Load after adding .net: ', k)
153 | replace_dict.append((k.replace('.net', ''), k))
154 |
155 | for k, k_new in replace_dict:
156 | sd[k_new] = sd.pop(k)
157 | keys1 = set(list(sd.keys()))
158 | keys2 = set(list(model_dict.keys()))
159 | set_diff = (keys1 - keys2) | (keys2 - keys1)
160 | print('#### Notice: keys that failed to load: {}'.format(set_diff))
161 | if args.dataset not in args.tune_from: # new dataset
162 | print('=> New dataset, do not load fc weights')
163 | sd = {k: v for k, v in sd.items() if 'fc' not in k}
164 | if args.modality == 'Flow' and 'Flow' not in args.tune_from:
165 | sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
166 | model_dict.update(sd)
167 | model.load_state_dict(model_dict)
168 |
169 | if args.temporal_pool and not args.resume:
170 | make_temporal_pool(model.module.base_model, args.num_segments)
171 |
172 | cudnn.benchmark = True
173 |
174 | # Data loading code
175 | if args.modality != 'RGBDiff':
176 | normalize = GroupNormalize(input_mean, input_std)
177 | else:
178 | normalize = IdentityTransform()
179 |
180 | if args.modality == 'RGB':
181 | data_length = 1
182 | elif args.modality in ['Flow', 'RGBDiff']:
183 | data_length = 5
184 |
185 | train_dataset = TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
186 | new_length=data_length,
187 | modality=args.modality,
188 | image_tmpl=prefix,
189 | transform=torchvision.transforms.Compose([
190 | GroupScale((240,320)),
191 | train_augmentation,
192 | Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
193 | ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
194 | normalize,
195 | ]), dense_sample=args.dense_sample)
196 |
197 | val_dataset = TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
198 | new_length=data_length,
199 | modality=args.modality,
200 | image_tmpl=prefix,
201 | random_shift=False,
202 | transform=torchvision.transforms.Compose([
203 | GroupScale((240,320)),
204 | GroupCenterCrop(crop_size),
205 | Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
206 | ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
207 | normalize,
208 | ]), dense_sample=args.dense_sample)
209 |
210 | if args.distributed:
211 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
212 | else:
213 | train_sampler = None
214 |
215 | train_loader = torch.utils.data.DataLoader(
216 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
217 | num_workers=args.workers, pin_memory=True, sampler=train_sampler,
218 | drop_last=True) # prevent something not % n_GPU
219 |
220 | val_loader = torch.utils.data.DataLoader(
221 | val_dataset, batch_size=args.batch_size, shuffle=False,
222 | num_workers=args.workers, pin_memory=True)
223 |
224 | # define loss function (criterion) and optimizer
225 | if args.loss_type == 'nll':
226 | criterion = torch.nn.CrossEntropyLoss().cuda()
227 | else:
228 | raise ValueError("Unknown loss type")
229 |
230 | for group in policies:
231 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
232 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))
233 |
234 | if args.evaluate:
235 | validate(val_loader, model, criterion, 0, args.rt)
236 | return
237 |
238 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
239 | log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
240 | with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
241 | f.write(str(args))
242 | tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))
243 |
244 | if args.amp:
245 | scaler = GradScaler()
246 | else:
247 | scaler = None
248 |
249 | for epoch in range(args.start_epoch, args.epochs):
250 | adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
251 | rt = adjust_ratio(epoch, args)
252 | # train for one epoch
253 | tr_acc1, tr_acc5, tr_loss, tr_loss_cls, tr_loss_rt, tr_ratios = train(train_loader, model, criterion, optimizer, epoch, rt, log_training, tf_writer, scaler)
254 |
255 | # evaluate on validation set
256 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
257 | val_acc1, val_acc5, val_loss, val_ratios, val_flops = validate(val_loader, model, criterion, epoch, rt, log_training, tf_writer)
258 |
259 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
260 | # remember best prec@1 and save checkpoint
261 | is_best = val_acc1 > best_prec1
262 | best_prec1 = max(val_acc1, best_prec1)
263 | tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)
264 |
265 | output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
266 | print(output_best)
267 | log_training.write(output_best + '\n')
268 | log_training.flush()
269 |
270 | val_acc_top1.append(val_acc1)
271 | val_acc_top5.append(val_acc5)
272 | val_big_rate.append(val_ratios)
273 | tr_big_rate.append(tr_ratios)
274 | val_FLOPs.append(val_flops)
275 | tr_acc_top1.append(tr_acc1)
276 | tr_acc_top5.append(tr_acc5)
277 | train_loss.append(tr_loss)
278 | train_loss_cls.append(tr_loss_cls)
279 | train_loss_ls.append(tr_loss_rt)
280 | valid_loss.append(val_loss)
281 | epoch_log.append(epoch)
282 |
283 | df = pd.DataFrame({'val_acc_top1': val_acc_top1, 'val_acc_top5': val_acc_top5,
284 | 'val_big_rate': val_big_rate, 'val_FLOPs': val_FLOPs,
285 | 'tr_big_rate': tr_big_rate, 'tr_acc_top1': tr_acc_top1, 'tr_acc_top5': tr_acc_top5,
286 | 'train_loss': train_loss, 'train_loss_cls': train_loss_cls, 'train_loss_ls': train_loss_ls,
287 | 'valid_loss': valid_loss, 'epoch_log': epoch_log})
288 |
289 | log_file = os.path.join(args.root_log, args.store_name, 'log_epoch.txt')
290 | with open(log_file, "w") as f:
291 | df.to_csv(f)
292 |
293 | save_checkpoint({
294 | 'epoch': epoch + 1,
295 | 'arch': args.arch,
296 | 'state_dict': model.state_dict(),
297 | 'optimizer': optimizer.state_dict(),
298 | 'best_prec1': best_prec1,
299 | 'val_acc_top1': val_acc_top1,
300 | 'val_acc_top5': val_acc_top5,
301 | 'val_big_rate': val_big_rate,
302 | 'val_FLOPs': val_FLOPs,
303 | 'tr_big_rate': tr_big_rate,
304 | 'tr_acc_top1': tr_acc_top1,
305 | 'tr_acc_top5': tr_acc_top5,
306 | 'train_loss': train_loss,
307 | 'train_loss_cls': train_loss_cls,
308 | 'train_loss_ls': train_loss_ls,
309 | 'valid_loss': valid_loss,
310 | 'epoch_log': epoch_log,
311 | }, is_best, epoch)
312 |
313 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
314 | file1 = pd.read_csv(log_file)
315 | acc1 = np.array(file1['val_acc_top1'])
316 | flops1 = np.array(file1['val_FLOPs'])
317 | loc = np.argmax(acc1)
318 | max_acc = acc1[loc]
319 | acc_flops = flops1[loc]
320 | fout = open(os.path.join(args.root_log, args.store_name, 'log_epoch.txt'), mode='a', encoding='utf-8')
321 | fout.write("%.6f\t%.6f" % (max_acc, acc_flops))
322 |
323 |
324 | def train(train_loader, model, criterion, optimizer, epoch, rt, log, tf_writer, scaler=None):
325 | batch_time = AverageMeter()
326 | data_time = AverageMeter()
327 | losses = AverageMeter()
328 | losses_cls = AverageMeter()
329 | losses_rt = AverageMeter()
330 | top1 = AverageMeter()
331 | top5 = AverageMeter()
332 | real_ratios = AverageMeter()
333 | train_batches_num = len(train_loader)
334 |
335 | if args.no_partialbn:
336 | model.module.partialBN(False)
337 | else:
338 | model.module.partialBN(True)
339 |
340 | # switch to train mode
341 | model.train()
342 |
343 | end = time.time()
344 |
345 | if args.amp:
346 | assert scaler is not None
347 |
348 | for i, (input, target) in enumerate(train_loader):
349 | # measure data loading time
350 | data_time.update(time.time() - end)
351 |
352 | target = target.cuda()
353 | input_var = torch.autograd.Variable(input)
354 | target_var = torch.autograd.Variable(target)
355 |
356 | adjust_temperature(epoch, i, train_batches_num, args)
357 | optimizer.zero_grad()
358 |
359 | if args.amp:
360 | with autocast():
361 | # compute output
362 | output, temporal_mask_ls = model(input_var, args.temp)
363 | loss_cls = criterion(output, target_var)
364 |
365 | real_ratio = 0.0
366 | loss_real_ratio = 0.0
367 | for temporal_mask in temporal_mask_ls:
368 | real_ratio += torch.mean(temporal_mask)
369 | loss_real_ratio += torch.pow(rt-torch.mean(temporal_mask), 2)
370 | real_ratio = torch.mean(real_ratio/len(temporal_mask_ls))
371 | loss_real_ratio = torch.mean(loss_real_ratio/len(temporal_mask_ls))
372 | loss_real_ratio = args.lambda_rt * loss_real_ratio
373 | loss = loss_cls + loss_real_ratio
374 |
375 | scaler.scale(loss).backward()
376 | scaler.step(optimizer)
377 | scaler.update()
378 | else:
379 | output, temporal_mask_ls = model(input_var, args.temp)
380 | loss_cls = criterion(output, target_var)
381 |
382 | real_ratio = 0.0
383 | loss_real_ratio = 0.0
384 | for temporal_mask in temporal_mask_ls:
385 | real_ratio += torch.mean(temporal_mask)
386 | loss_real_ratio += torch.pow(rt-torch.mean(temporal_mask), 2)
387 | real_ratio = torch.mean(real_ratio/len(temporal_mask_ls))
388 | loss_real_ratio = torch.mean(loss_real_ratio/len(temporal_mask_ls))
389 | loss_real_ratio = args.lambda_rt * loss_real_ratio
390 | loss = loss_cls + loss_real_ratio
391 |
392 | loss.backward()
393 | if args.clip_gradient is not None:
394 | total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
395 | optimizer.step()
396 |
397 |
398 | # measure accuracy and record loss
399 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
400 | real_ratios.update(real_ratio.item(), input.size(0))
401 | losses_cls.update(loss_cls.item(), input.size(0))
402 | losses_rt.update(loss_real_ratio.item(), input.size(0))
403 | losses.update(loss.item(), input.size(0))
404 | top1.update(prec1.item(), input.size(0))
405 | top5.update(prec5.item(), input.size(0))
406 |
407 | # measure elapsed time
408 | batch_time.update(time.time() - end)
409 | end = time.time()
410 |
411 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
412 | if i % args.print_freq == 0:
413 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
414 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
415 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
416 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
417 | 'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
418 | 'Loss_ls {loss_ls.val:.4f} ({loss_ls.avg:.4f})\t'
419 | 'Real_ratio {real_ratio.val:.4f} ({real_ratio.avg:.4f})\t'
420 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
421 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
422 | epoch, i, len(train_loader), batch_time=batch_time,
423 | data_time=data_time, loss=losses, loss_cls=losses_cls, loss_ls=losses_rt, real_ratio=real_ratios, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1)) # TODO
424 | print(output)
425 | log.write(output + '\n')
426 | log.flush()
427 |
428 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
429 | tf_writer.add_scalar('loss/train', losses.avg, epoch)
430 | tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
431 | tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
432 | tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
433 |
434 | return top1.avg, top5.avg, losses.avg, losses_cls.avg, losses_rt.avg, real_ratios.avg
435 |
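# Note on the auxiliary loss computed in train() above: at every two-branch stage the
# objective penalises the gap between the realised frame-selection ratio and the target rt,
#     loss = loss_cls + lambda_rt * mean_over_stages((rt - mean(temporal_mask))^2),
# so with the flags used in eval_sth.sh (--lambda_rt 0.5 --rt 0.5) the navigation modules
# are pushed to route roughly half of the frames through the focal branch once the
# warm-up implemented by adjust_ratio() below has finished.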
436 |
437 | def validate(val_loader, model, criterion, epoch, rt, log=None, tf_writer=None):
438 | batch_time = AverageMeter()
439 | losses = AverageMeter()
440 | top1 = AverageMeter()
441 | top5 = AverageMeter()
442 | real_ratios = AverageMeter()
443 | FLOPs = AverageMeter()
444 |
445 | # switch to evaluate mode
446 | model.eval()
447 |
448 | end = time.time()
449 | with torch.no_grad():
450 | for i, (input, target) in enumerate(val_loader):
451 | input = input.cuda()
452 | target = target.cuda()
453 |
454 | # compute output
455 | output, temporal_mask_ls, flops = model.module.forward_calc_flops(input, args.t1)
456 | flops /= 1e9
457 | loss_cls = criterion(output, target)
458 |
459 | real_ratio = 0.0
460 | loss_real_ratio = 0.0
461 | for temporal_mask in temporal_mask_ls:
462 | real_ratio += torch.mean(temporal_mask)
463 | loss_real_ratio += torch.pow(rt-torch.mean(temporal_mask), 2)
464 | real_ratio = torch.mean(real_ratio/len(temporal_mask_ls))
465 | loss_real_ratio = torch.mean(loss_real_ratio/len(temporal_mask_ls))
466 | loss_real_ratio = args.lambda_rt * loss_real_ratio
467 |
468 | loss = loss_cls + loss_real_ratio
469 |
470 | # measure accuracy and record loss
471 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
472 |
473 | FLOPs.update(flops.item(), input.size(0))
474 | real_ratios.update(real_ratio.item(), input.size(0))
475 | losses.update(loss.item(), input.size(0))
476 | top1.update(prec1.item(), input.size(0))
477 | top5.update(prec5.item(), input.size(0))
478 |
479 | # measure elapsed time
480 | batch_time.update(time.time() - end)
481 | end = time.time()
482 |
483 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
484 | if i % args.print_freq == 0:
485 | output = ('Test: [{0}/{1}]\t'
486 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
487 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
488 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
489 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
490 | i, len(val_loader), batch_time=batch_time, loss=losses,
491 | top1=top1, top5=top5))
492 | print(output)
493 | if log is not None:
494 | log.write(output + '\n')
495 | log.flush()
496 |
497 | if not args.distributed or (args.distributed and torch.distributed.get_rank() == 0):
498 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
499 | .format(top1=top1, top5=top5, loss=losses))
500 | print(output)
501 | if log is not None:
502 | log.write(output + '\n')
503 | log.flush()
504 |
505 | if tf_writer is not None:
506 | tf_writer.add_scalar('loss/test', losses.avg, epoch)
507 | tf_writer.add_scalar('acc/test_top1', top1.avg, epoch)
508 | tf_writer.add_scalar('acc/test_top5', top5.avg, epoch)
509 |
510 | return top1.avg, top5.avg, losses.avg, real_ratios.avg, FLOPs.avg
511 |
512 |
513 | def save_checkpoint(state, is_best, epoch):
514 | filename = '%s/%s/ckpt.pth.tar' % (args.root_log, args.store_name)
515 | torch.save(state, filename)
516 | if is_best:
517 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))
518 |
519 |
520 | def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
521 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
522 | if lr_type == 'step':
523 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
524 | lr = args.lr * decay
525 | decay = args.weight_decay
526 | elif lr_type == 'cos':
527 | import math
528 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs))
529 | decay = args.weight_decay
530 | else:
531 | raise NotImplementedError
532 | for param_group in optimizer.param_groups:
533 | param_group['lr'] = lr * param_group['lr_mult']
534 | param_group['weight_decay'] = decay * param_group['decay_mult']
535 |
536 |
537 | def check_rootfolders():
538 | """Create log and model folder"""
539 | folders_util = [args.root_log, os.path.join(args.root_log, args.store_name)]
540 | for folder in folders_util:
541 | if not os.path.exists(folder):
542 | print('creating folder ' + folder)
543 | os.makedirs(folder)
544 |
545 |
546 |
547 | def adjust_temperature(epoch, step, len_epoch, args):
548 | if epoch >= args.t_end:
549 | return args.t1
550 | else:
551 | T_total = args.t_end * len_epoch
552 | T_cur = epoch * len_epoch + step
553 | alpha = math.pow(args.t1 / args.t0, 1 / T_total)
554 | args.temp = math.pow(alpha, T_cur) * args.t0
555 |
556 |
557 | def adjust_ratio(epoch, args):
558 | if epoch < args.rt_begin :
559 | rt = 1.0
560 | elif epoch < args.rt_begin + (args.rt_end-args.rt_begin)//2:
561 | rt = args.rt + (1.0 - args.rt)/3*2
562 | elif epoch < args.rt_end:
563 | rt = args.rt + (1.0 - args.rt)/3
564 | else:
565 | rt = args.rt
566 | return rt
567 |
568 |
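# Worked example of the two schedules above, using the flag values from eval_sth.sh
# (--rt 0.5 --rt_begin 10 --rt_end 20 --t0 1 --t_end 50; args.t1 is defined in opts.py
# and is not shown in this file):
#   adjust_ratio:       epochs 0-9   -> rt = 1.0            (no selection pressure yet)
#                       epochs 10-14 -> rt = 0.5 + 0.5*2/3 ~ 0.833
#                       epochs 15-19 -> rt = 0.5 + 0.5/3   ~ 0.667
#                       epochs >= 20 -> rt = 0.5            (the target selection ratio)
#   adjust_temperature: the Gumbel-Softmax temperature is annealed geometrically from t0
#                       towards t1 over the first t_end epochs,
#                       args.temp = t0 * (t1 / t0) ** (T_cur / T_total).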
569 | if __name__ == '__main__':
570 | main()
571 |
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from ops.basic_ops import *
--------------------------------------------------------------------------------
/ops/backbone/AF_MobileNetv3.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import math
4 | import torch
5 | from .gumbel_softmax import GumbleSoftmax
6 |
7 |
8 | __all__ = ['AF_mobilenetv3']
9 |
10 |
11 |
12 | class TSM(nn.Module):
13 | def __init__(self):
14 | super(TSM, self).__init__()
15 | self.fold_div = 8
16 |
17 | def forward(self, x, n_segment):
18 | x = self.shift(x, n_segment, fold_div=self.fold_div)
19 | return x
20 |
21 | @staticmethod
22 | def shift(x, n_segment, fold_div=3):
23 | if type(n_segment) is int:
24 | nt, c, h, w = x.size()
25 | n_batch = nt // n_segment
26 | x = x.view(n_batch, n_segment, c, h, w)
27 |
28 | fold = c // fold_div
29 | out = torch.zeros_like(x)
30 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
31 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
32 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
33 | shift_out = out.view(nt, c, h, w)
34 | else:
35 | num_segment = int(n_segment.sum())
36 | ls = n_segment
37 | bool_list = ls > 0
38 | bool_list = bool_list.view(-1)
39 |
40 | shift_out = torch.zeros_like(x)
41 | x = x[bool_list]
42 | nt, c, h, w = x.size()
43 | x = x.view(-1, num_segment, c, h, w)
44 |
45 | fold = c // fold_div
46 | out = torch.zeros_like(x)
47 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
48 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
49 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
50 | out = out.view(-1, c, h, w)
51 | shift_out[bool_list] = out
52 |
53 | return shift_out
54 |
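# Usage note (illustrative): for an input x of shape (batch * n_segment, C, H, W),
# TSM()(x, n_segment) shifts the first C//8 channels one step towards earlier frames and
# the next C//8 channels one step towards later frames, leaving the remaining channels in
# place (fold_div = 8 above). When n_segment is a per-frame 0/1 gate tensor instead of an
# int, only the gated frames take part in the shift and the remaining positions of
# shift_out stay zero.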
55 |
56 | class dynamic_fusion(nn.Module):
57 | def __init__(self, channel, reduction=16):
58 | super(dynamic_fusion, self).__init__()
59 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
60 | self.reduction = reduction
61 | self.fc = nn.Sequential(
62 | nn.Linear(channel, int(channel // reduction)),
63 | nn.ReLU(inplace=True),
64 | nn.Linear(int(channel // reduction), channel),
65 | nn.Sigmoid()
66 | )
67 | for m in self.modules():
68 | if isinstance(m, nn.Linear):
69 | nn.init.normal_(m.weight, 0, 0.01)
70 |
71 | def forward(self, x):
72 | b, c, h, w = x.size()
73 | y = self.avg_pool(x).view(b,c)
74 | attention = self.fc(y)
75 | return attention.view(b,c,1,1)
76 |
77 | def forward_calc_flops(self, x):
78 | b, c, h, w = x.size()
79 | flops = c*h*w
80 | y = self.avg_pool(x).view(b,c)
81 | attention = self.fc(y)
82 | flops += c*c//self.reduction*2 + c
83 | return attention.view(b,c,1,1), flops
84 |
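# Note: dynamic_fusion is an SE-style channel-attention module. In AFMobileNetV3 below it
# produces the weight `att` used to fuse the two branches as att * x_base + (1 - att) * x_refine,
# i.e. the network learns, per channel, how much to trust the ample (x_base) features
# versus the focal (x_refine) features.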
85 |
86 | def _make_divisible(v, divisor, min_value=None):
87 | """
88 | This function is taken from the original tf repo.
89 | It ensures that all layers have a channel number that is divisible by 8
90 | It can be seen here:
91 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
92 | :param v:
93 | :param divisor:
94 | :param min_value:
95 | :return:
96 | """
97 | if min_value is None:
98 | min_value = divisor
99 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
100 | # Make sure that round down does not go down by more than 10%.
101 | if new_v < 0.9 * v:
102 | new_v += divisor
103 | return new_v
104 |
105 |
106 | class h_sigmoid(nn.Module):
107 | def __init__(self, inplace=True):
108 | super(h_sigmoid, self).__init__()
109 | self.relu = nn.ReLU6(inplace=inplace)
110 |
111 | def forward(self, x):
112 | return self.relu(x + 3) / 6
113 |
114 |
115 | class h_swish(nn.Module):
116 | def __init__(self, inplace=True):
117 | super(h_swish, self).__init__()
118 | self.sigmoid = h_sigmoid(inplace=inplace)
119 |
120 | def forward(self, x):
121 | return x * self.sigmoid(x)
122 |
123 |
124 | class SELayer(nn.Module):
125 | def __init__(self, channel, reduction=4):
126 | super(SELayer, self).__init__()
127 | self.reduction = reduction
128 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
129 | self.fc = nn.Sequential(
130 | nn.Linear(channel, _make_divisible(channel // reduction, 8)),
131 | nn.ReLU(inplace=True),
132 | nn.Linear(_make_divisible(channel // reduction, 8), channel),
133 | h_sigmoid()
134 | )
135 |
136 | def forward(self, x):
137 | b, c, _, _ = x.size()
138 | y = self.avg_pool(x).view(b, c)
139 | y = self.fc(y).view(b, c, 1, 1)
140 | return x * y
141 |
142 | def forward_calc_flops(self, x):
143 | b, c, h, w = x.size()
144 | flops = c*h*w
145 | y = self.avg_pool(x).view(b,c)
146 | y = self.fc(y).view(b, c, 1, 1)
147 | flops += c*c//self.reduction*2 + c
148 | return x * y, flops
149 |
150 |
151 | def conv_3x3_bn(inp, oup, stride):
152 | return nn.Sequential(
153 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
154 | nn.BatchNorm2d(oup),
155 | h_swish()
156 | )
157 |
158 |
159 | def conv_1x1_bn(inp, oup):
160 | return nn.Sequential(
161 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
162 | nn.BatchNorm2d(oup),
163 | h_swish()
164 | )
165 |
166 |
167 | class InvertedResidual_ample(nn.Module):
168 | def __init__(self, n_segment, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
169 | super(InvertedResidual_ample, self).__init__()
170 | assert stride in [1, 2]
171 |
172 | self.identity = stride == 1 and inp == oup
173 | self.tsm = TSM()
174 | self.inp = inp
175 | self.hidden_dim = hidden_dim
176 | self.use_se = use_se
177 |
178 | if inp == hidden_dim:
179 | # dw
180 | self.conv1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
181 | self.bn1 = nn.BatchNorm2d(hidden_dim)
182 | self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
183 | # Squeeze-and-Excite
184 | self.se = SELayer(hidden_dim) if use_se else nn.Identity()
185 | # pw-linear
186 | self.conv2 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)
187 | self.bn2 = nn.BatchNorm2d(oup)
188 | else:
189 | # pw
190 | self.conv1 = nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False)
191 | self.bn1 = nn.BatchNorm2d(hidden_dim)
192 | self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
193 | # dw
194 | self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
195 | self.bn2 = nn.BatchNorm2d(hidden_dim)
196 | # Squeeze-and-Excite
197 | self.se = SELayer(hidden_dim) if use_se else nn.Identity()
198 | self.act2 = h_swish() if use_hs else nn.ReLU(inplace=True)
199 | # pw-linear
200 | self.conv3 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)
201 | self.bn3 = nn.BatchNorm2d(oup)
202 |
203 | def forward(self, x, list_little):
204 | residual = x
205 | if self.inp == self.hidden_dim:
206 | x = self.tsm(x, list_little)
207 | x = self.conv1(x)
208 | x = self.bn1(x)
209 | x = self.act1(x)
210 | x = self.se(x)
211 | x = self.conv2(x)
212 | x = self.bn2(x)
213 | else:
214 | x = self.tsm(x, list_little)
215 | x = self.conv1(x)
216 | x = self.bn1(x)
217 | x = self.act1(x)
218 | x = self.conv2(x)
219 | x = self.bn2(x)
220 | x = self.se(x)
221 | x = self.act2(x)
222 | x = self.conv3(x)
223 | x = self.bn3(x)
224 |
225 | if self.identity:
226 | return x + residual
227 | else:
228 | return x
229 |
230 | def forward_calc_flops(self, x, list_little):
231 | flops = 0
232 | residual = x
233 | if self.inp == self.hidden_dim:
234 | x = self.tsm(x, list_little)
235 |
236 | c_in = x.shape[1]
237 | x = self.conv1(x)
238 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups
239 |
240 | x = self.bn1(x)
241 | x = self.act1(x)
242 | if self.use_se == True:
243 | x, _flops = self.se.forward_calc_flops(x)
244 | flops += _flops
245 | else:
246 | x = self.se(x)
247 |
248 | c_in = x.shape[1]
249 | x = self.conv2(x)
250 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
251 | x = self.bn2(x)
252 | else:
253 | x = self.tsm(x, list_little)
254 |
255 | c_in = x.shape[1]
256 | x = self.conv1(x)
257 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups
258 | x = self.bn1(x)
259 | x = self.act1(x)
260 |
261 | c_in = x.shape[1]
262 | x = self.conv2(x)
263 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
264 | x = self.bn2(x)
265 | if self.use_se == True:
266 | x, _flops = self.se.forward_calc_flops(x)
267 | flops += _flops
268 | else:
269 | x = self.se(x)
270 | x = self.act2(x)
271 |
272 | c_in = x.shape[1]
273 | x = self.conv3(x)
274 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv3.kernel_size[0] * self.conv3.kernel_size[1] / self.conv3.groups
275 | x = self.bn3(x)
276 | if self.identity:
277 | return x + residual, flops
278 | else:
279 | return x, flops
280 |
281 |
282 | class InvertedResidual_focal(nn.Module):
283 | def __init__(self, n_segment, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
284 | super(InvertedResidual_focal, self).__init__()
285 | assert stride in [1, 2]
286 | self.n_segment = n_segment
287 | self.identity = stride == 1 and inp == oup
288 | self.tsm = TSM()
289 | if stride != 1 or inp != oup:
290 | self.res_connect = nn.Sequential(
291 | nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False, groups=2),
292 | nn.BatchNorm2d(oup)
293 | )
294 |
295 | self.inp = inp
296 | self.hidden_dim = hidden_dim
297 | self.use_se = use_se
298 |
299 | if inp == hidden_dim:
300 | # dw
301 | self.conv1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
302 | self.bn1 = nn.BatchNorm2d(hidden_dim)
303 | self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
304 | # Squeeze-and-Excite
305 | self.se = SELayer(hidden_dim) if use_se else nn.Identity()
306 | # pw-linear
307 | self.conv2 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, groups=2, bias=False)
308 | self.bn2 = nn.BatchNorm2d(oup)
309 | else:
310 | # pw
311 | self.conv1 = nn.Conv2d(inp, hidden_dim, 1, 1, 0, groups=2, bias=False)
312 | self.bn1 = nn.BatchNorm2d(hidden_dim)
313 | self.act1 = h_swish() if use_hs else nn.ReLU(inplace=True)
314 | # dw
315 | self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False)
316 | self.bn2 = nn.BatchNorm2d(hidden_dim)
317 | # Squeeze-and-Excite
318 | self.se = SELayer(hidden_dim) if use_se else nn.Identity()
319 | self.act2 = h_swish() if use_hs else nn.ReLU(inplace=True)
320 | # pw-linear
321 | self.conv3 = nn.Conv2d(hidden_dim, oup, 1, 1, 0, groups=2, bias=False)
322 | self.bn3 = nn.BatchNorm2d(oup)
323 |
324 | def forward(self, x, list_big):
325 | if self.identity:
326 | residual = x
327 | else:
328 | residual = self.res_connect(x)
329 |
330 | if self.inp == self.hidden_dim:
331 | x = self.tsm(x, self.n_segment)
332 | x = x * list_big
333 | x = self.conv1(x)
334 | x = self.bn1(x)
335 | x = self.act1(x)
336 | x = self.se(x)
337 | x = self.conv2(x)
338 | x = self.bn2(x)
339 | else:
340 | x = self.tsm(x, self.n_segment)
341 | x = x * list_big
342 | x = self.conv1(x)
343 | x = self.bn1(x)
344 | x = self.act1(x)
345 | x = self.conv2(x)
346 | x = self.bn2(x)
347 | x = self.se(x)
348 | x = self.act2(x)
349 | x = self.conv3(x)
350 | x = self.bn3(x)
351 |
352 | return x + residual
353 |
354 |
355 | def forward_calc_flops(self, x, list_big):
356 | flops = 0
357 |
358 | if self.identity:
359 | residual = x
360 | else:
361 | c_in = x.shape[1]
362 | residual = self.res_connect(x)
363 | flops += c_in * residual.shape[1] * residual.shape[2] * residual.shape[3] / self.res_connect[0].groups
364 |
365 | if self.inp == self.hidden_dim:
366 | x = self.tsm(x, self.n_segment)
367 | x = x * list_big
368 | select_ratio = torch.mean(list_big)
369 | # select_ratio = 1
370 |
371 | c_in = x.shape[1]
372 | x = self.conv1(x)
373 | flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups
374 |
375 | x = self.bn1(x)
376 | x = self.act1(x)
377 | if self.use_se == True:
378 | x, _flops = self.se.forward_calc_flops(x)
379 | flops += select_ratio * _flops
380 | else:
381 | x = self.se(x)
382 |
383 | c_in = x.shape[1]
384 | x = self.conv2(x)
385 | flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
386 | x = self.bn2(x)
387 | else:
388 | x = self.tsm(x, self.n_segment)
389 | x = x * list_big
390 | select_ratio = torch.mean(list_big)
391 | # select_ratio = 1
392 |
393 | c_in = x.shape[1]
394 | x = self.conv1(x)
395 | flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] / self.conv1.groups
396 | x = self.bn1(x)
397 | x = self.act1(x)
398 |
399 | c_in = x.shape[1]
400 | x = self.conv2(x)
401 | flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] / self.conv2.groups
402 | x = self.bn2(x)
403 | if self.use_se == True:
404 | x, _flops = self.se.forward_calc_flops(x)
405 | flops += select_ratio * _flops
406 | else:
407 | x = self.se(x)
408 | x = self.act2(x)
409 |
410 | c_in = x.shape[1]
411 | x = self.conv3(x)
412 | flops += select_ratio * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv3.kernel_size[0] * self.conv3.kernel_size[1] / self.conv3.groups
413 | x = self.bn3(x)
414 |
415 | return x + residual, flops
416 |
417 |
418 |
419 | class navigation(nn.Module):
420 | def __init__(self, inplanes=64, num_segments=8):
421 | super(navigation,self).__init__()
422 | self.num_segments = num_segments
423 | self.conv_pool = nn.Conv2d(inplanes, 2, kernel_size=1, padding=0, stride=1, bias=False)
424 | self.bn = nn.BatchNorm2d(2)
425 | self.relu = nn.ReLU(inplace=True)
426 | self.pool = nn.AdaptiveAvgPool2d((1,1))
427 | self.conv_gs = nn.Conv2d(2*num_segments, 2*num_segments, kernel_size=1, padding=0, stride=1, bias=True, groups=num_segments)
428 | self.conv_gs.bias.data[:2*num_segments:2] = 1.0
429 | self.conv_gs.bias.data[1:2*num_segments+1:2] = 10.0
430 | self.gs = GumbleSoftmax()
431 |
432 | def forward(self, x, temperature=1.0):
433 | gates = self.pool(x)
434 | gates = self.conv_pool(gates)
435 | gates = self.bn(gates)
436 | gates = self.relu(gates)
437 |
438 | batch = x.shape[0] // self.num_segments
439 |
440 | gates = gates.view(batch, self.num_segments*2,1,1)
441 | gates = self.conv_gs(gates)
442 |
443 | gates = gates.view(batch, self.num_segments, 2, 1, 1)
444 | gates = self.gs(gates, temp=temperature, force_hard=True)
445 | list_big = gates[:, :, 1, :, :]
446 | list_big = list_big.view(x.shape[0],1,1,1)
447 |
448 | return list_big
449 |
450 | def forward_calc_flops(self, x, temperature=1.0):
451 | flops = 0
452 |
453 | flops += x.shape[1] * x.shape[2] * x.shape[3]
454 | gates = self.pool(x)
455 |
456 | c_in = gates.shape[1]
457 | gates = self.conv_pool(gates)
458 | flops += c_in * gates.shape[1] * gates.shape[2] * gates.shape[3]
459 | gates = self.bn(gates)
460 | gates = self.relu(gates)
461 |
462 | batch = x.shape[0] // self.num_segments
463 |
464 | gates = gates.view(batch, self.num_segments*2,1,1)
465 | gates = self.conv_gs(gates)
466 | flops += self.num_segments * 2 * gates.shape[1] * gates.shape[2] * gates.shape[3] / self.conv_gs.groups
467 |
468 | gates = gates.view(batch, self.num_segments, 2, 1, 1)
469 | gates = self.gs(gates, temp=temperature, force_hard=True)
470 | list_big = gates[:, :, 1, :, :]
471 | list_big = list_big.view(x.shape[0],1,1,1)
472 |
473 | return list_big, flops
474 |
475 |
476 |
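# Note on the navigation module above: each frame's globally pooled features are mapped to
# two logits by a grouped 1x1 convolution (one pair of logits per frame), and a hard
# Gumbel-Softmax turns them into the per-frame 0/1 decision `list_big` (1 = send this frame
# through the focal branch). The bias initialisation of conv_gs (1.0 for the first logit,
# 10.0 for the second) makes the gate select nearly every frame at the start of training,
# which matches the ratio warm-up performed by adjust_ratio() in main.py.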
477 | class AFMobileNetV3(nn.Module):
478 | def __init__(self, num_segments, num_class, cfgs_head, cfgs_stage1, cfgs_stage2_ample,
479 | cfgs_stage2_focal, cfgs_stage2_fuse, cfgs_stage3_ample, cfgs_stage3_focal,
480 | cfgs_stage3_fuse, cfgs_stage4, cfgs_stage5, mode, width_mult=1.):
481 | super(AFMobileNetV3, self).__init__()
482 | # setting of inverted residual blocks
483 | self.num_segments = num_segments
484 | self.cfgs_head = cfgs_head
485 | self.cfgs_stage1 = cfgs_stage1
486 | self.cfgs_stage2_ample = cfgs_stage2_ample
487 | self.cfgs_stage2_focal = cfgs_stage2_focal
488 | self.cfgs_stage2_fuse = cfgs_stage2_fuse
489 | self.cfgs_stage3_ample = cfgs_stage3_ample
490 | self.cfgs_stage3_focal = cfgs_stage3_focal
491 | self.cfgs_stage3_fuse = cfgs_stage3_fuse
492 | self.cfgs_stage4 = cfgs_stage4
493 | self.cfgs_stage5 = cfgs_stage5
494 | assert mode in ['large', 'small']
495 |
496 | # building first layer
497 | input_channel = _make_divisible(16 * width_mult, 8)
498 | self.conv = nn.Conv2d(3, input_channel, 3, 2, 1, bias=False)
499 | self.bn = nn.BatchNorm2d(input_channel)
500 | self.act = h_swish()
501 | # building inverted residual blocks
502 | block_base = InvertedResidual_ample
503 | block_refine = InvertedResidual_focal
504 |
505 | layers = []
506 | for k, t, c, use_se, use_hs, s in self.cfgs_head:
507 | output_channel = _make_divisible(c * width_mult, 8)
508 | exp_size = _make_divisible(input_channel * t, 8)
509 | layers.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
510 | input_channel = output_channel
511 | self.features_head = nn.Sequential(*layers)
512 |
513 |
514 | ###### stage 1
515 | layers_stage1 = []
516 | for k, t, c, use_se, use_hs, s in self.cfgs_stage1:
517 | output_channel = _make_divisible(c * width_mult, 8)
518 | exp_size = _make_divisible(input_channel * t, 8)
519 | layers_stage1.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
520 | input_channel = output_channel
521 | self.features_stage1 = nn.Sequential(*layers_stage1)
522 |
523 |
524 | ###### stage 2
525 | input_channel_before = input_channel
526 | layers_stage2_ample = []
527 | frame_gen_list_stage2 = []
528 | for k, t, c, use_se, use_hs, s in self.cfgs_stage2_ample:
529 | output_channel = _make_divisible(c * width_mult, 8)
530 | exp_size = _make_divisible(input_channel * t, 8)
531 | layers_stage2_ample.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
532 | frame_gen_list_stage2.append(navigation(inplanes=input_channel,num_segments=num_segments))
533 | input_channel = output_channel
534 | self.list_gen2 = nn.ModuleList(frame_gen_list_stage2)
535 | self.features_stage2_base = nn.Sequential(*layers_stage2_ample)
536 |
537 | layers_stage2_focal = []
538 | for k, t, c, use_se, use_hs, s in self.cfgs_stage2_focal:
539 | output_channel = _make_divisible(c * width_mult, 8)
540 | exp_size = _make_divisible(input_channel_before * t, 8)
541 | layers_stage2_focal.append(block_refine(num_segments, input_channel_before, exp_size, output_channel, k, s, use_se, use_hs))
542 | input_channel_before = output_channel
543 | input_channel = input_channel_before
544 | self.features_stage2_refine = nn.Sequential(*layers_stage2_focal)
545 |
546 | layers_stage2_fuse = []
547 | for k, t, c, use_se, use_hs, s in self.cfgs_stage2_fuse:
548 | output_channel = _make_divisible(c * width_mult, 8)
549 | exp_size = _make_divisible(input_channel * t, 8)
550 | layers_stage2_fuse.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
551 | input_channel = output_channel
552 | self.features_stage2_fuse = nn.Sequential(*layers_stage2_fuse)
553 | self.att_gen2 = dynamic_fusion(channel=input_channel, reduction=16)
554 |
555 |
556 | ###### stage 3
557 | input_channel_before = input_channel
558 | layers_stage3_ample = []
559 | frame_gen_list_stage3 = []
560 | for k, t, c, use_se, use_hs, s in self.cfgs_stage3_ample:
561 | output_channel = _make_divisible(c * width_mult, 8)
562 | exp_size = _make_divisible(input_channel * t, 8)
563 | layers_stage3_ample.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
564 | frame_gen_list_stage3.append(navigation(inplanes=input_channel,num_segments=num_segments))
565 | input_channel = output_channel
566 | self.list_gen3 = nn.ModuleList(frame_gen_list_stage3)
567 | self.features_stage3_base = nn.Sequential(*layers_stage3_ample)
568 |
569 | layers_stage3_focal = []
570 | for k, t, c, use_se, use_hs, s in self.cfgs_stage3_focal:
571 | output_channel = _make_divisible(c * width_mult, 8)
572 | exp_size = _make_divisible(input_channel_before * t, 8)
573 | layers_stage3_focal.append(block_refine(num_segments, input_channel_before, exp_size, output_channel, k, s, use_se, use_hs))
574 | input_channel_before = output_channel
575 | input_channel = input_channel_before
576 | self.features_stage3_refine = nn.Sequential(*layers_stage3_focal)
577 |
578 | layers_stage3_fuse = []
579 | for k, t, c, use_se, use_hs, s in self.cfgs_stage3_fuse:
580 | output_channel = _make_divisible(c * width_mult, 8)
581 | exp_size = _make_divisible(input_channel * t, 8)
582 | layers_stage3_fuse.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
583 | input_channel = output_channel
584 | self.features_stage3_fuse = nn.Sequential(*layers_stage3_fuse)
585 | self.att_gen3 = dynamic_fusion(channel=input_channel, reduction=16)
586 |
587 |
588 | ###### stage 4
589 | layers_stage4 = []
590 | for k, t, c, use_se, use_hs, s in self.cfgs_stage4:
591 | output_channel = _make_divisible(c * width_mult, 8)
592 | exp_size = _make_divisible(input_channel * t, 8)
593 | layers_stage4.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
594 | input_channel = output_channel
595 | self.features_stage4 = nn.Sequential(*layers_stage4)
596 |
597 | ###### stage 5
598 | layers_stage5 = []
599 | for k, t, c, use_se, use_hs, s in self.cfgs_stage5:
600 | output_channel = _make_divisible(c * width_mult, 8)
601 | exp_size = _make_divisible(input_channel * t, 8)
602 | layers_stage5.append(block_base(num_segments, input_channel, exp_size, output_channel, k, s, use_se, use_hs))
603 | input_channel = output_channel
604 | self.features_stage5 = nn.Sequential(*layers_stage5)
605 |
606 | # building last several layers
607 | # self.conv_last = conv_1x1_bn(input_channel, exp_size)
608 | self.conv_last = nn.Conv2d(input_channel, exp_size, 1, 1, 0, bias=False)
609 | self.bn_last = nn.BatchNorm2d(exp_size)
610 | self.act_last = h_swish()
611 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
612 | output_channel = {'large': 1280, 'small': 1024}
613 | self.output_channel_num = output_channel[mode]
614 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode]
615 | self.fc = nn.Sequential(
616 | nn.Linear(exp_size, output_channel),
617 | h_swish(),
618 | nn.Dropout(0.5),
619 | nn.Linear(output_channel, num_class),
620 | )
621 |
622 | self._initialize_weights()
623 |
624 | def forward(self, x, temperature=1e-8):
625 | _lists = []
626 | x = self.conv(x)
627 | x = self.bn(x)
628 | x = self.act(x)
629 |
630 | x = self.features_head[0](x, self.num_segments)
631 |
632 |
633 | for i in range(len(self.features_stage1)):
634 | x = self.features_stage1[i](x, self.num_segments)
635 |
636 |
637 | x_base = x
638 | x_refine = x
639 | for i in range(len(self.features_stage2_base)):
640 | list_big = self.list_gen2[i](x_base, temperature=temperature)
641 | _lists.append(list_big)
642 | x_base = self.features_stage2_base[i](x_base, self.num_segments)
643 | x_refine = self.features_stage2_refine[i](x_refine, list_big)
644 | _,_,h,w = x_refine.shape
645 | x_base = F.interpolate(x_base, size = (h,w))
646 | att = self.att_gen2(x_base+x_refine)
647 | x = self.features_stage2_fuse[0](att*x_base + (1-att)*x_refine, self.num_segments)
648 |
649 |
650 | x_base = x
651 | x_refine = x
652 | for i in range(len(self.features_stage3_base)):
653 | list_big = self.list_gen3[i](x_base, temperature=temperature)
654 | _lists.append(list_big)
655 | x_base = self.features_stage3_base[i](x_base, self.num_segments)
656 | x_refine = self.features_stage3_refine[i](x_refine, list_big)
657 | _,_,h,w = x_refine.shape
658 | x_base = F.interpolate(x_base, size = (h,w))
659 | att = self.att_gen3(x_base+x_refine)
660 | x = self.features_stage3_fuse[0](att*x_base + (1-att)*x_refine, self.num_segments)
661 |
662 |
663 | for i in range(len(self.features_stage4)):
664 | x = self.features_stage4[i](x, self.num_segments)
665 | for i in range(len(self.features_stage5)):
666 | x = self.features_stage5[i](x, self.num_segments)
667 |
668 |
669 | x = self.conv_last(x)
670 | x = self.bn_last(x)
671 | x = self.act_last(x)
672 | x = self.avgpool(x)
673 | x = x.view(x.size(0), -1)
674 |
675 |
676 | x = self.fc(x)
677 |
678 | return x, _lists
679 |
680 | def forward_calc_flops(self, x, temperature=1e-8):
681 | flops = 0
682 | _lists = []
683 |
684 | c_in = x.shape[1]
685 | x = self.conv(x)
686 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv.kernel_size[0] * self.conv.kernel_size[1] / self.conv.groups
687 | x = self.bn(x)
688 | x = self.act(x)
689 |
690 | x, _flops = self.features_head[0].forward_calc_flops(x, self.num_segments)
691 | flops += _flops
692 |
693 |
694 | for i in range(len(self.features_stage1)):
695 | x, _flops = self.features_stage1[i].forward_calc_flops(x, self.num_segments)
696 | flops += _flops
697 |
698 |
699 | x_base = x
700 | x_refine = x
701 | for i in range(len(self.features_stage2_base)):
702 | list_big, _flops = self.list_gen2[i].forward_calc_flops(x_base, temperature=temperature)
703 | _lists.append(list_big)
704 | flops += _flops
705 | x_base, _flops = self.features_stage2_base[i].forward_calc_flops(x_base, self.num_segments)
706 | flops += _flops
707 | x_refine, _flops = self.features_stage2_refine[i].forward_calc_flops(x_refine, list_big)
708 | flops += _flops
709 | _,_,h,w = x_refine.shape
710 | x_base = F.interpolate(x_base, size = (h,w))
711 | att, _flops = self.att_gen2.forward_calc_flops(x_base+x_refine)
712 | flops += _flops
713 | x, _flops = self.features_stage2_fuse[0].forward_calc_flops(att*x_base + (1-att)*x_refine, self.num_segments)
714 | flops += _flops
715 |
716 |
717 | x_base = x
718 | x_refine = x
719 | for i in range(len(self.features_stage3_base)):
720 | list_big, _flops = self.list_gen3[i].forward_calc_flops(x_base, temperature=temperature)
721 | _lists.append(list_big)
722 | flops += _flops
723 | x_base, _flops = self.features_stage3_base[i].forward_calc_flops(x_base, self.num_segments)
724 | flops += _flops
725 | x_refine, _flops = self.features_stage3_refine[i].forward_calc_flops(x_refine, list_big)
726 | flops += _flops
727 | _,_,h,w = x_refine.shape
728 | x_base = F.interpolate(x_base, size = (h,w))
729 | att, _flops = self.att_gen3.forward_calc_flops(x_base+x_refine)
730 | flops += _flops
731 | x, _flops = self.features_stage3_fuse[0].forward_calc_flops(att*x_base + (1-att)*x_refine, self.num_segments)
732 | flops += _flops
733 |
734 |
735 | for i in range(len(self.features_stage4)):
736 | x, _flops = self.features_stage4[i].forward_calc_flops(x, self.num_segments)
737 | flops += _flops
738 |
739 | for i in range(len(self.features_stage5)):
740 | x, _flops = self.features_stage5[i].forward_calc_flops(x, self.num_segments)
741 | flops += _flops
742 |
743 | c_in = x.shape[1]
744 | x = self.conv_last(x)
745 | flops += c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv_last.kernel_size[0] * self.conv_last.kernel_size[1] / self.conv_last.groups
746 | x = self.bn_last(x)
747 | x = self.act_last(x)
748 | x = self.avgpool(x)
749 | x = x.view(x.size(0), -1)
750 |
751 |
752 | c_in = x.shape[1]
753 | x = self.fc(x)
754 | c_out = x.shape[1]
755 | flops += c_in * self.output_channel_num + self.output_channel_num * c_out
756 |
757 | return x, _lists, self.num_segments * flops
758 |
759 | def _initialize_weights(self):
760 | for m in self.modules():
761 | if isinstance(m, nn.Conv2d):
762 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
763 | m.weight.data.normal_(0, math.sqrt(2. / n))
764 | # if m.bias is not None:
765 | # m.bias.data.zero_()
766 | elif isinstance(m, nn.BatchNorm2d):
767 | m.weight.data.fill_(1)
768 | m.bias.data.zero_()
769 | elif isinstance(m, nn.Linear):
770 | m.weight.data.normal_(0, 0.01)
771 | m.bias.data.zero_()
772 |
773 |
774 | def af_mobilenetv3(num_segments, num_class, **kwargs):
775 | """
776 |     Constructs an AF-MobileNetV3 model based on MobileNetV3-Large
777 | """
778 | cfgs_head = [
779 | # k, t, c, SE, HS, s
780 | [3, 1, 16, 0, 0, 1]]
781 | cfgs_stage1 = [
782 | [3, 4, 24, 0, 0, 2],
783 | [3, 3, 24, 0, 0, 2]]
784 | cfgs_stage2_ample = [
785 | [5, 3, 40, 1, 0, 2],
786 | [5, 3, 40, 1, 0, 1]]
787 | cfgs_stage2_focal = [
788 | [5, 3, 40, 1, 0, 1],
789 | [5, 3, 40, 1, 0, 1]]
790 | cfgs_stage2_fuse = [
791 | [5, 3, 40, 1, 0, 2]]
792 | cfgs_stage3_ample = [
793 | [3, 6, 80, 0, 1, 2],
794 | [3, 2.5, 80, 0, 1, 1],
795 | [3, 2.3, 80, 0, 1, 1]]
796 | cfgs_stage3_focal = [
797 | [3, 6, 80, 0, 1, 1],
798 | [3, 2.5, 80, 0, 1, 1],
799 | [3, 2.3, 80, 0, 1, 1]]
800 | cfgs_stage3_fuse = [
801 | [3, 2.3, 80, 0, 1, 1]]
802 | cfgs_stage4 = [
803 | [3, 6, 112, 1, 1, 1],
804 | [3, 6, 112, 1, 1, 1]]
805 | cfgs_stage5 = [
806 | [5, 6, 160, 1, 1, 2],
807 | [5, 6, 160, 1, 1, 1],
808 | [5, 6, 160, 1, 1, 1]]
809 | return AFMobileNetV3(num_segments, num_class, cfgs_head, cfgs_stage1, cfgs_stage2_ample,
810 | cfgs_stage2_focal, cfgs_stage2_fuse, cfgs_stage3_ample, cfgs_stage3_focal,
811 | cfgs_stage3_fuse, cfgs_stage4, cfgs_stage5, mode='large', **kwargs)
812 |
813 |
814 |
815 | def AF_mobilenetv3(pretrained=False, path_backbone = '.../.../checkpoint/ImageNet/AF-MobileNetv3.pth.tar', shift=False, num_segments=8, num_class=174, **kwargs):
816 | model = af_mobilenetv3(num_segments, num_class)
817 | if pretrained:
818 | checkpoint = torch.load(path_backbone, map_location='cpu')
819 | pretrained_dict = checkpoint['state_dict']
820 | new_state_dict = model.state_dict()
821 | for k, v in pretrained_dict.items():
822 | if (k[7:] in new_state_dict):
823 | new_state_dict.update({k[7:]:v})
824 | model.load_state_dict(new_state_dict)
825 | return model
--------------------------------------------------------------------------------
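
A minimal usage sketch for the file above (not part of the repository): it builds the AF-MobileNetV3 backbone without ImageNet weights and runs a dummy 8-frame clip. The clip size, the class count (174, as on Something-Something V1) and CPU execution are illustrative assumptions.

```python
import torch
from ops.backbone.AF_MobileNetv3 import AF_mobilenetv3

# Build the backbone without loading the ImageNet checkpoint.
model = AF_mobilenetv3(pretrained=False, num_segments=8, num_class=174).eval()

# AFNet expects frames flattened into the batch dimension: (batch * num_segments, 3, H, W).
frames = torch.randn(1 * 8, 3, 224, 224)

with torch.no_grad():
    logits, masks = model(frames)   # per-frame logits and the Gumbel-Softmax selection masks
print(logits.shape)                 # torch.Size([8, 174])
print(len(masks))                   # one mask per block in stages 2 and 3
```
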
/ops/backbone/AF_ResNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | # from torch.hub import load_state_dict_from_url
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import math
8 | from .gumbel_softmax import GumbleSoftmax
9 |
10 |
11 |
12 | class dynamic_fusion(nn.Module):
13 | def __init__(self, channel, reduction=16):
14 | super(dynamic_fusion, self).__init__()
15 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
16 | self.reduction = reduction
17 | self.fc = nn.Sequential(
18 | nn.Linear(channel, int(channel // reduction), bias=False),
19 | nn.ReLU(inplace=True),
20 | nn.Linear(int(channel // reduction), channel, bias=False),
21 | nn.Sigmoid()
22 | )
23 | for m in self.modules():
24 | if isinstance(m, nn.Linear):
25 | nn.init.normal_(m.weight, 0, 0.01)
26 |
27 | def forward(self, x):
28 | b, c, h, w = x.size()
29 | y = self.avg_pool(x).view(b,c)
30 | attention = self.fc(y)
31 | return attention.view(b,c,1,1)
32 |
33 | def forward_calc_flops(self, x):
34 | b, c, h, w = x.size()
35 | flops = c*h*w
36 | y = self.avg_pool(x).view(b,c)
37 | attention = self.fc(y)
38 | flops += c*c//self.reduction*2 + c
39 | return attention.view(b,c,1,1), flops
40 |
41 |
42 | class TSM(nn.Module):
43 | def __init__(self):
44 | super(TSM, self).__init__()
45 | self.fold_div = 8
46 |
47 | def forward(self, x, n_segment):
48 | x = self.shift(x, n_segment, fold_div=self.fold_div)
49 | return x
50 |
51 | @staticmethod
52 | def shift(x, n_segment, fold_div=3):
53 | if type(n_segment) is int:
54 | nt, c, h, w = x.size()
55 | n_batch = nt // n_segment
56 | x = x.view(n_batch, n_segment, c, h, w)
57 |
58 | fold = c // fold_div
59 | out = torch.zeros_like(x)
60 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
61 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
62 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
63 | shift_out = out.view(nt, c, h, w)
64 | else:
65 | num_segment = int(n_segment.sum())
66 | ls = n_segment
67 | bool_list = ls > 0
68 | bool_list = bool_list.view(-1)
69 |
70 | shift_out = torch.zeros_like(x)
71 | x = x[bool_list]
72 | nt, c, h, w = x.size()
73 | x = x.view(-1, num_segment, c, h, w)
74 |
75 | fold = c // fold_div
76 | out = torch.zeros_like(x)
77 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
78 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
79 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
80 | out = out.view(-1, c, h, w)
81 | shift_out[bool_list] = out
82 |
83 | return shift_out
84 |
85 |
86 | class Bottleneck_ample(nn.Module):
87 | expansion = 4
88 |
89 | def __init__(self, inplanes, planes, num_segments, stride=1, downsample=None, last_relu=True, patch_groups=1,
90 | base_scale=2, is_first=False):
91 | super(Bottleneck_ample, self).__init__()
92 | self.num_segments = num_segments
93 | self.conv1 = nn.Conv2d(inplanes, planes // self.expansion, kernel_size=1, bias=False)
94 | self.bn1 = nn.BatchNorm2d(planes // self.expansion)
95 | self.conv2 = nn.Conv2d(planes // self.expansion, planes // self.expansion, kernel_size=3, stride=stride,
96 | padding=1, bias=False, groups=1)
97 | self.bn2 = nn.BatchNorm2d(planes // self.expansion)
98 | self.conv3 = nn.Conv2d(planes // self.expansion, planes, kernel_size=1, bias=False)
99 | self.bn3 = nn.BatchNorm2d(planes)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.tsm = TSM()
102 |
103 | self.downsample = downsample
104 | self.have_pool = False
105 | self.have_1x1conv2d = False
106 |
107 | self.first_downsample = nn.AvgPool2d(3, stride=2, padding=1) if (base_scale == 4 and is_first) else None
108 |
109 | if self.downsample is not None:
110 | self.have_pool = True
111 | if len(self.downsample) > 1:
112 | self.have_1x1conv2d = True
113 |
114 | self.stride = stride
115 | self.last_relu = last_relu
116 |
117 | def forward(self, x, list_little, activate_tsm=False):
118 | if self.first_downsample is not None:
119 | x = self.first_downsample(x)
120 | residual = x
121 | if self.downsample is not None:
122 | residual = self.downsample(x)
123 |
124 | if activate_tsm:
125 | out = self.tsm(x, list_little)
126 | else:
127 | out = x
128 |
129 | out = self.conv1(out)
130 | out = self.bn1(out)
131 | out = self.relu(out)
132 |
133 | out = self.conv2(out)
134 | out = self.bn2(out)
135 | out = self.relu(out)
136 |
137 | out = self.conv3(out)
138 | out = self.bn3(out)
139 |
140 | out += residual
141 | if self.last_relu:
142 | out = self.relu(out)
143 | return out
144 |
145 | def forward_calc_flops(self, x, list_little, activate_tsm=False):
146 | flops = 0
147 | if self.first_downsample is not None:
148 | x = self.first_downsample(x)
149 | _, c, h, w = x.shape
150 | flops += 9 * c * h * w
151 |
152 | residual = x
153 | if self.downsample is not None:
154 | c_in = x.shape[1]
155 | residual = self.downsample(x)
156 | _, c, h, w = residual.shape
157 | if self.have_pool:
158 | flops += 9 * c_in * h * w
159 | if self.have_1x1conv2d:
160 | flops += c_in * c * h * w
161 |
162 | if activate_tsm:
163 | out = self.tsm(x, list_little)
164 | else:
165 | out = x
166 |
167 | c_in = out.shape[1]
168 | out = self.conv1(out)
169 | _,c_out,h,w = out.shape
170 | flops += c_in * c_out * h * w / self.conv1.groups
171 |
172 | out = self.bn1(out)
173 | out = self.relu(out)
174 |
175 | c_in = c_out
176 | out = self.conv2(out)
177 | _,c_out,h,w = out.shape
178 | flops += c_in * c_out * h * w * 9 / self.conv2.groups
179 | out = self.bn2(out)
180 | out = self.relu(out)
181 |
182 | c_in = c_out
183 | out = self.conv3(out)
184 | _,c_out,h,w = out.shape
185 | flops += c_in * c_out * h * w / self.conv3.groups
186 | out = self.bn3(out)
187 |
188 | out += residual
189 | if self.last_relu:
190 | out = self.relu(out)
191 |
192 | return out, flops
193 |
194 | class Bottleneck_focal(nn.Module):
195 | expansion = 4
196 |
197 | def __init__(self, inplanes, planes, num_segments, stride=1, downsample=None, last_relu=True, patch_groups=1, base_scale=2, is_first = True):
198 | super(Bottleneck_focal, self).__init__()
199 | self.num_segments = num_segments
200 | self.conv1 = nn.Conv2d(inplanes, planes // self.expansion, kernel_size=1, bias=False, groups=patch_groups)
201 | self.bn1 = nn.BatchNorm2d(planes // self.expansion)
202 | self.conv2 = nn.Conv2d(planes // self.expansion, planes // self.expansion, kernel_size=3, stride=stride,
203 | padding=1, bias=False, groups=patch_groups)
204 | self.bn2 = nn.BatchNorm2d(planes // self.expansion)
205 | self.conv3 = nn.Conv2d(planes // self.expansion, planes, kernel_size=1, bias=False, groups=patch_groups)
206 | self.bn3 = nn.BatchNorm2d(planes)
207 | self.relu = nn.ReLU(inplace=True)
208 | self.tsm = TSM()
209 | self.downsample = downsample
210 |
211 | self.stride = stride
212 | self.last_relu = last_relu
213 | self.patch_groups = patch_groups
214 |
215 | def forward(self, x, mask, activate_tsm=False):
216 | residual = x
217 | if self.downsample is not None: # skip connection before mask
218 | residual = self.downsample(x)
219 |
220 | if activate_tsm:
221 | out = self.tsm(x, self.num_segments)
222 | else:
223 | out = x
224 | out = out * mask
225 |
226 | out = self.conv1(out)
227 | out = self.bn1(out)
228 | out = self.relu(out)
229 |
230 | out = self.conv2(out)
231 | out = self.bn2(out)
232 | out = self.relu(out)
233 |
234 | out = self.conv3(out)
235 | out = self.bn3(out)
236 |
237 | out += residual
238 | if self.last_relu:
239 | out = self.relu(out)
240 | return out
241 |
242 |
243 | def forward_calc_flops(self, x, mask, activate_tsm=False):
244 | residual = x
245 | flops = 0
246 | if self.downsample is not None: # skip connection before mask
247 | c_in = x.shape[1]
248 | residual = self.downsample(x)
249 | flops += c_in * residual.shape[1] * residual.shape[2] * residual.shape[3]
250 |
251 | if activate_tsm:
252 | out = self.tsm(x, self.num_segments)
253 | else:
254 | out = x
255 | out = out * mask
256 | select_ratio = torch.mean(mask)
257 |
258 | c_in = out.shape[1]
259 | out = self.conv1(out)
260 | _,c_out,h,w = out.shape
261 | flops += select_ratio * c_in * c_out * h * w / self.conv1.groups
262 |
263 | out = self.bn1(out)
264 | out = self.relu(out)
265 |
266 | c_in = c_out
267 | out = self.conv2(out)
268 | _,c_out,h,w = out.shape
269 | flops += select_ratio * c_in * c_out * h * w * 9 / self.conv2.groups
270 | out = self.bn2(out)
271 | out = self.relu(out)
272 |
273 | c_in = c_out
274 | out = self.conv3(out)
275 | _,c_out,h,w = out.shape
276 | flops += select_ratio * c_in * c_out * h * w / self.conv3.groups
277 | out = self.bn3(out)
278 |
279 | out += residual
280 | if self.last_relu:
281 | out = self.relu(out)
282 |
283 | return out, flops
284 |
285 |
286 |
287 | class navigation(nn.Module):
288 | def __init__(self, inplanes=64, num_segments=8):
289 | super(navigation,self).__init__()
290 | self.num_segments = num_segments
291 | self.conv_pool = nn.Conv2d(inplanes, 2, kernel_size=1, padding=0, stride=1, bias=False)
292 | self.bn = nn.BatchNorm2d(2)
293 | self.relu = nn.ReLU(inplace=True)
294 | self.pool = nn.AdaptiveAvgPool2d((1,1))
295 | self.conv_gs = nn.Conv2d(2*num_segments, 2*num_segments, kernel_size=1, padding=0, stride=1, bias=True, groups=num_segments)
296 | self.conv_gs.bias.data[:2*num_segments:2] = 1.0
297 | self.conv_gs.bias.data[1:2*num_segments+1:2] = 10.0
298 | self.gs = GumbleSoftmax()
299 |
300 | def forward(self, x, temperature=1.0):
301 | gates = self.pool(x)
302 | gates = self.conv_pool(gates)
303 | gates = self.bn(gates)
304 | gates = self.relu(gates)
305 |
306 | batch = x.shape[0] // self.num_segments
307 |
308 | gates = gates.view(batch, self.num_segments*2,1,1)
309 | gates = self.conv_gs(gates)
310 |
311 | gates = gates.view(batch, self.num_segments, 2, 1, 1)
312 | gates = self.gs(gates, temp=temperature, force_hard=True)
313 |         mask = gates[:, :, 1, :, :]  # channel 1 of the gate is the "select this frame" decision
314 | mask = mask.view(x.shape[0],1,1,1)
315 |
316 | return mask
317 |
318 | def forward_calc_flops(self, x, temperature=1.0):
319 | flops = 0
320 |
321 | flops += x.shape[1] * x.shape[2] * x.shape[3]
322 | gates = self.pool(x)
323 |
324 | c_in = gates.shape[1]
325 | gates = self.conv_pool(gates)
326 | flops += c_in * gates.shape[1] * gates.shape[2] * gates.shape[3]
327 | gates = self.bn(gates)
328 | gates = self.relu(gates)
329 |
330 | batch = x.shape[0] // self.num_segments
331 |
332 | gates = gates.view(batch, self.num_segments*2,1,1)
333 | gates = self.conv_gs(gates)
334 | flops += self.num_segments * 2 * gates.shape[1] * gates.shape[2] * gates.shape[3] / self.conv_gs.groups
335 |
336 | gates = gates.view(batch, self.num_segments, 2, 1, 1)
337 | gates = self.gs(gates, temp=temperature, force_hard=True)
338 | mask = gates[:, :, 1, :, :]
339 | mask = mask.view(x.shape[0],1,1,1)
340 |
341 | return mask, flops
342 |
343 |
344 | class AFModule(nn.Module):
345 | def __init__(self, block_ample, block_focal, in_channels, out_channels, blocks, stride, patch_groups, alpha=1, num_segments=8):
346 | super(AFModule, self).__init__()
347 | self.num_segments = num_segments
348 | self.patch_groups = patch_groups
349 | self.relu = nn.ReLU(inplace=True)
350 |
351 | frame_gen_list = []
352 | for i in range(blocks - 1):
353 | frame_gen_list.append(navigation(inplanes=int(out_channels // alpha),num_segments=num_segments)) if i!=0 else frame_gen_list.append(navigation(inplanes=in_channels,num_segments=num_segments))
354 | self.list_gen = nn.ModuleList(frame_gen_list)
355 |
356 | self.base_module = self._make_layer(block_ample, in_channels, int(out_channels // alpha), num_segments, blocks - 1, 2, last_relu=False)
357 | self.refine_module = self._make_layer(block_focal, in_channels, out_channels, num_segments, blocks - 1, 1, last_relu=False)
358 |
359 | self.alpha = alpha
360 | if alpha != 1:
361 | self.base_transform = nn.Sequential(
362 | nn.Conv2d(int(out_channels // alpha), out_channels, kernel_size=1, bias=False),
363 | nn.BatchNorm2d(out_channels)
364 | )
365 | self.att_gen = dynamic_fusion(channel=out_channels, reduction=16)
366 | self.fusion = self._make_layer(block_ample, out_channels, out_channels, num_segments, 1, stride=stride)
367 |
368 | def _make_layer(self, block, inplanes, planes, num_segments, blocks, stride=1, last_relu=True, base_scale=2):
369 | downsample = []
370 | if stride != 1:
371 | downsample.append(nn.AvgPool2d(3, stride=2, padding=1))
372 | if inplanes != planes:
373 | downsample.append(nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False))
374 | downsample.append(nn.BatchNorm2d(planes))
375 | downsample = None if downsample == [] else nn.Sequential(*downsample)
376 | layers = []
377 | if blocks == 1: # fuse, is not the first of a base branch
378 | layers.append(block(inplanes, planes, num_segments, stride=stride, downsample=downsample,
379 | patch_groups=self.patch_groups, base_scale=base_scale, is_first = False))
380 | else:
381 | layers.append(block(inplanes, planes, num_segments, stride, downsample,patch_groups=self.patch_groups,
382 | base_scale=base_scale, is_first = True))
383 | for i in range(1, blocks):
384 | layers.append(block(planes, planes, num_segments,
385 | last_relu=last_relu if i == blocks - 1 else True,
386 | patch_groups=self.patch_groups, base_scale=base_scale, is_first = False))
387 |
388 | return nn.ModuleList(layers)
389 |
390 | def forward(self, x, temperature=1e-8, activate_tsm=False):
391 | b,c,h,w = x.size()
392 |         x_big = x      # focal branch input (full resolution, masked frames)
393 |         x_little = x   # ample branch input (downsampled, all frames)
394 | _masks = []
395 |
396 | for i in range(len(self.base_module)):
397 | mask = self.list_gen[i](x_little, temperature=temperature)
398 | _masks.append(mask)
399 |
400 | x_little = self.base_module[i](x_little, self.num_segments, activate_tsm)
401 | x_big = self.refine_module[i](x_big, mask, activate_tsm)
402 |
403 | if self.alpha != 1:
404 | x_little = self.base_transform(x_little)
405 |
406 | _,_,h,w = x_big.shape
407 | x_little = F.interpolate(x_little, size = (h,w))
408 | att = self.att_gen(x_little+x_big)
409 | out = self.relu(att*x_little + (1-att)*x_big)
410 | out = self.fusion[0](out, self.num_segments, activate_tsm)
411 | return out, _masks
412 |
413 | def forward_calc_flops(self, x, temperature=1e-8, activate_tsm=False):
414 | flops = 0
415 | b,c,h,w = x.size()
416 |
417 | x_big = x
418 | x_little = x
419 | _masks = []
420 |
421 | for i in range(len(self.base_module)):
422 | mask, _flops = self.list_gen[i].forward_calc_flops(x_little, temperature=temperature)
423 | _masks.append(mask)
424 | flops += _flops * b
425 |
426 | x_little, _flops = self.base_module[i].forward_calc_flops(x_little, self.num_segments, activate_tsm)
427 | flops += _flops * b
428 | x_big, _flops = self.refine_module[i].forward_calc_flops(x_big, mask, activate_tsm)
429 | flops += _flops * b
430 |
431 | c = x_little.shape[1]
432 | _,_, h,w = x_big.shape
433 | if self.alpha != 1:
434 | x_little = self.base_transform(x_little)
435 | flops += b * c * x_little.shape[1] * x_little.shape[2] * x_little.shape[3]
436 |
437 | x_little = F.interpolate(x_little, size = (h,w))
438 | att, _flops = self.att_gen.forward_calc_flops(x_little+x_big)
439 | flops += _flops * b
440 | out = self.relu(att*x_little + (1-att)*x_big)
441 | out, _flops = self.fusion[0].forward_calc_flops(out, self.num_segments, activate_tsm)
442 | flops += _flops * b
443 |
444 | seg = b / self.num_segments
445 | flops = flops / seg
446 |
447 | return out, _masks, flops
448 |
449 | class AFResNet(nn.Module):
450 | def __init__(self, block_ample, block_focal, layers, width=1.0, patch_groups=1, alpha=1, shift=True, num_segments=8, num_classes=1000):
451 | num_channels = [int(64*width), int(128*width), int(256*width), 512]
452 |
453 | self.num_segments = num_segments
454 | self.activate_tsm = shift
455 | self.inplanes = 64
456 | super(AFResNet, self).__init__()
457 | self.conv1 = nn.Conv2d(3, num_channels[0], kernel_size=7, stride=2, padding=3,
458 | bias=False)
459 | self.bn1 = nn.BatchNorm2d(num_channels[0])
460 | self.relu = nn.ReLU(inplace=True)
461 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
462 |
463 | self.layer1 = AFModule(block_ample, block_focal, num_channels[0], num_channels[0]*block_ample.expansion,
464 | layers[0], stride=2, patch_groups=patch_groups, alpha=alpha, num_segments=num_segments)
465 | self.layer2 = AFModule(block_ample, block_focal, num_channels[0]*block_ample.expansion,
466 | num_channels[1]*block_ample.expansion, layers[1], stride=2, patch_groups=patch_groups, alpha=alpha, num_segments=num_segments)
467 | self.layer3 = AFModule(block_ample, block_focal, num_channels[1]*block_ample.expansion,
468 | num_channels[2]*block_ample.expansion, layers[2], stride=1, patch_groups=patch_groups, alpha=alpha, num_segments=num_segments)
469 | self.layer4 = self._make_layer(num_segments,
470 | block_ample, num_channels[2]*block_ample.expansion, num_channels[3]*block_ample.expansion, layers[3], stride=2)
471 | self.gappool = nn.AdaptiveAvgPool2d(1)
472 | self.fc = nn.Linear(num_channels[3]*block_ample.expansion, num_classes)
473 |
474 | for k, m in self.named_modules():
475 | if isinstance(m, nn.Conv2d):
476 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
477 | # if 'gs' in str(k):
478 | # m.weight.data.normal_(0, 0.001)
479 | elif isinstance(m, nn.BatchNorm2d):
480 | nn.init.constant_(m.weight, 1)
481 | nn.init.constant_(m.bias, 0)
482 |
483 | # Zero-initialize the last BN in each block.
484 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
485 | for m in self.modules():
486 | if isinstance(m, Bottleneck_ample):
487 | nn.init.constant_(m.bn3.weight, 0)
488 |
489 | def _make_layer(self, num_segments, block, inplanes, planes, blocks, stride=1):
490 | downsample = []
491 | if stride != 1:
492 | downsample.append(nn.AvgPool2d(3, stride=2, padding=1))
493 | if inplanes != planes:
494 | downsample.append(nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False))
495 | downsample.append(nn.BatchNorm2d(planes))
496 | downsample = None if downsample == [] else nn.Sequential(*downsample)
497 |
498 | layers = []
499 | layers.append(block(inplanes, planes, num_segments, stride, downsample))
500 | for _ in range(1, blocks):
501 | layers.append(block(planes, planes, num_segments))
502 |
503 | return nn.ModuleList(layers)
504 |
505 | def forward(self, x, temperature=1.0):
506 | x = self.conv1(x)
507 | x = self.bn1(x)
508 | x = self.relu(x)
509 | x = self.maxpool(x)
510 |
511 | _masks = []
512 | x1, mask = self.layer1(x, temperature=temperature, activate_tsm=self.activate_tsm)
513 | _masks.extend(mask)
514 | x2, mask = self.layer2(x1, temperature=temperature, activate_tsm=self.activate_tsm)
515 | _masks.extend(mask)
516 | x3, mask = self.layer3(x2, temperature=temperature, activate_tsm=self.activate_tsm)
517 | _masks.extend(mask)
518 | x4 = x3
519 | for i in range(len(self.layer4)):
520 | x4 = self.layer4[i](x4, self.num_segments, self.activate_tsm)
521 |
522 | x = self.gappool(x4)
523 | x = x.view(x.size(0), -1)
524 | x = self.fc(x)
525 |
526 | return x, _masks
527 |
528 | def forward_calc_flops(self, x, temperature=1.0):
529 | flops = 0
530 | c_in = x.shape[1]
531 | x = self.conv1(x)
532 | flops += self.num_segments * c_in * x.shape[1] * x.shape[2] * x.shape[3] * self.conv1.weight.shape[2]*self.conv1.weight.shape[3]
533 | x = self.bn1(x)
534 | x = self.relu(x)
535 |
536 | x = self.maxpool(x)
537 | flops += self.num_segments * x.numel() / x.shape[0] * 9
538 |
539 | _masks = []
540 | x1, mask, _flops = self.layer1.forward_calc_flops(x, temperature=temperature, activate_tsm=self.activate_tsm)
541 | _masks.extend(mask)
542 | flops += _flops
543 | x2, mask, _flops = self.layer2.forward_calc_flops(x1, temperature=temperature, activate_tsm=self.activate_tsm)
544 | _masks.extend(mask)
545 | flops += _flops
546 | x3, mask, _flops = self.layer3.forward_calc_flops(x2, temperature=temperature, activate_tsm=self.activate_tsm)
547 | _masks.extend(mask)
548 | flops += _flops
549 | x4 = x3
550 | for i in range(len(self.layer4)):
551 | x4, _flops = self.layer4[i].forward_calc_flops(x4, self.num_segments, self.activate_tsm)
552 | flops += _flops * self.num_segments
553 | flops += self.num_segments * x4.shape[1] * x4.shape[2] * x4.shape[3]
554 | x = self.gappool(x4)
555 | x = x.view(x.size(0), -1)
556 | c_in = x.shape[1]
557 | x = self.fc(x)
558 | flops += self.num_segments * c_in * x.shape[1]
559 |
560 | return x, _masks, flops
561 |
562 |
563 | def AF_resnet(depth, patch_groups=1, width=1.0, alpha=1, shift=False, num_segments=8, **kwargs):
564 | layers = {
565 | 50: [3, 4, 6, 3],
566 | 101: [4, 8, 18, 3],
567 | }[depth]
568 | block = Bottleneck_ample
569 | block_focal = Bottleneck_focal
570 | model = AFResNet(block_ample=block, block_focal=block_focal, layers=layers, patch_groups=patch_groups,
571 | width=width, alpha=alpha, shift=shift, num_segments=num_segments, **kwargs)
572 | return model
573 |
574 |
575 | def AF_resnet50(pretrained=False, path_backbone = '.../.../checkpoint/ImageNet/AF-ResNet50.pth.tar', shift=False, num_segments=8, **kwargs):
576 | model = AF_resnet(depth=50, patch_groups=2, alpha=2, shift=shift, num_segments=num_segments, **kwargs)
577 | if pretrained:
578 | checkpoint = torch.load(path_backbone, map_location='cpu')
579 | pretrained_dict = checkpoint['state_dict']
580 | new_state_dict = model.state_dict()
581 | for k, v in pretrained_dict.items():
582 | if (k[7:] in new_state_dict):
583 | new_state_dict.update({k[7:]:v})
584 | model.load_state_dict(new_state_dict)
585 | return model
586 |
587 |
588 | def AF_resnet101(pretrained=False, path_backbone = '.../.../checkpoint/ImageNet/AF-ResNet101.pth.tar', shift=False, num_segments=8, **kwargs):
589 | model = AF_resnet(depth=101, patch_groups=2, alpha=2, shift=shift, num_segments=num_segments, **kwargs)
590 | if pretrained:
591 | checkpoint = torch.load(path_backbone, map_location='cpu')
592 | pretrained_dict = checkpoint['state_dict']
593 | new_state_dict = model.state_dict()
594 | for k, v in pretrained_dict.items():
595 | if (k[7:] in new_state_dict):
596 | new_state_dict.update({k[7:]:v})
597 | model.load_state_dict(new_state_dict)
598 | return model
--------------------------------------------------------------------------------
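
A minimal sketch (not part of the repository) showing how AF-ResNet50 from the file above can be instantiated without pretrained weights, and how forward_calc_flops() reports a dynamic FLOPs estimate. The class count and input size are illustrative assumptions.

```python
import torch
from ops.backbone.AF_ResNet import AF_resnet50

# Build AF-ResNet50 with temporal shift enabled; no ImageNet checkpoint is loaded.
model = AF_resnet50(pretrained=False, shift=True, num_segments=8, num_classes=174).eval()

frames = torch.randn(1 * 8, 3, 224, 224)   # one clip, 8 frames flattened into the batch
with torch.no_grad():
    logits, masks, flops = model.forward_calc_flops(frames)

print(logits.shape)                        # torch.Size([8, 174])
# The estimate is dynamic: it depends on how many frames the navigation modules select.
print(f'{float(flops) / 1e9:.2f} GFLOPs')
```
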
/ops/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .AF_ResNet import *
2 | from .AF_MobileNetv3 import *
--------------------------------------------------------------------------------
/ops/backbone/gumbel_softmax.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 |
5 | """
6 | Gumbel Softmax Sampler
7 | Expects logits with the category axis at dim=2 (here: gates shaped [batch, num_segments, 2, 1, 1])
8 |
9 | Does not support a single binary category. Use two dimensions with softmax instead.
10 | """
11 |
12 | class GumbleSoftmax(torch.nn.Module):
13 | def __init__(self, hard=False):
14 | super(GumbleSoftmax, self).__init__()
15 | self.hard = hard
16 | # self.gpu = False
17 |
18 | # def cuda(self):
19 | # self.gpu = True
20 |
21 | # def cpu(self):
22 | # self.gpu = False
23 |
24 | def sample_gumbel(self, shape, eps=1e-10):
25 | """Sample from Gumbel(0, 1)"""
26 | noise = torch.rand(shape)
27 | noise.add_(eps).log_().neg_()
28 | noise.add_(eps).log_().neg_()
29 | # if self.gpu:
30 | return Variable(noise).cuda()
31 | # else:
32 | # return Variable(noise)
33 |
34 | def sample_gumbel_like(self, template_tensor, eps=1e-10):
35 | uniform_samples_tensor = template_tensor.clone().uniform_()
36 | gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps))
37 | return gumble_samples_tensor
38 |
39 | def gumbel_softmax_sample(self, logits, temperature):
40 | """ Draw a sample from the Gumbel-Softmax distribution"""
41 | dim = logits.size(-1)
42 | gumble_samples_tensor = self.sample_gumbel_like(logits.data)
43 | gumble_trick_log_prob_samples = logits + Variable(gumble_samples_tensor)
44 | soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, dim=2)
45 | return soft_samples
46 |
47 | def gumbel_softmax(self, logits, temperature, hard=False):
48 | """Sample from the Gumbel-Softmax distribution and optionally discretize.
49 | Args:
50 | logits: [batch_size, n_class] unnormalized log-probs
51 | temperature: non-negative scalar
52 | hard: if True, take argmax, but differentiate w.r.t. soft sample y
53 | Returns:
54 | [batch_size, n_class] sample from the Gumbel-Softmax distribution.
55 | If hard=True, then the returned sample will be one-hot, otherwise it will
56 |         be a probability distribution that sums to 1 across classes
57 | """
58 | y = self.gumbel_softmax_sample(logits, temperature)
59 | if hard:
60 | _, max_value_indexes = y.data.max(2, keepdim=True)
61 | y_hard = logits.data.clone().zero_().scatter_(2, max_value_indexes, 1)
62 | y = Variable(y_hard - y.data) + y
63 | return y
64 |
65 | def forward(self, logits, temp=1, force_hard=False):
66 | samplesize = logits.size()
67 | if self.training and not force_hard:
68 | return self.gumbel_softmax(logits, temperature=temp, hard=False)
69 | else:
70 | return self.gumbel_softmax(logits, temperature=temp, hard=True)
71 |
--------------------------------------------------------------------------------
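
A small sketch (not part of the repository) of how GumbleSoftmax is driven by the navigation modules above: logits shaped (batch, num_segments, 2, 1, 1) are sampled into hard one-hot gates along dim=2, and channel 1 serves as the per-frame keep mask.

```python
import torch
from ops.backbone.gumbel_softmax import GumbleSoftmax

gs = GumbleSoftmax().eval()              # outside training (or with force_hard=True) the samples are one-hot
logits = torch.randn(2, 8, 2, 1, 1)      # 2 clips x 8 frames x {skip, keep} gates
gates = gs(logits, temp=1.0, force_hard=True)

mask = gates[:, :, 1, :, :]              # 1.0 for frames routed to the focal branch, 0.0 otherwise
print(mask.view(2, 8))
```
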
/ops/backbone/temporal_shift.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class TemporalShift(nn.Module):
12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False):
13 | super(TemporalShift, self).__init__()
14 | self.net = net
15 | self.n_segment = n_segment
16 | self.fold_div = n_div
17 | self.inplace = inplace
18 | if inplace:
19 | print('=> Using in-place shift...')
20 | print('=> Using fold div: {}'.format(self.fold_div))
21 |
22 | def forward(self, x):
23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
24 | return self.net(x)
25 |
26 | @staticmethod
27 | def shift(x, n_segment, fold_div=3, inplace=False):
28 | nt, c, h, w = x.size()
29 | n_batch = nt // n_segment
30 | x = x.view(n_batch, n_segment, c, h, w)
31 |
32 | fold = c // fold_div
33 | if inplace:
34 |             # In-place shift is disabled due to out-of-order errors during parallel execution.
35 |             # A dedicated CUDA kernel may be needed.
36 | raise NotImplementedError
37 | # out = InplaceShift.apply(x, fold)
38 | else:
39 | out = torch.zeros_like(x)
40 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
41 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
42 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
43 |
44 | return out.view(nt, c, h, w)
45 |
46 |
47 | class InplaceShift(torch.autograd.Function):
48 |     # Special thanks to @raoyongming for the help with this function
49 | @staticmethod
50 | def forward(ctx, input, fold):
51 |         # does not support higher-order gradients
52 | # input = input.detach_()
53 | ctx.fold_ = fold
54 | n, t, c, h, w = input.size()
55 | buffer = input.data.new(n, t, fold, h, w).zero_()
56 | buffer[:, :-1] = input.data[:, 1:, :fold]
57 | input.data[:, :, :fold] = buffer
58 | buffer.zero_()
59 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
60 | input.data[:, :, fold: 2 * fold] = buffer
61 | return input
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | # grad_output = grad_output.detach_()
66 | fold = ctx.fold_
67 | n, t, c, h, w = grad_output.size()
68 | buffer = grad_output.data.new(n, t, fold, h, w).zero_()
69 | buffer[:, 1:] = grad_output.data[:, :-1, :fold]
70 | grad_output.data[:, :, :fold] = buffer
71 | buffer.zero_()
72 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
73 | grad_output.data[:, :, fold: 2 * fold] = buffer
74 | return grad_output, None
75 |
76 |
77 | class TemporalPool(nn.Module):
78 | def __init__(self, net, n_segment):
79 | super(TemporalPool, self).__init__()
80 | self.net = net
81 | self.n_segment = n_segment
82 |
83 | def forward(self, x):
84 | x = self.temporal_pool(x, n_segment=self.n_segment)
85 | return self.net(x)
86 |
87 | @staticmethod
88 | def temporal_pool(x, n_segment):
89 | nt, c, h, w = x.size()
90 | n_batch = nt // n_segment
91 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w
92 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
93 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
94 | return x
95 |
96 |
97 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
98 | if temporal_pool:
99 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
100 | else:
101 | n_segment_list = [n_segment] * 4
102 | assert n_segment_list[-1] > 0
103 | print('=> n_segment per stage: {}'.format(n_segment_list))
104 |
105 | import torchvision
106 | if isinstance(net, torchvision.models.ResNet):
107 | if place == 'block':
108 | def make_block_temporal(stage, this_segment):
109 | blocks = list(stage.children())
110 | print('=> Processing stage with {} blocks'.format(len(blocks)))
111 | for i, b in enumerate(blocks):
112 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
113 | return nn.Sequential(*(blocks))
114 |
115 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
116 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
117 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
118 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
119 |
120 | elif 'blockres' in place:
121 | n_round = 1
122 | if len(list(net.layer3.children())) >= 23:
123 | n_round = 2
124 | print('=> Using n_round {} to insert temporal shift'.format(n_round))
125 |
126 | def make_block_temporal(stage, this_segment):
127 | blocks = list(stage.children())
128 |             print('=> Processing stage with {} residual blocks'.format(len(blocks)))
129 | for i, b in enumerate(blocks):
130 | if i % n_round == 0:
131 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
132 | return nn.Sequential(*blocks)
133 |
134 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
135 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
136 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
137 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
138 | else:
139 | raise NotImplementedError(place)
140 |
141 |
142 | def make_temporal_pool(net, n_segment):
143 | import torchvision
144 | if isinstance(net, torchvision.models.ResNet):
145 |         print('=> Injecting temporal pooling')
146 | net.layer2 = TemporalPool(net.layer2, n_segment)
147 | else:
148 | raise NotImplementedError
149 |
150 |
151 | if __name__ == '__main__':
152 | # test inplace shift v.s. vanilla shift
153 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False)
154 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True)
155 |
156 | print('=> Testing CPU...')
157 | # test forward
158 | with torch.no_grad():
159 | for i in range(10):
160 | x = torch.rand(2 * 8, 3, 224, 224)
161 | y1 = tsm1(x)
162 | y2 = tsm2(x)
163 | assert torch.norm(y1 - y2).item() < 1e-5
164 |
165 | # test backward
166 | with torch.enable_grad():
167 | for i in range(10):
168 | x1 = torch.rand(2 * 8, 3, 224, 224)
169 | x1.requires_grad_()
170 | x2 = x1.clone()
171 | y1 = tsm1(x1)
172 | y2 = tsm2(x2)
173 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
174 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
175 | assert torch.norm(grad1 - grad2).item() < 1e-5
176 |
177 | print('=> Testing GPU...')
178 | tsm1.cuda()
179 | tsm2.cuda()
180 | # test forward
181 | with torch.no_grad():
182 | for i in range(10):
183 | x = torch.rand(2 * 8, 3, 224, 224).cuda()
184 | y1 = tsm1(x)
185 | y2 = tsm2(x)
186 | assert torch.norm(y1 - y2).item() < 1e-5
187 |
188 | # test backward
189 | with torch.enable_grad():
190 | for i in range(10):
191 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda()
192 | x1.requires_grad_()
193 | x2 = x1.clone()
194 | y1 = tsm1(x1)
195 | y2 = tsm2(x2)
196 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
197 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
198 | assert torch.norm(grad1 - grad2).item() < 1e-5
199 | print('Test passed.')
200 |
201 |
202 |
203 |
204 |
--------------------------------------------------------------------------------
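
A minimal sketch (not part of the repository): make_temporal_shift() from the file above inserts TSM modules into a plain torchvision ResNet-50. The segment count and input size are illustrative.

```python
import torch
import torchvision
from ops.backbone.temporal_shift import make_temporal_shift

net = torchvision.models.resnet50(pretrained=False)
make_temporal_shift(net, n_segment=8, n_div=8, place='blockres')   # wraps conv1 of each residual block

# After wrapping, inputs must be flattened to (batch * n_segment, 3, H, W).
frames = torch.randn(2 * 8, 3, 224, 224)
with torch.no_grad():
    out = net(frames)
print(out.shape)   # torch.Size([16, 1000]) -- per-frame logits
```
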
/ops/basic_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Identity(torch.nn.Module):
5 | def forward(self, input):
6 | return input
7 |
8 |
9 | class SegmentConsensus(torch.nn.Module):
10 |
11 | def __init__(self, consensus_type, dim=1):
12 | super(SegmentConsensus, self).__init__()
13 | self.consensus_type = consensus_type
14 | self.dim = dim
15 | self.shape = None
16 |
17 | def forward(self, input_tensor):
18 | self.shape = input_tensor.size()
19 | if self.consensus_type == 'avg':
20 | output = input_tensor.mean(dim=self.dim, keepdim=True)
21 | elif self.consensus_type == 'identity':
22 | output = input_tensor
23 | else:
24 | output = None
25 |
26 | return output
27 |
28 |
29 | class ConsensusModule(torch.nn.Module):
30 |
31 | def __init__(self, consensus_type, dim=1):
32 | super(ConsensusModule, self).__init__()
33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
34 | self.dim = dim
35 |
36 | def forward(self, input):
37 | return SegmentConsensus(self.consensus_type, self.dim)(input)
38 |
--------------------------------------------------------------------------------
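
A short sketch (not part of the repository): ConsensusModule with the 'avg' type simply averages per-segment logits into a single clip-level prediction. The shapes below are illustrative.

```python
import torch
from ops.basic_ops import ConsensusModule

consensus = ConsensusModule('avg', dim=1)
segment_logits = torch.randn(4, 8, 174)   # (batch, num_segments, num_class)
clip_logits = consensus(segment_logits)   # mean over the segment dimension, keepdim=True
print(clip_logits.shape)                  # torch.Size([4, 1, 174])
```
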
/ops/dataset.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import torch.utils.data as data
7 |
8 | from PIL import Image
9 | import os
10 | import numpy as np
11 | from numpy.random import randint
12 |
13 |
14 | class VideoRecord(object):
15 | def __init__(self, row):
16 | self._data = row
17 |
18 | @property
19 | def path(self):
20 | return self._data[0]
21 |
22 | @property
23 | def num_frames(self):
24 | return int(self._data[1])
25 |
26 | @property
27 | def label(self):
28 | return int(self._data[2])
29 |
30 |
31 | class TSNDataSet(data.Dataset):
32 | def __init__(self, root_path, list_file,
33 | num_segments=3, new_length=1, modality='RGB',
34 | image_tmpl='img_{:05d}.jpg', transform=None,
35 | random_shift=True, test_mode=False,
36 | remove_missing=False, dense_sample=False, twice_sample=False):
37 |
38 | self.root_path = root_path
39 | self.list_file = list_file
40 | self.num_segments = num_segments
41 | self.new_length = new_length
42 | self.modality = modality
43 | self.image_tmpl = image_tmpl
44 | self.transform = transform
45 | self.random_shift = random_shift
46 | self.test_mode = test_mode
47 | self.remove_missing = remove_missing
48 |         self.dense_sample = dense_sample  # use dense sampling as in I3D
49 |         self.twice_sample = twice_sample  # sample each video twice for more robust validation
50 | if self.dense_sample:
51 | print('=> Using dense sample for the dataset...')
52 | if self.twice_sample:
53 | print('=> Using twice sample for the dataset...')
54 |
55 | if self.modality == 'RGBDiff':
56 | self.new_length += 1 # Diff needs one more image to calculate diff
57 |
58 | self._parse_list()
59 |
60 | def _load_image(self, directory, idx):
61 | if self.modality == 'RGB' or self.modality == 'RGBDiff':
62 | try:
63 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
64 | except Exception:
65 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
66 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
67 | elif self.modality == 'Flow':
68 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': # ucf
69 | x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert(
70 | 'L')
71 | y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert(
72 | 'L')
73 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow
74 | x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
75 | format(int(directory), 'x', idx))).convert('L')
76 | y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
77 | format(int(directory), 'y', idx))).convert('L')
78 | else:
79 | try:
80 | # idx_skip = 1 + (idx-1)*5
81 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert(
82 | 'RGB')
83 | except Exception:
84 | print('error loading flow file:',
85 | os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
86 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
87 | # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
88 | flow_x, flow_y, _ = flow.split()
89 | x_img = flow_x.convert('L')
90 | y_img = flow_y.convert('L')
91 |
92 | return [x_img, y_img]
93 |
94 | def _parse_list(self):
95 |         # keep only videos whose frame count is at least 3:
96 | tmp = [x.strip().split(' ') for x in open(self.list_file)]
97 |
98 | if any(len(items) >= 3 for items in tmp):
99 | tmp = [[' '.join(x[:-2]), x[-2], x[-1]] for x in tmp]
100 |
101 | if not self.test_mode or self.remove_missing:
102 | tmp = [item for item in tmp if int(item[1]) >= 3]
103 | self.video_list = [VideoRecord(item) for item in tmp]
104 |
105 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
106 | for v in self.video_list:
107 | v._data[1] = int(v._data[1]) / 2
108 | print('video number:%d' % (len(self.video_list)))
109 |
110 | def _sample_indices(self, record):
111 | """
112 |
113 | :param record: VideoRecord
114 | :return: list
115 | """
116 | if self.dense_sample: # i3d dense sample
117 | sample_pos = max(1, 1 + record.num_frames - 64)
118 | t_stride = 64 // self.num_segments
119 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
120 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
121 | return np.array(offsets) + 1
122 | else: # normal sample
123 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
124 | if average_duration > 0:
125 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
126 | size=self.num_segments)
127 | elif record.num_frames > self.num_segments:
128 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
129 | else:
130 | offsets = np.zeros((self.num_segments,))
131 | return offsets + 1
132 |
133 | def _get_val_indices(self, record):
134 | if self.dense_sample: # i3d dense sample
135 | sample_pos = max(1, 1 + record.num_frames - 64)
136 | t_stride = 64 // self.num_segments
137 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
138 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
139 | return np.array(offsets) + 1
140 | else:
141 | if record.num_frames > self.num_segments + self.new_length - 1:
142 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
143 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
144 | else:
145 | offsets = np.zeros((self.num_segments,))
146 | return offsets + 1
147 |
148 | def _get_test_indices(self, record):
149 | if self.dense_sample:
150 | sample_pos = max(1, 1 + record.num_frames - 64)
151 | t_stride = 64 // self.num_segments
152 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int)
153 | offsets = []
154 | for start_idx in start_list.tolist():
155 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
156 | return np.array(offsets) + 1
157 | elif self.twice_sample:
158 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
159 |
160 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] +
161 | [int(tick * x) for x in range(self.num_segments)])
162 |
163 | return offsets + 1
164 | else:
165 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
166 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
167 | return offsets + 1
168 |
169 | def __getitem__(self, index):
170 | record = self.video_list[index]
171 | # check this is a legit video folder
172 |
173 | if self.image_tmpl == 'flow_{}_{:05d}.jpg':
174 | file_name = self.image_tmpl.format('x', 1)
175 | full_path = os.path.join(self.root_path, record.path, file_name)
176 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
177 | file_name = self.image_tmpl.format(int(record.path), 'x', 1)
178 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
179 | else:
180 | file_name = self.image_tmpl.format(1)
181 | full_path = os.path.join(self.root_path, record.path, file_name)
182 |
183 | while not os.path.exists(full_path):
184 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
185 | index = np.random.randint(len(self.video_list))
186 | record = self.video_list[index]
187 | if self.image_tmpl == 'flow_{}_{:05d}.jpg':
188 | file_name = self.image_tmpl.format('x', 1)
189 | full_path = os.path.join(self.root_path, record.path, file_name)
190 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
191 | file_name = self.image_tmpl.format(int(record.path), 'x', 1)
192 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
193 | else:
194 | file_name = self.image_tmpl.format(1)
195 | full_path = os.path.join(self.root_path, record.path, file_name)
196 |
197 | if not self.test_mode:
198 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
199 | else:
200 | segment_indices = self._get_test_indices(record)
201 | return self.get(record, segment_indices)
202 |
203 | def get(self, record, indices):
204 |
205 | images = list()
206 | for seg_ind in indices:
207 | p = int(seg_ind)
208 | for i in range(self.new_length):
209 | seg_imgs = self._load_image(record.path, p)
210 | images.extend(seg_imgs)
211 | if p < record.num_frames:
212 | p += 1
213 |
214 | process_data = self.transform(images)
215 | return process_data, record.label
216 |
217 | def __len__(self):
218 | return len(self.video_list)
219 |
--------------------------------------------------------------------------------
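
A minimal sketch (not part of the repository) of wiring TSNDataSet into a DataLoader. The paths are placeholders, the list file is assumed to contain `path num_frames label` rows, and a simple stand-in transform is used where the group transforms from ops/transforms.py would normally go.

```python
import torch
import torchvision.transforms.functional as TF
from ops.dataset import TSNDataSet

# Stand-in for the group transforms in ops/transforms.py: the dataset passes a *list*
# of PIL images to the transform, so a per-image torchvision transform will not work as-is.
def simple_group_transform(images):
    return torch.stack([TF.to_tensor(img.resize((224, 224))) for img in images])

dataset = TSNDataSet(
    root_path='/path/to/something_v1',            # placeholder
    list_file='/path/to/train_videofolder.txt',   # placeholder
    num_segments=8,
    modality='RGB',
    image_tmpl='{:05d}.jpg',
    transform=simple_group_transform,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
```
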
/ops/dataset_config.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import os
7 |
8 | def return_actnet(ROOT_DATASET, modality):
9 | filename_categories = 'actnet/classInd.txt'
10 | root_data = ROOT_DATASET + 'actnet/frames'
11 | filename_imglist_train = 'actnet/actnet_train_split_newest.txt'
12 | filename_imglist_val = 'actnet/actnet_val_split_newest.txt'
13 | prefix = 'image_{:05d}.jpg'
14 |
15 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
16 |
17 |
18 | def return_ucf101(ROOT_DATASET, modality):
19 | filename_categories = 'UCF101/labels/classInd.txt'
20 | if modality == 'RGB':
21 | root_data = ROOT_DATASET + 'UCF101/jpg'
22 | filename_imglist_train = 'UCF101/file_list/ucf101_rgb_train_split_1.txt'
23 | filename_imglist_val = 'UCF101/file_list/ucf101_rgb_val_split_1.txt'
24 | prefix = 'img_{:05d}.jpg'
25 | elif modality == 'Flow':
26 | root_data = ROOT_DATASET + 'UCF101/jpg'
27 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
28 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
29 | prefix = 'flow_{}_{:05d}.jpg'
30 | else:
31 | raise NotImplementedError('no such modality:' + modality)
32 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
33 |
34 |
35 | def return_hmdb51(ROOT_DATASET, modality):
36 | filename_categories = 51
37 | if modality == 'RGB':
38 | root_data = ROOT_DATASET + 'HMDB51/images'
39 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
40 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
41 | prefix = 'img_{:05d}.jpg'
42 | elif modality == 'Flow':
43 | root_data = ROOT_DATASET + 'HMDB51/images'
44 | filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
45 | filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
46 | prefix = 'flow_{}_{:05d}.jpg'
47 | else:
48 | raise NotImplementedError('no such modality:' + modality)
49 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
50 |
51 |
52 | def return_mini_something(ROOT_DATASET, modality):
53 | filename_categories = 'mini_something_v1/category.txt'
54 | if modality == 'RGB':
55 | root_data = ROOT_DATASET + 'mini_something_v1'
56 | filename_imglist_train = 'mini_something_v1/train_videofolder.txt'
57 | filename_imglist_val = 'mini_something_v1/val_videofolder.txt'
58 | prefix = '{:05d}.jpg'
59 | elif modality == 'Flow':
60 | root_data = ROOT_DATASET + 'mini_something_v1/20bn-something-something-v1-flow'
61 | filename_imglist_train = 'mini_something_v1/train_videofolder_flow.txt'
62 | filename_imglist_val = 'mini_something_v1/val_videofolder_flow.txt'
63 | prefix = '{:06d}-{}_{:05d}.jpg'
64 | else:
65 | print('no such modality:'+modality)
66 | raise NotImplementedError
67 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
68 |
69 |
70 | def return_something(ROOT_DATASET, modality):
71 | filename_categories = 'something_v1/category.txt'
72 | if modality == 'RGB':
73 | root_data = ROOT_DATASET + 'something_v1'
74 | filename_imglist_train = 'something_v1/train_videofolder.txt'
75 | filename_imglist_val = 'something_v1/val_videofolder.txt'
76 | prefix = '{:05d}.jpg'
77 | elif modality == 'Flow':
78 | root_data = ROOT_DATASET + 'something_v1/20bn-something-something-v1-flow'
79 | filename_imglist_train = 'something_v1/train_videofolder_flow.txt'
80 | filename_imglist_val = 'something_v1/val_videofolder_flow.txt'
81 | prefix = '{:06d}-{}_{:05d}.jpg'
82 | else:
83 | print('no such modality:'+modality)
84 | raise NotImplementedError
85 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
86 |
87 |
88 | def return_somethingv2(ROOT_DATASET, modality):
89 | filename_categories = 'something_v2/category.txt'
90 | if modality == 'RGB':
91 | root_data = ROOT_DATASET + 'something_v2'
92 | filename_imglist_train = 'something_v2/train_videofolder.txt'
93 | filename_imglist_val = 'something_v2/val_videofolder.txt'
94 | prefix = '{:06d}.jpg'
95 | elif modality == 'Flow':
96 | root_data = ROOT_DATASET + 'something_v2/20bn-something-something-v2-flow'
97 | filename_imglist_train = 'something_v2/train_videofolder_flow.txt'
98 | filename_imglist_val = 'something_v2/val_videofolder_flow.txt'
99 | prefix = '{:06d}-{}_{:05d}.jpg'
100 | else:
101 | raise NotImplementedError('no such modality:'+modality)
102 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
103 |
104 |
105 | def return_jester(ROOT_DATASET, modality):
106 | filename_categories = 'jester/category.txt'
107 | if modality == 'RGB':
108 | prefix = '{:05d}.jpg'
109 | root_data = ROOT_DATASET + 'jester'
110 | filename_imglist_train = 'jester/train_videofolder.txt'
111 | filename_imglist_val = 'jester/val_videofolder.txt'
112 | else:
113 | raise NotImplementedError('no such modality:'+modality)
114 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
115 |
116 |
117 | def return_mini_kinetics(ROOT_DATASET, modality):
118 | filename_categories = 200
119 | if modality == 'RGB':
120 | root_data = ROOT_DATASET + 'mini-kinetics'
121 | filename_imglist_train = 'mini-kinetics/mini_train_videofolder.txt'
122 | filename_imglist_val = 'mini-kinetics/mini_val_videofolder.txt'
123 | prefix = 'image_{:05d}.jpg'
124 | else:
125 | raise NotImplementedError('no such modality:' + modality)
126 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
127 |
128 |
129 | def return_kinetics(ROOT_DATASET, modality):
130 | filename_categories = 400
131 | if modality == 'RGB':
132 | root_data = ROOT_DATASET + 'kinetics/images'
133 | filename_imglist_train = 'kinetics/labels/train_videofolder.txt'
134 | filename_imglist_val = 'kinetics/labels/val_videofolder.txt'
135 | prefix = 'image_{:05d}.jpg'
136 | else:
137 | raise NotImplementedError('no such modality:' + modality)
138 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
139 |
140 |
141 | def return_dataset(root_dataset, dataset, modality):
142 | ROOT_DATASET = root_dataset
143 |
144 | dict_single = {'jester': return_jester, 'mini_something': return_mini_something, 'something': return_something, 'somethingv2': return_somethingv2,
145 | 'ucf101': return_ucf101, 'hmdb51': return_hmdb51, 'actnet': return_actnet, 'mini_kinetics': return_mini_kinetics,
146 | 'kinetics': return_kinetics }
147 | if dataset in dict_single:
148 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](ROOT_DATASET, modality)
149 | else:
150 | raise ValueError('Unknown dataset '+dataset)
151 |
152 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
153 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
154 | if isinstance(file_categories, str):
155 | file_categories = os.path.join(ROOT_DATASET, file_categories)
156 | with open(file_categories) as f:
157 | lines = f.readlines()
158 | categories = [item.rstrip() for item in lines]
159 | else: # number of categories
160 | categories = [None] * file_categories
161 | n_class = len(categories)
162 | print('{}: {} classes'.format(dataset, n_class))
163 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix
164 |
--------------------------------------------------------------------------------
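
A small sketch (not part of the repository) of calling return_dataset(). 'kinetics' is used here only because its class count is hard-coded, so nothing needs to exist on disk; 'something'/'somethingv2' would additionally read category.txt under the same root.

```python
from ops.dataset_config import return_dataset

n_class, train_list, val_list, root_data, prefix = return_dataset(
    root_dataset='/path/to/datasets/',   # placeholder root
    dataset='kinetics',
    modality='RGB',
)
print(n_class, prefix)   # 400 image_{:05d}.jpg
print(train_list)        # /path/to/datasets/kinetics/labels/train_videofolder.txt
```
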
/ops/models.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | from torch import nn
7 | import ops.backbone
8 | from ops.basic_ops import ConsensusModule
9 | from ops.transforms import *
10 | from torch.nn.init import normal_, constant_
11 |
12 |
13 | class TSN(nn.Module):
14 | def __init__(self, arch_file, num_class, num_segments, modality, path_backbone,
15 | base_model='resnet101', new_length=None,
16 | consensus_type='avg', before_softmax=True,
17 | dropout=0.8, img_feature_dim=256,
18 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet',
19 | is_shift=False, fc_lr5=False,
20 | temporal_pool=False, non_local=False):
21 | super(TSN, self).__init__()
22 | self.modality = modality
23 | self.num_segments = num_segments
24 | self.reshape = True
25 | self.before_softmax = before_softmax
26 | self.dropout = dropout
27 | self.crop_num = crop_num
28 | self.consensus_type = consensus_type
29 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame
30 | self.pretrain = pretrain
31 |
32 | self.is_shift = is_shift
33 | self.base_model_name = base_model
34 | self.fc_lr5 = fc_lr5
35 | self.temporal_pool = temporal_pool
36 | self.non_local = non_local
37 |
38 | if not before_softmax and consensus_type != 'avg':
39 | raise ValueError("Only avg consensus can be used after Softmax")
40 |
41 | if new_length is None:
42 | self.new_length = 1 if modality == "RGB" else 5
43 | else:
44 | self.new_length = new_length
45 | if print_spec:
46 | print(("""
47 | Initializing TSN with base model: {}.
48 | TSN Configurations:
49 | input_modality: {}
50 | num_segments: {}
51 | new_length: {}
52 | consensus_module: {}
53 | dropout_ratio: {}
54 | img_feature_dim: {}
55 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim)))
56 |
57 | self._prepare_base_model(arch_file, base_model, path_backbone)
58 |
59 | feature_dim = self._prepare_tsn(num_class)
60 |
61 | if self.modality == 'Flow':
62 | print("Converting the ImageNet model to a flow init model")
63 | self.base_model = self._construct_flow_model(self.base_model)
64 | print("Done. Flow model ready...")
65 | elif self.modality == 'RGBDiff':
66 | print("Converting the ImageNet model to RGB+Diff init model")
67 | self.base_model = self._construct_diff_model(self.base_model)
68 | print("Done. RGBDiff model ready.")
69 |
70 | self.consensus = ConsensusModule(consensus_type)
71 |
72 | if not self.before_softmax:
73 |             self.softmax = nn.Softmax(dim=1)
74 |
75 | self._enable_pbn = partial_bn
76 | if partial_bn:
77 | self.partialBN(True)
78 |
79 | def _prepare_tsn(self, num_class):
80 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
81 | if self.dropout == 0:
82 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
83 | self.new_fc = None
84 | else:
85 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
86 | self.new_fc = nn.Linear(feature_dim, num_class)
87 |
88 | std = 0.001
89 | if self.new_fc is None:
90 | normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
91 | constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
92 | else:
93 | if hasattr(self.new_fc, 'weight'):
94 | normal_(self.new_fc.weight, 0, std)
95 | constant_(self.new_fc.bias, 0)
96 | return feature_dim
97 |
98 | def _prepare_base_model(self, arch_file, base_model, path_backbone):
99 | print('=> base model: {}'.format(base_model))
100 |
101 | if 'resnet' in base_model:
102 | self.base_model = eval(f'ops.backbone.{arch_file}.{base_model}')(True if self.pretrain == 'imagenet' else False, path_backbone=path_backbone, shift=self.is_shift, num_segments=self.num_segments)
103 | self.base_model.last_layer_name = 'fc'
104 | self.input_size = 224
105 | self.input_mean = [0.485, 0.456, 0.406]
106 | self.input_std = [0.229, 0.224, 0.225]
107 |
108 | if self.modality == 'Flow':
109 | self.input_mean = [0.5]
110 | self.input_std = [np.mean(self.input_std)]
111 | elif self.modality == 'RGBDiff':
112 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
113 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
114 | else:
115 | raise ValueError('Unknown base model: {}'.format(base_model))
116 |
117 | def train(self, mode=True):
118 | """
119 | Override the default train() to freeze the BN parameters
120 | :return:
121 | """
122 | super(TSN, self).train(mode)
123 | count = 0
124 | if self._enable_pbn and mode:
125 | print("Freezing BatchNorm2D except the first one.")
126 | for m in self.base_model.modules():
127 | if isinstance(m, nn.BatchNorm2d):
128 | count += 1
129 | if count >= (2 if self._enable_pbn else 1):
130 | m.eval()
131 | # shutdown update in frozen mode
132 | m.weight.requires_grad = False
133 | m.bias.requires_grad = False
134 |
135 | def partialBN(self, enable):
136 | self._enable_pbn = enable
137 |
138 | def get_optim_policies(self):
139 | navi_conv_weight = []
140 | navi_conv_bias = []
141 | first_conv_weight = []
142 | first_conv_bias = []
143 | normal_weight = []
144 | normal_bias = []
145 | lr5_weight = []
146 | lr10_bias = []
147 | bn = []
148 | navi_bn = []
149 | custom_ops = []
150 |
151 | conv_cnt = 0
152 | bn_cnt = 0
153 | for m in self.modules():
154 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d):
155 | ps = list(m.parameters())
156 | conv_cnt += 1
157 | if conv_cnt == 1:
158 | first_conv_weight.append(ps[0])
159 | if len(ps) == 2:
160 | first_conv_bias.append(ps[1])
161 | elif ps[0].shape == torch.Size([self.num_segments*2,2,1,1]):
162 | navi_conv_weight.append(ps[0])
163 | if len(ps) == 2:
164 | navi_conv_bias.append(ps[1])
165 | elif ps[0].shape[0] == 2:
166 | navi_conv_weight.append(ps[0])
167 | if len(ps) == 2:
168 | navi_conv_bias.append(ps[1])
169 | else:
170 | normal_weight.append(ps[0])
171 | if len(ps) == 2:
172 | normal_bias.append(ps[1])
173 | elif isinstance(m, torch.nn.Linear):
174 | ps = list(m.parameters())
175 | if self.fc_lr5:
176 | lr5_weight.append(ps[0])
177 | else:
178 | normal_weight.append(ps[0])
179 | if len(ps) == 2:
180 | if self.fc_lr5:
181 | lr10_bias.append(ps[1])
182 | else:
183 | normal_bias.append(ps[1])
184 |
185 | elif isinstance(m, torch.nn.BatchNorm2d):
186 | bn_cnt += 1
187 | # later BN's are frozen
188 | if list(m.parameters())[0].shape[0]==2:
189 | navi_bn.extend(list(m.parameters()))
190 | elif not self._enable_pbn or bn_cnt == 1:
191 | bn.extend(list(m.parameters()))
192 | elif isinstance(m, torch.nn.BatchNorm3d):
193 | bn_cnt += 1
194 | # later BN's are frozen
195 | if not self._enable_pbn or bn_cnt == 1:
196 | bn.extend(list(m.parameters()))
197 | elif len(m._modules) == 0:
198 | if len(list(m.parameters())) > 0:
199 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))
200 |
201 | return [
202 | {'params': navi_bn, 'lr_mult': 10, 'decay_mult': 1,
203 | 'name': "navi_bn"},
204 | {'params': navi_conv_weight, 'lr_mult': 10, 'decay_mult': 1,
205 | 'name': "navi_conv_weight"},
206 | {'params': navi_conv_bias, 'lr_mult': 10, 'decay_mult': 1,
207 | 'name': "navi_conv_bias"},
208 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
209 | 'name': "first_conv_weight"},
210 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
211 | 'name': "first_conv_bias"},
212 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
213 | 'name': "normal_weight"},
214 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
215 | 'name': "normal_bias"},
216 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
217 | 'name': "BN scale/shift"},
218 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1,
219 | 'name': "custom_ops"},
220 | # for fc
221 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1,
222 | 'name': "lr5_weight"},
223 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0,
224 | 'name': "lr10_bias"},
225 | ]
226 |
227 | def forward(self, input, temperature, no_reshape=False):
228 | if not no_reshape:
229 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
230 |
231 | if self.modality == 'RGBDiff':
232 | sample_len = 3 * self.new_length
233 | input = self._get_diff(input)
234 |
235 | base_out, temporal_mask_ls = self.base_model(input.view((-1, sample_len) + input.size()[-2:]), temperature)
236 | else:
237 | base_out, temporal_mask_ls = self.base_model(input, temperature)
238 |
239 | if self.dropout > 0:
240 | base_out = self.new_fc(base_out)
241 |
242 | if not self.before_softmax:
243 | base_out = self.softmax(base_out)
244 |
245 | if self.reshape:
246 | if self.is_shift and self.temporal_pool:
247 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:])
248 | else:
249 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
250 | output = self.consensus(base_out)
251 | return output.squeeze(1), temporal_mask_ls
252 |
253 | def forward_calc_flops(self, input, temperature, no_reshape=False):
254 | if not no_reshape:
255 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
256 |
257 | if self.modality == 'RGBDiff':
258 | sample_len = 3 * self.new_length
259 | input = self._get_diff(input)
260 |
261 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input.view((-1, sample_len) + input.size()[-2:]), temperature)
262 | else:
263 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input, temperature)
264 |
265 | if self.dropout > 0:
266 | base_out = self.new_fc(base_out)
267 |
268 | if not self.before_softmax:
269 | base_out = self.softmax(base_out)
270 |
271 | if self.reshape:
272 | if self.is_shift and self.temporal_pool:
273 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:])
274 | else:
275 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
276 | output = self.consensus(base_out)
277 | return output.squeeze(1), temporal_mask_ls, flops
278 |
279 | def _get_diff(self, input, keep_rgb=False):
280 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2
281 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:])
282 | if keep_rgb:
283 | new_data = input_view.clone()
284 | else:
285 | new_data = input_view[:, :, 1:, :, :, :].clone()
286 |
287 | for x in reversed(list(range(1, self.new_length + 1))):
288 | if keep_rgb:
289 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
290 | else:
291 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
292 |
293 | return new_data
294 |
295 | def _construct_flow_model(self, base_model):
296 | # modify the convolution layers
297 | # Torch models are usually defined in a hierarchical way.
298 |         # nn.Module.modules() returns all sub-modules in a DFS manner
299 | modules = list(self.base_model.modules())
300 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
301 | conv_layer = modules[first_conv_idx]
302 | container = modules[first_conv_idx - 1]
303 |
304 | # modify parameters, assume the first blob contains the convolution kernels
305 | params = [x.clone() for x in conv_layer.parameters()]
306 | kernel_size = params[0].size()
307 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:]
308 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
309 |
310 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels,
311 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
312 | bias=True if len(params) == 2 else False)
313 | new_conv.weight.data = new_kernels
314 | if len(params) == 2:
315 |             new_conv.bias.data = params[1].data # add bias if necessary
316 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
317 |
318 |         # replace the first convolution layer
319 | setattr(container, layer_name, new_conv)
320 |
321 | if self.base_model_name == 'BNInception':
322 | import torch.utils.model_zoo as model_zoo
323 | sd = model_zoo.load_url('https://www.dropbox.com/s/35ftw2t4mxxgjae/BNInceptionFlow-ef652051.pth.tar?dl=1')
324 | base_model.load_state_dict(sd)
325 | print('=> Loading pretrained Flow weight done...')
326 | else:
327 | print('#' * 30, 'Warning! No Flow pretrained model is found')
328 | return base_model
329 |
330 | def _construct_diff_model(self, base_model, keep_rgb=False):
331 | # modify the convolution layers
332 | # Torch models are usually defined in a hierarchical way.
333 |         # nn.Module.modules() returns all sub-modules in a DFS manner
334 | modules = list(self.base_model.modules())
335 |         first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
336 | conv_layer = modules[first_conv_idx]
337 | container = modules[first_conv_idx - 1]
338 |
339 | # modify parameters, assume the first blob contains the convolution kernels
340 | params = [x.clone() for x in conv_layer.parameters()]
341 | kernel_size = params[0].size()
342 | if not keep_rgb:
343 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
344 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
345 | else:
346 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
347 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()),
348 | 1)
349 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:]
350 |
351 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
352 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
353 | bias=True if len(params) == 2 else False)
354 | new_conv.weight.data = new_kernels
355 | if len(params) == 2:
356 |             new_conv.bias.data = params[1].data # add bias if necessary
357 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
358 |
359 | # replace the first convolution layer
360 | setattr(container, layer_name, new_conv)
361 | return base_model
362 |
363 | @property
364 | def crop_size(self):
365 | return self.input_size
366 |
367 | @property
368 | def scale_size(self):
369 | return self.input_size * 256 // 224
370 |
371 | def get_augmentation(self, flip=True):
372 | if self.modality == 'RGB':
373 | if flip:
374 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),
375 | GroupRandomHorizontalFlip(is_flow=False)])
376 | else:
377 | print('#' * 20, 'NO FLIP!!!')
378 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])])
379 | elif self.modality == 'Flow':
380 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
381 | GroupRandomHorizontalFlip(is_flow=True)])
382 | elif self.modality == 'RGBDiff':
383 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
384 | GroupRandomHorizontalFlip(is_flow=False)])
385 |
--------------------------------------------------------------------------------
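
A minimal sketch of driving the TSN wrapper defined in ops/models.py above. The backbone path, class count, segment count, and temperature are placeholders for illustration; main.py remains the actual entry point.

```python
# Hypothetical usage; 'path_backbone' is a placeholder checkpoint location.
import torch
from ops.models import TSN

model = TSN(arch_file='AF_ResNet', num_class=174, num_segments=12,
            modality='RGB', path_backbone='path_backbone',
            base_model='AF_resnet50', consensus_type='avg',
            dropout=0.5, partial_bn=True)

# Two clips, each with 12 RGB segments stacked along the channel axis.
clips = torch.randn(2, 12 * 3, 224, 224)
logits, temporal_masks = model(clips, temperature=5.0)
print(logits.shape)  # (2, 174): one prediction per clip after consensus
```
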
/ops/models_mobilenet.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | from torch import nn
7 | import ops.backbone
8 | from ops.basic_ops import ConsensusModule
9 | from ops.transforms import *
10 | from torch.nn.init import normal_, constant_
11 |
12 |
13 | class TSN(nn.Module):
14 | def __init__(self, arch_file, num_class, num_segments, modality, path_backbone,
15 | base_model='mobilenet', new_length=None,
16 | consensus_type='avg', before_softmax=True,
17 | dropout=0.8, img_feature_dim=256,
18 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet',
19 | is_shift=False, fc_lr5=False,
20 | temporal_pool=False, non_local=False):
21 | super(TSN, self).__init__()
22 | self.modality = modality
23 | self.num_segments = num_segments
24 | self.reshape = True
25 | self.before_softmax = before_softmax
26 | self.dropout = dropout
27 | self.crop_num = crop_num
28 | self.consensus_type = consensus_type
29 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame
30 | self.pretrain = pretrain
31 |
32 | self.is_shift = is_shift
33 | self.base_model_name = base_model
34 | self.fc_lr5 = fc_lr5
35 | self.temporal_pool = temporal_pool
36 | self.non_local = non_local
37 |
38 | if not before_softmax and consensus_type != 'avg':
39 | raise ValueError("Only avg consensus can be used after Softmax")
40 |
41 | if new_length is None:
42 | self.new_length = 1 if modality == "RGB" else 5
43 | else:
44 | self.new_length = new_length
45 | if print_spec:
46 | print(("""
47 | Initializing TSN with base model: {}.
48 | TSN Configurations:
49 | input_modality: {}
50 | num_segments: {}
51 | new_length: {}
52 | consensus_module: {}
53 | dropout_ratio: {}
54 | img_feature_dim: {}
55 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim)))
56 |
57 | self._prepare_base_model(arch_file, base_model, num_class, path_backbone)
58 |
59 | # feature_dim = self._prepare_tsn(num_class)
60 |
61 | if self.modality == 'Flow':
62 | print("Converting the ImageNet model to a flow init model")
63 | self.base_model = self._construct_flow_model(self.base_model)
64 | print("Done. Flow model ready...")
65 | elif self.modality == 'RGBDiff':
66 | print("Converting the ImageNet model to RGB+Diff init model")
67 | self.base_model = self._construct_diff_model(self.base_model)
68 | print("Done. RGBDiff model ready.")
69 |
70 | self.consensus = ConsensusModule(consensus_type)
71 |
72 | if not self.before_softmax:
73 |             self.softmax = nn.Softmax(dim=1)
74 |
75 | self._enable_pbn = partial_bn
76 | if partial_bn:
77 | self.partialBN(True)
78 |
79 |
80 | def _prepare_base_model(self, arch_file, base_model, num_class, path_backbone):
81 | print('=> base model: {}'.format(base_model))
82 |
83 | if 'mobilenet' in base_model:
84 | self.base_model = eval(f'ops.backbone.{arch_file}.{base_model}')(True if self.pretrain == 'imagenet' else False, path_backbone=path_backbone, shift=self.is_shift, num_segments=self.num_segments, num_class=num_class)
85 | self.base_model.last_layer_name = 'fc'
86 | self.input_size = 224
87 | self.input_mean = [0.485, 0.456, 0.406]
88 | self.input_std = [0.229, 0.224, 0.225]
89 |
90 | if self.modality == 'Flow':
91 | self.input_mean = [0.5]
92 | self.input_std = [np.mean(self.input_std)]
93 | elif self.modality == 'RGBDiff':
94 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
95 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
96 | else:
97 | raise ValueError('Unknown base model: {}'.format(base_model))
98 |
99 | def train(self, mode=True):
100 | """
101 | Override the default train() to freeze the BN parameters
102 | :return:
103 | """
104 | super(TSN, self).train(mode)
105 | count = 0
106 | if self._enable_pbn and mode:
107 | print("Freezing BatchNorm2D except the first one.")
108 | for m in self.base_model.modules():
109 | if isinstance(m, nn.BatchNorm2d):
110 | count += 1
111 | if count >= (2 if self._enable_pbn else 1):
112 | m.eval()
113 | # shutdown update in frozen mode
114 | m.weight.requires_grad = False
115 | m.bias.requires_grad = False
116 |
117 | def partialBN(self, enable):
118 | self._enable_pbn = enable
119 |
120 | def get_optim_policies(self):
121 | navi_conv_weight = []
122 | navi_conv_bias = []
123 | first_conv_weight = []
124 | first_conv_bias = []
125 | normal_weight = []
126 | normal_bias = []
127 | lr5_weight = []
128 | lr10_bias = []
129 | bn = []
130 | navi_bn = []
131 | custom_ops = []
132 |
133 | conv_cnt = 0
134 | bn_cnt = 0
135 | for m in self.modules():
136 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d):
137 | ps = list(m.parameters())
138 | conv_cnt += 1
139 | if conv_cnt == 1:
140 | first_conv_weight.append(ps[0])
141 | if len(ps) == 2:
142 | first_conv_bias.append(ps[1])
143 | elif ps[0].shape == torch.Size([self.num_segments*2,2,1,1]):
144 | navi_conv_weight.append(ps[0])
145 | if len(ps) == 2:
146 | navi_conv_bias.append(ps[1])
147 | elif ps[0].shape[0] == 2:
148 | navi_conv_weight.append(ps[0])
149 | if len(ps) == 2:
150 | navi_conv_bias.append(ps[1])
151 | else:
152 | normal_weight.append(ps[0])
153 | if len(ps) == 2:
154 | normal_bias.append(ps[1])
155 | elif isinstance(m, torch.nn.Linear):
156 | ps = list(m.parameters())
157 |
158 | if self.fc_lr5:
159 | if ps[0].shape[0] == 1280:
160 | lr5_weight.append(ps[0])
161 | elif ps[0].shape[1] == 1280:
162 | lr5_weight.append(ps[0])
163 | else:
164 | normal_weight.append(ps[0])
165 | else:
166 | normal_weight.append(ps[0])
167 | if len(ps) == 2:
168 | if self.fc_lr5:
169 | if ps[0].shape[0] == 1280:
170 | lr10_bias.append(ps[1])
171 | elif ps[0].shape[1] == 1280:
172 | lr10_bias.append(ps[1])
173 | else:
174 |                             normal_bias.append(ps[1])
175 | else:
176 | normal_bias.append(ps[1])
177 |
178 | elif isinstance(m, torch.nn.BatchNorm2d):
179 | bn_cnt += 1
180 | # later BN's are frozen
181 | if list(m.parameters())[0].shape[0]==2:
182 | navi_bn.extend(list(m.parameters()))
183 | elif not self._enable_pbn or bn_cnt == 1:
184 | bn.extend(list(m.parameters()))
185 | elif isinstance(m, torch.nn.BatchNorm3d):
186 | bn_cnt += 1
187 | # later BN's are frozen
188 | if not self._enable_pbn or bn_cnt == 1:
189 | bn.extend(list(m.parameters()))
190 | elif len(m._modules) == 0:
191 | if len(list(m.parameters())) > 0:
192 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))
193 |
194 | return [
195 | {'params': navi_bn, 'lr_mult': 10, 'decay_mult': 1,
196 | 'name': "navi_bn"},
197 | {'params': navi_conv_weight, 'lr_mult': 10, 'decay_mult': 1,
198 | 'name': "navi_conv_weight"},
199 | {'params': navi_conv_bias, 'lr_mult': 10, 'decay_mult': 1,
200 | 'name': "navi_conv_bias"},
201 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
202 | 'name': "first_conv_weight"},
203 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
204 | 'name': "first_conv_bias"},
205 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
206 | 'name': "normal_weight"},
207 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
208 | 'name': "normal_bias"},
209 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
210 | 'name': "BN scale/shift"},
211 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1,
212 | 'name': "custom_ops"},
213 | # for fc
214 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1,
215 | 'name': "lr5_weight"},
216 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0,
217 | 'name': "lr10_bias"},
218 | ]
219 |
220 | def forward(self, input, temperature, no_reshape=False):
221 | if not no_reshape:
222 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
223 |
224 | if self.modality == 'RGBDiff':
225 | sample_len = 3 * self.new_length
226 | input = self._get_diff(input)
227 |
228 | base_out, temporal_mask_ls = self.base_model(input.view((-1, sample_len) + input.size()[-2:]), temperature)
229 | else:
230 | base_out, temporal_mask_ls = self.base_model(input, temperature)
231 |
232 | if not self.before_softmax:
233 | base_out = self.softmax(base_out)
234 |
235 | if self.reshape:
236 | if self.is_shift and self.temporal_pool:
237 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:])
238 | else:
239 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
240 | output = self.consensus(base_out)
241 | return output.squeeze(1), temporal_mask_ls
242 |
243 | def forward_calc_flops(self, input, temperature, no_reshape=False):
244 | if not no_reshape:
245 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
246 |
247 | if self.modality == 'RGBDiff':
248 | sample_len = 3 * self.new_length
249 | input = self._get_diff(input)
250 |
251 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input.view((-1, sample_len) + input.size()[-2:]), temperature)
252 | else:
253 | base_out, temporal_mask_ls, flops = self.base_model.forward_calc_flops(input, temperature)
254 |
255 | if not self.before_softmax:
256 | base_out = self.softmax(base_out)
257 |
258 | if self.reshape:
259 | if self.is_shift and self.temporal_pool:
260 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:])
261 | else:
262 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
263 | output = self.consensus(base_out)
264 | return output.squeeze(1), temporal_mask_ls, flops
265 |
266 | def _get_diff(self, input, keep_rgb=False):
267 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2
268 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:])
269 | if keep_rgb:
270 | new_data = input_view.clone()
271 | else:
272 | new_data = input_view[:, :, 1:, :, :, :].clone()
273 |
274 | for x in reversed(list(range(1, self.new_length + 1))):
275 | if keep_rgb:
276 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
277 | else:
278 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
279 |
280 | return new_data
281 |
282 | def _construct_flow_model(self, base_model):
283 | # modify the convolution layers
284 | # Torch models are usually defined in a hierarchical way.
285 |         # nn.Module.modules() returns all sub-modules in a DFS manner
286 | modules = list(self.base_model.modules())
287 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
288 | conv_layer = modules[first_conv_idx]
289 | container = modules[first_conv_idx - 1]
290 |
291 | # modify parameters, assume the first blob contains the convolution kernels
292 | params = [x.clone() for x in conv_layer.parameters()]
293 | kernel_size = params[0].size()
294 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:]
295 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
296 |
297 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels,
298 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
299 | bias=True if len(params) == 2 else False)
300 | new_conv.weight.data = new_kernels
301 | if len(params) == 2:
302 |             new_conv.bias.data = params[1].data # add bias if necessary
303 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
304 |
305 |         # replace the first convolution layer
306 | setattr(container, layer_name, new_conv)
307 |
308 | if self.base_model_name == 'BNInception':
309 | import torch.utils.model_zoo as model_zoo
310 | sd = model_zoo.load_url('https://www.dropbox.com/s/35ftw2t4mxxgjae/BNInceptionFlow-ef652051.pth.tar?dl=1')
311 | base_model.load_state_dict(sd)
312 | print('=> Loading pretrained Flow weight done...')
313 | else:
314 | print('#' * 30, 'Warning! No Flow pretrained model is found')
315 | return base_model
316 |
317 | def _construct_diff_model(self, base_model, keep_rgb=False):
318 | # modify the convolution layers
319 | # Torch models are usually defined in a hierarchical way.
320 |         # nn.Module.modules() returns all sub-modules in a DFS manner
321 | modules = list(self.base_model.modules())
322 |         first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
323 | conv_layer = modules[first_conv_idx]
324 | container = modules[first_conv_idx - 1]
325 |
326 | # modify parameters, assume the first blob contains the convolution kernels
327 | params = [x.clone() for x in conv_layer.parameters()]
328 | kernel_size = params[0].size()
329 | if not keep_rgb:
330 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
331 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
332 | else:
333 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
334 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()),
335 | 1)
336 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:]
337 |
338 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
339 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
340 | bias=True if len(params) == 2 else False)
341 | new_conv.weight.data = new_kernels
342 | if len(params) == 2:
343 |             new_conv.bias.data = params[1].data # add bias if necessary
344 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
345 |
346 | # replace the first convolution layer
347 | setattr(container, layer_name, new_conv)
348 | return base_model
349 |
350 | @property
351 | def crop_size(self):
352 | return self.input_size
353 |
354 | @property
355 | def scale_size(self):
356 | return self.input_size * 256 // 224
357 |
358 | def get_augmentation(self, flip=True):
359 | if self.modality == 'RGB':
360 | if flip:
361 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),
362 | GroupRandomHorizontalFlip(is_flow=False)])
363 | else:
364 | print('#' * 20, 'NO FLIP!!!')
365 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])])
366 | elif self.modality == 'Flow':
367 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
368 | GroupRandomHorizontalFlip(is_flow=True)])
369 | elif self.modality == 'RGBDiff':
370 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
371 | GroupRandomHorizontalFlip(is_flow=False)])
--------------------------------------------------------------------------------
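
The `get_optim_policies` methods above only attach `lr_mult`/`decay_mult` hints to each parameter group; below is a sketch of one way an SGD optimizer could consume them. The base learning rate and weight decay are assumed values, and main.py may apply the multipliers with its own schedule.

```python
# Hypothetical sketch; `model` is a TSN instance as in the earlier example.
import torch

policies = model.get_optim_policies()
optimizer = torch.optim.SGD(policies, lr=0.01, momentum=0.9, weight_decay=5e-4)

def apply_multipliers(optimizer, base_lr, base_wd):
    # Scale each group's lr and weight decay by the hints it carries.
    for group in optimizer.param_groups:
        group['lr'] = base_lr * group['lr_mult']
        group['weight_decay'] = base_wd * group['decay_mult']

apply_multipliers(optimizer, base_lr=0.01, base_wd=5e-4)
for group in optimizer.param_groups:
    print(group['name'], group['lr'], group['weight_decay'])
```
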
/ops/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 | import math
7 | import torch
8 |
9 |
10 | class GroupRandomCrop(object):
11 | def __init__(self, size):
12 | if isinstance(size, numbers.Number):
13 | self.size = (int(size), int(size))
14 | else:
15 | self.size = size
16 |
17 | def __call__(self, img_group):
18 |
19 | w, h = img_group[0].size
20 | th, tw = self.size
21 |
22 | out_images = list()
23 |
24 | x1 = random.randint(0, w - tw)
25 | y1 = random.randint(0, h - th)
26 |
27 | for img in img_group:
28 | assert(img.size[0] == w and img.size[1] == h)
29 | if w == tw and h == th:
30 | out_images.append(img)
31 | else:
32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33 |
34 | return out_images
35 |
36 |
37 | class GroupCenterCrop(object):
38 | def __init__(self, size):
39 | self.worker = torchvision.transforms.CenterCrop(size)
40 |
41 | def __call__(self, img_group):
42 | return [self.worker(img) for img in img_group]
43 |
44 |
45 | class GroupRandomHorizontalFlip(object):
46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
47 | """
48 | def __init__(self, is_flow=False):
49 | self.is_flow = is_flow
50 |
51 | def __call__(self, img_group, is_flow=False):
52 | v = random.random()
53 | if v < 0.5:
54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
55 | if self.is_flow:
56 | for i in range(0, len(ret), 2):
57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping
58 | return ret
59 | else:
60 | return img_group
61 |
62 |
63 | class GroupNormalize(object):
64 | def __init__(self, mean, std):
65 | self.mean = mean
66 | self.std = std
67 |
68 | def __call__(self, tensor):
69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
70 | rep_std = self.std * (tensor.size()[0]//len(self.std))
71 |
72 | # TODO: make efficient
73 | for t, m, s in zip(tensor, rep_mean, rep_std):
74 | t.sub_(m).div_(s)
75 |
76 | return tensor
77 |
78 |
79 | class GroupScale(object):
80 | """ Rescales the input PIL.Image to the given 'size'.
81 | 'size' will be the size of the smaller edge.
82 |     For example, if height > width, then the image will be
83 | rescaled to (size * height / width, size)
84 | size: size of the smaller edge
85 | interpolation: Default: PIL.Image.BILINEAR
86 | """
87 |
88 | def __init__(self, size, interpolation=Image.BILINEAR):
89 | self.worker = torchvision.transforms.Resize(size, interpolation)
90 |
91 | def __call__(self, img_group):
92 | return [self.worker(img) for img in img_group]
93 |
94 |
95 | class GroupOverSample(object):
96 | def __init__(self, crop_size, scale_size=None, flip=True):
97 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
98 |
99 | if scale_size is not None:
100 | self.scale_worker = GroupScale(scale_size)
101 | else:
102 | self.scale_worker = None
103 | self.flip = flip
104 |
105 | def __call__(self, img_group):
106 |
107 | if self.scale_worker is not None:
108 | img_group = self.scale_worker(img_group)
109 |
110 | image_w, image_h = img_group[0].size
111 | crop_w, crop_h = self.crop_size
112 |
113 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
114 | oversample_group = list()
115 | for o_w, o_h in offsets:
116 | normal_group = list()
117 | flip_group = list()
118 | for i, img in enumerate(img_group):
119 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
120 | normal_group.append(crop)
121 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
122 |
123 | if img.mode == 'L' and i % 2 == 0:
124 | flip_group.append(ImageOps.invert(flip_crop))
125 | else:
126 | flip_group.append(flip_crop)
127 |
128 | oversample_group.extend(normal_group)
129 | if self.flip:
130 | oversample_group.extend(flip_group)
131 | return oversample_group
132 |
133 |
134 | class GroupFullResSample(object):
135 | def __init__(self, crop_size, scale_size=None, flip=True):
136 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
137 |
138 | if scale_size is not None:
139 | self.scale_worker = GroupScale(scale_size)
140 | else:
141 | self.scale_worker = None
142 | self.flip = flip
143 |
144 | def __call__(self, img_group):
145 |
146 | if self.scale_worker is not None:
147 | img_group = self.scale_worker(img_group)
148 |
149 | image_w, image_h = img_group[0].size
150 | crop_w, crop_h = self.crop_size
151 |
152 | w_step = (image_w - crop_w) // 4
153 | h_step = (image_h - crop_h) // 4
154 |
155 | offsets = list()
156 | offsets.append((0 * w_step, 2 * h_step)) # left
157 | offsets.append((4 * w_step, 2 * h_step)) # right
158 | offsets.append((2 * w_step, 2 * h_step)) # center
159 |
160 | oversample_group = list()
161 | for o_w, o_h in offsets:
162 | normal_group = list()
163 | flip_group = list()
164 | for i, img in enumerate(img_group):
165 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
166 | normal_group.append(crop)
167 | if self.flip:
168 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
169 |
170 | if img.mode == 'L' and i % 2 == 0:
171 | flip_group.append(ImageOps.invert(flip_crop))
172 | else:
173 | flip_group.append(flip_crop)
174 |
175 | oversample_group.extend(normal_group)
176 | oversample_group.extend(flip_group)
177 | return oversample_group
178 |
179 |
180 | class GroupMultiScaleCrop(object):
181 |
182 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
183 | self.scales = scales if scales is not None else [1, .875, .75, .66]
184 | self.max_distort = max_distort
185 | self.fix_crop = fix_crop
186 | self.more_fix_crop = more_fix_crop
187 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
188 | self.interpolation = Image.BILINEAR
189 |
190 | def __call__(self, img_group):
191 |
192 | im_size = img_group[0].size
193 |
194 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
195 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
196 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
197 | for img in crop_img_group]
198 | return ret_img_group
199 |
200 | def _sample_crop_size(self, im_size):
201 | image_w, image_h = im_size[0], im_size[1]
202 |
203 | # find a crop size
204 | base_size = min(image_w, image_h)
205 | crop_sizes = [int(base_size * x) for x in self.scales]
206 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
207 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
208 |
209 | pairs = []
210 | for i, h in enumerate(crop_h):
211 | for j, w in enumerate(crop_w):
212 | if abs(i - j) <= self.max_distort:
213 | pairs.append((w, h))
214 |
215 | crop_pair = random.choice(pairs)
216 | if not self.fix_crop:
217 | w_offset = random.randint(0, image_w - crop_pair[0])
218 | h_offset = random.randint(0, image_h - crop_pair[1])
219 | else:
220 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
221 |
222 | return crop_pair[0], crop_pair[1], w_offset, h_offset
223 |
224 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
225 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
226 | return random.choice(offsets)
227 |
228 | @staticmethod
229 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
230 | w_step = (image_w - crop_w) // 4
231 | h_step = (image_h - crop_h) // 4
232 |
233 | ret = list()
234 | ret.append((0, 0)) # upper left
235 | ret.append((4 * w_step, 0)) # upper right
236 | ret.append((0, 4 * h_step)) # lower left
237 | ret.append((4 * w_step, 4 * h_step)) # lower right
238 | ret.append((2 * w_step, 2 * h_step)) # center
239 |
240 | if more_fix_crop:
241 | ret.append((0, 2 * h_step)) # center left
242 | ret.append((4 * w_step, 2 * h_step)) # center right
243 | ret.append((2 * w_step, 4 * h_step)) # lower center
244 | ret.append((2 * w_step, 0 * h_step)) # upper center
245 |
246 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
247 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
248 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
249 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
250 |
251 | return ret
252 |
253 |
254 | class GroupRandomSizedCrop(object):
255 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
256 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
257 | This is popularly used to train the Inception networks
258 | size: size of the smaller edge
259 | interpolation: Default: PIL.Image.BILINEAR
260 | """
261 | def __init__(self, size, interpolation=Image.BILINEAR):
262 | self.size = size
263 | self.interpolation = interpolation
264 |
265 | def __call__(self, img_group):
266 | for attempt in range(10):
267 | area = img_group[0].size[0] * img_group[0].size[1]
268 | target_area = random.uniform(0.08, 1.0) * area
269 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
270 |
271 | w = int(round(math.sqrt(target_area * aspect_ratio)))
272 | h = int(round(math.sqrt(target_area / aspect_ratio)))
273 |
274 | if random.random() < 0.5:
275 | w, h = h, w
276 |
277 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
278 | x1 = random.randint(0, img_group[0].size[0] - w)
279 | y1 = random.randint(0, img_group[0].size[1] - h)
280 | found = True
281 | break
282 | else:
283 | found = False
284 | x1 = 0
285 | y1 = 0
286 |
287 | if found:
288 | out_group = list()
289 | for img in img_group:
290 | img = img.crop((x1, y1, x1 + w, y1 + h))
291 | assert(img.size == (w, h))
292 | out_group.append(img.resize((self.size, self.size), self.interpolation))
293 | return out_group
294 | else:
295 | # Fallback
296 | scale = GroupScale(self.size, interpolation=self.interpolation)
297 | crop = GroupRandomCrop(self.size)
298 | return crop(scale(img_group))
299 |
300 |
301 | class Stack(object):
302 |
303 | def __init__(self, roll=False):
304 | self.roll = roll
305 |
306 | def __call__(self, img_group):
307 | if img_group[0].mode == 'L':
308 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)
309 | elif img_group[0].mode == 'RGB':
310 | if self.roll:
311 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
312 | else:
313 | return np.concatenate(img_group, axis=2)
314 |
315 |
316 | class ToTorchFormatTensor(object):
317 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
318 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
319 | def __init__(self, div=True):
320 | self.div = div
321 |
322 | def __call__(self, pic):
323 | if isinstance(pic, np.ndarray):
324 | # handle numpy array
325 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
326 | else:
327 | # handle PIL Image
328 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
329 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
330 | # put it from HWC to CHW format
331 | # yikes, this transpose takes 80% of the loading time/CPU
332 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
333 | return img.float().div(255) if self.div else img.float()
334 |
335 |
336 | class IdentityTransform(object):
337 |
338 | def __call__(self, data):
339 | return data
340 |
341 |
342 | if __name__ == "__main__":
343 | trans = torchvision.transforms.Compose([
344 | GroupScale(256),
345 | GroupRandomCrop(224),
346 | Stack(),
347 | ToTorchFormatTensor(),
348 | GroupNormalize(
349 | mean=[.485, .456, .406],
350 | std=[.229, .224, .225]
351 | )]
352 | )
353 |
354 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
355 |
356 | color_group = [im] * 3
357 | rst = trans(color_group)
358 |
359 | gray_group = [im.convert('L')] * 9
360 | gray_rst = trans(gray_group)
361 |
362 | trans2 = torchvision.transforms.Compose([
363 | GroupRandomSizedCrop(256),
364 | Stack(),
365 | ToTorchFormatTensor(),
366 | GroupNormalize(
367 | mean=[.485, .456, .406],
368 | std=[.229, .224, .225])
369 | ])
370 | print(trans2(color_group))
--------------------------------------------------------------------------------
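
A minimal sketch composing the group transforms above into a training pipeline, mirroring the `__main__` demo; the crop size and normalization statistics follow the ImageNet defaults used elsewhere in this repository, and the dummy frames are placeholders.

```python
# Hypothetical pipeline on 8 dummy RGB frames.
import torchvision
from PIL import Image
from ops.transforms import (GroupMultiScaleCrop, GroupRandomHorizontalFlip,
                            GroupNormalize, Stack, ToTorchFormatTensor)

train_transform = torchvision.transforms.Compose([
    GroupMultiScaleCrop(224, [1, .875, .75, .66]),
    GroupRandomHorizontalFlip(is_flow=False),
    Stack(roll=False),
    ToTorchFormatTensor(div=True),
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
])

frames = [Image.new('RGB', (320, 240))] * 8   # one clip = a list of frames
clip = train_transform(frames)
print(clip.shape)  # torch.Size([24, 224, 224]): 8 frames x 3 channels
```
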
/ops/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | def softmax(scores):
6 | es = np.exp(scores - scores.max(axis=-1)[..., None])
7 | return es / es.sum(axis=-1)[..., None]
8 |
9 |
10 | class AverageMeter(object):
11 | """Computes and stores the average and current value"""
12 |
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def accuracy(output, target, topk=(1,)):
30 | """Computes the precision@k for the specified values of k"""
31 | maxk = max(topk)
32 | batch_size = target.size(0)
33 |
34 | _, pred = output.topk(maxk, 1, True, True)
35 | pred = pred.t()
36 | correct = pred.eq(target.view(1, -1).expand_as(pred))
37 |
38 | res = []
39 | for k in topk:
40 | correct_k = correct[:k].reshape(-1).float().sum(0)
41 | res.append(correct_k.mul_(100.0 / batch_size))
42 | return res
43 |
44 |
45 | def cal_map(output, old_test_y):
46 | batch_size = output.size(0)
47 | num_classes = output.size(1)
48 | ap = torch.zeros(num_classes)
49 | test_y = old_test_y.clone()
50 |
51 | gt = get_multi_hot(test_y, num_classes, False)
52 |
53 | probs = F.softmax(output, dim=1)
54 |
55 |     rg = torch.arange(1, batch_size + 1).float()
56 | for k in range(num_classes):
57 | scores = probs[:, k]
58 | targets = gt[:, k]
59 | _, sortind = torch.sort(scores, 0, True)
60 | truth = targets[sortind]
61 | tp = truth.float().cumsum(0)
62 | precision = tp.div(rg)
63 |         ap[k] = precision[truth.bool()].sum() / max(float(truth.sum()), 1)
64 | return ap.mean()*100, ap*100
65 |
66 |
67 |
68 | def get_multi_hot(test_y, classes, assumes_starts_zero=True):
69 | bs = test_y.shape[0]
70 | label_cnt = 0
71 |
72 | if not assumes_starts_zero:
73 | for label_val in torch.unique(test_y):
74 | if label_val >= 0:
75 | test_y[test_y == label_val] = label_cnt
76 | label_cnt += 1
77 |
78 | gt = torch.zeros(bs, classes + 1)
79 | for i in range(test_y.shape[1]):
80 | gt[torch.LongTensor(range(bs)), test_y[:, i]] = 1
81 |
82 | return gt[:, :classes]
--------------------------------------------------------------------------------
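
A small sketch exercising `AverageMeter` and `accuracy` above on random data; the batch size and class count are arbitrary.

```python
import torch
from ops.utils import AverageMeter, accuracy

top1, top5 = AverageMeter(), AverageMeter()

logits = torch.randn(32, 174)            # dummy class scores
targets = torch.randint(0, 174, (32,))   # dummy ground-truth labels

prec1, prec5 = accuracy(logits, targets, topk=(1, 5))
top1.update(prec1.item(), n=logits.size(0))
top5.update(prec5.item(), n=logits.size(0))
print(f'top-1 {top1.avg:.2f}%  top-5 {top5.avg:.2f}%')
```
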
/opts.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import argparse
7 |
8 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
9 | parser.add_argument('dataset', type=str)
10 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow'])
11 | parser.add_argument('--train_list', type=str, default="")
12 | parser.add_argument('--val_list', type=str, default="")
13 | parser.add_argument('--root_path', type=str, default="")
14 | parser.add_argument('--path_backbone', type=str)
15 | parser.add_argument('--root_dataset', type=str)
16 | # ========================= Model Configs ==========================
17 | parser.add_argument('--arch_file', type=str, default="resnet_TSM_mask")
18 | parser.add_argument('--arch', type=str, default="BNInception")
19 | parser.add_argument('--num_segments', type=int, default=3)
20 | parser.add_argument('--consensus_type', type=str, default='avg')
21 | parser.add_argument('--k', type=int, default=3)
22 |
23 | parser.add_argument('--dropout', '--do', default=0.5, type=float,
24 | metavar='DO', help='dropout ratio (default: 0.5)')
25 | parser.add_argument('--loss_type', type=str, default="nll",
26 | choices=['nll'])
27 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
28 | parser.add_argument('--suffix', type=str, default=None)
29 | parser.add_argument('--pretrain', type=str, default='imagenet')
30 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint')
31 |
32 | # ========================= Learning Configs ==========================
33 | parser.add_argument('--epochs', default=120, type=int, metavar='N',
34 | help='number of total epochs to run')
35 | parser.add_argument('-b', '--batch-size', default=128, type=int,
36 |                     metavar='N', help='mini-batch size (default: 128)')
37 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
38 | metavar='LR', help='initial learning rate')
39 | parser.add_argument('--lr_type', default='step', type=str,
40 | metavar='LRtype', help='learning rate type')
41 | parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
42 | metavar='LRSteps', help='epochs to decay learning rate by 10')
43 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
44 | help='momentum')
45 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
46 | metavar='W', help='weight decay (default: 5e-4)')
47 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
48 | metavar='W', help='gradient norm clipping (default: disabled)')
49 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
50 |
51 | # ========================= Monitor Configs ==========================
52 | parser.add_argument('--print-freq', '-p', default=20, type=int,
53 |                     metavar='N', help='print frequency (default: 20)')
54 | parser.add_argument('--eval-freq', '-ef', default=5, type=int,
55 | metavar='N', help='evaluation frequency (default: 5)')
56 |
57 |
58 | # ========================= Runtime Configs ==========================
59 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
60 | help='number of data loading workers (default: 8)')
61 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
62 | help='path to latest checkpoint (default: none)')
63 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
64 | help='evaluate model on validation set')
65 | parser.add_argument('--snapshot_pref', type=str, default="")
66 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
67 | help='manual epoch number (useful on restarts)')
68 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
69 | parser.add_argument('--flow_prefix', default="", type=str)
70 | parser.add_argument('--root_log',type=str, default='log')
71 | parser.add_argument('--root_model', type=str, default='checkpoint')
72 |
73 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
74 |
75 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
76 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')
77 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')
78 |
79 | parser.add_argument('--world_size', default=1, type=int)
80 | parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training')
81 | parser.add_argument('--distributed', default=False, action="store_true")
82 | parser.add_argument('--amp', default=False, action="store_true")
83 |
84 | parser.add_argument('--model_path', default='exp', type=str)
85 | parser.add_argument('--rt_begin', default=1, type=int)
86 | parser.add_argument('--rt_end', default=50, type=int)
87 | parser.add_argument('--rt', default=0, type=float)
88 | parser.add_argument('--t0', default=5.0, type=float)
89 | parser.add_argument('--t1', default=1e-2, type=float)
90 | parser.add_argument('--t_end', default=50, type=int)
91 | parser.add_argument('--temp', default=1, type=float)
92 | parser.add_argument('--lambda_rt', default=1, type=float)
93 | parser.add_argument('--round', default="", type=str)
--------------------------------------------------------------------------------
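
A minimal sketch of parsing flags equivalent to the first command in train_sth.sh below; the dataset and backbone paths are placeholders, exactly as in the script.

```python
# Hypothetical invocation mirroring train_sth.sh (placeholder paths).
from opts import parser

args = parser.parse_args([
    'something', 'RGB',
    '--arch_file', 'AF_ResNet', '--arch', 'AF_resnet50',
    '--num_segments', '12', '--batch-size', '32', '--lr', '0.01',
    '--lr_steps', '25', '45', '--epochs', '55',
    '--root_dataset', 'path_dataset', '--path_backbone', 'path_backbone',
    '--rt', '0.5', '--lambda_rt', '0.5',
])
print(args.arch, args.num_segments, args.rt)
```
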
/train_sth.sh:
--------------------------------------------------------------------------------
1 | ### train AF-ResNet(RT=0.5)
2 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \
3 | --arch_file AF_ResNet \
4 | --arch AF_resnet50 --num_segments 12 \
5 | --root_dataset 'path_dataset' \
6 | --path_backbone 'path_backbone' \
7 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \
8 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \
9 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \
10 | --model_path 'models' \
11 | --rt 0.5 --round 1;
12 |
13 |
14 |
15 | ### train AF-ResNet-TSM(RT=0.5)
16 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \
17 | --arch_file AF_ResNet \
18 | --arch AF_resnet50 --num_segments 12 \
19 | --root_dataset 'path_dataset' \
20 | --path_backbone 'path_backbone' \
21 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \
22 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \
23 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \
24 | --model_path 'models' \
25 | --shift \
26 | --rt 0.5 --round 1;
27 |
28 |
29 |
30 | ### train AF-MobileNetv3-TSM(RT=0.5)
31 | CUDA_VISIBLE_DEVICES=0,1 python main.py something RGB \
32 | --arch_file AF_MobileNetv3 \
33 | --arch AF_mobilenetv3 --num_segments 12 \
34 | --root_dataset 'path_dataset' \
35 | --path_backbone 'path_backbone' \
36 | --batch-size 32 --lr 0.01 --lr_steps 25 45 --epochs 55 \
37 | --gd 20 -j 12 --dropout 0.5 --consensus_type=avg --eval-freq=1 --npb \
38 | --rt_begin 10 --rt_end 20 --t0 1 --t_end 50 --lambda_rt 0.5 \
39 | --model_path 'models_mobilenet' \
40 | --shift \
41 | --rt 0.5 --round 1;
--------------------------------------------------------------------------------