├── README.md ├── archs ├── __init__.py ├── mobilenet_v2.pth.tar └── mobilenet_v2.py ├── datas ├── dataset.py ├── dataset_config.py ├── generate_label.py └── jester │ ├── jester-v1-labels.csv │ ├── jester-v1-test.csv │ ├── jester-v1-train.csv │ └── jester-v1-validation.csv ├── docker ├── Dockerfile ├── README.md └── sources.list ├── main.py ├── online_demo ├── README.md ├── main.py ├── mobilenet_v2_tsm.py └── mobilenetv2_jester.pth.tar ├── ops ├── __init__.py ├── basic_ops.py ├── models.py ├── temporal_shift.py ├── transforms.py └── utils.py ├── options.py ├── requirements.txt ├── test.sh └── train.sh /README.md: -------------------------------------------------------------------------------- 1 | ## Temporal Shift Module for Jester Gesture Recognition 2 | 3 | Based on the [MIT official code](https://github.com/mit-han-lab/temporal-shift-module), 4 | we trim and adapt the code for the Jester dataset. 5 | 6 | ### Prerequisites 7 | 8 | * Python 3.6 9 | * PyTorch 1.2 10 | * OpenCV 3.4 11 | * Other packages can be found in ```requirements.txt``` 12 | 13 | ### Data Preparation 14 | 15 | First, download the [Jester](https://20bn.com/datasets/jester/v1) dataset. 16 | Then, process the data and generate the corresponding labels. 17 | This produces the category.txt, train_videofolder.txt, val_videofolder.txt and test_videofolder.txt files. 18 | 19 | `python3 datas/generate_label.py` 20 | 21 | ### Train and Validate 22 | 23 | `bash train.sh` 24 | 25 | After all training epochs finish, result.csv is generated; 26 | this test result file lists each video number and its predicted label. 27 | 28 | ### Reference 29 | 30 | [TSM paper](https://arxiv.org/abs/1811.08383) 31 | -------------------------------------------------------------------------------- /archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiamingNo1/Temporal-Shift-Module/f6bad75d22c4038304cf5610462fd2fc951f0f82/archs/__init__.py -------------------------------------------------------------------------------- /archs/mobilenet_v2.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiamingNo1/Temporal-Shift-Module/f6bad75d22c4038304cf5610462fd2fc951f0f82/archs/mobilenet_v2.pth.tar -------------------------------------------------------------------------------- /archs/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def conv_bn(inp, oup, stride): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.ReLU6(inplace=True) 11 | ) 12 | 13 | 14 | def conv_1x1_bn(inp, oup): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | def make_divisible(x, divisible_by=8): 23 | import numpy as np 24 | return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = int(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 40 | nn.BatchNorm2d(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 54 | nn.BatchNorm2d(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class MobileNetV2(nn.Module): 69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 70 | super(MobileNetV2, self).__init__() 71 | block = InvertedResidual 72 | input_channel = 32 73 | last_channel = 1280 74 | interverted_residual_setting = [ 75 | # t, c, n, s 76 | [1, 16, 1, 1], 77 | [6, 24, 2, 2], 78 | [6, 32, 3, 2], 79 | [6, 64, 4, 2], 80 | [6, 96, 3, 1], 81 | [6, 160, 3, 2], 82 | [6, 320, 1, 1], 83 | ] 84 | 85 | # building first layer 86 | assert input_size % 32 == 0 87 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 88 | self.features = [conv_bn(3, input_channel, 2)] 89 | # building inverted residual blocks 90 | for t, c, n, s in interverted_residual_setting: 91 | output_channel = make_divisible(c * width_mult) if t > 1 else c 92 | for i in range(n): 93 | if i == 0: 94 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 95 | else: 96 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 97 | input_channel = output_channel 98 | # building last several layers 99 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 100 | # make it nn.Sequential 101 | self.features = nn.Sequential(*self.features) 102 | 103 | # building classifier modified 104 | self.classifier = nn.Linear(self.last_channel, n_class) 105 | 106 | self._initialize_weights() 107 | 108 | def forward(self, x): 109 | x = self.features(x) 110 | x = x.mean(3).mean(2) 111 | x = self.classifier(x) 112 | return x 113 | 114 | def _initialize_weights(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 119 | if m.bias is not None: 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.Linear): 125 | m.weight.data.normal_(0, 0.01) 126 | m.bias.data.zero_() 127 | 128 | 129 | def mobilenet_v2(pretrained=True): 130 | model = MobileNetV2(width_mult=1) 131 | model_dict = model.state_dict() 132 | if pretrained: 133 | mobilenetv2 = torch.load('archs/mobilenet_v2.pth.tar') 134 | mobilenetv2 = {k: v for k, v in mobilenetv2.items() if k in model_dict} 135 | model_dict.update(mobilenetv2) 136 | model.load_state_dict(model_dict) 137 | 138 | return model 139 | 140 | 141 | if __name__ == '__main__': 142 | model = mobilenet_v2(True) 143 | model.last_layer_name = 'classifier' 144 | feature_dim = getattr(model, model.last_layer_name).in_features 145 | setattr(model, model.last_layer_name, nn.Dropout(p=0.8)) 146 | print(model.state_dict().keys()) 147 | -------------------------------------------------------------------------------- /datas/dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | from numpy.random import randint 5 | import torch.utils.data as data 6 | 7 | 8 | class VideoRecord(object): 9 | def __init__(self, row): 10 | self._data = row 11 | 12 | @property 13 | def path(self): 14 | return self._data[0] 15 | 16 | @property 17 | def num_frames(self): 18 | return int(self._data[1]) 19 | 20 | @property 21 | def label(self): 22 | return int(self._data[2]) 23 | 24 | 25 | class TSNDataSet(data.Dataset): 26 | def __init__(self, list_file, num_segments=3, 27 | image_tmpl='{:05d}.jpg', transform=None, 28 | random_shift=True, test_mode=False, remove_missing=False): 29 | 30 | self.list_file = list_file 31 | self.num_segments = num_segments 32 | self.image_tmpl = image_tmpl 33 | self.transform = transform 34 | self.random_shift = random_shift 35 | self.test_mode = test_mode 36 | self.remove_missing = remove_missing 37 | 38 | self._parse_list() 39 | 40 | def _load_image(self, directory, idx): 41 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')] 42 | 43 | def _parse_list(self): 44 | # check the frame number is large >3: 45 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 46 | if not self.test_mode or self.remove_missing: 47 | tmp = [item for item in tmp if int(item[1]) >= 3] 48 | self.video_list = [VideoRecord(item) for item in tmp] 49 | print('video number:%d' % (len(self.video_list))) 50 | 51 | def _sample_indices(self, record): 52 | average_duration = record.num_frames // self.num_segments 53 | if average_duration > 0: 54 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, 55 | size=self.num_segments) 56 | elif record.num_frames > self.num_segments: 57 | offsets = np.sort(randint(record.num_frames, size=self.num_segments)) 58 | else: 59 | offsets = np.zeros((self.num_segments,)) 60 | return offsets + 1 61 | 62 | def _get_val_indices(self, record): 63 | if record.num_frames > self.num_segments: 64 | tick = record.num_frames / float(self.num_segments) 65 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 66 | else: 67 | offsets = np.zeros((self.num_segments,)) 68 | return offsets + 1 69 | 70 | def _get_test_indices(self, record): 71 | tick = record.num_frames / float(self.num_segments) 72 | offsets = np.array([int(tick / 2.0 + tick * x) for x in 
range(self.num_segments)]) 73 | return offsets + 1 74 | 75 | def __getitem__(self, index): 76 | record = self.video_list[index] 77 | if not self.test_mode: 78 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 79 | else: 80 | segment_indices = self._get_test_indices(record) 81 | return self.get(record, segment_indices) 82 | 83 | def __len__(self): 84 | return len(self.video_list) 85 | 86 | def get(self, record, indices): 87 | images = list() 88 | for seg_ind in indices: 89 | p = int(seg_ind) 90 | seg_imgs = self._load_image(record.path, p) 91 | images.extend(seg_imgs) 92 | process_data = self.transform(images) 93 | if self.test_mode: 94 | return process_data 95 | return process_data, record.label 96 | -------------------------------------------------------------------------------- /datas/dataset_config.py: -------------------------------------------------------------------------------- 1 | def dataset(): 2 | prefix = '{:05d}.jpg' 3 | file_categories = 'datas/jester/category.txt' 4 | file_imglist_train = 'datas/jester/train_videofolder.txt' 5 | file_imglist_val = 'datas/jester/val_videofolder.txt' 6 | file_imglist_test = 'datas/jester/test_videofolder.txt' 7 | 8 | with open(file_categories) as f: 9 | lines = f.readlines() 10 | categories = [item.rstrip() for item in lines] 11 | n_class = len(categories) 12 | print('jester: {} classes'.format(n_class)) 13 | 14 | return n_class, file_imglist_train, file_imglist_val, file_imglist_test, prefix 15 | -------------------------------------------------------------------------------- /datas/generate_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if __name__ == '__main__': 4 | dataset = 'datas/jester/jester-v1' 5 | with open('%s-labels.csv' % dataset) as f: 6 | lines = f.readlines() 7 | categories = [] 8 | for line in lines: 9 | line = line.rstrip() 10 | categories.append(line) 11 | categories = sorted(categories) 12 | with open('datas/jester/category.txt', 'w') as f: 13 | f.write('\n'.join(categories)) 14 | 15 | dict_categories = {} 16 | for i, category in enumerate(categories): 17 | dict_categories[category] = i 18 | 19 | # train and validate dataset 20 | files_input = ['%s-validation.csv' % dataset, '%s-train.csv' % dataset] 21 | files_output = ['datas/jester/val_videofolder.txt', 'datas/jester/train_videofolder.txt'] 22 | for (filename_input, filename_output) in zip(files_input, files_output): 23 | with open(filename_input) as f: 24 | lines = f.readlines() 25 | folders = [] 26 | idx_categories = [] 27 | for line in lines: 28 | line = line.rstrip() 29 | items = line.split(';') 30 | folders.append(items[0]) 31 | idx_categories.append(dict_categories[items[1]]) 32 | output = [] 33 | for i in range(len(folders)): 34 | curFolder = folders[i] 35 | curIDX = idx_categories[i] 36 | dir_files = os.listdir(os.path.join('/workspace/datas/jester/20bn-jester-v1/', curFolder)) 37 | output.append('%s %d %d' % ('/workspace/datas/jester/20bn-jester-v1/' + curFolder, len(dir_files), curIDX)) 38 | print('%d/%d' % (i, len(folders))) 39 | with open(filename_output, 'w') as f: 40 | f.write('\n'.join(output)) 41 | 42 | # test dataset 43 | with open('datas/jester/jester-v1-test.csv') as f: 44 | lines = f.readlines() 45 | folders = [] 46 | for line in lines: 47 | folders.append(line.strip()) 48 | output = [] 49 | for i in range(len(folders)): 50 | curFolder = folders[i] 51 | dir_files = os.listdir(os.path.join('/workspace/datas/jester/20bn-jester-v1/', 
curFolder)) 52 | output.append('%s %d' % ('/workspace/datas/jester/20bn-jester-v1/' + curFolder, len(dir_files))) 53 | print('%d/%d' % (i, len(folders))) 54 | with open('datas/jester/test_videofolder.txt', 'w') as f: 55 | f.write('\n'.join(output)) 56 | -------------------------------------------------------------------------------- /datas/jester/jester-v1-labels.csv: -------------------------------------------------------------------------------- 1 | Swiping Left 2 | Swiping Right 3 | Swiping Down 4 | Swiping Up 5 | Pushing Hand Away 6 | Pulling Hand In 7 | Sliding Two Fingers Left 8 | Sliding Two Fingers Right 9 | Sliding Two Fingers Down 10 | Sliding Two Fingers Up 11 | Pushing Two Fingers Away 12 | Pulling Two Fingers In 13 | Rolling Hand Forward 14 | Rolling Hand Backward 15 | Turning Hand Clockwise 16 | Turning Hand Counterclockwise 17 | Zooming In With Full Hand 18 | Zooming Out With Full Hand 19 | Zooming In With Two Fingers 20 | Zooming Out With Two Fingers 21 | Thumb Up 22 | Thumb Down 23 | Shaking Hand 24 | Stop Sign 25 | Drumming Fingers 26 | No gesture 27 | Doing other things 28 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.1-devel-ubuntu18.04 2 | MAINTAINER jiamingNo1 jiaming19.huang@foxmail.com 3 | 4 | COPY sources.list . 5 | 6 | RUN mv /etc/apt/sources.list /etc/apt/sources.list.save && \ 7 | mv sources.list /etc/apt/sources.list && \ 8 | apt-get update && apt-get install -y \ 9 | apt-utils \ 10 | build-essential \ 11 | cmake \ 12 | git \ 13 | curl \ 14 | python3-pip \ 15 | libsm6 libxext6 libxrender-dev libglib2.0-0 && \ 16 | rm -rf /var/lib/apt/lists/* 17 | 18 | WORKDIR /workspace 19 | EXPOSE 6006 20 | 21 | RUN python3 -m pip install --upgrade pip && \ 22 | pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ 23 | pip3 install setuptools 24 | 25 | RUN git clone https://github.com/jiamingNo1/Temporal-Shift-Module.git . 26 | 27 | RUN pip3 install -r requirements.txt 28 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ### Running the Temporal Shift Module using the Docker image 2 | 3 | #### Build docker 4 | 5 | `docker build -t temporal_shift_module/pytorch1.2:jiaming_huang -f Dockerfile .` 6 | 7 | #### Run docker 8 | 9 | We only cover the GPU case, which requires nvidia-docker; see the [nvidia-docker repository](https://github.com/NVIDIA/nvidia-docker) for installation instructions. 
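Before starting the container, you can optionally confirm that Docker can see your GPUs. A minimal sanity check, assuming the same CUDA 10.1 base image used in the Dockerfile above, is:

`docker run --gpus all --rm nvidia/cuda:10.1-devel-ubuntu18.04 nvidia-smi`

If `nvidia-smi` lists your GPUs, the run command below will be able to use them as well.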
10 | 11 | `docker run --gpus all --name tsm -p 6006:6006 --shm-size 8G 12 | -v xxx/20bn-jester-v1/:/workspace/datas/jester/20bn-jester-v1 13 | -it temporal_shift_module/pytorch1.2:jiaming_huang` 14 | -------------------------------------------------------------------------------- /docker/sources.list: -------------------------------------------------------------------------------- 1 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse 2 | # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse 3 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse 4 | # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse 5 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse 6 | # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse 7 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse 8 | # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse 9 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch.optim 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from datetime import datetime 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from datas.dataset import TSNDataSet 10 | from ops.models import TSN 11 | from ops.transforms import * 12 | from options import parser 13 | from datas import dataset_config 14 | from ops.utils import AverageMeter, accuracy, save_checkpoint, adjust_learning_rate, check_rootfolders 15 | 16 | best_prec1 = 0 17 | 18 | 19 | def main(): 20 | # settings 21 | global args, best_prec1 22 | args = parser.parse_args() 23 | n_class, args.train_list, args.val_list, args.test_list, prefix = dataset_config.dataset() 24 | full_arch_name = args.arch 25 | if args.shift: 26 | full_arch_name += '_shift{}'.format(args.shift_div) 27 | args.store_name = '_'.join( 28 | ['tsm', full_arch_name, 'segment%d' % args.num_segments]) 29 | print('storing name: ' + args.store_name) 30 | check_rootfolders(args.root_log, args.root_model, args.store_name) 31 | 32 | # tsn model added temporal shift module 33 | model = TSN(n_class, args.num_segments, 34 | base_model=args.arch, 35 | dropout=args.dropout, 36 | partial_bn=not args.no_partialbn, 37 | is_shift=args.shift, shift_div=args.shift_div) 38 | 39 | # preprocessing for input 40 | crop_size = model.crop_size 41 | scale_size = model.scale_size 42 | input_mean = model.input_mean 43 | input_std = model.input_std 44 | policies = model.get_optim_policies() 45 | train_augmentation = model.get_augmentation(flip=False) 46 | 47 | # optimizer 48 | optimizer = torch.optim.SGD(policies, args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 49 | 50 | # cuda and cudnn 51 | try: 52 | model = nn.DataParallel(model).cuda() 53 | except: 54 | model = model.cuda() 55 | cudnn.benchmark = True 56 | 57 | # data loader 58 | normalize = GroupNormalize(input_mean, input_std) 59 | train_loader = torch.utils.data.DataLoader( 60 | TSNDataSet(args.train_list, 61 | num_segments=args.num_segments, 62 | image_tmpl=prefix, 63 | transform=torchvision.transforms.Compose([ 64 | train_augmentation, 65 | 
Stack(roll=False), 66 | ToTorchFormatTensor(div=True), 67 | normalize])), 68 | batch_size=args.batch_size, shuffle=True, 69 | num_workers=args.workers, pin_memory=False, drop_last=True) 70 | 71 | val_loader = torch.utils.data.DataLoader( 72 | TSNDataSet(args.val_list, 73 | num_segments=args.num_segments, 74 | image_tmpl=prefix, 75 | random_shift=False, 76 | transform=torchvision.transforms.Compose([ 77 | GroupScale(int(scale_size)), 78 | GroupCenterCrop(crop_size), 79 | Stack(roll=False), 80 | ToTorchFormatTensor(div=True), 81 | normalize])), 82 | batch_size=args.batch_size, shuffle=False, 83 | num_workers=args.workers, pin_memory=False) 84 | 85 | test_loader = torch.utils.data.DataLoader( 86 | TSNDataSet(args.test_list, 87 | num_segments=args.num_segments, 88 | image_tmpl=prefix, 89 | random_shift=False, 90 | test_mode=True, 91 | transform=torchvision.transforms.Compose([ 92 | GroupScale(int(scale_size)), 93 | GroupCenterCrop(crop_size), 94 | Stack(roll=False), 95 | ToTorchFormatTensor(div=True), 96 | normalize])), 97 | batch_size=args.batch_size, shuffle=False, 98 | num_workers=args.workers, pin_memory=False) 99 | 100 | # loss function 101 | criterion = nn.CrossEntropyLoss().cuda() 102 | for group in policies: 103 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 104 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 105 | 106 | # tensorboard 107 | time_stamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now()) 108 | 109 | # train 110 | if args.mode == 'train': 111 | log_training = open(os.path.join(args.root_log, args.store_name, time_stamp, 'log.csv'), 'w') 112 | tf_writer = SummaryWriter('{}/{}/'.format(args.root_log, args.store_name) + time_stamp) 113 | for epoch in range(args.start_epoch, args.epochs): 114 | adjust_learning_rate(optimizer, epoch, args.lr_steps, args.lr, args.weight_decay) 115 | train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer) 116 | 117 | # evaluate on validation set 118 | if (epoch + 1) % args.eval_freq == 0: 119 | prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer) 120 | 121 | # remember best precision and save checkpoint 122 | is_best = prec1 >= best_prec1 123 | best_prec1 = max(prec1, best_prec1) 124 | output_best = 'Best Prec@1: %.2f\n' % (best_prec1) 125 | print(output_best) 126 | log_training.write(output_best + '\n') 127 | log_training.flush() 128 | 129 | save_checkpoint({ 130 | 'epoch': epoch + 1, 131 | 'arch': args.arch, 132 | 'state_dict': model.state_dict(), 133 | 'optimizer': optimizer.state_dict(), 134 | 'best_prec1': best_prec1, 135 | }, is_best, args.root_model, args.store_name) 136 | tf_writer.close() 137 | 138 | # test 139 | checkpoint = '%s/%s/ckpt.best.pth.tar' % (args.root_model, args.store_name) 140 | test(test_loader, model, checkpoint, time_stamp) 141 | 142 | 143 | def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer): 144 | batch_time = AverageMeter() 145 | data_time = AverageMeter() 146 | losses = AverageMeter() 147 | top1 = AverageMeter() 148 | if args.no_partialbn: 149 | try: 150 | model.module.partialBN(False) 151 | except: 152 | model.partialBN(False) 153 | else: 154 | try: 155 | model.module.partialBN(True) 156 | except: 157 | model.partialBN(True) 158 | model.train() 159 | 160 | end = time.time() 161 | for idx, (input, target) in enumerate(train_loader): 162 | data_time.update(time.time() - end) 163 | input, target = input.cuda(), target.cuda() 164 | output = model(input) 165 | loss = criterion(output, 
target) 166 | 167 | # accuracy and loss 168 | prec1, = accuracy(output.data, target, topk=(1,)) 169 | losses.update(loss.item(), input.size(0)) 170 | top1.update(prec1.item(), input.size(0)) 171 | 172 | # gradient and optimizer 173 | loss.backward() 174 | if (idx + 1) % args.update_weight == 0: 175 | optimizer.step() 176 | optimizer.zero_grad() 177 | 178 | # time 179 | batch_time.update(time.time() - end) 180 | end = time.time() 181 | if (idx + 1) % args.print_freq == 0: 182 | output = ('Train: epoch-{0} ({1}/{2})\t' 183 | 'batch_time {batch_time.avg:.2f}\t\t' 184 | 'data_time {data_time.avg:.2f}\t\t' 185 | 'loss {loss.avg:.3f}\t' 186 | 'prec@1 {top1.avg:.2f}\t'.format( 187 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 188 | data_time=data_time, loss=losses, top1=top1)) 189 | batch_time.reset() 190 | data_time.reset() 191 | losses.reset() 192 | top1.reset() 193 | print(output) 194 | log.write(output + '\n') 195 | log.flush() 196 | 197 | tf_writer.add_scalar('loss/train', losses.avg, epoch) 198 | tf_writer.add_scalar('acc/train_top1', top1.avg, epoch) 199 | tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch) 200 | 201 | 202 | def validate(val_loader, model, criterion, epoch, log, tf_writer): 203 | losses = AverageMeter() 204 | top1 = AverageMeter() 205 | model.eval() 206 | with torch.no_grad(): 207 | for input, target in val_loader: 208 | input, target = input.cuda(), target.cuda() 209 | output = model(input) 210 | loss = criterion(output, target) 211 | 212 | # accuracy and loss 213 | prec1, = accuracy(output.data, target, topk=(1,)) 214 | losses.update(loss.item(), input.size(0)) 215 | top1.update(prec1.item(), input.size(0)) 216 | 217 | output = ('Validate: Prec@1 {top1.avg:.2f} Loss {loss.avg:.3f}'.format(top1=top1, loss=losses)) 218 | print(output) 219 | log.write(output + '\n') 220 | log.flush() 221 | tf_writer.add_scalar('loss/val', losses.avg, epoch) 222 | tf_writer.add_scalar('acc/val_top1', top1.avg, epoch) 223 | 224 | return top1.avg 225 | 226 | 227 | def test(test_loader, model, checkpoint, time_stamp): 228 | model.load_state_dict(torch.load(checkpoint)['state_dict']) 229 | model.eval() 230 | labels = [] 231 | with torch.no_grad(): 232 | for input in test_loader: 233 | input = input.cuda() 234 | output = model(input) 235 | pred = output.argmax(dim=1).cpu().numpy().tolist() 236 | labels.extend(pred) 237 | 238 | with open('datas/jester/jester-v1-test.csv') as f: 239 | videos = f.readlines() 240 | with open('datas/jester/category.txt') as f: 241 | categories = f.readlines() 242 | assert len(videos) == len(labels) 243 | result = [] 244 | for idx in range(len(labels)): 245 | result.append(videos[idx].strip() + ';' + categories[labels[idx]].rstrip()) 246 | with open(os.path.join(args.root_log, args.store_name, time_stamp, 'result.csv'), 'w') as f: 247 | f.write('\n'.join(result)) 248 | 249 | 250 | if __name__ == '__main__': 251 | main() 252 | -------------------------------------------------------------------------------- /online_demo/README.md: -------------------------------------------------------------------------------- 1 | ## Hand Gesture Recognition Online Demo with TSM 2 | 3 | ### Prerequisites 4 | * tvm 5 | ``` 6 | sudo apt install llvm 7 | git clone https://github.com/dmlc/tvm.git 8 | cd tvm 9 | git submodule update --init 10 | mkdir build 11 | cp cmake/config.cmake build/ 12 | cd build 13 | #[ 14 | #edit config.cmake to change 15 | # 32 line: USE_CUDA OFF -> USE_CUDA ON 16 | #104 line: USE_LLVM OFF -> USE_LLVM ON 17 | #] 18 | cmake .. 
19 | make -j8 20 | cd .. 21 | cd python; sudo python3 setup.py install; cd .. 22 | cd nnvm/python; sudo python3 setup.py install; cd ../.. 23 | cd topi/python; sudo python3 setup.py install; cd ../.. 24 | ``` 25 | * onnx 26 | ``` 27 | sudo apt-get install protobuf-compiler libprotoc-dev 28 | pip3 install onnx 29 | ``` 30 | * add cuda path 31 | 32 | `export PATH=$PATH:/usr/local/cuda/bin` 33 | 34 | ### Run The Demo 35 | First, copy the trained PyTorch checkpoint into this directory. 36 | 37 | `cp xxx/xxx.pt.tar ./mobilenetv2_jester.pth.tar` 38 | 39 | Then, run the demo. The first run compiles the PyTorch model into an ONNX model, 40 | then compiles the ONNX model into a TVM binary, and finally runs it. 41 | Subsequent runs directly execute the compiled TVM model. 42 | 43 | `python3 main.py` 44 | 45 | Press `Q` or `Esc` to quit. Press `F` to enter or exit full screen. -------------------------------------------------------------------------------- /online_demo/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import tvm 4 | import time 5 | import cv2 6 | import onnx 7 | import torch 8 | import torch.onnx 9 | import torchvision 10 | import numpy as np 11 | import tvm.relay as relay 12 | from PIL import Image 13 | import tvm.contrib.graph_runtime as graph_runtime 14 | from mobilenet_v2_tsm import MobileNetV2 15 | 16 | HISTORY_LOGIT = True 17 | 18 | 19 | def torch2tvm_module(model, inputs, target): 20 | model.eval() 21 | input_names = [] 22 | input_shapes = {} 23 | with torch.no_grad(): 24 | for idx, input in enumerate(inputs): 25 | name = "i" + str(idx) 26 | input_names.append(name) 27 | input_shapes[name] = input.shape 28 | buffer = io.BytesIO() 29 | torch.onnx.export(model, inputs, buffer, input_names=input_names, 30 | output_names=["o" + str(i) for i in range(len(inputs))]) # torch to onnx model 31 | buffer.seek(0, 0) 32 | onnx_model = onnx.load_model(buffer) # load onnx model 33 | relay_module, params = relay.frontend.from_onnx(onnx_model, shape=input_shapes) # params(weights) 34 | with relay.build_config(opt_level=3): 35 | graph, tvm_module, params = relay.build(relay_module, target, params=params) 36 | 37 | return graph, tvm_module, params 38 | 39 | 40 | def torch2executor(model, inputs, target): 41 | prefix = f"mobilenetv2_tsm_tvm_{target}" 42 | lib_fname = f'{prefix}.tar' 43 | graph_fname = f'{prefix}.json' 44 | params_fname = f'{prefix}.params' 45 | if os.path.exists(lib_fname) and os.path.exists(graph_fname) and os.path.exists(params_fname): 46 | with open(graph_fname, 'rt') as f: 47 | graph = f.read() 48 | tvm_module = tvm.module.load(lib_fname) 49 | params = relay.load_param_dict(bytearray(open(params_fname, 'rb').read())) 50 | else: 51 | graph, tvm_module, params = torch2tvm_module(model, inputs, target) 52 | tvm_module.export_library(lib_fname) 53 | with open(graph_fname, 'wt') as f: 54 | f.write(graph) 55 | with open(params_fname, 'wb') as f: 56 | f.write(relay.save_param_dict(params)) 57 | 58 | ctx = tvm.gpu() if target.startswith('cuda') else tvm.cpu() 59 | graph_module = graph_runtime.create(graph, tvm_module, ctx) # graph json, tvm module and tvm context 60 | for pname, pvalue in params.items(): 61 | graph_module.set_input(pname, pvalue) 62 | 63 | def executor(inputs): 64 | for idx, value in enumerate(inputs): 65 | graph_module.set_input(idx, value) 66 | graph_module.run() 67 | return tuple(graph_module.get_output(idx) for idx in range(len(inputs))) 68 | 69 | return executor, ctx 70 | 71 | 72 | def get_executor(): 73 | model = 
MobileNetV2(n_class=27) 74 | mobilenetv2_jester = torch.load('mobilenetv2_jester.pth.tar')['state_dict'] 75 | from collections import OrderedDict 76 | new_state_dict = OrderedDict() 77 | for k, v in mobilenetv2_jester.items(): 78 | name = k[7:] 79 | if 'new_fc' in name: 80 | name = name.replace('new_fc', 'classifier') 81 | else: 82 | if 'net' in name: 83 | name = name.replace('net.', '') 84 | name = name[11:] 85 | new_state_dict[name] = v 86 | model.load_state_dict(new_state_dict) 87 | inputs = (torch.rand(1, 3, 224, 224), 88 | torch.zeros([1, 3, 56, 56]), 89 | torch.zeros([1, 4, 28, 28]), 90 | torch.zeros([1, 4, 28, 28]), 91 | torch.zeros([1, 8, 14, 14]), 92 | torch.zeros([1, 8, 14, 14]), 93 | torch.zeros([1, 8, 14, 14]), 94 | torch.zeros([1, 12, 14, 14]), 95 | torch.zeros([1, 12, 14, 14]), 96 | torch.zeros([1, 20, 7, 7]), 97 | torch.zeros([1, 20, 7, 7])) 98 | return torch2executor(model, inputs, target='cuda') 99 | 100 | 101 | class GroupScale(object): 102 | def __init__(self, size, interpolation=Image.BILINEAR): 103 | self.worker = torchvision.transforms.Resize(size, interpolation) 104 | 105 | def __call__(self, img_group): 106 | return [self.worker(img) for img in img_group] 107 | 108 | 109 | class GroupCenterCrop(object): 110 | def __init__(self, size): 111 | self.worker = torchvision.transforms.CenterCrop(size) 112 | 113 | def __call__(self, img_group): 114 | return [self.worker(img) for img in img_group] 115 | 116 | 117 | class Stack(object): 118 | def __init__(self, roll=False): 119 | self.roll = roll 120 | 121 | def __call__(self, img_group): 122 | if self.roll: 123 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 124 | else: 125 | return np.concatenate(img_group, axis=2) 126 | 127 | 128 | class ToTorchFormatTensor(object): 129 | def __init__(self, div=True): 130 | self.div = div 131 | 132 | def __call__(self, pic): 133 | if isinstance(pic, np.ndarray): 134 | # handle numpy array 135 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 136 | else: 137 | # handle PIL Image 138 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 139 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 140 | # from HWC to CHW format 141 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 142 | return img.float().div(255) if self.div else img.float() 143 | 144 | 145 | class GroupNormalize(object): 146 | def __init__(self, mean, std): 147 | self.mean = mean 148 | self.std = std 149 | 150 | def __call__(self, tensor): 151 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 152 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 153 | for t, m, s in zip(tensor, rep_mean, rep_std): 154 | t.sub_(m).div_(s) 155 | return tensor 156 | 157 | 158 | def transform(frame): 159 | # from H*W*C to 1*C*H*W 160 | frame = cv2.resize(frame, (224, 224)) 161 | frame = frame / 255.0 162 | frame = np.transpose(frame, axes=[2, 0, 1]) 163 | frame = np.expand_dims(frame, axis=0) 164 | return frame 165 | 166 | 167 | def get_transform(): 168 | cropping = torchvision.transforms.Compose([ 169 | GroupScale(256), 170 | GroupCenterCrop(224), 171 | ]) 172 | transform = torchvision.transforms.Compose([ 173 | cropping, 174 | Stack(roll=False), 175 | ToTorchFormatTensor(div=True), 176 | GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 177 | ]) 178 | return transform 179 | 180 | 181 | catigories = [ 182 | "Doing other things", # 0 183 | "Drumming Fingers", # 1 184 | "No gesture", # 2 185 | "Pulling Hand In", # 3 186 | "Pulling Two Fingers In", # 4 
187 | "Pushing Hand Away", # 5 188 | "Pushing Two Fingers Away", # 6 189 | "Rolling Hand Backward", # 7 190 | "Rolling Hand Forward", # 8 191 | "Shaking Hand", # 9 192 | "Sliding Two Fingers Down", # 10 193 | "Sliding Two Fingers Left", # 11 194 | "Sliding Two Fingers Right", # 12 195 | "Sliding Two Fingers Up", # 13 196 | "Stop Sign", # 14 197 | "Swiping Down", # 15 198 | "Swiping Left", # 16 199 | "Swiping Right", # 17 200 | "Swiping Up", # 18 201 | "Thumb Down", # 19 202 | "Thumb Up", # 20 203 | "Turning Hand Clockwise", # 21 204 | "Turning Hand Counterclockwise", # 22 205 | "Zooming In With Full Hand", # 23 206 | "Zooming In With Two Fingers", # 24 207 | "Zooming Out With Full Hand", # 25 208 | "Zooming Out With Two Fingers" # 26 209 | ] 210 | 211 | 212 | def process_output(idx_, history): 213 | max_hist_len = 20 214 | 215 | # mask out illegal action 216 | if idx_ in [7, 8, 21, 22, 3]: 217 | idx_ = history[-1] 218 | 219 | # use only single no action class 220 | if idx_ == 0: 221 | idx_ = 2 222 | 223 | # history smoothing 224 | if idx_ != history[-1]: 225 | if not (history[-1] == history[-2]): 226 | idx_ = history[-1] 227 | 228 | history.append(idx_) 229 | history = history[-max_hist_len:] 230 | 231 | return history[-1], history 232 | 233 | 234 | WINDOW_NAME = 'Video Gesture Recognition' 235 | 236 | 237 | def main(): 238 | print("Open Camera...") 239 | cap = cv2.VideoCapture(0) 240 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320) 241 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240) 242 | 243 | # env variables 244 | full_screen = False 245 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 246 | cv2.resizeWindow(WINDOW_NAME, 640, 480) # 640->width, 480->height 247 | cv2.moveWindow(WINDOW_NAME, 0, 0) 248 | cv2.setWindowTitle(WINDOW_NAME, WINDOW_NAME) 249 | 250 | print("Build Transformer...") 251 | transform = get_transform() 252 | print("Build Executor...") 253 | executor, ctx = get_executor() 254 | buffer = ( 255 | tvm.nd.empty((1, 3, 56, 56), ctx=ctx), 256 | tvm.nd.empty((1, 4, 28, 28), ctx=ctx), 257 | tvm.nd.empty((1, 4, 28, 28), ctx=ctx), 258 | tvm.nd.empty((1, 8, 14, 14), ctx=ctx), 259 | tvm.nd.empty((1, 8, 14, 14), ctx=ctx), 260 | tvm.nd.empty((1, 8, 14, 14), ctx=ctx), 261 | tvm.nd.empty((1, 12, 14, 14), ctx=ctx), 262 | tvm.nd.empty((1, 12, 14, 14), ctx=ctx), 263 | tvm.nd.empty((1, 20, 7, 7), ctx=ctx), 264 | tvm.nd.empty((1, 20, 7, 7), ctx=ctx) 265 | ) 266 | idx = 0 267 | history = [2] 268 | history_logit = [] 269 | idx_frame = -1 270 | 271 | print("Ready!") 272 | while True: 273 | idx_frame += 1 274 | _, img = cap.read() # 240*320*3 275 | if idx_frame % 2 == 0: # skip every other frame to obtain a suitable frame rate 276 | end = time.time() 277 | img_tran = transform([Image.fromarray(img).convert('RGB')]) 278 | torch_input = img_tran.view(1, 3, img_tran.size(1), img_tran.size(2)) 279 | img_nd = tvm.nd.array(torch_input.detach().numpy(), ctx=ctx) 280 | inputs = (img_nd,) + buffer 281 | outputs = executor(inputs) 282 | feat, buffer = outputs[0], outputs[1:] 283 | idx_ = np.argmax(feat.asnumpy(), axis=1)[0] 284 | 285 | if HISTORY_LOGIT: 286 | history_logit.append(feat.asnumpy()) 287 | history_logit = history_logit[-12:] 288 | avg_logit = sum(history_logit) 289 | idx_ = np.argmax(avg_logit, axis=1)[0] 290 | 291 | idx, history = process_output(idx_, history) 292 | print(f"{idx_frame} {catigories[idx]}") 293 | current_time = time.time() - end 294 | 295 | img = cv2.resize(img, (640, 480)) 296 | img = img[:, ::-1] 297 | height, width, _ = img.shape 298 | label = np.zeros([height // 10, width, 
3]).astype('uint8') + 255 299 | 300 | cv2.putText(label, 'Prediction: ' + catigories[idx], 301 | (0, int(height / 16)), 302 | cv2.FONT_HERSHEY_SIMPLEX, 303 | 0.7, (0, 0, 0), 2) 304 | cv2.putText(label, '{:.1f} Vid/s'.format(1 / current_time), 305 | (width - 170, int(height / 16)), 306 | cv2.FONT_HERSHEY_SIMPLEX, 307 | 0.7, (0, 0, 0), 2) 308 | 309 | img = np.concatenate((img, label), axis=0) 310 | cv2.imshow(WINDOW_NAME, img) 311 | 312 | key = cv2.waitKey(1) 313 | if key & 0xFF == ord('q') or key == 27: 314 | break 315 | elif key == ord('F') or key == ord('f'): 316 | print('Changing full screen option!') 317 | full_screen = not full_screen 318 | if full_screen: 319 | print('Setting FS!!!') 320 | cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, 321 | cv2.WINDOW_FULLSCREEN) 322 | else: 323 | cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, 324 | cv2.WINDOW_NORMAL) 325 | 326 | cap.release() 327 | cv2.destroyAllWindows() 328 | 329 | 330 | if __name__ == '__main__': 331 | main() 332 | -------------------------------------------------------------------------------- /online_demo/mobilenet_v2_tsm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def conv_bn(inp, oup, stride): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.ReLU6(inplace=True) 11 | ) 12 | 13 | 14 | def conv_1x1_bn(inp, oup): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | def make_divisible(x, divisible_by=8): 23 | import numpy as np 24 | return int(np.ceil(x * 1. / divisible_by) * divisible_by) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = int(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 40 | nn.BatchNorm2d(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 54 | nn.BatchNorm2d(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class InvertedResidualWithShift(nn.Module): 69 | def __init__(self, inp, oup, stride, expand_ratio): 70 | super(InvertedResidualWithShift, self).__init__() 71 | self.stride = stride 72 | assert stride in [1, 2] 73 | 74 | assert expand_ratio > 1 75 | 76 | hidden_dim = int(inp * expand_ratio) 77 | self.use_res_connect = self.stride == 1 and inp == oup 78 | assert self.use_res_connect 79 | 80 | self.conv = nn.Sequential( 81 | # pw 82 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 83 | nn.BatchNorm2d(hidden_dim), 84 | nn.ReLU6(inplace=True), 85 | # dw 86 | 
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 87 | nn.BatchNorm2d(hidden_dim), 88 | nn.ReLU6(inplace=True), 89 | # pw-linear 90 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 91 | nn.BatchNorm2d(oup), 92 | ) 93 | 94 | def forward(self, x, shift_buffer): 95 | c = x.size(1) 96 | x1, x2 = x[:, : c // 8], x[:, c // 8:] 97 | return x + self.conv(torch.cat((shift_buffer, x2), dim=1)), x1 98 | 99 | 100 | class MobileNetV2(nn.Module): 101 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 102 | super(MobileNetV2, self).__init__() 103 | input_channel = 32 104 | last_channel = 1280 105 | interverted_residual_setting = [ 106 | # t, c, n, s 107 | [1, 16, 1, 1], 108 | [6, 24, 2, 2], 109 | [6, 32, 3, 2], 110 | [6, 64, 4, 2], 111 | [6, 96, 3, 1], 112 | [6, 160, 3, 2], 113 | [6, 320, 1, 1], 114 | ] 115 | 116 | # building first layer 117 | assert input_size % 32 == 0 118 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 119 | self.features = [conv_bn(3, input_channel, 2)] 120 | # building inverted residual blocks 121 | global_idx = 0 122 | shift_block_idx = [2, 4, 5, 7, 8, 9, 11, 12, 14, 15] 123 | for t, c, n, s in interverted_residual_setting: 124 | output_channel = make_divisible(c * width_mult) if t > 1 else c 125 | for i in range(n): 126 | if i == 0: 127 | block = InvertedResidualWithShift if global_idx in shift_block_idx else InvertedResidual 128 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 129 | global_idx += 1 130 | else: 131 | block = InvertedResidualWithShift if global_idx in shift_block_idx else InvertedResidual 132 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 133 | global_idx += 1 134 | input_channel = output_channel 135 | # building last several layers 136 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 137 | # make it nn.MoudleList 138 | self.features = nn.ModuleList(self.features) 139 | 140 | # building classifier 141 | self.classifier = nn.Linear(self.last_channel, n_class) 142 | 143 | self._initialize_weights() 144 | 145 | def forward(self, x, *shift_buffer): 146 | idx = 0 147 | out_buffer = [] 148 | for f in self.features: 149 | if isinstance(f, InvertedResidualWithShift): 150 | x, s = f(x, shift_buffer[idx]) 151 | idx += 1 152 | out_buffer.append(s) 153 | else: 154 | x = f(x) 155 | x = x.mean(3).mean(2) 156 | x = self.classifier(x) 157 | 158 | return (x, *out_buffer) 159 | 160 | def _initialize_weights(self): 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 164 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 165 | if m.bias is not None: 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | elif isinstance(m, nn.Linear): 171 | m.weight.data.normal_(0, 0.01) 172 | m.bias.data.zero_() 173 | 174 | 175 | if __name__ == '__main__': 176 | model = MobileNetV2(n_class=27) 177 | mobilenetv2_jester = torch.load('mobilenetv2_jester.pth.tar')['state_dict'] 178 | from collections import OrderedDict 179 | 180 | new_state_dict = OrderedDict() 181 | for k, v in mobilenetv2_jester.items(): 182 | name = k[7:] 183 | if 'new_fc' in name: 184 | name = name.replace('new_fc', 'classifier') 185 | else: 186 | if 'net' in name: 187 | name = name.replace('net.', '') 188 | name = name[11:] 189 | new_state_dict[name] = v 190 | model.load_state_dict(new_state_dict) 191 | print(model.state_dict()) -------------------------------------------------------------------------------- /online_demo/mobilenetv2_jester.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiamingNo1/Temporal-Shift-Module/f6bad75d22c4038304cf5610462fd2fc951f0f82/online_demo/mobilenetv2_jester.pth.tar -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiamingNo1/Temporal-Shift-Module/f6bad75d22c4038304cf5610462fd2fc951f0f82/ops/__init__.py -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SegmentConsensus(torch.autograd.Function): 5 | 6 | @staticmethod 7 | def forward(ctx, x): 8 | ctx.tensor = x 9 | output = x.mean(dim=1, keepdim=True) 10 | 11 | return output 12 | 13 | @staticmethod 14 | def backward(ctx, grad_output): 15 | shape = ctx.tensor.size() 16 | grad_in = grad_output.expand(shape) / float(shape[1]) 17 | 18 | return grad_in 19 | 20 | 21 | class ConsensusModule(torch.nn.Module): 22 | def forward(self, input): 23 | return SegmentConsensus.apply(input) 24 | -------------------------------------------------------------------------------- /ops/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.init import normal_, constant_ 3 | 4 | from ops.basic_ops import ConsensusModule 5 | from ops.transforms import * 6 | 7 | 8 | class TSN(nn.Module): 9 | def __init__(self, num_class, num_segments, base_model='mobilenetv2', 10 | dropout=0.5, partial_bn=True, is_shift=False, shift_div=8): 11 | super(TSN, self).__init__() 12 | self.num_segments = num_segments 13 | self.base_model_name = base_model 14 | self.dropout = dropout 15 | self.is_shift = is_shift 16 | self.shift_div = shift_div 17 | 18 | print((""" 19 | TSN Configurations: 20 | base model: {} 21 | num_segments: {} 22 | dropout_ratio: {} 23 | shift_div: {} 24 | """.format(base_model, self.num_segments, self.dropout, self.shift_div))) 25 | 26 | self._prepare_base_model(base_model) 27 | self._prepare_tsn(num_class) 28 | 29 | self.consensus = ConsensusModule() 30 | 31 | self._enable_pbn = partial_bn 32 | if partial_bn: 33 | self.partialBN(True) 34 | 35 | def _prepare_tsn(self, num_class): 36 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 37 | 38 | setattr(self.base_model, self.base_model.last_layer_name, 
nn.Dropout(p=self.dropout)) 39 | self.new_fc = nn.Linear(feature_dim, num_class) 40 | if hasattr(self.new_fc, 'weight'): 41 | normal_(self.new_fc.weight, 0, 0.001) 42 | constant_(self.new_fc.bias, 0) 43 | 44 | def _prepare_base_model(self, base_model): 45 | print('=> base model: {}'.format(base_model)) 46 | if base_model == 'mobilenetv2': 47 | from archs.mobilenet_v2 import mobilenet_v2, InvertedResidual 48 | self.base_model = mobilenet_v2(True) 49 | 50 | self.base_model.last_layer_name = 'classifier' 51 | self.input_size = 224 52 | self.input_mean = [0.485, 0.456, 0.406] 53 | self.input_std = [0.229, 0.224, 0.225] 54 | 55 | if self.is_shift: 56 | from ops.temporal_shift import TemporalShift 57 | for m in self.base_model.modules(): 58 | if isinstance(m, InvertedResidual) and len(m.conv) == 8 and m.use_res_connect: 59 | print('Adding temporal shift... {}'.format(m.use_res_connect)) 60 | m.conv[0] = TemporalShift(m.conv[0], n_segment=self.num_segments, n_div=self.shift_div) 61 | else: 62 | raise ValueError('Unknown base model: {}'.format(base_model)) 63 | 64 | def train(self, mode=True): 65 | """ 66 | Override the default train() to freeze the BN parameters 67 | :return: 68 | """ 69 | super(TSN, self).train(mode) 70 | count = 0 71 | if self._enable_pbn and mode: 72 | print("Freezing BatchNorm2D except the first one.") 73 | for m in self.base_model.modules(): 74 | if isinstance(m, nn.BatchNorm2d): 75 | count += 1 76 | if count >= (2 if self._enable_pbn else 1): 77 | m.eval() 78 | # shutdown update in frozen mode 79 | m.weight.requires_grad = False 80 | m.bias.requires_grad = False 81 | 82 | def partialBN(self, enable): 83 | self._enable_pbn = enable 84 | 85 | def get_optim_policies(self): 86 | first_conv_weight = [] 87 | first_conv_bias = [] 88 | normal_weight = [] 89 | normal_bias = [] 90 | bn = [] 91 | 92 | conv_cnt = 0 93 | bn_cnt = 0 94 | for m in self.modules(): 95 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 96 | ps = list(m.parameters()) 97 | conv_cnt += 1 98 | if conv_cnt == 1: 99 | first_conv_weight.append(ps[0]) 100 | if len(ps) == 2: 101 | first_conv_bias.append(ps[1]) 102 | else: 103 | normal_weight.append(ps[0]) 104 | if len(ps) == 2: 105 | normal_bias.append(ps[1]) 106 | elif isinstance(m, torch.nn.Linear): 107 | ps = list(m.parameters()) 108 | normal_weight.append(ps[0]) 109 | if len(ps) == 2: 110 | normal_bias.append(ps[1]) 111 | elif isinstance(m, torch.nn.BatchNorm2d): 112 | bn_cnt += 1 113 | # later BN's are frozen 114 | if not self._enable_pbn or bn_cnt == 1: 115 | bn.extend(list(m.parameters())) 116 | elif isinstance(m, torch.nn.BatchNorm3d): 117 | bn_cnt += 1 118 | # later BN's are frozen 119 | if not self._enable_pbn or bn_cnt == 1: 120 | bn.extend(list(m.parameters())) 121 | 122 | return [ 123 | {'params': first_conv_weight, 'lr_mult': 1, 'decay_mult': 1, 124 | 'name': "first_conv_weight"}, 125 | {'params': first_conv_bias, 'lr_mult': 2, 'decay_mult': 0, 126 | 'name': "first_conv_bias"}, 127 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 128 | 'name': "normal_weight"}, 129 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 130 | 'name': "normal_bias"}, 131 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 132 | 'name': "BN scale/shift"} 133 | ] 134 | 135 | def forward(self, input): 136 | base_out = self.base_model(input.view((-1, 3) + input.size()[-2:])) 137 | base_out = self.new_fc(base_out) 138 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 139 | output = 
self.consensus(base_out) 140 | 141 | return output.squeeze(1) 142 | 143 | @property 144 | def crop_size(self): 145 | return self.input_size 146 | 147 | @property 148 | def scale_size(self): 149 | return self.input_size * 256 // 224 150 | 151 | def get_augmentation(self, flip=True): 152 | if flip: 153 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 154 | GroupRandomHorizontalFlip(is_flow=False)]) 155 | else: 156 | print('#' * 10, 'NO FLIP!!!', '#' * 10) 157 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 158 | -------------------------------------------------------------------------------- /ops/temporal_shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TemporalShift(nn.Module): 6 | def __init__(self, net, n_segment=3, n_div=8): 7 | super(TemporalShift, self).__init__() 8 | self.net = net 9 | self.n_segment = n_segment 10 | self.fold_div = n_div 11 | print('=> Using fold div: {}'.format(self.fold_div)) 12 | 13 | def forward(self, x): 14 | x = self.shift(x, self.n_segment, fold_div=self.fold_div) 15 | return self.net(x) 16 | 17 | @staticmethod 18 | def shift(x, n_segment, fold_div=3): 19 | nt, c, h, w = x.size() 20 | n_batch = nt // n_segment 21 | x = x.view(n_batch, n_segment, c, h, w) 22 | 23 | fold = c // fold_div 24 | 25 | out = torch.zeros_like(x) 26 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift future frames 27 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift past frames 28 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 29 | 30 | return out.view(nt, c, h, w) 31 | 32 | 33 | if __name__ == '__main__': 34 | tsm = TemporalShift(nn.Sequential(), n_segment=3, n_div=8) 35 | 36 | print('=> Testing GPU...') 37 | tsm.cuda() 38 | # test forward 39 | with torch.no_grad(): 40 | x = torch.rand(3, 8, 1, 1).cuda() 41 | y = tsm(x) 42 | print(x) 43 | print(y) 44 | 45 | print('Test passed.') 46 | -------------------------------------------------------------------------------- /ops/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert (img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | def __init__(self, is_flow=False): 47 | self.is_flow = is_flow 48 | 49 | def __call__(self, img_group, is_flow=False): 50 | v = random.random() 51 | if v < 0.5: 52 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 53 | if 
self.is_flow: 54 | for i in range(0, len(ret), 2): 55 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 56 | return ret 57 | else: 58 | return img_group 59 | 60 | 61 | class GroupNormalize(object): 62 | def __init__(self, mean, std): 63 | self.mean = mean 64 | self.std = std 65 | 66 | def __call__(self, tensor): 67 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 68 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 69 | for t, m, s in zip(tensor, rep_mean, rep_std): 70 | t.sub_(m).div_(s) 71 | 72 | return tensor 73 | 74 | 75 | class GroupScale(object): 76 | def __init__(self, size, interpolation=Image.BILINEAR): 77 | self.worker = torchvision.transforms.Resize(size, interpolation) 78 | 79 | def __call__(self, img_group): 80 | return [self.worker(img) for img in img_group] 81 | 82 | 83 | class GroupOverSample(object): 84 | def __init__(self, crop_size, scale_size=None, flip=True): 85 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 86 | 87 | if scale_size is not None: 88 | self.scale_worker = GroupScale(scale_size) 89 | else: 90 | self.scale_worker = None 91 | self.flip = flip 92 | 93 | def __call__(self, img_group): 94 | 95 | if self.scale_worker is not None: 96 | img_group = self.scale_worker(img_group) 97 | 98 | image_w, image_h = img_group[0].size 99 | crop_w, crop_h = self.crop_size 100 | 101 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 102 | oversample_group = list() 103 | for o_w, o_h in offsets: 104 | normal_group = list() 105 | flip_group = list() 106 | for i, img in enumerate(img_group): 107 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 108 | normal_group.append(crop) 109 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 110 | 111 | if img.mode == 'L' and i % 2 == 0: 112 | flip_group.append(ImageOps.invert(flip_crop)) 113 | else: 114 | flip_group.append(flip_crop) 115 | 116 | oversample_group.extend(normal_group) 117 | if self.flip: 118 | oversample_group.extend(flip_group) 119 | return oversample_group 120 | 121 | 122 | class GroupFullResSample(object): 123 | def __init__(self, crop_size, scale_size=None, flip=True): 124 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 125 | 126 | if scale_size is not None: 127 | self.scale_worker = GroupScale(scale_size) 128 | else: 129 | self.scale_worker = None 130 | self.flip = flip 131 | 132 | def __call__(self, img_group): 133 | 134 | if self.scale_worker is not None: 135 | img_group = self.scale_worker(img_group) 136 | 137 | image_w, image_h = img_group[0].size 138 | crop_w, crop_h = self.crop_size 139 | 140 | w_step = (image_w - crop_w) // 4 141 | h_step = (image_h - crop_h) // 4 142 | 143 | offsets = list() 144 | offsets.append((0 * w_step, 2 * h_step)) # left 145 | offsets.append((4 * w_step, 2 * h_step)) # right 146 | offsets.append((2 * w_step, 2 * h_step)) # center 147 | 148 | oversample_group = list() 149 | for o_w, o_h in offsets: 150 | normal_group = list() 151 | flip_group = list() 152 | for i, img in enumerate(img_group): 153 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 154 | normal_group.append(crop) 155 | if self.flip: 156 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 157 | flip_group.append(flip_crop) 158 | 159 | oversample_group.extend(normal_group) 160 | oversample_group.extend(flip_group) 161 | return oversample_group 162 | 163 | 164 | class GroupMultiScaleCrop(object): 165 | def __init__(self, 
input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 166 | self.scales = scales if scales is not None else [1, .875, .75, .66] 167 | self.max_distort = max_distort 168 | self.fix_crop = fix_crop 169 | self.more_fix_crop = more_fix_crop 170 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 171 | self.interpolation = Image.BILINEAR 172 | 173 | def __call__(self, img_group): 174 | 175 | im_size = img_group[0].size 176 | 177 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 178 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 179 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 180 | for img in crop_img_group] 181 | return ret_img_group 182 | 183 | def _sample_crop_size(self, im_size): 184 | image_w, image_h = im_size[0], im_size[1] 185 | 186 | # find a crop size 187 | base_size = min(image_w, image_h) 188 | crop_sizes = [int(base_size * x) for x in self.scales] 189 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 190 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 191 | 192 | pairs = [] 193 | for i, h in enumerate(crop_h): 194 | for j, w in enumerate(crop_w): 195 | if abs(i - j) <= self.max_distort: 196 | pairs.append((w, h)) 197 | 198 | crop_pair = random.choice(pairs) 199 | if not self.fix_crop: 200 | w_offset = random.randint(0, image_w - crop_pair[0]) 201 | h_offset = random.randint(0, image_h - crop_pair[1]) 202 | else: 203 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 204 | 205 | return crop_pair[0], crop_pair[1], w_offset, h_offset 206 | 207 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 208 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 209 | return random.choice(offsets) 210 | 211 | @staticmethod 212 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 213 | w_step = (image_w - crop_w) // 4 214 | h_step = (image_h - crop_h) // 4 215 | 216 | ret = list() 217 | ret.append((0, 0)) # upper left 218 | ret.append((4 * w_step, 0)) # upper right 219 | ret.append((0, 4 * h_step)) # lower left 220 | ret.append((4 * w_step, 4 * h_step)) # lower right 221 | ret.append((2 * w_step, 2 * h_step)) # center 222 | 223 | if more_fix_crop: 224 | ret.append((0, 2 * h_step)) # center left 225 | ret.append((4 * w_step, 2 * h_step)) # center right 226 | ret.append((2 * w_step, 4 * h_step)) # lower center 227 | ret.append((2 * w_step, 0 * h_step)) # upper center 228 | 229 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 230 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 231 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 232 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 233 | 234 | return ret 235 | 236 | 237 | class GroupRandomSizedCrop(object): 238 | def __init__(self, size, interpolation=Image.BILINEAR): 239 | self.size = size 240 | self.interpolation = interpolation 241 | 242 | def __call__(self, img_group): 243 | for attempt in range(10): 244 | area = img_group[0].size[0] * img_group[0].size[1] 245 | target_area = random.uniform(0.08, 1.0) * area 246 | aspect_ratio = random.uniform(3. / 4, 4.
/ 3) 247 | 248 | w = int(round(math.sqrt(target_area * aspect_ratio))) 249 | h = int(round(math.sqrt(target_area / aspect_ratio))) 250 | 251 | if random.random() < 0.5: 252 | w, h = h, w 253 | 254 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 255 | x1 = random.randint(0, img_group[0].size[0] - w) 256 | y1 = random.randint(0, img_group[0].size[1] - h) 257 | found = True 258 | break 259 | else: 260 | found = False 261 | x1 = 0 262 | y1 = 0 263 | 264 | if found: 265 | out_group = list() 266 | for img in img_group: 267 | img = img.crop((x1, y1, x1 + w, y1 + h)) 268 | assert (img.size == (w, h)) 269 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 270 | return out_group 271 | else: 272 | # Fallback 273 | scale = GroupScale(self.size, interpolation=self.interpolation) 274 | crop = GroupRandomCrop(self.size) 275 | return crop(scale(img_group)) 276 | 277 | 278 | class Stack(object): 279 | def __init__(self, roll=False): 280 | self.roll = roll 281 | 282 | def __call__(self, img_group): 283 | if self.roll: 284 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 285 | else: 286 | return np.concatenate(img_group, axis=2) 287 | 288 | 289 | class ToTorchFormatTensor(object): 290 | def __init__(self, div=True): 291 | self.div = div 292 | 293 | def __call__(self, pic): 294 | if isinstance(pic, np.ndarray): 295 | # handle numpy array 296 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 297 | else: 298 | # handle PIL Image 299 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 300 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 301 | # from HWC to CHW format 302 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 303 | return img.float().div(255) if self.div else img.float() 304 | 305 | 306 | if __name__ == "__main__": 307 | trans = torchvision.transforms.Compose([ 308 | GroupScale(256), 309 | GroupRandomCrop(224), 310 | Stack(), 311 | ToTorchFormatTensor(), 312 | GroupNormalize( 313 | mean=[.485, .456, .406], 314 | std=[.229, .224, .225] 315 | )] 316 | ) 317 | -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import numpy as np 5 | 6 | 7 | class AverageMeter(object): 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def accuracy(output, target, topk=(1,)): 25 | maxk = max(topk) 26 | batch_size = target.size(0) 27 | 28 | _, pred = output.topk(maxk, 1, True, True) 29 | pred = pred.t() 30 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 31 | 32 | res = [] 33 | for k in topk: 34 | correct_k = correct[:k].view(-1).float().sum(0) 35 | res.append(correct_k.mul_(100.0 / batch_size)) 36 | return res 37 | 38 | 39 | def save_checkpoint(state, is_best, root_model, store_name): 40 | filename = '%s/%s/ckpt.pth.tar' % (root_model, store_name) 41 | torch.save(state, filename) 42 | if is_best: 43 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 44 | 45 | 46 | def adjust_learning_rate(optimizer, epoch, lr_steps, lr, decay): 47 | lr_factor = 0.1 ** (sum(epoch >= np.array(lr_steps))) 48 | lr *= lr_factor 49 | for param_group in optimizer.param_groups: 50 | 
param_group['lr'] = lr * param_group['lr_mult'] 51 | param_group['weight_decay'] = decay * param_group['decay_mult'] 52 | 53 | 54 | def check_rootfolders(root_log, root_model, store_name): 55 | folders_util = [root_log, root_model, 56 | os.path.join(root_log, store_name), 57 | os.path.join(root_model, store_name)] 58 | for folder in folders_util: 59 | if not os.path.exists(folder): 60 | print('creating folder ' + folder) 61 | os.mkdir(folder) 62 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description="PyTorch Implementation of Temporal Shift Module for Jester") 4 | parser.add_argument('--train_list', type=str, default="") 5 | parser.add_argument('--val_list', type=str, default="") 6 | parser.add_argument('--test_list', type=str, default="") 7 | parser.add_argument('--store_name', type=str, default="") 8 | parser.add_argument('--mode', type=str, default="train") 9 | 10 | # ========================= Model Configs ========================== 11 | parser.add_argument('--arch', type=str, default="mobilenetv2") 12 | parser.add_argument('--num_segments', type=int, default=8) 13 | parser.add_argument('--dropout', default=0.5, type=float) 14 | 15 | # ========================= Learning Configs ========================== 16 | parser.add_argument('--epochs', default=30, type=int, help='number of total epochs') 17 | parser.add_argument('--batch_size', default=16, type=int, help='number of images per iteration') 18 | parser.add_argument('--update_weight', default=4, type=int, help='the actual batch size for updating weights') 19 | parser.add_argument('--lr', default=0.01, type=float) 20 | parser.add_argument('--lr_type', default='step', type=str) 21 | parser.add_argument('--lr_steps', default=[18, 24], type=float, nargs="+") 22 | parser.add_argument('--momentum', default=0.9, type=float) 23 | parser.add_argument('--weight_decay', default=1e-4, type=float) 24 | parser.add_argument('--no_partialbn', default=False, action="store_true") 25 | 26 | # ========================= Monitor Configs ========================== 27 | parser.add_argument('--print_freq', default=200, type=int) 28 | parser.add_argument('--eval-freq', default=1, type=int, help='evaluation frequency') 29 | 30 | # ========================= Runtime Configs ========================== 31 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers') 32 | parser.add_argument('--start_epoch', default=0, type=int, help='manual epoch number') 33 | parser.add_argument('--root_log', type=str, default='log') 34 | parser.add_argument('--root_model', type=str, default='checkpoint') 35 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models') 36 | parser.add_argument('--shift_div', default=8, type=int, help='number for shift') 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future==0.17.1 2 | numpy==1.17.0 3 | pillow==6.2.0 4 | opencv-python==3.4.5.20 5 | torch==1.2.0 6 | torchvision==0.4.0 7 | tensorboard==1.14.0 8 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python3 main.py \ 2 | --mode test \ 3 | --arch mobilenetv2 \ 4 | 
--num_segments 8 \ 5 | --update_weight 4 \ 6 | --no_partialbn \ 7 | --shift --shift_div=8 8 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python3 main.py \ 2 | --mode train \ 3 | --arch mobilenetv2 \ 4 | --num_segments 8 \ 5 | --update_weight 4 \ 6 | --no_partialbn \ 7 | --shift --shift_div=8 8 | --------------------------------------------------------------------------------
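
For reference, a minimal sketch of how the pieces above compose, assuming the repository root is on `PYTHONPATH` so that `ops.transforms` and `ops.temporal_shift` are importable: the group transforms preprocess a dummy clip of `--num_segments` frames, and `TemporalShift` wraps an ordinary 2D layer. The plain convolutions below are illustrative stand-ins, not the MobileNetV2 blocks that `ops/models.py` actually builds.

```python
# Minimal sketch, assuming the repo root is importable; the two convolutions are
# stand-ins for MobileNetV2 blocks (the real model wiring lives in ops/models.py).
import torch
import torch.nn as nn
import torchvision
from PIL import Image

from ops.temporal_shift import TemporalShift
from ops.transforms import (GroupMultiScaleCrop, GroupRandomHorizontalFlip,
                            GroupNormalize, Stack, ToTorchFormatTensor)

n_segments = 8  # matches --num_segments in train.sh / test.sh

# One clip = a list of n_segments PIL frames; the group transforms apply the same
# crop/flip parameters to every frame of the clip.
transform = torchvision.transforms.Compose([
    GroupMultiScaleCrop(224, [1, .875, .75, .66]),
    GroupRandomHorizontalFlip(is_flow=False),
    Stack(roll=False),
    ToTorchFormatTensor(div=True),
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
])
frames = [Image.new('RGB', (256, 256)) for _ in range(n_segments)]
clip = transform(frames)                   # (3 * n_segments, 224, 224)
clip = clip.view(n_segments, 3, 224, 224)  # (N*T, C, H, W) layout expected by the shift

# Wrap a layer with TemporalShift: with 32 channels and n_div=8, fold = 32 // 8 = 4,
# so 4 channels take their values from the next frame, 4 from the previous frame,
# and the remaining 24 are left in place before the wrapped conv runs.
stem = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
block = TemporalShift(nn.Conv2d(32, 32, kernel_size=3, padding=1),
                      n_segment=n_segments, n_div=8)

with torch.no_grad():
    out = block(stem(clip))
print(out.shape)  # torch.Size([8, 32, 112, 112])
```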