├── README.md ├── archs ├── __init__.py ├── bn_inception.py └── mobilenet_v2.py ├── main.py ├── online_demo ├── README.md ├── main.py ├── mobilenet_tsm_tvm_cuda.json └── mobilenet_v2_tsm.py ├── ops ├── __init__.py ├── basic_ops.py ├── dataset.py ├── dataset_config.py ├── models.py ├── non_local.py ├── temporal_shift.py ├── transforms.py └── utils.py ├── opts.py ├── scripts ├── finetune_tsm_ucf101_rgb_8f.sh ├── test_tsm_kinetics_rgb_8f.sh ├── train_tsm_kinetics_rgb_16f.sh ├── train_tsm_kinetics_rgb_8f.sh └── train_tsn_kinetics_rgb_5f.sh ├── test_models.py └── tools ├── gen_label_kinetics.py ├── gen_label_sthv1.py ├── gen_label_sthv2.py ├── kinetics_label_map.txt ├── vid2img_kinetics.py └── vid2img_sthv2.py /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mutual-modality-learning-for-video-action/action-recognition-in-videos-on-something)](https://paperswithcode.com/sota/action-recognition-in-videos-on-something?p=mutual-modality-learning-for-video-action) 2 | 3 | # Mutual Modality Learning for Video Action Classification 4 | 5 | By Stepan Komkov, Maksim Dzabraev and Aleksandr Petiushko 6 | 7 | This is the code for the [Mutual Modality Learning article](https://arxiv.org/abs/2011.02543). 8 | 9 | ## Abstract 10 | 11 | The construction of models for video action classification progresses rapidly. 12 | However, the performance of those models can still be easily improved by ensembling 13 | with the same models trained on different modalities (e.g., optical flow). Unfortunately, 14 | it is computationally expensive to use several modalities during inference. Recent works 15 | examine ways to integrate the advantages of multi-modality into a single RGB model. Yet, 16 | there is still room for improvement. In this paper, we explore various methods to 17 | embed the ensemble power into a single model. We show that proper initialization, as well 18 | as mutual modality learning, enhances single-modality models. As a result, we achieve 19 | state-of-the-art results on the Something-Something-v2 benchmark. 20 | 21 | ## Code changes 22 | 23 | This is a forked repository 24 | from the [code](https://github.com/mit-han-lab/temporal-shift-module) presented 25 | in the [TSM article](https://openaccess.thecvf.com/content_ICCV_2019/papers/Lin_TSM_Temporal_Shift_Module_for_Efficient_Video_Understanding_ICCV_2019_paper.pdf). 26 | 27 | Since this is a forked repository, we describe only the differences with respect to the 28 | [original code](https://github.com/mit-han-lab/temporal-shift-module). 29 | 30 | ### Mutual Learning and Mutual Modality Learning launch 31 | 32 | To launch ordinary Mutual Learning with RGB inputs, use the commands 33 | 34 | ``` 35 | python main.py somethingv2 RGB,RGB --rank 0 --world_size 2 [other training parameters] 36 | python main.py somethingv2 RGB,RGB --rank 1 --world_size 2 [other training parameters] 37 | ``` 38 | 39 | To launch Mutual Modality Learning with RGB-, Flow- and Diff-based models, use the commands 40 | 41 | ``` 42 | python main.py somethingv2 RGBDiff,Flow,RGB --rank 0 --world_size 3 [other training parameters] 43 | python main.py somethingv2 RGBDiff,Flow,RGB --rank 1 --world_size 3 [other training parameters] 44 | python main.py somethingv2 RGBDiff,Flow,RGB --rank 2 --world_size 3 [other training parameters] 45 | ``` 46 | 47 | Use the `--gpus` and `--init_method` arguments to specify devices for each model and/or launch multi-node training.
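For intuition, each rank launched above trains its own single-modality network on the ground-truth labels while an additional distillation term pulls its predictions towards the predictions exchanged with the peer ranks. The snippet below is a minimal sketch of such a mutual-learning objective (cross-entropy plus a KL-divergence term towards each peer), not the exact loss code from `main.py`; the function name `mutual_learning_loss`, the `kl_weight` coefficient, and the averaging over peers are illustrative assumptions.

```
# Minimal sketch of a mutual-learning objective (illustrative, not the exact
# implementation in main.py): own cross-entropy plus the mean KL divergence
# towards the softmax outputs received from the peer ranks.
import torch.nn.functional as F


def mutual_learning_loss(own_logits, peer_logits_list, targets, kl_weight=1.0):
    # Supervised term on this rank's own predictions.
    ce = F.cross_entropy(own_logits, targets)
    # Distillation term: KL(own || peer) averaged over all peer models.
    log_p = F.log_softmax(own_logits, dim=1)
    kl = sum(F.kl_div(log_p, F.softmax(peer.detach(), dim=1), reduction='batchmean')
             for peer in peer_logits_list) / max(len(peer_logits_list), 1)
    return ce + kl_weight * kl
```

In the actual training script the peer predictions would be exchanged between processes (e.g. with `torch.distributed` collectives) before such a term can be computed.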
48 | 49 | Use the `--tune_from` argument to specify the initialization model (in the same way as before). 50 | 51 | Use the `--random_sample` argument to turn on the random sampling strategy during training. 52 | 53 | Use the `--dense_length` argument to specify the number of frames for dense sampling (it also affects random sampling). 54 | 55 | Thus, these are the minimal commands needed to reproduce the MML results: 56 | 57 | ``` 58 | python main.py somethingv2 RGB [other training parameters] 59 | 60 | python main.py somethingv2 RGB,Flow --rank 0 --world_size 2 --tune_from $PATH_TO_MODEL_FROM_THE_FIRST_STEP$ [other training parameters] 61 | python main.py somethingv2 RGB,Flow --rank 1 --world_size 2 --tune_from $PATH_TO_MODEL_FROM_THE_FIRST_STEP$ [other training parameters] 62 | ``` 63 | 64 | ### Testing 65 | 66 | The testing script can be launched as before. Several new options are available during testing. 67 | 68 | Use the `--random_sample` argument to use both uniform and dense sampling during testing. 69 | 70 | Use the `--dense_length` argument to specify the number of frames for dense sampling (it also affects random sampling). 71 | 72 | Use the `--dense_number` argument to specify the number of dense samplings (it also affects random sampling). 73 | 74 | Use the `--twice_sample` argument to use two uniform samplings during testing, shifted by half a period (it also affects random sampling). 75 | 76 | ## Citation 77 | 78 | ``` 79 | @article{komkov2020mml, 80 | title={Mutual Modality Learning for Video Action Classification}, 81 | author={Komkov, Stepan and Dzabraev, Maksim and Petiushko, Aleksandr}, 82 | journal={arXiv preprint arXiv:2011.02543}, 83 | year={2020} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /archs/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn_inception import * 2 | -------------------------------------------------------------------------------- /archs/bn_inception.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | import torch.nn.functional as F 6 | 7 | 8 | __all__ = ['BNInception', 'bninception'] 9 | 10 | pretrained_settings = { 11 | 'bninception': { 12 | 'imagenet': { 13 | 'url': 'https://www.dropbox.com/s/3cvod6kzwluijcw/BNInception-9baff57459f5a1744.pth?dl=1', 14 | 'input_space': 'BGR', 15 | 'input_size': 224, 16 | 'input_range': [0, 255], 17 | 'mean': [104, 117, 128], 18 | 'std': [1, 1, 1], 19 | 'num_classes': 1000 20 | }, 21 | 'kinetics': { 22 | 'url': 'https://www.dropbox.com/s/gx4u7itoyygix0c/BNInceptionKinetics-47f0695e.pth?dl=1', 23 | 'input_space': 'BGR', 24 | 'input_size': 224, 25 | 'input_range': [0, 255], 26 | 'mean': [104, 117, 128], # [96.29023126, 103.16065604, 110.63666788] 27 | 'std': [1, 1, 1], # [40.02898126, 37.88248729, 38.7568578], 28 | 'num_classes': 400 29 | } 30 | }, 31 | } 32 | 33 | 34 | class BNInception(nn.Module): 35 | def __init__(self, num_classes=1000): 36 | super(BNInception, self).__init__() 37 | inplace = True 38 | self._build_features(inplace, num_classes) 39 | 40 | def forward(self, x): 41 | # if self.input_space == 'BGR': 42 | # assert len(x.size()) == 4 43 | # x = x[:, (2, 1, 0)] 44 | x = self.features(x) 45 | x = self.logits(x) 46 | return x 47 | 48 | def features(self, x): 49 | # stage1 50 | pool1_3x3_s2_out = 
self._temporal_forward_wrap(self._block_1, 0)(x) 51 | # stage2 52 | pool2_3x3_s2_out = self._temporal_forward_wrap(self._block_2, 1)(pool1_3x3_s2_out) 53 | 54 | # stage3 55 | inception_3a_output_out = self._temporal_forward_wrap(self._block_3a, 2)(pool2_3x3_s2_out) 56 | inception_3b_output_out = self._temporal_forward_wrap(self._block_3b, 3)(inception_3a_output_out) 57 | inception_3c_output_out = self._temporal_forward_wrap(self._block_3c, 4)(inception_3b_output_out) 58 | 59 | inception_4a_output_out = self._temporal_forward_wrap(self._block_4a, 5)(inception_3c_output_out) 60 | inception_4b_output_out = self._temporal_forward_wrap(self._block_4b, 6)(inception_4a_output_out) 61 | inception_4c_output_out = self._temporal_forward_wrap(self._block_4c, 7)(inception_4b_output_out) 62 | inception_4d_output_out = self._temporal_forward_wrap(self._block_4d, 8)(inception_4c_output_out) 63 | inception_4e_output_out = self._temporal_forward_wrap(self._block_4e, 9)(inception_4d_output_out) 64 | 65 | inception_5a_output_out = self._temporal_forward_wrap(self._block_5a, 10)(inception_4e_output_out) 66 | inception_5b_output_out = self._temporal_forward_wrap(self._block_5b, 11)(inception_5a_output_out) 67 | 68 | return inception_5b_output_out 69 | 70 | def logits(self, features): 71 | x = self.global_pool(features) 72 | x = x.view(x.size(0), -1) 73 | x = self.fc(x) 74 | return x 75 | 76 | def build_temporal_ops(self, n_segment, is_temporal_shift='0' * 12, shift_div=8): 77 | # must call after loading weights 78 | self.n_segment = n_segment 79 | self.residual = 'res' in is_temporal_shift 80 | if self.residual: 81 | print('=> Using residual shift functions...') 82 | if is_temporal_shift in ['block', 'blockres']: 83 | self.is_temporal_shift = '1' * 12 84 | else: 85 | self.is_temporal_shift = is_temporal_shift 86 | self.is_temporal_shift = '0' + self.is_temporal_shift[1:] # image input does not shift 87 | 88 | assert len(self.is_temporal_shift) == 12 89 | 90 | print('=> Injecting temporal shift with mask {}'.format(self.is_temporal_shift)) 91 | self.fold_div = shift_div 92 | print('=> Using fold div: {}'.format(self.fold_div)) 93 | 94 | def _temporal_forward_wrap(self, layer_func, index): 95 | if hasattr(self, 'is_temporal_shift') and self.is_temporal_shift[index] == '1': # run temporal shuffling 96 | from ops.temporal_shift import TemporalShift 97 | def wrapped_func(x, is_residual, n_segment, fold_div): 98 | if is_residual: 99 | x_shift = TemporalShift.shift(x, n_segment, fold_div=fold_div) 100 | return F.relu(x + layer_func(x_shift)) 101 | else: 102 | x = TemporalShift.shift(x, n_segment, fold_div=fold_div) 103 | return layer_func(x) 104 | from functools import partial 105 | return partial(wrapped_func, is_residual=self.residual, n_segment=self.n_segment, 106 | fold_div=self.fold_div) 107 | else: 108 | return layer_func 109 | 110 | def _block_1(self, x): 111 | conv1_7x7_s2_out = self.conv1_7x7_s2(x) 112 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) 113 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) 114 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_7x7_s2_bn_out) 115 | return pool1_3x3_s2_out 116 | 117 | def _block_2(self, x): 118 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(x) 119 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) 120 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) 121 | conv2_3x3_out = self.conv2_3x3(conv2_3x3_reduce_bn_out) 122 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) 123 | 
conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) 124 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_3x3_bn_out) 125 | return pool2_3x3_s2_out 126 | 127 | def _block_3a(self, pool2_3x3_s2_out): 128 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) 129 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) 130 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) 131 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) 132 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) 133 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) 134 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_3x3_reduce_bn_out) 135 | inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) 136 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) 137 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) 138 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn( 139 | inception_3a_double_3x3_reduce_out) 140 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce( 141 | inception_3a_double_3x3_reduce_bn_out) 142 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_double_3x3_reduce_bn_out) 143 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) 144 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) 145 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_double_3x3_1_bn_out) 146 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) 147 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) 148 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) 149 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) 150 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) 151 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) 152 | inception_3a_output_out = torch.cat( 153 | [inception_3a_1x1_bn_out, inception_3a_3x3_bn_out, inception_3a_double_3x3_2_bn_out, 154 | inception_3a_pool_proj_bn_out], 1) 155 | return inception_3a_output_out 156 | 157 | def _block_3b(self, inception_3a_output_out): 158 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) 159 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) 160 | inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) 161 | inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) 162 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) 163 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) 164 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_3x3_reduce_bn_out) 165 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) 166 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) 167 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) 168 | inception_3b_double_3x3_reduce_bn_out = 
self.inception_3b_double_3x3_reduce_bn( 169 | inception_3b_double_3x3_reduce_out) 170 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce( 171 | inception_3b_double_3x3_reduce_bn_out) 172 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_double_3x3_reduce_bn_out) 173 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) 174 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) 175 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_double_3x3_1_bn_out) 176 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) 177 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) 178 | inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) 179 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) 180 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) 181 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) 182 | inception_3b_output_out = torch.cat( 183 | [inception_3b_1x1_bn_out, inception_3b_3x3_bn_out, inception_3b_double_3x3_2_bn_out, 184 | inception_3b_pool_proj_bn_out], 1) 185 | return inception_3b_output_out 186 | 187 | def _block_3c(self, inception_3b_output_out): 188 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) 189 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) 190 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) 191 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_3x3_reduce_bn_out) 192 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) 193 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) 194 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) 195 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn( 196 | inception_3c_double_3x3_reduce_out) 197 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce( 198 | inception_3c_double_3x3_reduce_bn_out) 199 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_double_3x3_reduce_bn_out) 200 | inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) 201 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) 202 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_double_3x3_1_bn_out) 203 | inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) 204 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) 205 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) 206 | inception_3c_output_out = torch.cat( 207 | [inception_3c_3x3_bn_out, inception_3c_double_3x3_2_bn_out, inception_3c_pool_out], 1) 208 | return inception_3c_output_out 209 | 210 | def _block_4a(self, inception_3c_output_out): 211 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) 212 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) 213 | 
inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) 214 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) 215 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) 216 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) 217 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_3x3_reduce_bn_out) 218 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) 219 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) 220 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) 221 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn( 222 | inception_4a_double_3x3_reduce_out) 223 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce( 224 | inception_4a_double_3x3_reduce_bn_out) 225 | inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_double_3x3_reduce_bn_out) 226 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) 227 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) 228 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_double_3x3_1_bn_out) 229 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) 230 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) 231 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) 232 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) 233 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) 234 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) 235 | inception_4a_output_out = torch.cat( 236 | [inception_4a_1x1_bn_out, inception_4a_3x3_bn_out, inception_4a_double_3x3_2_bn_out, 237 | inception_4a_pool_proj_bn_out], 1) 238 | return inception_4a_output_out 239 | 240 | def _block_4b(self, inception_4a_output_out): 241 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) 242 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) 243 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) 244 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) 245 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) 246 | inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) 247 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_3x3_reduce_bn_out) 248 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) 249 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) 250 | inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) 251 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn( 252 | inception_4b_double_3x3_reduce_out) 253 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce( 254 | inception_4b_double_3x3_reduce_bn_out) 255 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_double_3x3_reduce_bn_out) 256 | 
inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) 257 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) 258 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_double_3x3_1_bn_out) 259 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) 260 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) 261 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) 262 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) 263 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) 264 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) 265 | inception_4b_output_out = torch.cat( 266 | [inception_4b_1x1_bn_out, inception_4b_3x3_bn_out, inception_4b_double_3x3_2_bn_out, 267 | inception_4b_pool_proj_bn_out], 1) 268 | return inception_4b_output_out 269 | 270 | def _block_4c(self, inception_4b_output_out): 271 | inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) 272 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) 273 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) 274 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) 275 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) 276 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) 277 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_3x3_reduce_bn_out) 278 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) 279 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) 280 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) 281 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn( 282 | inception_4c_double_3x3_reduce_out) 283 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce( 284 | inception_4c_double_3x3_reduce_bn_out) 285 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_double_3x3_reduce_bn_out) 286 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) 287 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) 288 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_double_3x3_1_bn_out) 289 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) 290 | inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) 291 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) 292 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) 293 | inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) 294 | inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) 295 | inception_4c_output_out = torch.cat( 296 | [inception_4c_1x1_bn_out, inception_4c_3x3_bn_out, inception_4c_double_3x3_2_bn_out, 297 | inception_4c_pool_proj_bn_out], 1) 298 | return inception_4c_output_out 299 | 
300 | def _block_4d(self, inception_4c_output_out): 301 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) 302 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) 303 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) 304 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) 305 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) 306 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) 307 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_3x3_reduce_bn_out) 308 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) 309 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) 310 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) 311 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn( 312 | inception_4d_double_3x3_reduce_out) 313 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce( 314 | inception_4d_double_3x3_reduce_bn_out) 315 | inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_double_3x3_reduce_bn_out) 316 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) 317 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) 318 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_double_3x3_1_bn_out) 319 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) 320 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) 321 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) 322 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) 323 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) 324 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) 325 | inception_4d_output_out = torch.cat( 326 | [inception_4d_1x1_bn_out, inception_4d_3x3_bn_out, inception_4d_double_3x3_2_bn_out, 327 | inception_4d_pool_proj_bn_out], 1) 328 | return inception_4d_output_out 329 | 330 | def _block_4e(self, inception_4d_output_out): 331 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) 332 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) 333 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) 334 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_3x3_reduce_bn_out) 335 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) 336 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) 337 | inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) 338 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn( 339 | inception_4e_double_3x3_reduce_out) 340 | inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce( 341 | inception_4e_double_3x3_reduce_bn_out) 342 | inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_double_3x3_reduce_bn_out) 343 | inception_4e_double_3x3_1_bn_out = 
self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) 344 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) 345 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_double_3x3_1_bn_out) 346 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) 347 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) 348 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) 349 | inception_4e_output_out = torch.cat( 350 | [inception_4e_3x3_bn_out, inception_4e_double_3x3_2_bn_out, inception_4e_pool_out], 1) 351 | return inception_4e_output_out 352 | 353 | def _block_5a(self, inception_4e_output_out): 354 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) 355 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) 356 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) 357 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) 358 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) 359 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) 360 | inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_3x3_reduce_bn_out) 361 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) 362 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) 363 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) 364 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn( 365 | inception_5a_double_3x3_reduce_out) 366 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce( 367 | inception_5a_double_3x3_reduce_bn_out) 368 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_double_3x3_reduce_bn_out) 369 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) 370 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) 371 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_double_3x3_1_bn_out) 372 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) 373 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) 374 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) 375 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) 376 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) 377 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) 378 | inception_5a_output_out = torch.cat( 379 | [inception_5a_1x1_bn_out, inception_5a_3x3_bn_out, inception_5a_double_3x3_2_bn_out, 380 | inception_5a_pool_proj_bn_out], 1) 381 | return inception_5a_output_out 382 | 383 | def _block_5b(self, inception_5a_output_out): 384 | inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) 385 | inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) 386 | inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) 387 | inception_5b_3x3_reduce_out = 
self.inception_5b_3x3_reduce(inception_5a_output_out) 388 | inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) 389 | inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) 390 | inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_3x3_reduce_bn_out) 391 | inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) 392 | inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) 393 | inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) 394 | inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn( 395 | inception_5b_double_3x3_reduce_out) 396 | inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce( 397 | inception_5b_double_3x3_reduce_bn_out) 398 | inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_double_3x3_reduce_bn_out) 399 | inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) 400 | inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) 401 | inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_double_3x3_1_bn_out) 402 | inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) 403 | inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) 404 | inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) 405 | inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) 406 | inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) 407 | inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) 408 | inception_5b_output_out = torch.cat( 409 | [inception_5b_1x1_bn_out, inception_5b_3x3_bn_out, inception_5b_double_3x3_2_bn_out, 410 | inception_5b_pool_proj_bn_out], 1) 411 | return inception_5b_output_out 412 | 413 | def _build_features(self, inplace, num_classes): 414 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 415 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 416 | self.conv1_relu_7x7 = nn.ReLU(inplace) 417 | self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 418 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) 419 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 420 | self.conv2_relu_3x3_reduce = nn.ReLU(inplace) 421 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 422 | self.conv2_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 423 | self.conv2_relu_3x3 = nn.ReLU(inplace) 424 | self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 425 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 426 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 427 | self.inception_3a_relu_1x1 = nn.ReLU(inplace) 428 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 429 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 430 | self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace) 431 | self.inception_3a_3x3 = nn.Conv2d(64, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 432 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 433 | self.inception_3a_relu_3x3 = nn.ReLU(inplace) 434 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 435 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 436 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace) 437 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 438 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 439 | self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace) 440 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 441 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 442 | self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace) 443 | self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 444 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) 445 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, eps=1e-05, momentum=0.9, affine=True) 446 | self.inception_3a_relu_pool_proj = nn.ReLU(inplace) 447 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 448 | self.inception_3b_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 449 | self.inception_3b_relu_1x1 = nn.ReLU(inplace) 450 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 451 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 452 | self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace) 453 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 454 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 455 | self.inception_3b_relu_3x3 = nn.ReLU(inplace) 456 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 457 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 458 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace) 459 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 460 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 461 | self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace) 462 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 463 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 464 | self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace) 465 | self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 466 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 467 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 468 | self.inception_3b_relu_pool_proj = nn.ReLU(inplace) 469 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) 470 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 471 | self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace) 472 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), 
padding=(1, 1)) 473 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 474 | self.inception_3c_relu_3x3 = nn.ReLU(inplace) 475 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) 476 | self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 477 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace) 478 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 479 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 480 | self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace) 481 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 482 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 483 | self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace) 484 | self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 485 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) 486 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 487 | self.inception_4a_relu_1x1 = nn.ReLU(inplace) 488 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) 489 | self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 490 | self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace) 491 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 492 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 493 | self.inception_4a_relu_3x3 = nn.ReLU(inplace) 494 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 495 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 496 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace) 497 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 498 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 499 | self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace) 500 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 501 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 502 | self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace) 503 | self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 504 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 505 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 506 | self.inception_4a_relu_pool_proj = nn.ReLU(inplace) 507 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) 508 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 509 | self.inception_4b_relu_1x1 = nn.ReLU(inplace) 510 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 511 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 512 | self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace) 513 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 514 | self.inception_4b_3x3_bn = 
nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 515 | self.inception_4b_relu_3x3 = nn.ReLU(inplace) 516 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 517 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 518 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace) 519 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 520 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 521 | self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace) 522 | self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 523 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 524 | self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace) 525 | self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 526 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 527 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 528 | self.inception_4b_relu_pool_proj = nn.ReLU(inplace) 529 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) 530 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 531 | self.inception_4c_relu_1x1 = nn.ReLU(inplace) 532 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 533 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 534 | self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace) 535 | self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 536 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 537 | self.inception_4c_relu_3x3 = nn.ReLU(inplace) 538 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 539 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 540 | self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace) 541 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 542 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 543 | self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace) 544 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 545 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 546 | self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace) 547 | self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 548 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 549 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 550 | self.inception_4c_relu_pool_proj = nn.ReLU(inplace) 551 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) 552 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 553 | self.inception_4d_relu_1x1 = nn.ReLU(inplace) 554 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 555 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, 
momentum=0.9, affine=True) 556 | self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace) 557 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 558 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 559 | self.inception_4d_relu_3x3 = nn.ReLU(inplace) 560 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) 561 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 562 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace) 563 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 564 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 565 | self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace) 566 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 567 | self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 568 | self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace) 569 | self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 570 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 571 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 572 | self.inception_4d_relu_pool_proj = nn.ReLU(inplace) 573 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 574 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 575 | self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace) 576 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 577 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 578 | self.inception_4e_relu_3x3 = nn.ReLU(inplace) 579 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) 580 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 581 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace) 582 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 583 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) 584 | self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace) 585 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 586 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) 587 | self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace) 588 | self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 589 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) 590 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) 591 | self.inception_5a_relu_1x1 = nn.ReLU(inplace) 592 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) 593 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 594 | self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace) 595 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 596 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, 
momentum=0.9, affine=True) 597 | self.inception_5a_relu_3x3 = nn.ReLU(inplace) 598 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) 599 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 600 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace) 601 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 602 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 603 | self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace) 604 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 605 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 606 | self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace) 607 | self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 608 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) 609 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 610 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace) 611 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) 612 | self.inception_5b_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) 613 | self.inception_5b_relu_1x1 = nn.ReLU(inplace) 614 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 615 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 616 | self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace) 617 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 618 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) 619 | self.inception_5b_relu_3x3 = nn.ReLU(inplace) 620 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 621 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 622 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace) 623 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 624 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 625 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace) 626 | self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 627 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 628 | self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace) 629 | self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) 630 | self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) 631 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 632 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace) 633 | self.global_pool = nn.AvgPool2d(7, stride=1, padding=0, ceil_mode=True, count_include_pad=True) 634 | self.fc = nn.Linear(1024, num_classes) 635 | 636 | 637 | def bninception(pretrained='imagenet'): 638 | r"""BNInception model architecture from `_ paper. 
639 | """ 640 | if pretrained is not None: 641 | print('=> Loading from pretrained model: {}'.format(pretrained)) 642 | settings = pretrained_settings['bninception'][pretrained] 643 | num_classes = settings['num_classes'] 644 | model = BNInception(num_classes=num_classes) 645 | model.load_state_dict(model_zoo.load_url(settings['url'])) 646 | model.input_space = settings['input_space'] 647 | model.input_size = settings['input_size'] 648 | model.input_range = settings['input_range'] 649 | model.mean = settings['mean'] 650 | model.std = settings['std'] 651 | else: 652 | raise NotImplementedError 653 | return model 654 | 655 | 656 | if __name__ == '__main__': 657 | model = bninception() 658 | -------------------------------------------------------------------------------- /archs/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://github.com/tonylins/pytorch-mobilenet-v2 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | def conv_bn(inp, oup, stride): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 10 | nn.BatchNorm2d(oup), 11 | nn.ReLU6(inplace=True) 12 | ) 13 | 14 | 15 | def conv_1x1_bn(inp, oup): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU6(inplace=True) 20 | ) 21 | 22 | 23 | def make_divisible(x, divisible_by=8): 24 | import numpy as np 25 | return int(np.ceil(x * 1. / divisible_by) * divisible_by) 26 | 27 | 28 | class InvertedResidual(nn.Module): 29 | def __init__(self, inp, oup, stride, expand_ratio): 30 | super(InvertedResidual, self).__init__() 31 | self.stride = stride 32 | assert stride in [1, 2] 33 | 34 | hidden_dim = int(inp * expand_ratio) 35 | self.use_res_connect = self.stride == 1 and inp == oup 36 | 37 | if expand_ratio == 1: 38 | self.conv = nn.Sequential( 39 | # dw 40 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 41 | nn.BatchNorm2d(hidden_dim), 42 | nn.ReLU6(inplace=True), 43 | # pw-linear 44 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 45 | nn.BatchNorm2d(oup), 46 | ) 47 | else: 48 | self.conv = nn.Sequential( 49 | # pw 50 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 51 | nn.BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # dw 54 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 55 | nn.BatchNorm2d(hidden_dim), 56 | nn.ReLU6(inplace=True), 57 | # pw-linear 58 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 59 | nn.BatchNorm2d(oup), 60 | ) 61 | 62 | def forward(self, x): 63 | if self.use_res_connect: 64 | return x + self.conv(x) 65 | else: 66 | return self.conv(x) 67 | 68 | 69 | class MobileNetV2(nn.Module): 70 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 71 | super(MobileNetV2, self).__init__() 72 | block = InvertedResidual 73 | input_channel = 32 74 | last_channel = 1280 75 | interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [6, 24, 2, 2], 79 | [6, 32, 3, 2], 80 | [6, 64, 4, 2], 81 | [6, 96, 3, 1], 82 | [6, 160, 3, 2], 83 | [6, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32! 
89 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 90 | self.features = [conv_bn(3, input_channel, 2)] 91 | # building inverted residual blocks 92 | for t, c, n, s in interverted_residual_setting: 93 | output_channel = make_divisible(c * width_mult) if t > 1 else c 94 | for i in range(n): 95 | if i == 0: 96 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 97 | else: 98 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 99 | input_channel = output_channel 100 | # building last several layers 101 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 102 | # make it nn.Sequential 103 | self.features = nn.Sequential(*self.features) 104 | 105 | # building classifier 106 | self.classifier = nn.Linear(self.last_channel, n_class) 107 | 108 | self._initialize_weights() 109 | 110 | def forward(self, x): 111 | x = self.features(x) 112 | x = x.mean(3).mean(2) 113 | x = self.classifier(x) 114 | return x 115 | 116 | def _initialize_weights(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. / n)) 121 | if m.bias is not None: 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | n = m.weight.size(1) 128 | m.weight.data.normal_(0, 0.01) 129 | m.bias.data.zero_() 130 | 131 | 132 | def mobilenet_v2(pretrained=True): 133 | model = MobileNetV2(width_mult=1) 134 | 135 | if pretrained: 136 | try: 137 | from torch.hub import load_state_dict_from_url 138 | except ImportError: 139 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 140 | state_dict = load_state_dict_from_url( 141 | 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True) 142 | model.load_state_dict(state_dict) 143 | return model 144 | 145 | 146 | if __name__ == '__main__': 147 | net = mobilenet_v2(True) 148 | 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import os 7 | import time 8 | import shutil 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | from torch.nn.utils import clip_grad_norm_ 13 | import torch.distributed as distr 14 | 15 | from ops.dataset import TSNDataSet 16 | from ops.models import TSN 17 | from ops.transforms import * 18 | from opts import parser 19 | from ops import dataset_config 20 | from ops.utils import AverageMeter, accuracy 21 | from ops.temporal_shift import make_temporal_pool 22 | 23 | from tensorboardX import SummaryWriter 24 | import datetime 25 | 26 | best_prec1 = 0 27 | 28 | def main(): 29 | global args, best_prec1 30 | args = parser.parse_args() 31 | 32 | distr.init_process_group(backend='nccl',init_method=args.init_method, 33 | rank=args.rank, world_size=args.world_size, timeout=datetime.timedelta(hours=1.)) 34 | 35 | num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset, 36 | args.modality) 37 | 38 | args.modality = 
args.modality.split(',') 39 | full_arch_name = args.arch 40 | if args.shift: 41 | full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place) 42 | if args.temporal_pool: 43 | full_arch_name += '_tpool' 44 | args.store_name = '_'.join( 45 | ['TSM', args.dataset, args.modality[args.rank], full_arch_name, args.consensus_type, 'segment%d' % args.num_segments, 46 | 'e{}'.format(args.epochs)]) 47 | if args.pretrain != 'imagenet': 48 | args.store_name += '_{}'.format(args.pretrain) 49 | if args.lr_type != 'step': 50 | args.store_name += '_{}'.format(args.lr_type) 51 | if args.dense_sample: 52 | args.store_name += '_dense{}'.format(args.dense_length) 53 | elif args.random_sample: 54 | args.store_name += '_random{}'.format(args.dense_length) 55 | if args.non_local > 0: 56 | args.store_name += '_nl' 57 | if args.suffix is not None: 58 | args.store_name += '_{}'.format(args.suffix) 59 | if len(args.modality)>1: 60 | args.store_name += '_ML{}'.format(args.rank) 61 | print('storing name: ' + args.store_name) 62 | 63 | check_rootfolders() 64 | 65 | model = TSN(num_class, args.num_segments, args.modality[args.rank], 66 | base_model=args.arch, 67 | consensus_type=args.consensus_type, 68 | dropout=args.dropout, 69 | img_feature_dim=args.img_feature_dim, 70 | partial_bn=not args.no_partialbn, 71 | pretrain=args.pretrain, 72 | is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place, 73 | fc_lr5=not (args.tune_from and args.dataset in args.tune_from), 74 | temporal_pool=args.temporal_pool, 75 | non_local=args.non_local) 76 | 77 | crop_size = model.crop_size 78 | scale_size = model.scale_size 79 | policies = model.get_optim_policies() 80 | train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True) 81 | 82 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda(args.gpus[0]) 83 | 84 | optimizer = torch.optim.SGD(policies, 85 | args.lr, 86 | momentum=args.momentum, 87 | weight_decay=args.weight_decay) 88 | 89 | if args.resume: 90 | if args.temporal_pool: # early temporal pool so that we can load the state_dict 91 | make_temporal_pool(model.module.base_model, args.num_segments) 92 | if os.path.isfile(args.resume): 93 | print(("=> loading checkpoint '{}'".format(args.resume))) 94 | checkpoint = torch.load(args.resume) 95 | args.start_epoch = checkpoint['epoch'] 96 | best_prec1 = checkpoint['best_prec1'] 97 | model.load_state_dict(checkpoint['state_dict']) 98 | optimizer.load_state_dict(checkpoint['optimizer']) 99 | print(("=> loaded checkpoint '{}' (epoch {})" 100 | .format(args.evaluate, checkpoint['epoch']))) 101 | else: 102 | print(("=> no checkpoint found at '{}'".format(args.resume))) 103 | 104 | if args.tune_from: 105 | print(("=> fine-tuning from '{}'".format(args.tune_from))) 106 | sd = torch.load(args.tune_from) 107 | sd = sd['state_dict'] 108 | model_dict = model.state_dict() 109 | replace_dict = [] 110 | for k, v in sd.items(): 111 | if k not in model_dict and k.replace('.net', '') in model_dict: 112 | print('=> Load after remove .net: ', k) 113 | replace_dict.append((k, k.replace('.net', ''))) 114 | for k, v in model_dict.items(): 115 | if k not in sd and k.replace('.net', '') in sd: 116 | print('=> Load after adding .net: ', k) 117 | replace_dict.append((k.replace('.net', ''), k)) 118 | 119 | for k, k_new in replace_dict: 120 | sd[k_new] = sd.pop(k) 121 | keys1 = set(list(sd.keys())) 122 | keys2 = set(list(model_dict.keys())) 123 | set_diff = (keys1 - keys2) | (keys2 - keys1) 124 | print('#### 
Notice: keys that failed to load: {}'.format(set_diff)) 125 | if args.dataset not in args.tune_from: # new dataset 126 | print('=> New dataset, do not load fc weights') 127 | sd = {k: v for k, v in sd.items() if 'fc' not in k} 128 | model_dict.update(sd) 129 | 130 | if args.modality[args.rank] not in args.tune_from or (args.modality[args.rank]=='RGB' and 'RGBDiff' in args.tune_from): 131 | if 'Flow' in args.tune_from: 132 | model._construct_flow_model(model.base_model) 133 | elif 'RGBDiff' in args.tune_from: 134 | model._construct_diff_model(model.base_model) 135 | else: 136 | model._construct_rgb_model(model.base_model) 137 | model.load_state_dict(model_dict) 138 | if args.modality[args.rank]=='Flow': 139 | model._construct_flow_model(model.base_model) 140 | elif args.modality[args.rank]=='RGBDiff': 141 | model._construct_diff_model(model.base_model) 142 | else: 143 | model._construct_rgb_model(model.base_model) 144 | else: 145 | model.load_state_dict(model_dict) 146 | 147 | if args.temporal_pool and not args.resume: 148 | make_temporal_pool(model.module.base_model, args.num_segments) 149 | 150 | cudnn.benchmark = True 151 | 152 | # Data loading code 153 | train_loader = None 154 | if args.rank==0: 155 | input_mean = [] 156 | input_std = [] 157 | data_length = [] 158 | for moda in args.modality: 159 | if moda=='RGB': 160 | input_mean += [0.485, 0.456, 0.406] 161 | input_std += [0.229, 0.224, 0.225] 162 | data_length += [1] 163 | elif moda=='Flow': 164 | input_mean += [0.5]*10 165 | input_std += [0.226]*10 166 | data_length += [5] 167 | elif moda=='RGBDiff': 168 | input_mean += [0.]*18 169 | input_std += [1.]*18 170 | data_length += [6] 171 | 172 | normalize = GroupNormalize(input_mean, input_std) 173 | train_loader = torch.utils.data.DataLoader( 174 | TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments, 175 | modality=args.modality, 176 | new_length=data_length, 177 | image_tmpl=prefix, 178 | transform=torchvision.transforms.Compose([ 179 | train_augmentation, 180 | Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])), 181 | ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])), 182 | normalize, 183 | ]), dense_sample=args.dense_sample, random_sample=args.random_sample, 184 | dense_length=args.dense_length), 185 | batch_size=args.batch_size, shuffle=True, 186 | num_workers=args.workers, pin_memory=True, 187 | drop_last=True) # prevent something not % n_GPU 188 | 189 | if args.modality[args.rank]=='RGB': 190 | normalize_val = GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 191 | elif args.modality[args.rank]=='Flow': 192 | normalize_val = GroupNormalize([0.5]*10, [0.226]*10) 193 | elif args.modality[args.rank]=='RGBDiff': 194 | normalize_val = IdentityTransform() 195 | 196 | val_loader = torch.utils.data.DataLoader( 197 | TSNDataSet([args.root_path[args.rank]], [args.val_list[args.rank]], num_segments=args.num_segments, 198 | new_length=[data_length[args.rank]], 199 | modality=[args.modality[args.rank]], 200 | image_tmpl=[prefix[args.rank]], 201 | random_shift=False, 202 | transform=torchvision.transforms.Compose([ 203 | GroupScale(int(scale_size)), 204 | GroupCenterCrop(crop_size), 205 | Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])), 206 | ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])), 207 | normalize_val, 208 | ]), dense_sample=args.dense_sample, dense_length=args.dense_length), 209 | batch_size=args.batch_size, shuffle=False, 210 | num_workers=args.workers, pin_memory=True) 211 | 212 | 
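    # Annotation (not in the original source): channel bookkeeping used above and in
    # train() below.  Per segment, RGB contributes 1 * 3 = 3 channels, Flow 5 * 2 = 10,
    # and RGBDiff 6 * 3 = 18 (new_length frames times channels per frame).  Rank 0 loads
    # all modalities stacked channel-wise and broadcasts the other modalities' channels
    # to the remaining ranks (see start_ind / end_ind / send_inds in train()).
    # The kl_loss defined next implements the mutual-learning term: each model is pulled
    # towards the averaged softmax of the other models' logits, obtained via all_reduce:
    #     target_i = softmax((sum_j logits_j - logits_i) / (world_size - 1))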
# define loss function (criterion) and optimizer 213 | if args.loss_type == 'nll': 214 | criterion = torch.nn.CrossEntropyLoss().cuda(args.gpus[-1]) 215 | else: 216 | raise ValueError("Unknown loss type") 217 | 218 | if len(args.modality)>1: 219 | kl_loss = torch.nn.KLDivLoss(reduction='batchmean').cuda(args.gpus[-1]) 220 | logsoftmax = torch.nn.LogSoftmax(dim=1).cuda(args.gpus[-1]) 221 | softmax = torch.nn.Softmax(dim=1).cuda(args.gpus[-1]) 222 | else: 223 | kl_loss = None 224 | logsoftmax = None 225 | softmax = None 226 | 227 | for group in policies: 228 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 229 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 230 | 231 | if args.evaluate: 232 | validate(val_loader, model, criterion, 0) 233 | return 234 | 235 | log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w') 236 | with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f: 237 | f.write(str(args)) 238 | tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name)) 239 | for epoch in range(args.start_epoch, args.epochs): 240 | adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps) 241 | 242 | # train for one epoch 243 | train(train_loader, model, criterion, kl_loss, logsoftmax, softmax, optimizer, epoch, log_training, tf_writer) 244 | 245 | # evaluate on validation set 246 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 247 | prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer) 248 | 249 | # remember best prec@1 and save checkpoint 250 | is_best = prec1 > best_prec1 251 | best_prec1 = max(prec1, best_prec1) 252 | tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch) 253 | 254 | output_best = 'Best Prec@1: %.3f\n' % (best_prec1) 255 | print(output_best) 256 | log_training.write(output_best + '\n') 257 | log_training.flush() 258 | 259 | save_checkpoint({ 260 | 'epoch': epoch + 1, 261 | 'arch': args.arch, 262 | 'state_dict': model.state_dict(), 263 | 'optimizer': optimizer.state_dict(), 264 | 'best_prec1': best_prec1, 265 | }, is_best) 266 | 267 | 268 | def train(train_loader, model, criterion, kl_loss, logsoftmax, softmax, optimizer, epoch, log, tf_writer): 269 | batch_time = AverageMeter() 270 | data_time = AverageMeter() 271 | losses = AverageMeter() 272 | loss_kl = AverageMeter() 273 | top1 = AverageMeter() 274 | top5 = AverageMeter() 275 | 276 | total = 0 277 | shift = 0 278 | for i,moda in enumerate(args.modality): 279 | tmp = total 280 | if moda=='RGB': 281 | total += 3 282 | elif moda=='Flow': 283 | total += 10 284 | elif moda=='RGBDiff': 285 | total += 18 286 | if i==0: 287 | shift = total 288 | if i==args.rank and i>0: 289 | start_ind = tmp-shift 290 | end_ind = total-shift 291 | elif i==args.rank and i==0: 292 | start_ind = 0 293 | end_ind = total 294 | 295 | if args.rank==0: 296 | inds = [] 297 | for x in range(args.num_segments): 298 | inds.extend(list(range(x*total+start_ind,x*total+end_ind))) 299 | send_inds = [] 300 | for x in range(args.num_segments): 301 | send_inds.extend(list(range(x*total+end_ind,x*total+total))) 302 | else: 303 | inds = [] 304 | for x in range(args.num_segments): 305 | inds.extend(list(range(x*(total-shift)+start_ind,x*(total-shift)+end_ind))) 306 | 307 | if args.no_partialbn: 308 | model.module.partialBN(False) 309 | else: 310 | model.module.partialBN(True) 311 | 312 | # switch to train mode5r 313 | model.train() 314 | 315 | if args.rank==0: 316 | iter_through = 
train_loader 317 | else: 318 | iter_through = range(int(len([x for x in open(args.train_list[0])])/args.batch_size)) 319 | 320 | end = time.time() 321 | for i, data in enumerate(iter_through): 322 | # measure data loading time 323 | data_time.update(time.time() - end) 324 | 325 | if args.rank==0: 326 | input, target = data 327 | 328 | target = target.cuda(args.gpus[-1]) 329 | input = input.cuda(args.gpus[0]) 330 | 331 | if args.world_size>1: 332 | torch.distributed.broadcast(input[:,send_inds].contiguous(),0) 333 | torch.distributed.broadcast(target,0) 334 | else: 335 | input = torch.zeros((args.batch_size,(total-shift)*args.num_segments,224,224)).cuda(args.gpus[0]) 336 | target = torch.zeros((args.batch_size,),dtype=torch.int64).cuda(args.gpus[-1]) 337 | torch.distributed.broadcast(input,0) 338 | torch.distributed.broadcast(target,0) 339 | 340 | input_var = torch.autograd.Variable(input[:,inds].contiguous()) 341 | target_var = torch.autograd.Variable(target) 342 | 343 | # compute output 344 | output = model(input_var).cuda(args.gpus[-1]) 345 | loss1 = criterion(output, target_var) 346 | 347 | if args.world_size>1: 348 | reduce_output = output.clone().detach() 349 | distr.all_reduce(reduce_output) 350 | reduce_output = (reduce_output-output.detach())/(args.world_size-1) 351 | loss2 = kl_loss(logsoftmax(output), softmax(reduce_output.detach())) 352 | else: 353 | loss2 = torch.tensor(0.) 354 | loss = loss1+loss2 355 | 356 | # measure accuracy and record loss 357 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 358 | losses.update(loss1.item(), input.size(0)) 359 | loss_kl.update(loss2.item(), input.size(0)) 360 | top1.update(prec1.item(), input.size(0)) 361 | top5.update(prec5.item(), input.size(0)) 362 | 363 | # compute gradient and do SGD step 364 | loss.backward() 365 | 366 | if args.clip_gradient is not None: 367 | total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient) 368 | 369 | optimizer.step() 370 | optimizer.zero_grad() 371 | 372 | # measure elapsed time 373 | batch_time.update(time.time() - end) 374 | end = time.time() 375 | 376 | if i % args.print_freq == 0: 377 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 378 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 379 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 380 | 'Loss {loss1.val:.4f} ({loss1.avg:.4f})\t' 381 | 'LossKL {loss2.val:.4f} ({loss2.avg:.4f})\t' 382 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 383 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 384 | epoch, i, len(iter_through), batch_time=batch_time, 385 | data_time=data_time, loss1=losses, loss2=loss_kl, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1)) # TODO 386 | print(output) 387 | log.write(output + '\n') 388 | log.flush() 389 | 390 | tf_writer.add_scalar('loss/train', losses.avg, epoch) 391 | tf_writer.add_scalar('loss/mutual', loss_kl.avg, epoch) 392 | tf_writer.add_scalar('acc/train_top1', top1.avg, epoch) 393 | tf_writer.add_scalar('acc/train_top5', top5.avg, epoch) 394 | tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch) 395 | 396 | 397 | def validate(val_loader, model, criterion, epoch, log=None, tf_writer=None): 398 | batch_time = AverageMeter() 399 | losses = AverageMeter() 400 | top1 = AverageMeter() 401 | top5 = AverageMeter() 402 | 403 | # switch to evaluate mode 404 | model.eval() 405 | 406 | end = time.time() 407 | with torch.no_grad(): 408 | for i, (input, target) in enumerate(val_loader): 409 | target = target.cuda(args.gpus[-1]) 410 | 411 | # compute output 412 | 
output = model(input).cuda(args.gpus[-1]) 413 | loss = criterion(output, target) 414 | 415 | # measure accuracy and record loss 416 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 417 | 418 | losses.update(loss.item(), input.size(0)) 419 | top1.update(prec1.item(), input.size(0)) 420 | top5.update(prec5.item(), input.size(0)) 421 | 422 | # measure elapsed time 423 | batch_time.update(time.time() - end) 424 | end = time.time() 425 | 426 | if i % args.print_freq == 0: 427 | output = ('Test: [{0}/{1}]\t' 428 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 429 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 430 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 431 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 432 | i, len(val_loader), batch_time=batch_time, loss=losses, 433 | top1=top1, top5=top5)) 434 | print(output) 435 | if log is not None: 436 | log.write(output + '\n') 437 | log.flush() 438 | 439 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 440 | .format(top1=top1, top5=top5, loss=losses)) 441 | print(output) 442 | if log is not None: 443 | log.write(output + '\n') 444 | log.flush() 445 | 446 | if tf_writer is not None: 447 | tf_writer.add_scalar('loss/test', losses.avg, epoch) 448 | tf_writer.add_scalar('acc/test_top1', top1.avg, epoch) 449 | tf_writer.add_scalar('acc/test_top5', top5.avg, epoch) 450 | 451 | return top1.avg 452 | 453 | 454 | def save_checkpoint(state, is_best): 455 | filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name) 456 | torch.save(state, filename) 457 | if is_best: 458 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 459 | 460 | 461 | def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps): 462 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 463 | if lr_type == 'step': 464 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 465 | lr = args.lr * decay 466 | decay = args.weight_decay 467 | elif lr_type == 'cos': 468 | import math 469 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs)) 470 | decay = args.weight_decay 471 | else: 472 | raise NotImplementedError 473 | for param_group in optimizer.param_groups: 474 | param_group['lr'] = lr * param_group['lr_mult'] 475 | param_group['weight_decay'] = decay * param_group['decay_mult'] 476 | 477 | 478 | def check_rootfolders(): 479 | """Create log and model folder""" 480 | folders_util = [args.root_log, args.root_model, 481 | os.path.join(args.root_log, args.store_name), 482 | os.path.join(args.root_model, args.store_name)] 483 | for folder in folders_util: 484 | if not os.path.exists(folder): 485 | print('creating folder ' + folder) 486 | os.mkdir(folder) 487 | 488 | 489 | if __name__ == '__main__': 490 | main() 491 | -------------------------------------------------------------------------------- /online_demo/README.md: -------------------------------------------------------------------------------- 1 | # TSM Online Hand Gesture Recognition Demo 2 | 3 | ``` 4 | @inproceedings{lin2019tsm, 5 | title={TSM: Temporal Shift Module for Efficient Video Understanding}, 6 | author={Lin, Ji and Gan, Chuang and Han, Song}, 7 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 8 | year={2019} 9 | } 10 | ``` 11 | 12 | ![tsm-demo](https://file.lzhu.me/projects/tsm/external/tsm-demo2.gif) 13 | 14 | See the [[full video]](https://hanlab.mit.edu/projects/tsm/#live_demo) of our demo on NVIDIA Jetson Nano. 
15 | 16 | ## Overview 17 | 18 | We show how to deploy an online hand gesture recognition system on **NVIDIA Jetson Nano**. The model is based on a MobileNetV2 backbone with the **Temporal Shift Module (TSM)** to model temporal relationships. It is compiled with **TVM** [1] for acceleration. 19 | 20 | The model achieves **real-time** recognition: excluding data I/O time, it runs at **>70 FPS** on the Nano GPU. 21 | 22 | [1] Tianqi Chen *et al.*, *TVM: An automated end-to-end optimizing compiler for deep learning*, in OSDI 2018 23 | 24 | ## Model 25 | 26 | We use an online version of the Temporal Shift Module in this demo. The model design is shown below: 27 | [figure: online TSM model design] 28 | 29 | [figure: online TSM model design, continued] 30 | 31 | After compilation with TVM, our model runs efficiently on low-power devices. 32 | 33 | [figure] 34 | 35 | [figure]
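For reference, instead of shifting channels across a whole clip as in offline TSM, the online module keeps a small per-layer buffer of channels cached from the previous frame. A minimal sketch of the shift itself (the full residual block is `InvertedResidualWithShift` in `mobilenet_v2_tsm.py`; the helper name here is ours):

```
import torch

def online_shift(x, shift_buffer):
    # x: (1, C, H, W) activation of the current frame
    # shift_buffer: (1, C // 8, H, W) channels cached from the previous frame
    c = x.size(1)
    to_cache = x[:, : c // 8]                                  # saved for the next frame
    shifted = torch.cat((shift_buffer, x[:, c // 8:]), dim=1)  # reuse cached channels
    return shifted, to_cache
```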

36 | 37 | ## Step-by-step Tutorial 38 | 39 | We show how to set up the environment on the Jetson Nano, compile the PyTorch model with TVM, and run the online demo from the camera stream. 40 | 41 | 1. Get an [NVIDIA Jetson Nano](https://developer.nvidia.com/embedded/jetson-nano-developer-kit) board (it is only $99!). 42 | 2. Get a micro SD card and burn the **Nano system image** onto it following [here](https://developer.nvidia.com/embedded/learn/get-started-jetson-nano-devkit). Insert the card and boot the Nano. **Note**: you may want to get a power adaptor for a stable power supply. 43 | 3. Check whether OpenCV 4.X is installed (it is included in the SD card image from r32.3.1 onwards): 44 | ``` 45 | $ python3 46 | >>> import cv2 47 | >>> cv2.__version__ 48 | ``` 49 | It should show 4.X. 50 | If not, build **OpenCV** 4.0.0 using [this script](https://github.com/AastaNV/JEP/blob/master/script/install_opencv4.0.0_Nano.sh) so that camera access is enabled (it may take a while due to the weak CPU). You also need to add the cv2 package to the import search path: 51 | 52 | ``` 53 | export PYTHONPATH=/usr/local/python 54 | ``` 55 | 56 | 4. Follow [here](https://devtalk.nvidia.com/default/topic/1049071/jetson-nano/pytorch-for-jetson-nano/) to install **PyTorch** and **torchvision**. 57 | 5. Build **TVM** with the following commands: 58 | 59 | ``` 60 | sudo apt install llvm # install LLVM, which is required by TVM 61 | git clone -b v0.6 https://github.com/apache/incubator-tvm.git 62 | cd incubator-tvm 63 | git submodule update --init 64 | mkdir build 65 | cp cmake/config.cmake build/ 66 | cd build 67 | #[ 68 | # edit config.cmake to change: 69 | #   line 32:  USE_CUDA OFF -> USE_CUDA ON 70 | #   line 104: USE_LLVM OFF -> USE_LLVM ON 71 | #] 72 | cmake .. 73 | make -j4 74 | cd .. 75 | cd python; sudo python3 setup.py install; cd .. 76 | cd topi/python; sudo python3 setup.py install; cd ../.. 77 | ``` 78 | 79 | 6. Install **ONNX**: 80 | 81 | ``` 82 | # install onnx 83 | sudo apt-get install protobuf-compiler libprotoc-dev 84 | pip3 install onnx 85 | ``` 86 | 87 | 7. Export the CUDA toolkit binaries to `PATH`: 88 | 89 | ``` 90 | export PATH=$PATH:/usr/local/cuda/bin 91 | ``` 92 | 93 | 8. **Finally, run the demo**. The first run compiles the PyTorch TSM model into a TVM binary and then runs it; later runs directly execute the compiled TVM model. 94 | 95 | ``` 96 | python3 main.py 97 | ``` 98 | 99 | Press `Q` or `Esc` to quit. Press `F` to enter/exit full-screen. 
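Note that the first run caches the compiled model next to `main.py` as `mobilenet_tsm_tvm_cuda.json`, `mobilenet_tsm_tvm_cuda.params` and `mobilenet_tsm_tvm_cuda.tar`; deleting these files forces a re-compile. If the compile step fails, a quick sanity check (using the TVM v0.6 Python API that this demo already relies on) is to verify that TVM sees the GPU:

```
import tvm
print(tvm.gpu(0).exist)  # should print True when TVM was built with USE_CUDA ON
```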
100 | 101 | ## Supported Gestures 102 | 103 | - No gesture 104 | - Stop Sign 105 | - Drumming Fingers 106 | - Thumb Up 107 | - Thumb Down 108 | - Zooming In With Full Hand 109 | - Zooming In With Two Fingers 110 | - Zooming Out With Full Hand 111 | - Zooming Out With Two Fingers 112 | - Swiping Down 113 | - Swiping Left 114 | - Swiping Right 115 | - Swiping Up 116 | - Sliding Two Fingers Down 117 | - Sliding Two Fingers Left 118 | - Sliding Two Fingers Right 119 | - Sliding Two Fingers Up 120 | - Pulling Hand In 121 | - Pulling Two Fingers In 122 | 123 | ## Contact 124 | 125 | For any problems, contact: 126 | 127 | Ji Lin, jilin@mit.edu 128 | 129 | Yaoyao Ding, yyding@mit.edu 130 | -------------------------------------------------------------------------------- /online_demo/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from typing import Tuple 4 | import io 5 | import tvm 6 | import tvm.relay 7 | import time 8 | import cv2 9 | import torch 10 | import torchvision 11 | import torch.onnx 12 | from PIL import Image, ImageOps 13 | import onnx 14 | import tvm.contrib.graph_runtime as graph_runtime 15 | from mobilenet_v2_tsm import MobileNetV2 16 | 17 | SOFTMAX_THRES = 0 18 | HISTORY_LOGIT = True 19 | REFINE_OUTPUT = True 20 | 21 | def torch2tvm_module(torch_module: torch.nn.Module, torch_inputs: Tuple[torch.Tensor, ...], target): 22 | torch_module.eval() 23 | input_names = [] 24 | input_shapes = {} 25 | with torch.no_grad(): 26 | for index, torch_input in enumerate(torch_inputs): 27 | name = "i" + str(index) 28 | input_names.append(name) 29 | input_shapes[name] = torch_input.shape 30 | buffer = io.BytesIO() 31 | torch.onnx.export(torch_module, torch_inputs, buffer, input_names=input_names, output_names=["o" + str(i) for i in range(len(torch_inputs))]) 32 | outs = torch_module(*torch_inputs) 33 | buffer.seek(0, 0) 34 | onnx_model = onnx.load_model(buffer) 35 | relay_module, params = tvm.relay.frontend.from_onnx(onnx_model, shape=input_shapes) 36 | with tvm.relay.build_config(opt_level=3): 37 | graph, tvm_module, params = tvm.relay.build(relay_module, target, params=params) 38 | return graph, tvm_module, params 39 | 40 | 41 | def torch2executor(torch_module: torch.nn.Module, torch_inputs: Tuple[torch.Tensor, ...], target): 42 | prefix = f"mobilenet_tsm_tvm_{target}" 43 | lib_fname = f'{prefix}.tar' 44 | graph_fname = f'{prefix}.json' 45 | params_fname = f'{prefix}.params' 46 | if os.path.exists(lib_fname) and os.path.exists(graph_fname) and os.path.exists(params_fname): 47 | with open(graph_fname, 'rt') as f: 48 | graph = f.read() 49 | tvm_module = tvm.module.load(lib_fname) 50 | params = tvm.relay.load_param_dict(bytearray(open(params_fname, 'rb').read())) 51 | else: 52 | graph, tvm_module, params = torch2tvm_module(torch_module, torch_inputs, target) 53 | tvm_module.export_library(lib_fname) 54 | with open(graph_fname, 'wt') as f: 55 | f.write(graph) 56 | with open(params_fname, 'wb') as f: 57 | f.write(tvm.relay.save_param_dict(params)) 58 | 59 | ctx = tvm.gpu() if target.startswith('cuda') else tvm.cpu() 60 | graph_module = graph_runtime.create(graph, tvm_module, ctx) 61 | for pname, pvalue in params.items(): 62 | graph_module.set_input(pname, pvalue) 63 | 64 | def executor(inputs: Tuple[tvm.nd.NDArray]): 65 | for index, value in enumerate(inputs): 66 | graph_module.set_input(index, value) 67 | graph_module.run() 68 | return tuple(graph_module.get_output(index) for index in range(len(inputs))) 69 | 70 | 
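    # Annotation (not in the original source): the executor mirrors the PyTorch forward
    # signature.  It is called with 11 tvm.nd.NDArrays (the current frame plus the 10
    # shift buffers) and returns 11 outputs: the class logits followed by the updated
    # shift buffers -- see how main() unpacks outputs[0] and outputs[1:].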
return executor, ctx 71 | 72 | 73 | def get_executor(use_gpu=True): 74 | torch_module = MobileNetV2(n_class=27) 75 | if not os.path.exists("mobilenetv2_jester_online.pth.tar"): # checkpoint not downloaded 76 | print('Downloading PyTorch checkpoint...') 77 | import urllib.request 78 | url = 'https://file.lzhu.me/projects/tsm/models/mobilenetv2_jester_online.pth.tar' 79 | urllib.request.urlretrieve(url, './mobilenetv2_jester_online.pth.tar') 80 | torch_module.load_state_dict(torch.load("mobilenetv2_jester_online.pth.tar")) 81 | torch_inputs = (torch.rand(1, 3, 224, 224), 82 | torch.zeros([1, 3, 56, 56]), 83 | torch.zeros([1, 4, 28, 28]), 84 | torch.zeros([1, 4, 28, 28]), 85 | torch.zeros([1, 8, 14, 14]), 86 | torch.zeros([1, 8, 14, 14]), 87 | torch.zeros([1, 8, 14, 14]), 88 | torch.zeros([1, 12, 14, 14]), 89 | torch.zeros([1, 12, 14, 14]), 90 | torch.zeros([1, 20, 7, 7]), 91 | torch.zeros([1, 20, 7, 7])) 92 | if use_gpu: 93 | target = 'cuda' 94 | else: 95 | target = 'llvm -mcpu=cortex-a72 -target=armv7l-linux-gnueabihf' 96 | return torch2executor(torch_module, torch_inputs, target) 97 | 98 | 99 | def transform(frame: np.ndarray): 100 | # 480, 640, 3, 0 ~ 255 101 | frame = cv2.resize(frame, (224, 224)) # (224, 224, 3) 0 ~ 255 102 | frame = frame / 255.0 # (224, 224, 3) 0 ~ 1.0 103 | frame = np.transpose(frame, axes=[2, 0, 1]) # (3, 224, 224) 0 ~ 1.0 104 | frame = np.expand_dims(frame, axis=0) # (1, 3, 480, 640) 0 ~ 1.0 105 | return frame 106 | 107 | 108 | class GroupScale(object): 109 | """ Rescales the input PIL.Image to the given 'size'. 110 | 'size' will be the size of the smaller edge. 111 | For example, if height > width, then image will be 112 | rescaled to (size * height / width, size) 113 | size: size of the smaller edge 114 | interpolation: Default: PIL.Image.BILINEAR 115 | """ 116 | 117 | def __init__(self, size, interpolation=Image.BILINEAR): 118 | self.worker = torchvision.transforms.Scale(size, interpolation) 119 | 120 | def __call__(self, img_group): 121 | return [self.worker(img) for img in img_group] 122 | 123 | 124 | class GroupCenterCrop(object): 125 | def __init__(self, size): 126 | self.worker = torchvision.transforms.CenterCrop(size) 127 | 128 | def __call__(self, img_group): 129 | return [self.worker(img) for img in img_group] 130 | 131 | 132 | class Stack(object): 133 | 134 | def __init__(self, roll=False): 135 | self.roll = roll 136 | 137 | def __call__(self, img_group): 138 | if img_group[0].mode == 'L': 139 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 140 | elif img_group[0].mode == 'RGB': 141 | if self.roll: 142 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 143 | else: 144 | return np.concatenate(img_group, axis=2) 145 | 146 | 147 | class ToTorchFormatTensor(object): 148 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 149 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 150 | 151 | def __init__(self, div=True): 152 | self.div = div 153 | 154 | def __call__(self, pic): 155 | if isinstance(pic, np.ndarray): 156 | # handle numpy array 157 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 158 | else: 159 | # handle PIL Image 160 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 161 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 162 | # put it from HWC to CHW format 163 | # yikes, this transpose takes 80% of the loading time/CPU 164 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 165 | return 
img.float().div(255) if self.div else img.float() 166 | 167 | 168 | class GroupNormalize(object): 169 | def __init__(self, mean, std): 170 | self.mean = mean 171 | self.std = std 172 | 173 | def __call__(self, tensor): 174 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 175 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 176 | 177 | # TODO: make efficient 178 | for t, m, s in zip(tensor, rep_mean, rep_std): 179 | t.sub_(m).div_(s) 180 | 181 | return tensor 182 | 183 | 184 | def get_transform(): 185 | cropping = torchvision.transforms.Compose([ 186 | GroupScale(256), 187 | GroupCenterCrop(224), 188 | ]) 189 | transform = torchvision.transforms.Compose([ 190 | cropping, 191 | Stack(roll=False), 192 | ToTorchFormatTensor(div=True), 193 | GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 194 | ]) 195 | return transform 196 | 197 | catigories = [ 198 | "Doing other things", # 0 199 | "Drumming Fingers", # 1 200 | "No gesture", # 2 201 | "Pulling Hand In", # 3 202 | "Pulling Two Fingers In", # 4 203 | "Pushing Hand Away", # 5 204 | "Pushing Two Fingers Away", # 6 205 | "Rolling Hand Backward", # 7 206 | "Rolling Hand Forward", # 8 207 | "Shaking Hand", # 9 208 | "Sliding Two Fingers Down", # 10 209 | "Sliding Two Fingers Left", # 11 210 | "Sliding Two Fingers Right", # 12 211 | "Sliding Two Fingers Up", # 13 212 | "Stop Sign", # 14 213 | "Swiping Down", # 15 214 | "Swiping Left", # 16 215 | "Swiping Right", # 17 216 | "Swiping Up", # 18 217 | "Thumb Down", # 19 218 | "Thumb Up", # 20 219 | "Turning Hand Clockwise", # 21 220 | "Turning Hand Counterclockwise", # 22 221 | "Zooming In With Full Hand", # 23 222 | "Zooming In With Two Fingers", # 24 223 | "Zooming Out With Full Hand", # 25 224 | "Zooming Out With Two Fingers" # 26 225 | ] 226 | 227 | 228 | n_still_frame = 0 229 | 230 | def process_output(idx_, history): 231 | # idx_: the output of current frame 232 | # history: a list containing the history of predictions 233 | if not REFINE_OUTPUT: 234 | return idx_, history 235 | 236 | max_hist_len = 20 # max history buffer 237 | 238 | # mask out illegal action 239 | if idx_ in [7, 8, 21, 22, 3]: 240 | idx_ = history[-1] 241 | 242 | # use only single no action class 243 | if idx_ == 0: 244 | idx_ = 2 245 | 246 | # history smoothing 247 | if idx_ != history[-1]: 248 | if not (history[-1] == history[-2]): # and history[-2] == history[-3]): 249 | idx_ = history[-1] 250 | 251 | 252 | history.append(idx_) 253 | history = history[-max_hist_len:] 254 | 255 | return history[-1], history 256 | 257 | 258 | WINDOW_NAME = 'Video Gesture Recognition' 259 | def main(): 260 | print("Open camera...") 261 | cap = cv2.VideoCapture(0) 262 | 263 | print(cap) 264 | 265 | # set a lower resolution for speed up 266 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320) 267 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240) 268 | 269 | # env variables 270 | full_screen = False 271 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 272 | cv2.resizeWindow(WINDOW_NAME, 640, 480) 273 | cv2.moveWindow(WINDOW_NAME, 0, 0) 274 | cv2.setWindowTitle(WINDOW_NAME, WINDOW_NAME) 275 | 276 | 277 | t = None 278 | index = 0 279 | print("Build transformer...") 280 | transform = get_transform() 281 | print("Build Executor...") 282 | executor, ctx = get_executor() 283 | buffer = ( 284 | tvm.nd.empty((1, 3, 56, 56), ctx=ctx), 285 | tvm.nd.empty((1, 4, 28, 28), ctx=ctx), 286 | tvm.nd.empty((1, 4, 28, 28), ctx=ctx), 287 | tvm.nd.empty((1, 8, 14, 14), ctx=ctx), 288 | tvm.nd.empty((1, 8, 14, 14), ctx=ctx), 289 | tvm.nd.empty((1, 8, 14, 
14), ctx=ctx), 290 | tvm.nd.empty((1, 12, 14, 14), ctx=ctx), 291 | tvm.nd.empty((1, 12, 14, 14), ctx=ctx), 292 | tvm.nd.empty((1, 20, 7, 7), ctx=ctx), 293 | tvm.nd.empty((1, 20, 7, 7), ctx=ctx) 294 | ) 295 | idx = 0 296 | history = [2] 297 | history_logit = [] 298 | history_timing = [] 299 | 300 | i_frame = -1 301 | 302 | print("Ready!") 303 | while True: 304 | i_frame += 1 305 | _, img = cap.read() # (480, 640, 3) 0 ~ 255 306 | if i_frame % 2 == 0: # skip every other frame to obtain a suitable frame rate 307 | t1 = time.time() 308 | img_tran = transform([Image.fromarray(img).convert('RGB')]) 309 | input_var = torch.autograd.Variable(img_tran.view(1, 3, img_tran.size(1), img_tran.size(2))) 310 | img_nd = tvm.nd.array(input_var.detach().numpy(), ctx=ctx) 311 | inputs: Tuple[tvm.nd.NDArray] = (img_nd,) + buffer 312 | outputs = executor(inputs) 313 | feat, buffer = outputs[0], outputs[1:] 314 | assert isinstance(feat, tvm.nd.NDArray) 315 | 316 | if SOFTMAX_THRES > 0: 317 | feat_np = feat.asnumpy().reshape(-1) 318 | feat_np -= feat_np.max() 319 | softmax = np.exp(feat_np) / np.sum(np.exp(feat_np)) 320 | 321 | print(max(softmax)) 322 | if max(softmax) > SOFTMAX_THRES: 323 | idx_ = np.argmax(feat.asnumpy(), axis=1)[0] 324 | else: 325 | idx_ = idx 326 | else: 327 | idx_ = np.argmax(feat.asnumpy(), axis=1)[0] 328 | 329 | if HISTORY_LOGIT: 330 | history_logit.append(feat.asnumpy()) 331 | history_logit = history_logit[-12:] 332 | avg_logit = sum(history_logit) 333 | idx_ = np.argmax(avg_logit, axis=1)[0] 334 | 335 | idx, history = process_output(idx_, history) 336 | 337 | t2 = time.time() 338 | print(f"{index} {catigories[idx]}") 339 | 340 | 341 | current_time = t2 - t1 342 | 343 | img = cv2.resize(img, (640, 480)) 344 | img = img[:, ::-1] 345 | height, width, _ = img.shape 346 | label = np.zeros([height // 10, width, 3]).astype('uint8') + 255 347 | 348 | cv2.putText(label, 'Prediction: ' + catigories[idx], 349 | (0, int(height / 16)), 350 | cv2.FONT_HERSHEY_SIMPLEX, 351 | 0.7, (0, 0, 0), 2) 352 | cv2.putText(label, '{:.1f} Vid/s'.format(1 / current_time), 353 | (width - 170, int(height / 16)), 354 | cv2.FONT_HERSHEY_SIMPLEX, 355 | 0.7, (0, 0, 0), 2) 356 | 357 | img = np.concatenate((img, label), axis=0) 358 | cv2.imshow(WINDOW_NAME, img) 359 | 360 | key = cv2.waitKey(1) 361 | if key & 0xFF == ord('q') or key == 27: # exit 362 | break 363 | elif key == ord('F') or key == ord('f'): # full screen 364 | print('Changing full screen option!') 365 | full_screen = not full_screen 366 | if full_screen: 367 | print('Setting FS!!!') 368 | cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, 369 | cv2.WINDOW_FULLSCREEN) 370 | else: 371 | cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, 372 | cv2.WINDOW_NORMAL) 373 | 374 | 375 | if t is None: 376 | t = time.time() 377 | else: 378 | nt = time.time() 379 | index += 1 380 | t = nt 381 | 382 | cap.release() 383 | cv2.destroyAllWindows() 384 | 385 | 386 | main() 387 | -------------------------------------------------------------------------------- /online_demo/mobilenet_v2_tsm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | def conv_bn(inp, oup, stride): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.ReLU6(inplace=True) 11 | ) 12 | 13 | 14 | def conv_1x1_bn(inp, oup): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 17 | nn.BatchNorm2d(oup), 18 | 
nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | def make_divisible(x, divisible_by=8): 23 | import numpy as np 24 | return int(np.ceil(x * 1. / divisible_by) * divisible_by) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = int(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 40 | nn.BatchNorm2d(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 54 | nn.BatchNorm2d(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | class InvertedResidualWithShift(nn.Module): 68 | def __init__(self, inp, oup, stride, expand_ratio): 69 | super(InvertedResidualWithShift, self).__init__() 70 | self.stride = stride 71 | assert stride in [1, 2] 72 | 73 | assert expand_ratio > 1 74 | 75 | hidden_dim = int(inp * expand_ratio) 76 | self.use_res_connect = self.stride == 1 and inp == oup 77 | assert self.use_res_connect 78 | 79 | self.conv = nn.Sequential( 80 | # pw 81 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 82 | nn.BatchNorm2d(hidden_dim), 83 | nn.ReLU6(inplace=True), 84 | # dw 85 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 86 | nn.BatchNorm2d(hidden_dim), 87 | nn.ReLU6(inplace=True), 88 | # pw-linear 89 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 90 | nn.BatchNorm2d(oup), 91 | ) 92 | 93 | def forward(self, x, shift_buffer): 94 | c = x.size(1) 95 | x1, x2 = x[:, : c // 8], x[:, c // 8:] 96 | return x + self.conv(torch.cat((shift_buffer, x2), dim=1)), x1 97 | 98 | 99 | class MobileNetV2(nn.Module): 100 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 101 | super(MobileNetV2, self).__init__() 102 | input_channel = 32 103 | last_channel = 1280 104 | interverted_residual_setting = [ 105 | # t, c, n, s 106 | [1, 16, 1, 1], 107 | [6, 24, 2, 2], 108 | [6, 32, 3, 2], 109 | [6, 64, 4, 2], 110 | [6, 96, 3, 1], 111 | [6, 160, 3, 2], 112 | [6, 320, 1, 1], 113 | ] 114 | 115 | # building first layer 116 | assert input_size % 32 == 0 117 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32! 
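        # Annotation (not in the original source): shift_block_idx below lists the ten
        # inverted-residual blocks that receive an online temporal shift (all of them
        # have identity shortcuts, which InvertedResidualWithShift asserts).  Each such
        # block consumes and emits one shift-buffer tensor, which is why the demo passes
        # exactly ten buffer tensors (see get_executor() in main.py and __main__ below).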
118 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 119 | self.features = [conv_bn(3, input_channel, 2)] 120 | # building inverted residual blocks 121 | global_idx = 0 122 | shift_block_idx = [2, 4, 5, 7, 8, 9, 11, 12, 14, 15] 123 | for t, c, n, s in interverted_residual_setting: 124 | output_channel = make_divisible(c * width_mult) if t > 1 else c 125 | for i in range(n): 126 | if i == 0: 127 | block = InvertedResidualWithShift if global_idx in shift_block_idx else InvertedResidual 128 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 129 | global_idx += 1 130 | else: 131 | block = InvertedResidualWithShift if global_idx in shift_block_idx else InvertedResidual 132 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 133 | global_idx += 1 134 | input_channel = output_channel 135 | # building last several layers 136 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 137 | # make it nn.Sequential 138 | self.features = nn.ModuleList(self.features) 139 | 140 | # building classifier 141 | self.classifier = nn.Linear(self.last_channel, n_class) 142 | 143 | self._initialize_weights() 144 | 145 | def forward(self, x, *shift_buffer): 146 | shift_buffer_idx = 0 147 | out_buffer = [] 148 | for f in self.features: 149 | if isinstance(f, InvertedResidualWithShift): 150 | x, s = f(x, shift_buffer[shift_buffer_idx]) 151 | shift_buffer_idx += 1 152 | out_buffer.append(s) 153 | else: 154 | x = f(x) 155 | x = x.mean(3).mean(2) 156 | x = self.classifier(x) 157 | return (x, *out_buffer) 158 | 159 | def _initialize_weights(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 164 | if m.bias is not None: 165 | m.bias.data.zero_() 166 | elif isinstance(m, nn.BatchNorm2d): 167 | m.weight.data.fill_(1) 168 | m.bias.data.zero_() 169 | elif isinstance(m, nn.Linear): 170 | n = m.weight.size(1) 171 | m.weight.data.normal_(0, 0.01) 172 | m.bias.data.zero_() 173 | 174 | 175 | def mobilenet_v2_140(): 176 | return MobileNetV2(width_mult=1.4) 177 | 178 | 179 | if __name__ == '__main__': 180 | net = MobileNetV2() 181 | x = torch.rand(1, 3, 224, 224) 182 | shift_buffer = [torch.zeros([1, 3, 56, 56]), 183 | torch.zeros([1, 4, 28, 28]), 184 | torch.zeros([1, 4, 28, 28]), 185 | torch.zeros([1, 8, 14, 14]), 186 | torch.zeros([1, 8, 14, 14]), 187 | torch.zeros([1, 8, 14, 14]), 188 | torch.zeros([1, 12, 14, 14]), 189 | torch.zeros([1, 12, 14, 14]), 190 | torch.zeros([1, 20, 7, 7]), 191 | torch.zeros([1, 20, 7, 7])] 192 | with torch.no_grad(): 193 | for _ in range(10): 194 | y, shift_buffer = net(x, *shift_buffer) 195 | print([s.shape for s in shift_buffer]) 196 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Identity(torch.nn.Module): 5 | def forward(self, input): 6 | return input 7 | 8 | 9 | class SegmentConsensus(torch.nn.Module): 10 | 11 | def __init__(self, consensus_type, dim=1): 12 | super(SegmentConsensus, self).__init__() 13 | self.consensus_type = consensus_type 14 | self.dim = dim 15 | self.shape = None 16 | 17 | def forward(self, input_tensor): 18 | self.shape = input_tensor.size() 19 | if self.consensus_type == 'avg': 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) 21 | elif self.consensus_type == 'identity': 22 | output = input_tensor 23 | else: 24 | output = None 25 | 26 | return output 27 | 28 | 29 | class ConsensusModule(torch.nn.Module): 30 | 31 | def __init__(self, consensus_type, dim=1): 32 | super(ConsensusModule, self).__init__() 33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 34 | self.dim = dim 35 | 36 | def forward(self, input): 37 | return SegmentConsensus(self.consensus_type, self.dim)(input) 38 | -------------------------------------------------------------------------------- /ops/dataset.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import torch.utils.data as data 7 | 8 | from PIL import Image 9 | import os 10 | import numpy as np 11 | from numpy.random import randint 12 | 13 | 14 | class VideoRecord(object): 15 | def __init__(self, row): 16 | self._data = row 17 | 18 | @property 19 | def path(self): 20 | return self._data[0] 21 | 22 | @property 23 | def num_frames(self): 24 | return int(self._data[1]) 25 | 26 | @property 27 | def label(self): 28 | return int(self._data[2]) 29 | 30 | 31 | class TSNDataSet(data.Dataset): 32 | def __init__(self, root_path, list_file, 33 | num_segments=8, new_length=[1], modality=['RGB'], 34 | image_tmpl=['{:06d}.jpg'], transform=None, 35 | random_shift=True, test_mode=False, 36 | remove_missing=False, random_sample=False, twice_sample=False, 37 | 
dense_length=32, dense_number=1, dense_sample=False): 38 | 39 | if len(modality)==2 and modality[0]==modality[1]: 40 | self.mml = False 41 | self.root_path = root_path[:1] 42 | self.list_file = list_file[:1] 43 | self.new_length = new_length[:1] 44 | self.modality = modality[:1] 45 | self.image_tmpl = image_tmpl[:1] 46 | else: 47 | self.mml = True 48 | self.root_path = root_path 49 | self.list_file = list_file 50 | self.new_length = new_length 51 | self.modality = modality 52 | self.image_tmpl = image_tmpl 53 | self.num_segments = num_segments 54 | self.transform = transform 55 | self.random_shift = random_shift 56 | self.test_mode = test_mode 57 | self.remove_missing = remove_missing 58 | self.dense_sample = dense_sample # using dense sample as I3D 59 | self.twice_sample = twice_sample # twice sample for more validation 60 | self.random_sample = random_sample 61 | self.dense_length = dense_length 62 | self.dense_number = dense_number 63 | if self.dense_sample: 64 | print('=> Using dense sample for the dataset...') 65 | print('=> Number of frames for run:',dense_length) 66 | elif self.random_sample: 67 | print('=> Using random sample for the dataset...') 68 | print('=> Number of frames for run:',dense_length) 69 | if self.twice_sample: 70 | print('=> Using twice sample for the dataset...') 71 | if test_mode and (self.random_sample or self.dense_sample): 72 | print('=> Number of runs:',dense_number) 73 | 74 | self._parse_list() 75 | 76 | def _load_image(self, directory, idx, moda, root, tmpl): 77 | if moda == 'RGB' or moda == 'RGBDiff': 78 | try: 79 | return [Image.open(os.path.join(root, directory, tmpl.format(idx))).convert('RGB')] 80 | except Exception: 81 | print('error loading image:', os.path.join(root, directory, tmpl.format(idx))) 82 | return [Image.open(os.path.join(root, directory, tmpl.format(1))).convert('RGB')] 83 | elif moda == 'Flow': 84 | if tmpl == 'flow_{}_{:05d}.jpg': # ucf 85 | x_img = Image.open(os.path.join(root, directory, tmpl.format('x', idx))).convert( 86 | 'L') 87 | y_img = Image.open(os.path.join(root, directory, tmpl.format('y', idx))).convert( 88 | 'L') 89 | elif tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow 90 | x_img = Image.open(os.path.join(root, '{:06d}'.format(int(directory)), tmpl. 91 | format(int(directory), 'x', idx))).convert('L') 92 | y_img = Image.open(os.path.join(root, '{:06d}'.format(int(directory)), tmpl. 
93 | format(int(directory), 'y', idx))).convert('L') 94 | else: 95 | try: 96 | # idx_skip = 1 + (idx-1)*5 97 | flow = Image.open(os.path.join(root, directory, tmpl.format(idx))).convert( 98 | 'RGB') 99 | except Exception: 100 | print('error loading flow file:', 101 | os.path.join(root, directory, tmpl.format(idx))) 102 | flow = Image.open(os.path.join(root, directory, tmpl.format(1))).convert('RGB') 103 | # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel 104 | flow_x, flow_y, _ = flow.split() 105 | x_img = flow_x.convert('L') 106 | y_img = flow_y.convert('L') 107 | 108 | return [x_img, y_img] 109 | 110 | def _parse_list(self): 111 | self.video_list = [] 112 | for i,lst in enumerate(self.list_file): 113 | tmp = [x.strip().split(' ') for x in open(lst)] 114 | if not self.test_mode or self.remove_missing: 115 | tmp = [item for item in tmp if int(item[1]) >= 3] 116 | self.video_list.append([VideoRecord(item) for item in tmp]) 117 | 118 | if self.image_tmpl[i] == '{:06d}-{}_{:05d}.jpg': 119 | for v in self.video_list[i]: 120 | v._data[1] = int(v._data[1]) / 2 121 | print('video number:%d' % (len(self.video_list[i]))) 122 | 123 | def _sample_indices(self, record): 124 | """ 125 | 126 | :param record: VideoRecord 127 | :return: list 128 | """ 129 | frame_nums = np.array([x.num_frames for x in record]) 130 | argmin_frames = np.argmin(frame_nums) 131 | differences = frame_nums-frame_nums[argmin_frames] 132 | 133 | if self.dense_sample or (self.random_sample and np.random.randint(2)==0): # i3d dense sample 134 | sample_pos = max(1, 1 + record[argmin_frames].num_frames - self.dense_length) 135 | t_stride = self.dense_length // self.num_segments 136 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 137 | offsets = [(idx * t_stride + start_idx) % record[argmin_frames].num_frames for idx in range(self.num_segments)] 138 | return [np.array(offsets) + 1 + np.random.randint(x+1) for x in differences] 139 | else: # normal sample 140 | average_duration = (record[argmin_frames].num_frames - self.new_length[argmin_frames] + 1) // self.num_segments 141 | if average_duration > 0: 142 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, 143 | size=self.num_segments) 144 | elif record[argmin_frames].num_frames > self.num_segments: 145 | offsets = np.sort(randint(record[argmin_frames].num_frames - self.new_length[argmin_frames] + 1, size=self.num_segments)) 146 | else: 147 | offsets = np.zeros((self.num_segments,)) 148 | return [np.array(offsets) + 1 + np.random.randint(x+1) for x in differences] 149 | 150 | def _get_val_indices(self, record): 151 | if self.dense_sample: # i3d dense sample 152 | sample_pos = max(1, 1 + record[0].num_frames - self.dense_length) 153 | t_stride = self.dense_length // self.num_segments 154 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 155 | offsets = [(idx * t_stride + start_idx) % record[0].num_frames for idx in range(self.num_segments)] 156 | return [np.array(offsets) + 1] 157 | else: 158 | if record[0].num_frames > self.num_segments + self.new_length[0] - 1: 159 | tick = (record[0].num_frames - self.new_length[0] + 1) / float(self.num_segments) 160 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 161 | else: 162 | offsets = np.zeros((self.num_segments,)) 163 | return [offsets + 1] 164 | 165 | def _get_test_indices(self, record): 166 | if self.random_sample: 167 | sample_pos = max(1, 1 + record[0].num_frames - 
self.dense_length) 168 | t_stride = self.dense_length // self.num_segments 169 | start_list = np.linspace(0, sample_pos - 1, num=self.dense_number, dtype=int) if self.dense_number>1 else np.array([int((sample_pos-1)/2)]) 170 | offsets = [] 171 | for start_idx in start_list.tolist(): 172 | offsets += [(idx * t_stride + start_idx) % record[0].num_frames for idx in range(self.num_segments)] 173 | 174 | if self.twice_sample: 175 | tick = (record[0].num_frames - self.new_length[0] + 1) / float(self.num_segments) 176 | offsets += [int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + [int(tick * x) for x in range(self.num_segments)] 177 | else: 178 | tick = (record[0].num_frames - self.new_length[0] + 1) / float(self.num_segments) 179 | offsets += [int(tick / 2.0 + tick * x) for x in range(self.num_segments)] 180 | 181 | return [np.array(offsets) + 1] 182 | elif self.dense_sample: 183 | sample_pos = max(1, 1 + record[0].num_frames - self.dense_length) 184 | t_stride = self.dense_length // self.num_segments 185 | start_list = np.linspace(0, sample_pos - 1, num=self.dense_number, dtype=int) if self.dense_number>1 else np.array([int((sample_pos-1)/2)]) 186 | offsets = [] 187 | for start_idx in start_list.tolist(): 188 | offsets += [(idx * t_stride + start_idx) % record[0].num_frames for idx in range(self.num_segments)] 189 | return [np.array(offsets) + 1] 190 | elif self.twice_sample: 191 | tick = (record[0].num_frames - self.new_length[0] + 1) / float(self.num_segments) 192 | 193 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + 194 | [int(tick * x) for x in range(self.num_segments)]) 195 | 196 | return [offsets + 1] 197 | else: 198 | tick = (record[0].num_frames - self.new_length[0] + 1) / float(self.num_segments) 199 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 200 | return [offsets + 1] 201 | 202 | def __getitem__(self, index): 203 | record = [x[index] for x in self.video_list] 204 | # check this is a legit video folder 205 | 206 | for i in range(len(self.modality)): 207 | if self.image_tmpl[i] == 'flow_{}_{:05d}.jpg': 208 | file_name = self.image_tmpl[i].format('x', 1) 209 | full_path = os.path.join(self.root_path[i], record[i].path, file_name) 210 | elif self.image_tmpl[i] == '{:06d}-{}_{:05d}.jpg': 211 | file_name = self.image_tmpl[i].format(int(record[i].path), 'x', 1) 212 | full_path = os.path.join(self.root_path[i], '{:06d}'.format(int(record[i].path)), file_name) 213 | else: 214 | file_name = self.image_tmpl[i].format(1) 215 | full_path = os.path.join(self.root_path[i], record[i].path, file_name) 216 | 217 | while not os.path.exists(full_path): 218 | print('################## Not Found:', os.path.join(self.root_path[i], record[i].path, file_name)) 219 | index = np.random.randint(len(self.video_list[i])) 220 | record[i] = self.video_list[i][index] 221 | if self.image_tmpl[i] == 'flow_{}_{:05d}.jpg': 222 | file_name = self.image_tmpl[i].format('x', 1) 223 | full_path = os.path.join(self.root_path[i], record[i].path, file_name) 224 | elif self.image_tmpl[i] == '{:06d}-{}_{:05d}.jpg': 225 | file_name = self.image_tmpl[i].format(int(record[i].path), 'x', 1) 226 | full_path = os.path.join(self.root_path[i], '{:06d}'.format(int(record[i].path)), file_name) 227 | else: 228 | file_name = self.image_tmpl[i].format(1) 229 | full_path = os.path.join(self.root_path[i], record[i].path, file_name) 230 | 231 | if not self.test_mode: 232 | segment_indices = self._sample_indices(record) if self.random_shift else 
self._get_val_indices(record) 233 | else: 234 | segment_indices = self._get_test_indices(record) 235 | return self.get(record, segment_indices) 236 | 237 | def get(self, record, indices): 238 | images = list() 239 | for seg_ind in zip(*indices): 240 | p = [int(x) for x in seg_ind] 241 | for part in range(len(self.modality)): 242 | for i in range(self.new_length[part]): 243 | seg_imgs = self._load_image(record[part].path, p[part], self.modality[part], self.root_path[part], self.image_tmpl[part]) 244 | images.extend(seg_imgs) 245 | if not self.mml: 246 | images.extend(seg_imgs) 247 | if p[part] < record[part].num_frames: 248 | p[part] += 1 249 | 250 | process_data = self.transform(images) 251 | return process_data, record[0].label 252 | 253 | def __len__(self): 254 | return len(self.video_list[0]) 255 | -------------------------------------------------------------------------------- /ops/dataset_config.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import os 7 | 8 | ROOT_DATASET = '/ssd/ssd_srv80/video/datasets/' 9 | 10 | def return_somethingv2(modality): 11 | filename_categories = 'Something-Something-v2/lists/category.txt' 12 | filename_imglist_train = [] 13 | filename_imglist_val = [] 14 | root_data = [] 15 | prefix = [] 16 | 17 | for moda in modality: 18 | if moda == 'RGB' or moda == 'RGBDiff': 19 | root_data.append(ROOT_DATASET + 'Something-Something-v2/images') 20 | filename_imglist_train.append('Something-Something-v2/lists/train_videofolder.txt') 21 | filename_imglist_val.append('Something-Something-v2/lists/val_videofolder.txt') 22 | prefix.append('img_{:06d}.jpg') 23 | elif moda == 'Flow': 24 | root_data.append(ROOT_DATASET + 'Something-Something-v2/flow_images') 25 | filename_imglist_train.append('Something-Something-v2/lists/train_videofolder_flow.txt') 26 | filename_imglist_val.append('Something-Something-v2/lists/val_videofolder_flow.txt') 27 | prefix.append('flow_{}_{:05d}.jpg') 28 | else: 29 | raise NotImplementedError('no such modality:'+modality) 30 | 31 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 32 | 33 | def return_dataset(dataset, modality): 34 | modality = modality.split(',') 35 | #dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2, 36 | # 'ucf101': return_ucf101, 'hmdb51': return_hmdb51, 37 | # 'kinetics': return_kinetics } 38 | dict_single = {'somethingv2': return_somethingv2} 39 | 40 | if dataset in dict_single: 41 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality) 42 | else: 43 | raise ValueError('Unknown dataset '+dataset) 44 | 45 | file_imglist_train = [os.path.join(ROOT_DATASET, x) for x in file_imglist_train] 46 | file_imglist_val = [os.path.join(ROOT_DATASET, x) for x in file_imglist_val] 47 | 48 | if isinstance(file_categories, str): 49 | file_categories = os.path.join(ROOT_DATASET, file_categories) 50 | with open(file_categories) as f: 51 | lines = f.readlines() 52 | categories = [item.rstrip() for item in lines] 53 | else: # number of categories 54 | categories = [None] * file_categories 55 | n_class = len(categories) 56 | print('{}: {} classes'.format(dataset, n_class)) 57 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix 58 | 
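# Annotation (not in the original source): for a Mutual Modality Learning run the lists
# returned above are parallel, one entry per modality in command-line order, e.g.
#     n_class, train_lists, val_lists, roots, prefixes = return_dataset('somethingv2', 'RGBDiff,Flow,RGB')
# yields three-element lists; main.py indexes them with args.rank (e.g. args.val_list[args.rank])
# or passes them to TSNDataSet as-is.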
-------------------------------------------------------------------------------- /ops/models.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | from torch import nn 7 | 8 | from ops.basic_ops import ConsensusModule 9 | from ops.transforms import * 10 | from torch.nn.init import normal_, constant_ 11 | 12 | 13 | class TSN(nn.Module): 14 | def __init__(self, num_class, num_segments, modality, 15 | base_model='resnet101', new_length=None, 16 | consensus_type='avg', before_softmax=True, 17 | dropout=0.8, img_feature_dim=256, 18 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet', 19 | is_shift=False, shift_div=8, shift_place='blockres', fc_lr5=False, 20 | temporal_pool=False, non_local=False): 21 | super(TSN, self).__init__() 22 | self.modality = modality 23 | self.num_segments = num_segments 24 | self.reshape = True 25 | self.before_softmax = before_softmax 26 | self.dropout = dropout 27 | self.crop_num = crop_num 28 | self.consensus_type = consensus_type 29 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame 30 | self.pretrain = pretrain 31 | 32 | self.is_shift = is_shift 33 | self.shift_div = shift_div 34 | self.shift_place = shift_place 35 | self.base_model_name = base_model 36 | self.fc_lr5 = fc_lr5 37 | self.temporal_pool = temporal_pool 38 | self.non_local = non_local 39 | 40 | if not before_softmax and consensus_type != 'avg': 41 | raise ValueError("Only avg consensus can be used after Softmax") 42 | 43 | if new_length is None: 44 | self.new_length = 1 if modality == "RGB" else 5 45 | else: 46 | self.new_length = new_length 47 | if print_spec: 48 | print((""" 49 | Initializing TSN with base model: {}. 
50 | TSN Configurations: 51 | input_modality: {} 52 | num_segments: {} 53 | new_length: {} 54 | consensus_module: {} 55 | dropout_ratio: {} 56 | img_feature_dim: {} 57 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) 58 | 59 | self._prepare_base_model(base_model) 60 | 61 | feature_dim = self._prepare_tsn(num_class) 62 | 63 | if self.modality == 'Flow': 64 | self.base_model = self._construct_flow_model(self.base_model) 65 | elif self.modality == 'RGBDiff': 66 | self.base_model = self._construct_diff_model(self.base_model) 67 | 68 | self.consensus = ConsensusModule(consensus_type) 69 | 70 | if not self.before_softmax: 71 | self.softmax = nn.Softmax() 72 | 73 | self._enable_pbn = partial_bn 74 | if partial_bn: 75 | self.partialBN(True) 76 | 77 | def _prepare_tsn(self, num_class): 78 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 79 | if self.dropout == 0: 80 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) 81 | self.new_fc = None 82 | else: 83 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 84 | self.new_fc = nn.Linear(feature_dim, num_class) 85 | 86 | std = 0.001 87 | if self.new_fc is None: 88 | normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) 89 | constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 90 | else: 91 | if hasattr(self.new_fc, 'weight'): 92 | normal_(self.new_fc.weight, 0, std) 93 | constant_(self.new_fc.bias, 0) 94 | return feature_dim 95 | 96 | def _prepare_base_model(self, base_model): 97 | print('=> base model: {}'.format(base_model)) 98 | 99 | if 'resnet' in base_model: 100 | self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False) 101 | if self.is_shift: 102 | print('Adding temporal shift...') 103 | from ops.temporal_shift import make_temporal_shift 104 | make_temporal_shift(self.base_model, self.num_segments, 105 | n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool) 106 | 107 | if self.non_local: 108 | print('Adding non-local module...') 109 | from ops.non_local import make_non_local 110 | make_non_local(self.base_model, self.num_segments) 111 | 112 | self.base_model.last_layer_name = 'fc' 113 | self.input_size = 224 114 | self.input_mean = [0.485, 0.456, 0.406] 115 | self.input_std = [0.229, 0.224, 0.225] 116 | 117 | self.base_model.avgpool = nn.AdaptiveAvgPool2d(1) 118 | 119 | if self.modality == 'Flow': 120 | self.input_mean = [0.5] 121 | self.input_std = [np.mean(self.input_std)] 122 | elif self.modality == 'RGBDiff': 123 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 124 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 125 | 126 | elif base_model == 'mobilenetv2': 127 | from archs.mobilenet_v2 import mobilenet_v2, InvertedResidual 128 | self.base_model = mobilenet_v2(True if self.pretrain == 'imagenet' else False) 129 | 130 | self.base_model.last_layer_name = 'classifier' 131 | self.input_size = 224 132 | self.input_mean = [0.485, 0.456, 0.406] 133 | self.input_std = [0.229, 0.224, 0.225] 134 | 135 | self.base_model.avgpool = nn.AdaptiveAvgPool2d(1) 136 | if self.is_shift: 137 | from ops.temporal_shift import TemporalShift 138 | for m in self.base_model.modules(): 139 | if isinstance(m, InvertedResidual) and len(m.conv) == 8 and m.use_res_connect: 140 | if 
self.print_spec: 141 | print('Adding temporal shift... {}'.format(m.use_res_connect)) 142 | m.conv[0] = TemporalShift(m.conv[0], n_segment=self.num_segments, n_div=self.shift_div) 143 | if self.modality == 'Flow': 144 | self.input_mean = [0.5] 145 | self.input_std = [np.mean(self.input_std)] 146 | elif self.modality == 'RGBDiff': 147 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 148 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 149 | 150 | elif base_model == 'BNInception': 151 | from archs.bn_inception import bninception 152 | self.base_model = bninception(pretrained=self.pretrain) 153 | self.input_size = self.base_model.input_size 154 | self.input_mean = self.base_model.mean 155 | self.input_std = self.base_model.std 156 | self.base_model.last_layer_name = 'fc' 157 | if self.modality == 'Flow': 158 | self.input_mean = [128] 159 | elif self.modality == 'RGBDiff': 160 | self.input_mean = self.input_mean * (1 + self.new_length) 161 | if self.is_shift: 162 | print('Adding temporal shift...') 163 | self.base_model.build_temporal_ops( 164 | self.num_segments, is_temporal_shift=self.shift_place, shift_div=self.shift_div) 165 | else: 166 | raise ValueError('Unknown base model: {}'.format(base_model)) 167 | 168 | def train(self, mode=True): 169 | """ 170 | Override the default train() to freeze the BN parameters 171 | :return: 172 | """ 173 | super(TSN, self).train(mode) 174 | count = 0 175 | if self._enable_pbn and mode: 176 | print("Freezing BatchNorm2D except the first one.") 177 | for m in self.base_model.modules(): 178 | if isinstance(m, nn.BatchNorm2d): 179 | count += 1 180 | if count >= (2 if self._enable_pbn else 1): 181 | m.eval() 182 | # shutdown update in frozen mode 183 | m.weight.requires_grad = False 184 | m.bias.requires_grad = False 185 | 186 | def partialBN(self, enable): 187 | self._enable_pbn = enable 188 | 189 | def get_optim_policies(self): 190 | first_conv_weight = [] 191 | first_conv_bias = [] 192 | normal_weight = [] 193 | normal_bias = [] 194 | lr5_weight = [] 195 | lr10_bias = [] 196 | bn = [] 197 | custom_ops = [] 198 | 199 | conv_cnt = 0 200 | bn_cnt = 0 201 | for m in self.modules(): 202 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 203 | ps = list(m.parameters()) 204 | conv_cnt += 1 205 | if conv_cnt == 1: 206 | first_conv_weight.append(ps[0]) 207 | if len(ps) == 2: 208 | first_conv_bias.append(ps[1]) 209 | else: 210 | normal_weight.append(ps[0]) 211 | if len(ps) == 2: 212 | normal_bias.append(ps[1]) 213 | elif isinstance(m, torch.nn.Linear): 214 | ps = list(m.parameters()) 215 | if self.fc_lr5: 216 | lr5_weight.append(ps[0]) 217 | else: 218 | normal_weight.append(ps[0]) 219 | if len(ps) == 2: 220 | if self.fc_lr5: 221 | lr10_bias.append(ps[1]) 222 | else: 223 | normal_bias.append(ps[1]) 224 | 225 | elif isinstance(m, torch.nn.BatchNorm2d): 226 | bn_cnt += 1 227 | # later BN's are frozen 228 | if not self._enable_pbn or bn_cnt == 1: 229 | bn.extend(list(m.parameters())) 230 | elif isinstance(m, torch.nn.BatchNorm3d): 231 | bn_cnt += 1 232 | # later BN's are frozen 233 | if not self._enable_pbn or bn_cnt == 1: 234 | bn.extend(list(m.parameters())) 235 | elif len(m._modules) == 0: 236 | if len(list(m.parameters())) > 0: 237 | raise ValueError("New atomic module type: {}. 
Need to give it a learning policy".format(type(m))) 238 | 239 | return [ 240 | {'params': first_conv_weight, 'lr_mult': 1, 'decay_mult': 1, 241 | 'name': "first_conv_weight"}, 242 | {'params': first_conv_bias, 'lr_mult': 2, 'decay_mult': 0, 243 | 'name': "first_conv_bias"}, 244 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 245 | 'name': "normal_weight"}, 246 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 247 | 'name': "normal_bias"}, 248 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 249 | 'name': "BN scale/shift"}, 250 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, 251 | 'name': "custom_ops"}, 252 | # for fc 253 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1, 254 | 'name': "lr5_weight"}, 255 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0, 256 | 'name': "lr10_bias"}, 257 | ] 258 | 259 | def forward(self, input, no_reshape=False): 260 | if not no_reshape: 261 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 262 | 263 | if self.modality == 'RGBDiff': 264 | sample_len = 3 * self.new_length 265 | input = self._get_diff(input) 266 | 267 | base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:])) 268 | else: 269 | base_out = self.base_model(input) 270 | 271 | if self.dropout > 0: 272 | base_out = self.new_fc(base_out) 273 | 274 | if not self.before_softmax: 275 | base_out = self.softmax(base_out) 276 | 277 | if self.reshape: 278 | if self.is_shift and self.temporal_pool: 279 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:]) 280 | else: 281 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 282 | output = self.consensus(base_out) 283 | return output.squeeze(1) 284 | 285 | def _get_diff(self, input, keep_rgb=False): 286 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 287 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[-2:]) 288 | if keep_rgb: 289 | new_data = input_view.clone() 290 | else: 291 | new_data = input_view[:, :, 1:, :, :, :].clone() 292 | 293 | for x in reversed(list(range(1, self.new_length + 1))): 294 | if keep_rgb: 295 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 296 | else: 297 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 298 | 299 | return new_data 300 | 301 | def _construct_flow_model(self, base_model): 302 | print('... Constructing Flow model') 303 | # modify the convolution layers 304 | # Torch models are usually defined in a hierarchical way. 
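        # Below, the first conv is replaced by a (2 x 5)-channel version (x/y flow
        # for the 5 stacked frames used as new_length for Flow), and its kernels are
        # set to the channel-wise mean of the pretrained RGB kernels so that the
        # ImageNet initialisation is still reused.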
305 | # nn.modules.children() return all sub modules in a DFS manner 306 | modules = list(self.base_model.modules()) 307 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 308 | conv_layer = modules[first_conv_idx] 309 | container = modules[first_conv_idx - 1] 310 | 311 | # modify parameters, assume the first blob contains the convolution kernels 312 | params = [x.clone() for x in conv_layer.parameters()] 313 | kernel_size = params[0].size() 314 | new_kernel_size = kernel_size[:1] + (2 * 5, ) + kernel_size[2:] 315 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 316 | 317 | new_conv = nn.Conv2d(2 * 5, conv_layer.out_channels, 318 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 319 | bias=True if len(params) == 2 else False) 320 | new_conv.weight.data = new_kernels 321 | if len(params) == 2: 322 | new_conv.bias.data = params[1].data # add bias if neccessary 323 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 324 | 325 | # replace the first convlution layer 326 | setattr(container, layer_name, new_conv) 327 | return base_model 328 | 329 | def _construct_diff_model(self, base_model, keep_rgb=False): 330 | print('... Constructing Diff model') 331 | # modify the convolution layers 332 | # Torch models are usually defined in a hierarchical way. 333 | # nn.modules.children() return all sub modules in a DFS manner 334 | modules = list(self.base_model.modules()) 335 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 336 | conv_layer = modules[first_conv_idx] 337 | container = modules[first_conv_idx - 1] 338 | 339 | # modify parameters, assume the first blob contains the convolution kernels 340 | params = [x.clone() for x in conv_layer.parameters()] 341 | kernel_size = params[0].size() 342 | if not keep_rgb: 343 | new_kernel_size = kernel_size[:1] + (3 * 5,) + kernel_size[2:] 344 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 345 | else: 346 | new_kernel_size = kernel_size[:1] + (3 * 5,) + kernel_size[2:] 347 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 348 | 1) 349 | new_kernel_size = kernel_size[:1] + (3 + 3 * 5,) + kernel_size[2:] 350 | 351 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 352 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 353 | bias=True if len(params) == 2 else False) 354 | new_conv.weight.data = new_kernels 355 | if len(params) == 2: 356 | new_conv.bias.data = params[1].data # add bias if neccessary 357 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 358 | 359 | # replace the first convolution layer 360 | setattr(container, layer_name, new_conv) 361 | return base_model 362 | 363 | def _construct_rgb_model(self, base_model): 364 | print('... Constructing RGB model') 365 | # modify the convolution layers 366 | # Torch models are usually defined in a hierarchical way. 
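        # Mirrors _construct_flow_model, but builds a plain 3-channel first conv from
        # the channel-wise mean of the existing kernels; this appears intended for
        # re-initialising a network whose first conv was built for another modality
        # (Flow / RGBDiff) so that it can consume RGB input.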
367 | # nn.modules.children() return all sub modules in a DFS manner 368 | modules = list(self.base_model.modules()) 369 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 370 | conv_layer = modules[first_conv_idx] 371 | container = modules[first_conv_idx - 1] 372 | 373 | # modify parameters, assume the first blob contains the convolution kernels 374 | params = [x.clone() for x in conv_layer.parameters()] 375 | kernel_size = params[0].size() 376 | new_kernel_size = kernel_size[:1] + (3 * 1, ) + kernel_size[2:] 377 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 378 | 379 | new_conv = nn.Conv2d(3 * 1, conv_layer.out_channels, 380 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 381 | bias=True if len(params) == 2 else False) 382 | new_conv.weight.data = new_kernels 383 | if len(params) == 2: 384 | new_conv.bias.data = params[1].data # add bias if neccessary 385 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 386 | 387 | # replace the first convlution layer 388 | setattr(container, layer_name, new_conv) 389 | 390 | return base_model 391 | 392 | @property 393 | def crop_size(self): 394 | return self.input_size 395 | 396 | @property 397 | def scale_size(self): 398 | return self.input_size * 256 // 224 399 | 400 | def get_augmentation(self, flip=True): 401 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 402 | 403 | -------------------------------------------------------------------------------- /ops/non_local.py: -------------------------------------------------------------------------------- 1 | # Non-local block using embedded gaussian 2 | # Code from 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class _NonLocalBlockND(nn.Module): 10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 11 | super(_NonLocalBlockND, self).__init__() 12 | 13 | assert dimension in [1, 2, 3] 14 | 15 | self.dimension = dimension 16 | self.sub_sample = sub_sample 17 | 18 | self.in_channels = in_channels 19 | self.inter_channels = inter_channels 20 | 21 | if self.inter_channels is None: 22 | self.inter_channels = in_channels // 2 23 | if self.inter_channels == 0: 24 | self.inter_channels = 1 25 | 26 | if dimension == 3: 27 | conv_nd = nn.Conv3d 28 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 29 | bn = nn.BatchNorm3d 30 | elif dimension == 2: 31 | conv_nd = nn.Conv2d 32 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 33 | bn = nn.BatchNorm2d 34 | else: 35 | conv_nd = nn.Conv1d 36 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 37 | bn = nn.BatchNorm1d 38 | 39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | 42 | if bn_layer: 43 | self.W = nn.Sequential( 44 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 45 | kernel_size=1, stride=1, padding=0), 46 | bn(self.in_channels) 47 | ) 48 | nn.init.constant_(self.W[1].weight, 0) 49 | nn.init.constant_(self.W[1].bias, 0) 50 | else: 51 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 52 | kernel_size=1, stride=1, padding=0) 53 | nn.init.constant_(self.W.weight, 0) 54 | 
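            # With W zero-initialised (weight here, bias just below), W_y starts at
            # zero, so the block initially reduces to the identity mapping z = W_y + x
            # in forward(); the bn_layer branch above achieves the same effect through
            # the zeroed BatchNorm affine parameters.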
nn.init.constant_(self.W.bias, 0) 55 | 56 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 59 | kernel_size=1, stride=1, padding=0) 60 | 61 | if sub_sample: 62 | self.g = nn.Sequential(self.g, max_pool_layer) 63 | self.phi = nn.Sequential(self.phi, max_pool_layer) 64 | 65 | def forward(self, x): 66 | ''' 67 | :param x: (b, c, t, h, w) 68 | :return: 69 | ''' 70 | 71 | batch_size = x.size(0) 72 | 73 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 74 | g_x = g_x.permute(0, 2, 1) 75 | 76 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 77 | theta_x = theta_x.permute(0, 2, 1) 78 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 79 | f = torch.matmul(theta_x, phi_x) 80 | f_div_C = F.softmax(f, dim=-1) 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | return z 89 | 90 | 91 | class NONLocalBlock1D(_NonLocalBlockND): 92 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 93 | super(NONLocalBlock1D, self).__init__(in_channels, 94 | inter_channels=inter_channels, 95 | dimension=1, sub_sample=sub_sample, 96 | bn_layer=bn_layer) 97 | 98 | 99 | class NONLocalBlock2D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock2D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=2, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock3D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock3D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=3, sub_sample=sub_sample, 112 | bn_layer=bn_layer) 113 | 114 | 115 | class NL3DWrapper(nn.Module): 116 | def __init__(self, block, n_segment): 117 | super(NL3DWrapper, self).__init__() 118 | self.block = block 119 | self.nl = NONLocalBlock3D(block.bn3.num_features) 120 | self.n_segment = n_segment 121 | 122 | def forward(self, x): 123 | x = self.block(x) 124 | 125 | nt, c, h, w = x.size() 126 | x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w 127 | x = self.nl(x) 128 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w) 129 | return x 130 | 131 | 132 | def make_non_local(net, n_segment): 133 | import torchvision 134 | import archs 135 | if isinstance(net, torchvision.models.ResNet): 136 | net.layer2 = nn.Sequential( 137 | NL3DWrapper(net.layer2[0], n_segment), 138 | net.layer2[1], 139 | NL3DWrapper(net.layer2[2], n_segment), 140 | net.layer2[3], 141 | ) 142 | net.layer3 = nn.Sequential( 143 | NL3DWrapper(net.layer3[0], n_segment), 144 | net.layer3[1], 145 | NL3DWrapper(net.layer3[2], n_segment), 146 | net.layer3[3], 147 | NL3DWrapper(net.layer3[4], n_segment), 148 | net.layer3[5], 149 | ) 150 | else: 151 | raise NotImplementedError 152 | 153 | 154 | if __name__ == '__main__': 155 | from torch.autograd import Variable 156 | import torch 157 | 158 | sub_sample = True 159 | bn_layer = True 160 | 161 | img = Variable(torch.zeros(2, 3, 20)) 162 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 163 | out = net(img) 164 | print(out.size()) 165 | 166 | img = Variable(torch.zeros(2, 
3, 20, 20)) 167 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 168 | out = net(img) 169 | print(out.size()) 170 | 171 | img = Variable(torch.randn(2, 3, 10, 20, 20)) 172 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 173 | out = net(img) 174 | print(out.size()) -------------------------------------------------------------------------------- /ops/temporal_shift.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class TemporalShift(nn.Module): 12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): 13 | super(TemporalShift, self).__init__() 14 | self.net = net 15 | self.n_segment = n_segment 16 | self.fold_div = n_div 17 | self.inplace = inplace 18 | if inplace: 19 | print('=> Using in-place shift...') 20 | print('=> Using fold div: {}'.format(self.fold_div)) 21 | 22 | def forward(self, x): 23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) 24 | return self.net(x) 25 | 26 | @staticmethod 27 | def shift(x, n_segment, fold_div=3, inplace=False): 28 | nt, c, h, w = x.size() 29 | n_batch = nt // n_segment 30 | x = x.view(n_batch, n_segment, c, h, w) 31 | 32 | fold = c // fold_div 33 | if inplace: 34 | # Due to some out of order error when performing parallel computing. 35 | # May need to write a CUDA kernel. 36 | raise NotImplementedError 37 | # out = InplaceShift.apply(x, fold) 38 | else: 39 | out = torch.zeros_like(x) 40 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 41 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right 42 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 43 | 44 | return out.view(nt, c, h, w) 45 | 46 | 47 | class InplaceShift(torch.autograd.Function): 48 | # Special thanks to @raoyongming for the help to this function 49 | @staticmethod 50 | def forward(ctx, input, fold): 51 | # not support higher order gradient 52 | # input = input.detach_() 53 | ctx.fold_ = fold 54 | n, t, c, h, w = input.size() 55 | buffer = input.data.new(n, t, fold, h, w).zero_() 56 | buffer[:, :-1] = input.data[:, 1:, :fold] 57 | input.data[:, :, :fold] = buffer 58 | buffer.zero_() 59 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] 60 | input.data[:, :, fold: 2 * fold] = buffer 61 | return input 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | # grad_output = grad_output.detach_() 66 | fold = ctx.fold_ 67 | n, t, c, h, w = grad_output.size() 68 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() 69 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] 70 | grad_output.data[:, :, :fold] = buffer 71 | buffer.zero_() 72 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] 73 | grad_output.data[:, :, fold: 2 * fold] = buffer 74 | return grad_output, None 75 | 76 | 77 | class TemporalPool(nn.Module): 78 | def __init__(self, net, n_segment): 79 | super(TemporalPool, self).__init__() 80 | self.net = net 81 | self.n_segment = n_segment 82 | 83 | def forward(self, x): 84 | x = self.temporal_pool(x, n_segment=self.n_segment) 85 | return self.net(x) 86 | 87 | @staticmethod 88 | def temporal_pool(x, n_segment): 89 | nt, c, h, w = x.size() 90 | n_batch = nt // n_segment 91 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w 92 | 
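        # Max-pooling over time with kernel 3 and stride 2 halves the number of
        # segments, which is why the reshape below uses nt // 2 and why models.py
        # views the consensus input as num_segments // 2 when temporal_pool is set.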
x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0)) 93 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) 94 | return x 95 | 96 | 97 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False): 98 | if temporal_pool: 99 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] 100 | else: 101 | n_segment_list = [n_segment] * 4 102 | assert n_segment_list[-1] > 0 103 | print('=> n_segment per stage: {}'.format(n_segment_list)) 104 | 105 | import torchvision 106 | if isinstance(net, torchvision.models.ResNet): 107 | if place == 'block': 108 | def make_block_temporal(stage, this_segment): 109 | blocks = list(stage.children()) 110 | print('=> Processing stage with {} blocks'.format(len(blocks))) 111 | for i, b in enumerate(blocks): 112 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div) 113 | return nn.Sequential(*(blocks)) 114 | 115 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 116 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 117 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 118 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 119 | 120 | elif 'blockres' in place: 121 | n_round = 1 122 | if len(list(net.layer3.children())) >= 23: 123 | n_round = 2 124 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) 125 | 126 | def make_block_temporal(stage, this_segment): 127 | blocks = list(stage.children()) 128 | print('=> Processing stage with {} blocks residual'.format(len(blocks))) 129 | for i, b in enumerate(blocks): 130 | if i % n_round == 0: 131 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div) 132 | return nn.Sequential(*blocks) 133 | 134 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 135 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 136 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 137 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 138 | else: 139 | raise NotImplementedError(place) 140 | 141 | 142 | def make_temporal_pool(net, n_segment): 143 | import torchvision 144 | if isinstance(net, torchvision.models.ResNet): 145 | print('=> Injecting nonlocal pooling') 146 | net.layer2 = TemporalPool(net.layer2, n_segment) 147 | else: 148 | raise NotImplementedError 149 | 150 | 151 | if __name__ == '__main__': 152 | # test inplace shift v.s. 
vanilla shift 153 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) 154 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) 155 | 156 | print('=> Testing CPU...') 157 | # test forward 158 | with torch.no_grad(): 159 | for i in range(10): 160 | x = torch.rand(2 * 8, 3, 224, 224) 161 | y1 = tsm1(x) 162 | y2 = tsm2(x) 163 | assert torch.norm(y1 - y2).item() < 1e-5 164 | 165 | # test backward 166 | with torch.enable_grad(): 167 | for i in range(10): 168 | x1 = torch.rand(2 * 8, 3, 224, 224) 169 | x1.requires_grad_() 170 | x2 = x1.clone() 171 | y1 = tsm1(x1) 172 | y2 = tsm2(x2) 173 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 174 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 175 | assert torch.norm(grad1 - grad2).item() < 1e-5 176 | 177 | print('=> Testing GPU...') 178 | tsm1.cuda() 179 | tsm2.cuda() 180 | # test forward 181 | with torch.no_grad(): 182 | for i in range(10): 183 | x = torch.rand(2 * 8, 3, 224, 224).cuda() 184 | y1 = tsm1(x) 185 | y2 = tsm2(x) 186 | assert torch.norm(y1 - y2).item() < 1e-5 187 | 188 | # test backward 189 | with torch.enable_grad(): 190 | for i in range(10): 191 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() 192 | x1.requires_grad_() 193 | x2 = x1.clone() 194 | y1 = tsm1(x1) 195 | y2 = tsm2(x2) 196 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 197 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 198 | assert torch.norm(grad1 - grad2).item() < 1e-5 199 | print('Test passed.') 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /ops/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 47 | """ 48 | def __init__(self, is_flow=False): 49 | self.is_flow = is_flow 50 | 51 | def __call__(self, img_group, is_flow=False): 52 | v = random.random() 53 | if v < 0.5: 54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 55 | if self.is_flow: 56 | for i in range(0, len(ret), 2): 57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 58 | return ret 59 | else: 60 | return img_group 61 | 62 | 63 | class GroupNormalize(object): 64 | def __init__(self, mean, std): 65 | self.mean = mean 66 | self.std = std 67 | 68 | def __call__(self, tensor): 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 70 | rep_std = 
self.std * (tensor.size()[0]//len(self.std)) 71 | 72 | # TODO: make efficient 73 | assert len(rep_mean)==tensor.size()[0] 74 | for t, m, s in zip(tensor, rep_mean, rep_std): 75 | t.sub_(m).div_(s) 76 | 77 | return tensor 78 | 79 | 80 | class GroupScale(object): 81 | """ Rescales the input PIL.Image to the given 'size'. 82 | 'size' will be the size of the smaller edge. 83 | For example, if height > width, then image will be 84 | rescaled to (size * height / width, size) 85 | size: size of the smaller edge 86 | interpolation: Default: PIL.Image.BILINEAR 87 | """ 88 | 89 | def __init__(self, size, interpolation=Image.BILINEAR): 90 | self.worker = torchvision.transforms.Resize(size, interpolation) 91 | 92 | def __call__(self, img_group): 93 | return [self.worker(img) for img in img_group] 94 | 95 | 96 | class GroupOverSample(object): 97 | def __init__(self, crop_size, scale_size=None, flip=True): 98 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 99 | 100 | if scale_size is not None: 101 | self.scale_worker = GroupScale(scale_size) 102 | else: 103 | self.scale_worker = None 104 | self.flip = flip 105 | 106 | def __call__(self, img_group): 107 | 108 | if self.scale_worker is not None: 109 | img_group = self.scale_worker(img_group) 110 | 111 | image_w, image_h = img_group[0].size 112 | crop_w, crop_h = self.crop_size 113 | 114 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 115 | oversample_group = list() 116 | for o_w, o_h in offsets: 117 | normal_group = list() 118 | flip_group = list() 119 | for i, img in enumerate(img_group): 120 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 121 | normal_group.append(crop) 122 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 123 | 124 | if img.mode == 'L' and i % 2 == 0: 125 | flip_group.append(ImageOps.invert(flip_crop)) 126 | else: 127 | flip_group.append(flip_crop) 128 | 129 | oversample_group.extend(normal_group) 130 | if self.flip: 131 | oversample_group.extend(flip_group) 132 | return oversample_group 133 | 134 | 135 | class GroupFullResSample(object): 136 | def __init__(self, crop_size, scale_size=None, flip=True): 137 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 138 | 139 | if scale_size is not None: 140 | self.scale_worker = GroupScale(scale_size) 141 | else: 142 | self.scale_worker = None 143 | self.flip = flip 144 | 145 | def __call__(self, img_group): 146 | 147 | if self.scale_worker is not None: 148 | img_group = self.scale_worker(img_group) 149 | 150 | image_w, image_h = img_group[0].size 151 | crop_w, crop_h = self.crop_size 152 | 153 | w_step = (image_w - crop_w) // 4 154 | h_step = (image_h - crop_h) // 4 155 | 156 | offsets = list() 157 | offsets.append((0 * w_step, 2 * h_step)) # left 158 | offsets.append((4 * w_step, 2 * h_step)) # right 159 | offsets.append((2 * w_step, 2 * h_step)) # center 160 | 161 | oversample_group = list() 162 | for o_w, o_h in offsets: 163 | normal_group = list() 164 | flip_group = list() 165 | for i, img in enumerate(img_group): 166 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 167 | normal_group.append(crop) 168 | if self.flip: 169 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 170 | 171 | if img.mode == 'L' and i % 2 == 0: 172 | flip_group.append(ImageOps.invert(flip_crop)) 173 | else: 174 | flip_group.append(flip_crop) 175 | 176 | oversample_group.extend(normal_group) 177 | oversample_group.extend(flip_group) 178 | return 
oversample_group 179 | 180 | 181 | class GroupMultiScaleCrop(object): 182 | 183 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 184 | self.scales = scales if scales is not None else [1, .875, .75, .66] 185 | self.max_distort = max_distort 186 | self.fix_crop = fix_crop 187 | self.more_fix_crop = more_fix_crop 188 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 189 | self.interpolation = Image.BILINEAR 190 | 191 | def __call__(self, img_group): 192 | 193 | im_size = img_group[0].size 194 | 195 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 196 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 197 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 198 | for img in crop_img_group] 199 | return ret_img_group 200 | 201 | def _sample_crop_size(self, im_size): 202 | image_w, image_h = im_size[0], im_size[1] 203 | 204 | # find a crop size 205 | base_size = min(image_w, image_h) 206 | crop_sizes = [int(base_size * x) for x in self.scales] 207 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 208 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 209 | 210 | pairs = [] 211 | for i, h in enumerate(crop_h): 212 | for j, w in enumerate(crop_w): 213 | if abs(i - j) <= self.max_distort: 214 | pairs.append((w, h)) 215 | 216 | crop_pair = random.choice(pairs) 217 | if not self.fix_crop: 218 | w_offset = random.randint(0, image_w - crop_pair[0]) 219 | h_offset = random.randint(0, image_h - crop_pair[1]) 220 | else: 221 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 222 | 223 | return crop_pair[0], crop_pair[1], w_offset, h_offset 224 | 225 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 226 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 227 | return random.choice(offsets) 228 | 229 | @staticmethod 230 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 231 | w_step = (image_w - crop_w) // 4 232 | h_step = (image_h - crop_h) // 4 233 | 234 | ret = list() 235 | ret.append((0, 0)) # upper left 236 | ret.append((4 * w_step, 0)) # upper right 237 | ret.append((0, 4 * h_step)) # lower left 238 | ret.append((4 * w_step, 4 * h_step)) # lower right 239 | ret.append((2 * w_step, 2 * h_step)) # center 240 | 241 | if more_fix_crop: 242 | ret.append((0, 2 * h_step)) # center left 243 | ret.append((4 * w_step, 2 * h_step)) # center right 244 | ret.append((2 * w_step, 4 * h_step)) # lower center 245 | ret.append((2 * w_step, 0 * h_step)) # upper center 246 | 247 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 248 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 249 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 250 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 251 | 252 | return ret 253 | 254 | 255 | class GroupRandomSizedCrop(object): 256 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 257 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 258 | This is popularly used to train the Inception networks 259 | size: size of the smaller edge 260 | interpolation: Default: PIL.Image.BILINEAR 261 | """ 262 | def __init__(self, size, interpolation=Image.BILINEAR): 263 | self.size = size 
264 | self.interpolation = interpolation 265 | 266 | def __call__(self, img_group): 267 | for attempt in range(10): 268 | area = img_group[0].size[0] * img_group[0].size[1] 269 | target_area = random.uniform(0.08, 1.0) * area 270 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 271 | 272 | w = int(round(math.sqrt(target_area * aspect_ratio))) 273 | h = int(round(math.sqrt(target_area / aspect_ratio))) 274 | 275 | if random.random() < 0.5: 276 | w, h = h, w 277 | 278 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 279 | x1 = random.randint(0, img_group[0].size[0] - w) 280 | y1 = random.randint(0, img_group[0].size[1] - h) 281 | found = True 282 | break 283 | else: 284 | found = False 285 | x1 = 0 286 | y1 = 0 287 | 288 | if found: 289 | out_group = list() 290 | for img in img_group: 291 | img = img.crop((x1, y1, x1 + w, y1 + h)) 292 | assert(img.size == (w, h)) 293 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 294 | return out_group 295 | else: 296 | # Fallback 297 | scale = GroupScale(self.size, interpolation=self.interpolation) 298 | crop = GroupRandomCrop(self.size) 299 | return crop(scale(img_group)) 300 | 301 | 302 | class Stack(object): 303 | 304 | def __init__(self, roll=False): 305 | self.roll = roll 306 | 307 | def __call__(self, img_group): 308 | for i in range(len(img_group)): 309 | if img_group[i].mode == 'L': 310 | img_group[i] = np.expand_dims(img_group[i], 2) 311 | elif img_group[i].mode == 'RGB': 312 | if self.roll: 313 | img_group[i] = np.array(img_group[i])[:, :, ::-1] 314 | else: 315 | img_group[i] = img_group[i] 316 | return np.concatenate(img_group, axis=2) 317 | 318 | 319 | class ToTorchFormatTensor(object): 320 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 321 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 322 | def __init__(self, div=True): 323 | self.div = div 324 | 325 | def __call__(self, pic): 326 | if isinstance(pic, np.ndarray): 327 | # handle numpy array 328 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 329 | else: 330 | # handle PIL Image 331 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 332 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 333 | # put it from HWC to CHW format 334 | # yikes, this transpose takes 80% of the loading time/CPU 335 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 336 | return img.float().div(255) if self.div else img.float() 337 | 338 | 339 | class IdentityTransform(object): 340 | 341 | def __call__(self, data): 342 | return data 343 | 344 | 345 | if __name__ == "__main__": 346 | trans = torchvision.transforms.Compose([ 347 | GroupScale(256), 348 | GroupRandomCrop(224), 349 | Stack(), 350 | ToTorchFormatTensor(), 351 | GroupNormalize( 352 | mean=[.485, .456, .406], 353 | std=[.229, .224, .225] 354 | )] 355 | ) 356 | 357 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 358 | 359 | color_group = [im] * 3 360 | rst = trans(color_group) 361 | 362 | gray_group = [im.convert('L')] * 9 363 | gray_rst = trans(gray_group) 364 | 365 | trans2 = torchvision.transforms.Compose([ 366 | GroupRandomSizedCrop(256), 367 | Stack(), 368 | ToTorchFormatTensor(), 369 | GroupNormalize( 370 | mean=[.485, .456, .406], 371 | std=[.229, .224, .225]) 372 | ]) 373 | print(trans2(color_group)) 374 | -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | 3 | 4 | def softmax(scores): 5 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 6 | return es / es.sum(axis=-1)[..., None] 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | def accuracy(output, target, topk=(1,)): 29 | """Computes the precision@k for the specified values of k""" 30 | maxk = max(topk) 31 | batch_size = target.size(0) 32 | 33 | _, pred = output.topk(maxk, 1, True, True) 34 | pred = pred.t() 35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 36 | 37 | res = [] 38 | for k in topk: 39 | correct_k = correct[:k].view(-1).float().sum(0) 40 | res.append(correct_k.mul_(100.0 / batch_size)) 41 | return res -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import argparse 7 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 8 | parser.add_argument('dataset', type=str) 9 | parser.add_argument('modality', type=str) 10 | parser.add_argument('--train_list', type=str, default="") 11 | parser.add_argument('--val_list', type=str, default="") 12 | parser.add_argument('--root_path', type=str, default="") 13 | parser.add_argument('--store_name', type=str, default="") 14 | # ========================= Model Configs ========================== 15 | parser.add_argument('--arch', type=str, default="BNInception") 16 | parser.add_argument('--num_segments', type=int, default=3) 17 | parser.add_argument('--consensus_type', type=str, default='avg') 18 | parser.add_argument('--k', type=int, default=3) 19 | 20 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 21 | metavar='DO', help='dropout ratio (default: 0.5)') 22 | parser.add_argument('--loss_type', type=str, default="nll", 23 | choices=['nll']) 24 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame") 25 | parser.add_argument('--suffix', type=str, default=None) 26 | parser.add_argument('--pretrain', type=str, default='imagenet') 27 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint') 28 | 29 | # ========================= Learning Configs ========================== 30 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 31 | help='number of total epochs to run') 32 | parser.add_argument('-b', '--batch-size', default=128, type=int, 33 | metavar='N', help='mini-batch size (default: 256)') 34 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 35 | metavar='LR', help='initial learning rate') 36 | parser.add_argument('--lr_type', default='step', type=str, 37 | metavar='LRtype', help='learning rate type') 38 | parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+", 39 | metavar='LRSteps', help='epochs to decay learning rate by 10') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | 
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 43 | metavar='W', help='weight decay (default: 5e-4)') 44 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 45 | metavar='W', help='gradient norm clipping (default: disabled)') 46 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 47 | 48 | # ========================= Monitor Configs ========================== 49 | parser.add_argument('--print-freq', '-p', default=20, type=int, 50 | metavar='N', help='print frequency (default: 10)') 51 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 52 | metavar='N', help='evaluation frequency (default: 5)') 53 | 54 | 55 | # ========================= Runtime Configs ========================== 56 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 57 | help='number of data loading workers (default: 8)') 58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 59 | help='path to latest checkpoint (default: none)') 60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 61 | help='evaluate model on validation set') 62 | parser.add_argument('--snapshot_pref', type=str, default="") 63 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 64 | help='manual epoch number (useful on restarts)') 65 | parser.add_argument('--flow_prefix', default="", type=str) 66 | parser.add_argument('--root_log',type=str, default='log') 67 | parser.add_argument('--root_model', type=str, default='checkpoint') 68 | 69 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models') 70 | parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)') 71 | parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: stageres)') 72 | 73 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling') 74 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block') 75 | 76 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset') 77 | parser.add_argument('--random_sample', default=False, action="store_true", help='use random sample for video dataset') 78 | parser.add_argument('--dense_length', type=int, default=32, help='Length of frames part while dense sampling') 79 | parser.add_argument('--dense_number', type=int, default=1, help='Number of runs for dense sampling test') 80 | 81 | 82 | # ========================= Mutual Learning Configs ========================== 83 | parser.add_argument('--rank', default=0, type=int) 84 | parser.add_argument('--world_size', default=1, type=int) 85 | parser.add_argument('--init_method', default='tcp://127.0.0.1:52525', type=str) 86 | parser.add_argument('--gpus', default='0,1,2,3', type=str) 87 | -------------------------------------------------------------------------------- /scripts/finetune_tsm_ucf101_rgb_8f.sh: -------------------------------------------------------------------------------- 1 | python main.py ucf101 RGB \ 2 | --arch resnet50 --num_segments 8 \ 3 | --gd 20 --lr 0.001 --lr_steps 10 20 --epochs 25 \ 4 | --batch-size 64 -j 16 --dropout 0.8 --consensus_type=avg --eval-freq=1 \ 5 | --shift --shift_div=8 --shift_place=blockres \ 6 | --tune_from=pretrained/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth 
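# Short flags above are defined in opts.py: --gd is --clip-gradient (gradient-norm
# clipping) and -j is --workers (data-loading processes); --tune_from initialises the
# model from the Kinetics-pretrained TSM checkpoint before fine-tuning on UCF101.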
-------------------------------------------------------------------------------- /scripts/test_tsm_kinetics_rgb_8f.sh: -------------------------------------------------------------------------------- 1 | # test the TSN and TSM on Kinetics using 8-frame, you should get top-1 accuracy around: 2 | # TSN: 68.8% 3 | # TSM: 71.2% 4 | 5 | # test TSN 6 | python test_models.py kinetics \ 7 | --weights=pretrained/TSM_kinetics_RGB_resnet50_avg_segment5_e50.pth \ 8 | --test_segments=8 --test_crops=1 \ 9 | --batch_size=64 10 | 11 | # test TSM 12 | python test_models.py kinetics \ 13 | --weights=pretrained/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth \ 14 | --test_segments=8 --test_crops=1 \ 15 | --batch_size=64 -------------------------------------------------------------------------------- /scripts/train_tsm_kinetics_rgb_16f.sh: -------------------------------------------------------------------------------- 1 | # You should get TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50 2 | python main.py kinetics RGB \ 3 | --arch resnet50 --num_segments 16 \ 4 | --gd 20 --lr 0.02 --wd 1e-4 --lr_steps 20 40 --epochs 50 \ 5 | --batch-size 128 -j 16 --dropout 0.5 --consensus_type=avg --eval-freq=1 \ 6 | --shift --shift_div=8 --shift_place=blockres --npb -------------------------------------------------------------------------------- /scripts/train_tsm_kinetics_rgb_8f.sh: -------------------------------------------------------------------------------- 1 | # You should get TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth 2 | python main.py kinetics RGB \ 3 | --arch resnet50 --num_segments 8 \ 4 | --gd 20 --lr 0.02 --wd 1e-4 --lr_steps 20 40 --epochs 50 \ 5 | --batch-size 128 -j 16 --dropout 0.5 --consensus_type=avg --eval-freq=1 \ 6 | --shift --shift_div=8 --shift_place=blockres --npb -------------------------------------------------------------------------------- /scripts/train_tsn_kinetics_rgb_5f.sh: -------------------------------------------------------------------------------- 1 | # You should get TSM_kinetics_RGB_resnet50_avg_segment5_e50 2 | # Notice that for TSN 2D baseline, it is recommended to train using 5 segments and test with more segments to avoid overfitting 3 | 4 | python main.py kinetics RGB \ 5 | --arch resnet50 --num_segments 5 \ 6 | --gd 20 --lr 0.02 --wd 1e-4 --lr_steps 20 40 --epochs 50 \ 7 | --batch-size 128 -j 16 --dropout 0.5 --consensus_type=avg --eval-freq=1 \ 8 | --npb -------------------------------------------------------------------------------- /test_models.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | # Notice that this file has been modified to support ensemble testing 7 | 8 | import argparse 9 | import time 10 | 11 | import torch.nn.parallel 12 | import torch.optim 13 | from sklearn.metrics import confusion_matrix 14 | from ops.dataset import TSNDataSet 15 | from ops.models import TSN 16 | from ops.transforms import * 17 | from ops import dataset_config 18 | from torch.nn import functional as F 19 | 20 | # options 21 | parser = argparse.ArgumentParser(description="TSM testing on the full validation set") 22 | parser.add_argument('dataset', type=str) 23 | 24 | # may contain splits 25 | parser.add_argument('--weights', type=str, default=None) 26 | parser.add_argument('--test_segments', type=str, 
default=25) 27 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D') 28 | parser.add_argument('--dense_length', type=int, default=32, help='Num of frames for the dense sampling') 29 | parser.add_argument('--dense_number', type=int, default=1, help='Num of dense samples') 30 | parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble') 31 | parser.add_argument('--random_sample', default=False, action="store_true", help='use random sample for ensemble') 32 | parser.add_argument('--full_res', default=False, action="store_true", 33 | help='use full resolution 256x256 for test as in Non-local I3D') 34 | 35 | parser.add_argument('--test_crops', type=int, default=1) 36 | parser.add_argument('--coeff', type=str, default=None) 37 | parser.add_argument('--batch_size', type=int, default=1) 38 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 39 | help='number of data loading workers (default: 8)') 40 | 41 | # for true test 42 | parser.add_argument('--test_list', type=str, default=None) 43 | parser.add_argument('--csv_file', type=str, default=None) 44 | 45 | parser.add_argument('--softmax', default=False, action="store_true", help='use softmax') 46 | 47 | parser.add_argument('--max_num', type=int, default=-1) 48 | parser.add_argument('--input_size', type=int, default=224) 49 | parser.add_argument('--crop_fusion_type', type=str, default='avg') 50 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 51 | parser.add_argument('--img_feature_dim',type=int, default=256) 52 | parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video') 53 | parser.add_argument('--pretrain', type=str, default='imagenet') 54 | 55 | args = parser.parse_args() 56 | 57 | 58 | class AverageMeter(object): 59 | """Computes and stores the average and current value""" 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.avg = 0 66 | self.sum = 0 67 | self.count = 0 68 | 69 | def update(self, val, n=1): 70 | self.val = val 71 | self.sum += val * n 72 | self.count += n 73 | self.avg = self.sum / self.count 74 | 75 | 76 | def accuracy(output, target, topk=(1,)): 77 | """Computes the precision@k for the specified values of k""" 78 | maxk = max(topk) 79 | batch_size = target.size(0) 80 | _, pred = output.topk(maxk, 1, True, True) 81 | pred = pred.t() 82 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 83 | res = [] 84 | for k in topk: 85 | correct_k = correct[:k].view(-1).float().sum(0) 86 | res.append(correct_k.mul_(100.0 / batch_size)) 87 | return res 88 | 89 | 90 | def parse_shift_option_from_log_name(log_name): 91 | if 'shift' in log_name: 92 | strings = log_name.split('_') 93 | for i, s in enumerate(strings): 94 | if 'shift' in s: 95 | break 96 | return True, int(strings[i].replace('shift', '')), strings[i + 1] 97 | else: 98 | return False, None, None 99 | 100 | 101 | weights_list = args.weights.split(',') 102 | test_segments_list = [int(s) for s in args.test_segments.split(',')] 103 | assert len(weights_list) == len(test_segments_list) 104 | if args.coeff is None: 105 | coeff_list = [1] * len(weights_list) 106 | else: 107 | coeff_list = [float(c) for c in args.coeff.split(',')] 108 | 109 | if args.test_list is not None: 110 | test_file_list = args.test_list.split(',') 111 | else: 112 | test_file_list = [None] * len(weights_list) 113 | 114 | 115 | data_iter_list = [] 116 | net_list = 
[] 117 | modality_list = [] 118 | 119 | total_num = None 120 | for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list): 121 | is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights) 122 | if 'RGBDiff' in this_weights: 123 | modality = 'RGBDiff' 124 | new_length = 6 125 | elif 'Flow' in this_weights: 126 | modality = 'Flow' 127 | new_length = 5 128 | else: 129 | modality = 'RGB' 130 | new_length = 1 131 | this_arch = this_weights.split('TSM_')[1].split('_')[2] 132 | modality_list.append(modality) 133 | num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset, 134 | modality) 135 | print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place)) 136 | net = TSN(num_class, this_test_segments if is_shift else 1, modality, 137 | base_model=this_arch, 138 | consensus_type=args.crop_fusion_type, 139 | img_feature_dim=args.img_feature_dim, 140 | pretrain=args.pretrain, 141 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, 142 | non_local='_nl' in this_weights, 143 | ) 144 | 145 | if 'tpool' in this_weights: 146 | from ops.temporal_shift import make_temporal_pool 147 | make_temporal_pool(net.base_model, this_test_segments) # since DataParallel 148 | 149 | checkpoint = torch.load(this_weights) 150 | checkpoint = checkpoint['state_dict'] 151 | 152 | # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())} 153 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} 154 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', 155 | 'base_model.classifier.bias': 'new_fc.bias', 156 | } 157 | for k, v in replace_dict.items(): 158 | if k in base_dict: 159 | base_dict[v] = base_dict.pop(k) 160 | 161 | net.load_state_dict(base_dict) 162 | 163 | input_size = net.scale_size if args.full_res else net.input_size 164 | if args.test_crops == 1: 165 | cropping = torchvision.transforms.Compose([ 166 | GroupScale(net.scale_size), 167 | GroupCenterCrop(input_size), 168 | ]) 169 | elif args.test_crops == 3: # do not flip, so only 5 crops 170 | cropping = torchvision.transforms.Compose([ 171 | GroupFullResSample(input_size, net.scale_size, flip=False) 172 | ]) 173 | elif args.test_crops == 5: # do not flip, so only 5 crops 174 | cropping = torchvision.transforms.Compose([ 175 | GroupOverSample(input_size, net.scale_size, flip=False) 176 | ]) 177 | elif args.test_crops == 10: 178 | cropping = torchvision.transforms.Compose([ 179 | GroupOverSample(input_size, net.scale_size) 180 | ]) 181 | else: 182 | raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops)) 183 | 184 | data_loader = torch.utils.data.DataLoader( 185 | TSNDataSet(root_path, [test_file] if test_file is not None else val_list, num_segments=this_test_segments, 186 | new_length=[new_length], 187 | modality=[modality], 188 | image_tmpl=prefix, 189 | test_mode=True, 190 | remove_missing=len(weights_list) == 1, 191 | transform=torchvision.transforms.Compose([ 192 | cropping, 193 | Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])), 194 | ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])), 195 | GroupNormalize(net.input_mean, net.input_std) if modality!='RGBDiff' else IdentityTransform(), 196 | ]), dense_sample=args.dense_sample, dense_length=args.dense_length, dense_number=args.dense_number, 197 | twice_sample=args.twice_sample, 
random_sample=args.random_sample), 198 | batch_size=args.batch_size, shuffle=False, 199 | num_workers=args.workers, pin_memory=True, 200 | ) 201 | 202 | if args.gpus is not None: 203 | devices = [args.gpus[i] for i in range(args.workers)] 204 | else: 205 | devices = list(range(args.workers)) 206 | 207 | net = torch.nn.DataParallel(net.cuda()) 208 | net.eval() 209 | 210 | data_gen = enumerate(data_loader) 211 | 212 | if total_num is None: 213 | total_num = len(data_loader.dataset) 214 | else: 215 | assert total_num == len(data_loader.dataset) 216 | 217 | data_iter_list.append(data_gen) 218 | net_list.append(net) 219 | 220 | 221 | output = [] 222 | 223 | 224 | def eval_video(video_data, net, this_test_segments, modality): 225 | net.eval() 226 | with torch.no_grad(): 227 | i, data, label = video_data 228 | batch_size = label.numel() 229 | num_crop = args.test_crops 230 | if args.random_sample: 231 | if args.twice_sample: 232 | num_crop *= (args.dense_number+2) 233 | else: 234 | num_crop *= (args.dense_number+1) 235 | elif args.dense_sample: 236 | num_crop *= args.dense_number # 10 clips for testing when using dense sample 237 | elif args.twice_sample: 238 | num_crop *= 2 239 | 240 | if modality == 'RGB': 241 | length = 3 242 | elif modality == 'Flow': 243 | length = 10 244 | elif modality == 'RGBDiff': 245 | length = 18 246 | else: 247 | raise ValueError("Unknown modality "+ modality) 248 | 249 | data_in = data.view(-1, length, data.size(2), data.size(3)) 250 | if is_shift: 251 | data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3)) 252 | rst = net(data_in) 253 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) 254 | 255 | if args.softmax: 256 | # take the softmax to normalize the output to probability 257 | rst = F.softmax(rst, dim=1) 258 | 259 | rst = rst.data.cpu().numpy().copy() 260 | 261 | if net.module.is_shift: 262 | rst = rst.reshape(batch_size, num_class) 263 | else: 264 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) 265 | 266 | return i, rst, label 267 | 268 | 269 | proc_start_time = time.time() 270 | max_num = args.max_num if args.max_num > 0 else total_num 271 | 272 | top1 = AverageMeter() 273 | top5 = AverageMeter() 274 | 275 | for i, data_label_pairs in enumerate(zip(*data_iter_list)): 276 | with torch.no_grad(): 277 | if i >= max_num: 278 | break 279 | this_rst_list = [] 280 | this_label = None 281 | for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list): 282 | rst = eval_video((i, data, label), net, n_seg, modality) 283 | this_rst_list.append(rst[1]) 284 | this_label = label 285 | assert len(this_rst_list) == len(coeff_list) 286 | for i_coeff in range(len(this_rst_list)): 287 | this_rst_list[i_coeff] *= coeff_list[i_coeff] 288 | ensembled_predict = sum(this_rst_list) / len(this_rst_list) 289 | 290 | for p, g in zip(ensembled_predict, this_label.cpu().numpy()): 291 | output.append([p[None, ...], g]) 292 | cnt_time = time.time() - proc_start_time 293 | prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5)) 294 | top1.update(prec1.item(), this_label.numel()) 295 | top5.update(prec5.item(), this_label.numel()) 296 | if i % 20 == 0: 297 | print('video {} done, total {}/{}, average {:.3f} sec/video, ' 298 | 'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num, 299 | float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg)) 300 | 301 | 
video_pred = [np.argmax(x[0]) for x in output] 302 | video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output] 303 | 304 | video_labels = [x[1] for x in output] 305 | 306 | 307 | if args.csv_file is not None: 308 | print('=> Writing result to csv file: {}'.format(args.csv_file)) 309 | with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f: 310 | categories = f.readlines() 311 | categories = [f.strip() for f in categories] 312 | with open(test_file_list[0]) as f: 313 | vid_names = f.readlines() 314 | vid_names = [n.split(' ')[0] for n in vid_names] 315 | assert len(vid_names) == len(video_pred) 316 | if args.dataset != 'somethingv2': # only output top1 317 | with open(args.csv_file, 'w') as f: 318 | for n, pred in zip(vid_names, video_pred): 319 | f.write('{};{}\n'.format(n, categories[pred])) 320 | else: 321 | with open(args.csv_file, 'w') as f: 322 | for n, pred5 in zip(vid_names, video_pred_top5): 323 | fill = [n] 324 | for p in list(pred5): 325 | fill.append(p) 326 | f.write('{};{};{};{};{};{}\n'.format(*fill)) 327 | 328 | 329 | cf = confusion_matrix(video_labels, video_pred).astype(float) 330 | 331 | np.save('cm.npy', cf) 332 | cls_cnt = cf.sum(axis=1) 333 | cls_hit = np.diag(cf) 334 | 335 | cls_acc = cls_hit / cls_cnt 336 | print(cls_acc) 337 | upper = np.mean(np.max(cf, axis=1) / cls_cnt) 338 | print('upper bound: {}'.format(upper)) 339 | 340 | print('-----Evaluation is finished------') 341 | print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 342 | print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg)) 343 | 344 | 345 | -------------------------------------------------------------------------------- /tools/gen_label_kinetics.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | 8 | import os 9 | 10 | 11 | dataset_path = '/ssd/video/kinetics/images256/' 12 | label_path = '/ssd/video/kinetics/labels' 13 | 14 | if __name__ == '__main__': 15 | with open('kinetics_label_map.txt') as f: 16 | categories = f.readlines() 17 | categories = [c.strip().replace(' ', '_').replace('"', '').replace('(', '').replace(')', '').replace("'", '') for c in categories] 18 | assert len(set(categories)) == 400 19 | dict_categories = {} 20 | for i, category in enumerate(categories): 21 | dict_categories[category] = i 22 | 23 | print(dict_categories) 24 | 25 | files_input = ['kinetics_val.csv', 'kinetics_train.csv'] 26 | files_output = ['val_videofolder.txt', 'train_videofolder.txt'] 27 | for (filename_input, filename_output) in zip(files_input, files_output): 28 | count_cat = {k: 0 for k in dict_categories.keys()} 29 | with open(os.path.join(label_path, filename_input)) as f: 30 | lines = f.readlines()[1:] 31 | folders = [] 32 | idx_categories = [] 33 | categories_list = [] 34 | for line in lines: 35 | line = line.rstrip() 36 | items = line.split(',') 37 | folders.append(items[1] + '_' + items[2]) 38 | this_catergory = items[0].replace(' ', '_').replace('"', '').replace('(', '').replace(')', '').replace("'", '') 39 | categories_list.append(this_catergory) 40 | idx_categories.append(dict_categories[this_catergory]) 41 | 
count_cat[this_catergory] += 1 42 | print(max(count_cat.values())) 43 | 44 | assert len(idx_categories) == len(folders) 45 | missing_folders = [] 46 | output = [] 47 | for i in range(len(folders)): 48 | curFolder = folders[i] 49 | curIDX = idx_categories[i] 50 | # counting the number of frames in each video folders 51 | img_dir = os.path.join(dataset_path, categories_list[i], curFolder) 52 | if not os.path.exists(img_dir): 53 | missing_folders.append(img_dir) 54 | # print(missing_folders) 55 | else: 56 | dir_files = os.listdir(img_dir) 57 | output.append('%s %d %d'%(os.path.join(categories_list[i], curFolder), len(dir_files), curIDX)) 58 | print('%d/%d, missing %d'%(i, len(folders), len(missing_folders))) 59 | with open(os.path.join(label_path, filename_output),'w') as f: 60 | f.write('\n'.join(output)) 61 | with open(os.path.join(label_path, 'missing_' + filename_output),'w') as f: 62 | f.write('\n'.join(missing_folders)) 63 | -------------------------------------------------------------------------------- /tools/gen_label_sthv1.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V1 8 | 9 | import os 10 | 11 | if __name__ == '__main__': 12 | dataset_name = 'something-something-v1' # 'jester-v1' 13 | with open('%s-labels.csv' % dataset_name) as f: 14 | lines = f.readlines() 15 | categories = [] 16 | for line in lines: 17 | line = line.rstrip() 18 | categories.append(line) 19 | categories = sorted(categories) 20 | with open('category.txt', 'w') as f: 21 | f.write('\n'.join(categories)) 22 | 23 | dict_categories = {} 24 | for i, category in enumerate(categories): 25 | dict_categories[category] = i 26 | 27 | files_input = ['%s-validation.csv' % dataset_name, '%s-train.csv' % dataset_name] 28 | files_output = ['val_videofolder.txt', 'train_videofolder.txt'] 29 | for (filename_input, filename_output) in zip(files_input, files_output): 30 | with open(filename_input) as f: 31 | lines = f.readlines() 32 | folders = [] 33 | idx_categories = [] 34 | for line in lines: 35 | line = line.rstrip() 36 | items = line.split(';') 37 | folders.append(items[0]) 38 | idx_categories.append(dict_categories[items[1]]) 39 | output = [] 40 | for i in range(len(folders)): 41 | curFolder = folders[i] 42 | curIDX = idx_categories[i] 43 | # counting the number of frames in each video folders 44 | dir_files = os.listdir(os.path.join('../img', curFolder)) 45 | output.append('%s %d %d' % ('something/v1/img/' + curFolder, len(dir_files), curIDX)) 46 | print('%d/%d' % (i, len(folders))) 47 | with open(filename_output, 'w') as f: 48 | f.write('\n'.join(output)) 49 | -------------------------------------------------------------------------------- /tools/gen_label_sthv2.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from 
https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V2 8 | 9 | import os 10 | import json 11 | 12 | if __name__ == '__main__': 13 | dataset_name = 'something-something-v2' # 'jester-v1' 14 | with open('%s-labels.json' % dataset_name) as f: 15 | data = json.load(f) 16 | categories = [] 17 | for i, (cat, idx) in enumerate(data.items()): 18 | assert i == int(idx) # make sure the rank is right 19 | categories.append(cat) 20 | 21 | with open('category.txt', 'w') as f: 22 | f.write('\n'.join(categories)) 23 | 24 | dict_categories = {} 25 | for i, category in enumerate(categories): 26 | dict_categories[category] = i 27 | 28 | files_input = ['%s-validation.json' % dataset_name, '%s-train.json' % dataset_name, '%s-test.json' % dataset_name] 29 | files_output = ['val_videofolder.txt', 'train_videofolder.txt', 'test_videofolder.txt'] 30 | for (filename_input, filename_output) in zip(files_input, files_output): 31 | with open(filename_input) as f: 32 | data = json.load(f) 33 | folders = [] 34 | idx_categories = [] 35 | for item in data: 36 | folders.append(item['id']) 37 | if 'test' not in filename_input: 38 | idx_categories.append(dict_categories[item['template'].replace('[', '').replace(']', '')]) 39 | else: 40 | idx_categories.append(0) 41 | output = [] 42 | for i in range(len(folders)): 43 | curFolder = folders[i] 44 | curIDX = idx_categories[i] 45 | # counting the number of frames in each video folders 46 | dir_files = os.listdir(os.path.join('20bn-something-something-v2-frames', curFolder)) 47 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX)) 48 | print('%d/%d' % (i, len(folders))) 49 | with open(filename_output, 'w') as f: 50 | f.write('\n'.join(output)) 51 | -------------------------------------------------------------------------------- /tools/kinetics_label_map.txt: -------------------------------------------------------------------------------- 1 | abseiling 2 | air drumming 3 | answering questions 4 | applauding 5 | applying cream 6 | archery 7 | arm wrestling 8 | arranging flowers 9 | assembling computer 10 | auctioning 11 | baby waking up 12 | baking cookies 13 | balloon blowing 14 | bandaging 15 | barbequing 16 | bartending 17 | beatboxing 18 | bee keeping 19 | belly dancing 20 | bench pressing 21 | bending back 22 | bending metal 23 | biking through snow 24 | blasting sand 25 | blowing glass 26 | blowing leaves 27 | blowing nose 28 | blowing out candles 29 | bobsledding 30 | bookbinding 31 | bouncing on trampoline 32 | bowling 33 | braiding hair 34 | breading or breadcrumbing 35 | breakdancing 36 | brush painting 37 | brushing hair 38 | brushing teeth 39 | building cabinet 40 | building shed 41 | bungee jumping 42 | busking 43 | canoeing or kayaking 44 | capoeira 45 | carrying baby 46 | cartwheeling 47 | carving pumpkin 48 | catching fish 49 | catching or throwing baseball 50 | catching or throwing frisbee 51 | catching or throwing softball 52 | celebrating 53 | changing oil 54 | changing wheel 55 | checking tires 56 | cheerleading 57 | chopping wood 58 | clapping 59 | clay pottery making 60 | clean and jerk 61 | cleaning floor 62 | cleaning gutters 63 | cleaning pool 64 | cleaning shoes 65 | cleaning toilet 66 | cleaning windows 67 | climbing a rope 68 | climbing ladder 69 | climbing tree 70 | contact juggling 71 | cooking chicken 72 | cooking egg 73 | cooking on campfire 74 | cooking sausages 75 | counting money 76 | country line dancing 77 | cracking neck 78 | crawling baby 79 | 
crossing river 80 | crying 81 | curling hair 82 | cutting nails 83 | cutting pineapple 84 | cutting watermelon 85 | dancing ballet 86 | dancing charleston 87 | dancing gangnam style 88 | dancing macarena 89 | deadlifting 90 | decorating the christmas tree 91 | digging 92 | dining 93 | disc golfing 94 | diving cliff 95 | dodgeball 96 | doing aerobics 97 | doing laundry 98 | doing nails 99 | drawing 100 | dribbling basketball 101 | drinking 102 | drinking beer 103 | drinking shots 104 | driving car 105 | driving tractor 106 | drop kicking 107 | drumming fingers 108 | dunking basketball 109 | dying hair 110 | eating burger 111 | eating cake 112 | eating carrots 113 | eating chips 114 | eating doughnuts 115 | eating hotdog 116 | eating ice cream 117 | eating spaghetti 118 | eating watermelon 119 | egg hunting 120 | exercising arm 121 | exercising with an exercise ball 122 | extinguishing fire 123 | faceplanting 124 | feeding birds 125 | feeding fish 126 | feeding goats 127 | filling eyebrows 128 | finger snapping 129 | fixing hair 130 | flipping pancake 131 | flying kite 132 | folding clothes 133 | folding napkins 134 | folding paper 135 | front raises 136 | frying vegetables 137 | garbage collecting 138 | gargling 139 | getting a haircut 140 | getting a tattoo 141 | giving or receiving award 142 | golf chipping 143 | golf driving 144 | golf putting 145 | grinding meat 146 | grooming dog 147 | grooming horse 148 | gymnastics tumbling 149 | hammer throw 150 | headbanging 151 | headbutting 152 | high jump 153 | high kick 154 | hitting baseball 155 | hockey stop 156 | holding snake 157 | hopscotch 158 | hoverboarding 159 | hugging 160 | hula hooping 161 | hurdling 162 | hurling (sport) 163 | ice climbing 164 | ice fishing 165 | ice skating 166 | ironing 167 | javelin throw 168 | jetskiing 169 | jogging 170 | juggling balls 171 | juggling fire 172 | juggling soccer ball 173 | jumping into pool 174 | jumpstyle dancing 175 | kicking field goal 176 | kicking soccer ball 177 | kissing 178 | kitesurfing 179 | knitting 180 | krumping 181 | laughing 182 | laying bricks 183 | long jump 184 | lunge 185 | making a cake 186 | making a sandwich 187 | making bed 188 | making jewelry 189 | making pizza 190 | making snowman 191 | making sushi 192 | making tea 193 | marching 194 | massaging back 195 | massaging feet 196 | massaging legs 197 | massaging person's head 198 | milking cow 199 | mopping floor 200 | motorcycling 201 | moving furniture 202 | mowing lawn 203 | news anchoring 204 | opening bottle 205 | opening present 206 | paragliding 207 | parasailing 208 | parkour 209 | passing American football (in game) 210 | passing American football (not in game) 211 | peeling apples 212 | peeling potatoes 213 | petting animal (not cat) 214 | petting cat 215 | picking fruit 216 | planting trees 217 | plastering 218 | playing accordion 219 | playing badminton 220 | playing bagpipes 221 | playing basketball 222 | playing bass guitar 223 | playing cards 224 | playing cello 225 | playing chess 226 | playing clarinet 227 | playing controller 228 | playing cricket 229 | playing cymbals 230 | playing didgeridoo 231 | playing drums 232 | playing flute 233 | playing guitar 234 | playing harmonica 235 | playing harp 236 | playing ice hockey 237 | playing keyboard 238 | playing kickball 239 | playing monopoly 240 | playing organ 241 | playing paintball 242 | playing piano 243 | playing poker 244 | playing recorder 245 | playing saxophone 246 | playing squash or racquetball 247 | playing tennis 248 | playing trombone 249 | 
playing trumpet 250 | playing ukulele 251 | playing violin 252 | playing volleyball 253 | playing xylophone 254 | pole vault 255 | presenting weather forecast 256 | pull ups 257 | pumping fist 258 | pumping gas 259 | punching bag 260 | punching person (boxing) 261 | push up 262 | pushing car 263 | pushing cart 264 | pushing wheelchair 265 | reading book 266 | reading newspaper 267 | recording music 268 | riding a bike 269 | riding camel 270 | riding elephant 271 | riding mechanical bull 272 | riding mountain bike 273 | riding mule 274 | riding or walking with horse 275 | riding scooter 276 | riding unicycle 277 | ripping paper 278 | robot dancing 279 | rock climbing 280 | rock scissors paper 281 | roller skating 282 | running on treadmill 283 | sailing 284 | salsa dancing 285 | sanding floor 286 | scrambling eggs 287 | scuba diving 288 | setting table 289 | shaking hands 290 | shaking head 291 | sharpening knives 292 | sharpening pencil 293 | shaving head 294 | shaving legs 295 | shearing sheep 296 | shining shoes 297 | shooting basketball 298 | shooting goal (soccer) 299 | shot put 300 | shoveling snow 301 | shredding paper 302 | shuffling cards 303 | side kick 304 | sign language interpreting 305 | singing 306 | situp 307 | skateboarding 308 | ski jumping 309 | skiing (not slalom or crosscountry) 310 | skiing crosscountry 311 | skiing slalom 312 | skipping rope 313 | skydiving 314 | slacklining 315 | slapping 316 | sled dog racing 317 | smoking 318 | smoking hookah 319 | snatch weight lifting 320 | sneezing 321 | sniffing 322 | snorkeling 323 | snowboarding 324 | snowkiting 325 | snowmobiling 326 | somersaulting 327 | spinning poi 328 | spray painting 329 | spraying 330 | springboard diving 331 | squat 332 | sticking tongue out 333 | stomping grapes 334 | stretching arm 335 | stretching leg 336 | strumming guitar 337 | surfing crowd 338 | surfing water 339 | sweeping floor 340 | swimming backstroke 341 | swimming breast stroke 342 | swimming butterfly stroke 343 | swing dancing 344 | swinging legs 345 | swinging on something 346 | sword fighting 347 | tai chi 348 | taking a shower 349 | tango dancing 350 | tap dancing 351 | tapping guitar 352 | tapping pen 353 | tasting beer 354 | tasting food 355 | testifying 356 | texting 357 | throwing axe 358 | throwing ball 359 | throwing discus 360 | tickling 361 | tobogganing 362 | tossing coin 363 | tossing salad 364 | training dog 365 | trapezing 366 | trimming or shaving beard 367 | trimming trees 368 | triple jump 369 | tying bow tie 370 | tying knot (not on a tie) 371 | tying tie 372 | unboxing 373 | unloading truck 374 | using computer 375 | using remote controller (not gaming) 376 | using segway 377 | vault 378 | waiting in line 379 | walking the dog 380 | washing dishes 381 | washing feet 382 | washing hair 383 | washing hands 384 | water skiing 385 | water sliding 386 | watering plants 387 | waxing back 388 | waxing chest 389 | waxing eyebrows 390 | waxing legs 391 | weaving basket 392 | welding 393 | whistling 394 | windsurfing 395 | wrapping present 396 | wrestling 397 | writing 398 | yawning 399 | yoga 400 | zumba -------------------------------------------------------------------------------- /tools/vid2img_kinetics.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | from __future__ import print_function, 
division 7 | import os 8 | import sys 9 | import subprocess 10 | from multiprocessing import Pool 11 | from tqdm import tqdm 12 | 13 | n_thread = 100 14 | 15 | 16 | def vid2jpg(file_name, class_path, dst_class_path): 17 | if '.mp4' not in file_name: 18 | return 19 | name, ext = os.path.splitext(file_name) 20 | dst_directory_path = os.path.join(dst_class_path, name) 21 | 22 | video_file_path = os.path.join(class_path, file_name) 23 | try: 24 | if os.path.exists(dst_directory_path): 25 | if not os.path.exists(os.path.join(dst_directory_path, 'img_00001.jpg')): 26 | subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True) 27 | print('remove {}'.format(dst_directory_path)) 28 | os.mkdir(dst_directory_path) 29 | else: 30 | print('*** convert has been done: {}'.format(dst_directory_path)) 31 | return 32 | else: 33 | os.mkdir(dst_directory_path) 34 | except: 35 | print(dst_directory_path) 36 | return 37 | cmd = 'ffmpeg -i \"{}\" -threads 1 -vf scale=-1:331 -q:v 0 \"{}/img_%05d.jpg\"'.format(video_file_path, dst_directory_path) 38 | # print(cmd) 39 | subprocess.call(cmd, shell=True, 40 | stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 41 | 42 | 43 | def class_process(dir_path, dst_dir_path, class_name): 44 | print('*' * 20, class_name, '*'*20) 45 | class_path = os.path.join(dir_path, class_name) 46 | if not os.path.isdir(class_path): 47 | print('*** is not a dir {}'.format(class_path)) 48 | return 49 | 50 | dst_class_path = os.path.join(dst_dir_path, class_name) 51 | if not os.path.exists(dst_class_path): 52 | os.mkdir(dst_class_path) 53 | 54 | vid_list = os.listdir(class_path) 55 | vid_list.sort() 56 | p = Pool(n_thread) 57 | from functools import partial 58 | worker = partial(vid2jpg, class_path=class_path, dst_class_path=dst_class_path) 59 | for _ in tqdm(p.imap_unordered(worker, vid_list), total=len(vid_list)): 60 | pass 61 | # p.map(worker, vid_list) 62 | p.close() 63 | p.join() 64 | 65 | print('\n') 66 | 67 | 68 | if __name__ == "__main__": 69 | dir_path = sys.argv[1] 70 | dst_dir_path = sys.argv[2] 71 | 72 | class_list = os.listdir(dir_path) 73 | class_list.sort() 74 | for class_name in class_list: 75 | class_process(dir_path, dst_dir_path, class_name) 76 | 77 | class_name = 'test' 78 | class_process(dir_path, dst_dir_path, class_name) 79 | -------------------------------------------------------------------------------- /tools/vid2img_sthv2.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import os 7 | import threading 8 | 9 | NUM_THREADS = 100 10 | VIDEO_ROOT = '/ssd/video/something/v2/20bn-something-something-v2' # Downloaded webm videos 11 | FRAME_ROOT = '/ssd/video/something/v2/20bn-something-something-v2-frames' # Directory for extracted frames 12 | 13 | 14 | def split(l, n): 15 | """Yield successive n-sized chunks from l.""" 16 | for i in range(0, len(l), n): 17 | yield l[i:i + n] 18 | 19 | 20 | def extract(video, tmpl='%06d.jpg'): 21 | # os.system(f'ffmpeg -i {VIDEO_ROOT}/{video} -vf -threads 1 -vf scale=-1:256 -q:v 0 ' 22 | # f'{FRAME_ROOT}/{video[:-5]}/{tmpl}') 23 | cmd = 'ffmpeg -i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video, 24 | FRAME_ROOT, video[:-5]) 25 | os.system(cmd) 26 | 27 | 28 | def target(video_list): 29 | for video in video_list: 30 | os.makedirs(os.path.join(FRAME_ROOT, 
video[:-5])) 31 | extract(video) 32 | 33 | 34 | if __name__ == '__main__': 35 | if not os.path.exists(VIDEO_ROOT): 36 | raise ValueError('Please download videos and set VIDEO_ROOT variable.') 37 | if not os.path.exists(FRAME_ROOT): 38 | os.makedirs(FRAME_ROOT) 39 | 40 | video_list = os.listdir(VIDEO_ROOT) 41 | splits = list(split(video_list, NUM_THREADS)) 42 | 43 | threads = [] 44 | for i, chunk in enumerate(splits): 45 | thread = threading.Thread(target=target, args=(chunk,)) 46 | thread.start() 47 | threads.append(thread) 48 | 49 | for thread in threads: 50 | thread.join() --------------------------------------------------------------------------------
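Note on `tools/vid2img_sthv2.py` above: the `split(l, n)` helper yields consecutive chunks of *size* `n`, so with the default `NUM_THREADS = 100` the script starts one thread per 100 videos (roughly `len(video_list) / 100` threads overall) rather than exactly 100 threads, and all threads are started before any are joined. A minimal, standalone sketch of that chunking behaviour (the video file names below are hypothetical):

```python
# Standalone sketch of the chunking used by tools/vid2img_sthv2.py.
# The video file names are made up for illustration.

def split(l, n):
    """Yield successive n-sized chunks from l (same helper as in the script)."""
    for i in range(0, len(l), n):
        yield l[i:i + n]

videos = ['a.webm', 'b.webm', 'c.webm', 'd.webm', 'e.webm']
chunks = list(split(videos, 2))

print(chunks)       # [['a.webm', 'b.webm'], ['c.webm', 'd.webm'], ['e.webm']]
print(len(chunks))  # 3 -> the script would spawn 3 worker threads for this list
```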