├── ops ├── __init__.py ├── utils.py └── basic_ops.py ├── .gitmodules ├── LICENSE ├── .gitignore ├── README.md ├── opts.py ├── dataset.py ├── test_models.py ├── main.py ├── transforms.py └── models.py /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tf_model_zoo"] 2 | path = tf_model_zoo 3 | url = https://github.com/yjxiong/tensorflow-model-zoo.torch 4 | branch = tsn-pytorch 5 | -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def get_grad_hook(name): 6 | def hook(m, grad_in, grad_out): 7 | print((name, grad_out[0].data.abs().mean(), grad_in[0].data.abs().mean())) 8 | print((grad_out[0].size())) 9 | print((grad_in[0].size())) 10 | 11 | print((grad_out[0])) 12 | print((grad_in[0])) 13 | 14 | return hook 15 | 16 | 17 | def softmax(scores): 18 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 19 | return es / es.sum(axis=-1)[..., None] 20 | 21 | 22 | def log_add(log_a, log_b): 23 | return log_a + np.log(1 + np.exp(log_b - log_a)) 24 | 25 | 26 | def class_accuracy(prediction, label): 27 | cf = confusion_matrix(prediction, label) 28 | cls_cnt = cf.sum(axis=1) 29 | cls_hit = np.diag(cf) 30 | 31 | cls_acc = cls_hit / cls_cnt.astype(float) 32 | 33 | mean_cls_acc = cls_acc.mean() 34 | 35 | return cls_acc, mean_cls_acc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2017, Multimedia Laboratary, The Chinese University of Hong Kong 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
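The helpers in `ops/utils.py` above are small NumPy/scikit-learn utilities. A minimal usage sketch follows; the score values and labels are made up purely for illustration:

```python
import numpy as np
from ops.utils import softmax, class_accuracy

# Dummy clip-level scores for 4 samples over 3 classes (illustrative values only).
scores = np.array([[2.0, 1.0, 0.1],
                   [0.2, 3.0, 0.3],
                   [1.5, 0.2, 0.1],
                   [0.1, 0.4, 2.2]])

probs = softmax(scores)          # each row now sums to 1
pred = probs.argmax(axis=1)
label = np.array([0, 1, 1, 2])

# Caveat: sklearn's confusion_matrix expects (y_true, y_pred), while class_accuracy
# passes (prediction, label) in that order, so the per-class rates it returns are
# normalized by predicted-class counts (precision-like) rather than ground-truth counts.
cls_acc, mean_cls_acc = class_accuracy(pred, label)
print(probs)
print(cls_acc, mean_cls_acc)
```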
26 | -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class Identity(torch.nn.Module): 6 | def forward(self, input): 7 | return input 8 | 9 | 10 | class SegmentConsensus(torch.autograd.Function): 11 | 12 | def __init__(self, consensus_type, dim=1): 13 | self.consensus_type = consensus_type 14 | self.dim = dim 15 | self.shape = None 16 | 17 | def forward(self, input_tensor): 18 | self.shape = input_tensor.size() 19 | if self.consensus_type == 'avg': 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) 21 | elif self.consensus_type == 'identity': 22 | output = input_tensor 23 | else: 24 | output = None 25 | 26 | return output 27 | 28 | def backward(self, grad_output): 29 | if self.consensus_type == 'avg': 30 | grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim]) 31 | elif self.consensus_type == 'identity': 32 | grad_in = grad_output 33 | else: 34 | grad_in = None 35 | 36 | return grad_in 37 | 38 | 39 | class ConsensusModule(torch.nn.Module): 40 | 41 | def __init__(self, consensus_type, dim=1): 42 | super(ConsensusModule, self).__init__() 43 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 44 | self.dim = dim 45 | 46 | def forward(self, input): 47 | return SegmentConsensus(self.consensus_type, self.dim)(input) 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TSN-Pytorch 2 | 3 | **We have released [MMAction](https://github.com/open-mmlab/mmaction), a full-fledged action understanding toolbox based on PyTorch. It includes an implementation of TSN as well as other state-of-the-art frameworks for various tasks. The lessons we learned in this repo are incorporated into MMAction to make it better. We highly recommend you switch to it. This repo will remain here for historical reference.** 4 | 5 | **Note**: Always use `git clone --recursive https://github.com/yjxiong/tsn-pytorch` to clone this project. 6 | Otherwise you will not be able to use the Inception-series CNN architectures. 7 | 8 | This is a reimplementation of temporal segment networks (TSN) in PyTorch. All settings are kept identical to the original Caffe implementation. 9 | 10 | For optical flow extraction and video list generation, you still need to use the original [TSN codebase](https://github.com/yjxiong/temporal-segment-networks). 11 | 12 | ## Training 13 | 14 | To train a new model, use the `main.py` script. 15 | 16 | The command to reproduce the original TSN experiments for the RGB modality on UCF101 is: 17 | 18 | ```bash 19 | python main.py ucf101 RGB \ 20 | --arch BNInception --num_segments 3 \ 21 | --gd 20 --lr 0.001 --lr_steps 30 60 --epochs 80 \ 22 | -b 128 -j 8 --dropout 0.8 \ 23 | --snapshot_pref ucf101_bninception_ 24 | ``` 25 | 26 | For flow models: 27 | 28 | ```bash 29 | python main.py ucf101 Flow \ 30 | --arch BNInception --num_segments 3 \ 31 | --gd 20 --lr 0.001 --lr_steps 190 300 --epochs 340 \ 32 | -b 128 -j 8 --dropout 0.7 \ 33 | --snapshot_pref ucf101_bninception_ --flow_pref flow_ 34 | ``` 35 | 36 | For RGB-diff models: 37 | 38 | ```bash 39 | python main.py ucf101 RGBDiff \ 40 | --arch BNInception --num_segments 7 \ 41 | --gd 40 --lr 0.001 --lr_steps 80 160 --epochs 180 \ 42 | -b 128 -j 8 --dropout 0.8 \ 43 | --snapshot_pref ucf101_bninception_ 44 | ``` 45 | 46 | ## Testing 47 | 48 | After training, checkpoints will be saved by PyTorch, for example `ucf101_bninception_rgb_checkpoint.pth`. 
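Such a checkpoint is an ordinary `torch.save` dictionary. A minimal sketch of inspecting one (assuming the checkpoint file name above); the key handling mirrors what `test_models.py` does before `load_state_dict`:

```python
import torch

ckpt = torch.load('ucf101_bninception_rgb_checkpoint.pth', map_location='cpu')
print(ckpt['epoch'], ckpt['arch'], ckpt['best_prec1'])

# Training wraps the model in nn.DataParallel, so every state_dict key carries a
# leading "module." prefix; strip it before loading into a bare TSN instance.
base_dict = {'.'.join(k.split('.')[1:]): v for k, v in ckpt['state_dict'].items()}
print(list(base_dict)[:3])
```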
49 | 50 | Use the following command to test its performance in the standard TSN testing protocol: 51 | 52 | ```bash 53 | python test_models.py ucf101 RGB ucf101_bninception_rgb_checkpoint.pth \ 54 | --arch BNInception --save_scores 55 | 56 | ``` 57 | 58 | Or for flow models: 59 | 60 | ```bash 61 | python test_models.py ucf101 Flow ucf101_bninception_flow_checkpoint.pth \ 62 | --arch BNInception --save_scores --flow_pref flow_ 63 | 64 | ``` 65 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 3 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics']) 4 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff']) 5 | parser.add_argument('train_list', type=str) 6 | parser.add_argument('val_list', type=str) 7 | 8 | # ========================= Model Configs ========================== 9 | parser.add_argument('--arch', type=str, default="resnet101") 10 | parser.add_argument('--num_segments', type=int, default=3) 11 | parser.add_argument('--consensus_type', type=str, default='avg', 12 | choices=['avg', 'max', 'topk', 'identity', 'rnn', 'cnn']) 13 | parser.add_argument('--k', type=int, default=3) 14 | 15 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 16 | metavar='DO', help='dropout ratio (default: 0.5)') 17 | parser.add_argument('--loss_type', type=str, default="nll", 18 | choices=['nll']) 19 | 20 | # ========================= Learning Configs ========================== 21 | parser.add_argument('--epochs', default=45, type=int, metavar='N', 22 | help='number of total epochs to run') 23 | parser.add_argument('-b', '--batch-size', default=256, type=int, 24 | metavar='N', help='mini-batch size (default: 256)') 25 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 26 | metavar='LR', help='initial learning rate') 27 | parser.add_argument('--lr_steps', default=[20, 40], type=float, nargs="+", 28 | metavar='LRSteps', help='epochs to decay learning rate by 10') 29 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 30 | help='momentum') 31 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 32 | metavar='W', help='weight decay (default: 5e-4)') 33 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 34 | metavar='W', help='gradient norm clipping (default: disabled)') 35 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 36 | 37 | # ========================= Monitor Configs ========================== 38 | parser.add_argument('--print-freq', '-p', default=20, type=int, 39 | metavar='N', help='print frequency (default: 10)') 40 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 41 | metavar='N', help='evaluation frequency (default: 5)') 42 | 43 | 44 | # ========================= Runtime Configs ========================== 45 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 46 | help='number of data loading workers (default: 4)') 47 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 48 | help='path to latest checkpoint (default: none)') 49 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 50 | help='evaluate model on validation set') 51 | parser.add_argument('--snapshot_pref', type=str, 
default="") 52 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 53 | help='manual epoch number (useful on restarts)') 54 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 55 | parser.add_argument('--flow_prefix', default="", type=str) 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from numpy.random import randint 8 | 9 | class VideoRecord(object): 10 | def __init__(self, row): 11 | self._data = row 12 | 13 | @property 14 | def path(self): 15 | return self._data[0] 16 | 17 | @property 18 | def num_frames(self): 19 | return int(self._data[1]) 20 | 21 | @property 22 | def label(self): 23 | return int(self._data[2]) 24 | 25 | 26 | class TSNDataSet(data.Dataset): 27 | def __init__(self, root_path, list_file, 28 | num_segments=3, new_length=1, modality='RGB', 29 | image_tmpl='img_{:05d}.jpg', transform=None, 30 | force_grayscale=False, random_shift=True, test_mode=False): 31 | 32 | self.root_path = root_path 33 | self.list_file = list_file 34 | self.num_segments = num_segments 35 | self.new_length = new_length 36 | self.modality = modality 37 | self.image_tmpl = image_tmpl 38 | self.transform = transform 39 | self.random_shift = random_shift 40 | self.test_mode = test_mode 41 | 42 | if self.modality == 'RGBDiff': 43 | self.new_length += 1# Diff needs one more image to calculate diff 44 | 45 | self._parse_list() 46 | 47 | def _load_image(self, directory, idx): 48 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 49 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')] 50 | elif self.modality == 'Flow': 51 | x_img = Image.open(os.path.join(directory, self.image_tmpl.format('x', idx))).convert('L') 52 | y_img = Image.open(os.path.join(directory, self.image_tmpl.format('y', idx))).convert('L') 53 | 54 | return [x_img, y_img] 55 | 56 | def _parse_list(self): 57 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)] 58 | 59 | def _sample_indices(self, record): 60 | """ 61 | 62 | :param record: VideoRecord 63 | :return: list 64 | """ 65 | 66 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 67 | if average_duration > 0: 68 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments) 69 | elif record.num_frames > self.num_segments: 70 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 71 | else: 72 | offsets = np.zeros((self.num_segments,)) 73 | return offsets + 1 74 | 75 | def _get_val_indices(self, record): 76 | if record.num_frames > self.num_segments + self.new_length - 1: 77 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 78 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 79 | else: 80 | offsets = np.zeros((self.num_segments,)) 81 | return offsets + 1 82 | 83 | def _get_test_indices(self, record): 84 | 85 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 86 | 87 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 88 | 89 | return offsets + 1 90 | 91 | def __getitem__(self, index): 92 | record = self.video_list[index] 93 | 94 
| if not self.test_mode: 95 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 96 | else: 97 | segment_indices = self._get_test_indices(record) 98 | 99 | return self.get(record, segment_indices) 100 | 101 | def get(self, record, indices): 102 | 103 | images = list() 104 | for seg_ind in indices: 105 | p = int(seg_ind) 106 | for i in range(self.new_length): 107 | seg_imgs = self._load_image(record.path, p) 108 | images.extend(seg_imgs) 109 | if p < record.num_frames: 110 | p += 1 111 | 112 | process_data = self.transform(images) 113 | return process_data, record.label 114 | 115 | def __len__(self): 116 | return len(self.video_list) 117 | -------------------------------------------------------------------------------- /test_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import torch.nn.parallel 6 | import torch.optim 7 | from sklearn.metrics import confusion_matrix 8 | 9 | from dataset import TSNDataSet 10 | from models import TSN 11 | from transforms import * 12 | from ops import ConsensusModule 13 | 14 | # options 15 | parser = argparse.ArgumentParser( 16 | description="Standard video-level testing") 17 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics']) 18 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff']) 19 | parser.add_argument('test_list', type=str) 20 | parser.add_argument('weights', type=str) 21 | parser.add_argument('--arch', type=str, default="resnet101") 22 | parser.add_argument('--save_scores', type=str, default=None) 23 | parser.add_argument('--test_segments', type=int, default=25) 24 | parser.add_argument('--max_num', type=int, default=-1) 25 | parser.add_argument('--test_crops', type=int, default=10) 26 | parser.add_argument('--input_size', type=int, default=224) 27 | parser.add_argument('--crop_fusion_type', type=str, default='avg', 28 | choices=['avg', 'max', 'topk']) 29 | parser.add_argument('--k', type=int, default=3) 30 | parser.add_argument('--dropout', type=float, default=0.7) 31 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 34 | parser.add_argument('--flow_prefix', type=str, default='') 35 | 36 | args = parser.parse_args() 37 | 38 | 39 | if args.dataset == 'ucf101': 40 | num_class = 101 41 | elif args.dataset == 'hmdb51': 42 | num_class = 51 43 | elif args.dataset == 'kinetics': 44 | num_class = 400 45 | else: 46 | raise ValueError('Unknown dataset '+args.dataset) 47 | 48 | net = TSN(num_class, 1, args.modality, 49 | base_model=args.arch, 50 | consensus_type=args.crop_fusion_type, 51 | dropout=args.dropout) 52 | 53 | checkpoint = torch.load(args.weights) 54 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1'])) 55 | 56 | base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint['state_dict'].items())} 57 | net.load_state_dict(base_dict) 58 | 59 | if args.test_crops == 1: 60 | cropping = torchvision.transforms.Compose([ 61 | GroupScale(net.scale_size), 62 | GroupCenterCrop(net.input_size), 63 | ]) 64 | elif args.test_crops == 10: 65 | cropping = torchvision.transforms.Compose([ 66 | GroupOverSample(net.input_size, net.scale_size) 67 | ]) 68 | else: 69 | raise ValueError("Only 1 and 10 crops are supported while we got 
{}".format(args.test_crops)) 70 | 71 | data_loader = torch.utils.data.DataLoader( 72 | TSNDataSet("", args.test_list, num_segments=args.test_segments, 73 | new_length=1 if args.modality == "RGB" else 5, 74 | modality=args.modality, 75 | image_tmpl="img_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg", 76 | test_mode=True, 77 | transform=torchvision.transforms.Compose([ 78 | cropping, 79 | Stack(roll=args.arch == 'BNInception'), 80 | ToTorchFormatTensor(div=args.arch != 'BNInception'), 81 | GroupNormalize(net.input_mean, net.input_std), 82 | ])), 83 | batch_size=1, shuffle=False, 84 | num_workers=args.workers * 2, pin_memory=True) 85 | 86 | if args.gpus is not None: 87 | devices = [args.gpus[i] for i in range(args.workers)] 88 | else: 89 | devices = list(range(args.workers)) 90 | 91 | 92 | net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices) 93 | net.eval() 94 | 95 | data_gen = enumerate(data_loader) 96 | 97 | total_num = len(data_loader.dataset) 98 | output = [] 99 | 100 | 101 | def eval_video(video_data): 102 | i, data, label = video_data 103 | num_crop = args.test_crops 104 | 105 | if args.modality == 'RGB': 106 | length = 3 107 | elif args.modality == 'Flow': 108 | length = 10 109 | elif args.modality == 'RGBDiff': 110 | length = 18 111 | else: 112 | raise ValueError("Unknown modality "+args.modality) 113 | 114 | input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)), 115 | volatile=True) 116 | rst = net(input_var).data.cpu().numpy().copy() 117 | return i, rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape( 118 | (args.test_segments, 1, num_class) 119 | ), label[0] 120 | 121 | 122 | proc_start_time = time.time() 123 | max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset) 124 | 125 | for i, (data, label) in data_gen: 126 | if i >= max_num: 127 | break 128 | rst = eval_video((i, data, label)) 129 | output.append(rst[1:]) 130 | cnt_time = time.time() - proc_start_time 131 | print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1, 132 | total_num, 133 | float(cnt_time) / (i+1))) 134 | 135 | video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output] 136 | 137 | video_labels = [x[1] for x in output] 138 | 139 | 140 | cf = confusion_matrix(video_labels, video_pred).astype(float) 141 | 142 | cls_cnt = cf.sum(axis=1) 143 | cls_hit = np.diag(cf) 144 | 145 | cls_acc = cls_hit / cls_cnt 146 | 147 | print(cls_acc) 148 | 149 | print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 150 | 151 | if args.save_scores is not None: 152 | 153 | # reorder before saving 154 | name_list = [x.strip().split()[0] for x in open(args.test_list)] 155 | 156 | order_dict = {e:i for i, e in enumerate(sorted(name_list))} 157 | 158 | reorder_output = [None] * len(output) 159 | reorder_label = [None] * len(output) 160 | 161 | for i in range(len(output)): 162 | idx = order_dict[name_list[i]] 163 | reorder_output[idx] = output[i] 164 | reorder_label[idx] = video_labels[i] 165 | 166 | np.savez(args.save_scores, scores=reorder_output, labels=reorder_label) 167 | 168 | 169 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import torch 6 | import torchvision 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torch.nn.utils import 
clip_grad_norm 11 | 12 | from dataset import TSNDataSet 13 | from models import TSN 14 | from transforms import * 15 | from opts import parser 16 | 17 | best_prec1 = 0 18 | 19 | 20 | def main(): 21 | global args, best_prec1 22 | args = parser.parse_args() 23 | 24 | if args.dataset == 'ucf101': 25 | num_class = 101 26 | elif args.dataset == 'hmdb51': 27 | num_class = 51 28 | elif args.dataset == 'kinetics': 29 | num_class = 400 30 | else: 31 | raise ValueError('Unknown dataset '+args.dataset) 32 | 33 | model = TSN(num_class, args.num_segments, args.modality, 34 | base_model=args.arch, 35 | consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn) 36 | 37 | crop_size = model.crop_size 38 | scale_size = model.scale_size 39 | input_mean = model.input_mean 40 | input_std = model.input_std 41 | policies = model.get_optim_policies() 42 | train_augmentation = model.get_augmentation() 43 | 44 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 45 | 46 | if args.resume: 47 | if os.path.isfile(args.resume): 48 | print(("=> loading checkpoint '{}'".format(args.resume))) 49 | checkpoint = torch.load(args.resume) 50 | args.start_epoch = checkpoint['epoch'] 51 | best_prec1 = checkpoint['best_prec1'] 52 | model.load_state_dict(checkpoint['state_dict']) 53 | print(("=> loaded checkpoint '{}' (epoch {})" 54 | .format(args.evaluate, checkpoint['epoch']))) 55 | else: 56 | print(("=> no checkpoint found at '{}'".format(args.resume))) 57 | 58 | cudnn.benchmark = True 59 | 60 | # Data loading code 61 | if args.modality != 'RGBDiff': 62 | normalize = GroupNormalize(input_mean, input_std) 63 | else: 64 | normalize = IdentityTransform() 65 | 66 | if args.modality == 'RGB': 67 | data_length = 1 68 | elif args.modality in ['Flow', 'RGBDiff']: 69 | data_length = 5 70 | 71 | train_loader = torch.utils.data.DataLoader( 72 | TSNDataSet("", args.train_list, num_segments=args.num_segments, 73 | new_length=data_length, 74 | modality=args.modality, 75 | image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg", 76 | transform=torchvision.transforms.Compose([ 77 | train_augmentation, 78 | Stack(roll=args.arch == 'BNInception'), 79 | ToTorchFormatTensor(div=args.arch != 'BNInception'), 80 | normalize, 81 | ])), 82 | batch_size=args.batch_size, shuffle=True, 83 | num_workers=args.workers, pin_memory=True) 84 | 85 | val_loader = torch.utils.data.DataLoader( 86 | TSNDataSet("", args.val_list, num_segments=args.num_segments, 87 | new_length=data_length, 88 | modality=args.modality, 89 | image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg", 90 | random_shift=False, 91 | transform=torchvision.transforms.Compose([ 92 | GroupScale(int(scale_size)), 93 | GroupCenterCrop(crop_size), 94 | Stack(roll=args.arch == 'BNInception'), 95 | ToTorchFormatTensor(div=args.arch != 'BNInception'), 96 | normalize, 97 | ])), 98 | batch_size=args.batch_size, shuffle=False, 99 | num_workers=args.workers, pin_memory=True) 100 | 101 | # define loss function (criterion) and optimizer 102 | if args.loss_type == 'nll': 103 | criterion = torch.nn.CrossEntropyLoss().cuda() 104 | else: 105 | raise ValueError("Unknown loss type") 106 | 107 | for group in policies: 108 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 109 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 110 | 111 | optimizer = torch.optim.SGD(policies, 112 | args.lr, 113 | momentum=args.momentum, 
114 | weight_decay=args.weight_decay) 115 | 116 | if args.evaluate: 117 | validate(val_loader, model, criterion, 0) 118 | return 119 | 120 | for epoch in range(args.start_epoch, args.epochs): 121 | adjust_learning_rate(optimizer, epoch, args.lr_steps) 122 | 123 | # train for one epoch 124 | train(train_loader, model, criterion, optimizer, epoch) 125 | 126 | # evaluate on validation set 127 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 128 | prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader)) 129 | 130 | # remember best prec@1 and save checkpoint 131 | is_best = prec1 > best_prec1 132 | best_prec1 = max(prec1, best_prec1) 133 | save_checkpoint({ 134 | 'epoch': epoch + 1, 135 | 'arch': args.arch, 136 | 'state_dict': model.state_dict(), 137 | 'best_prec1': best_prec1, 138 | }, is_best) 139 | 140 | 141 | def train(train_loader, model, criterion, optimizer, epoch): 142 | batch_time = AverageMeter() 143 | data_time = AverageMeter() 144 | losses = AverageMeter() 145 | top1 = AverageMeter() 146 | top5 = AverageMeter() 147 | 148 | if args.no_partialbn: 149 | model.module.partialBN(False) 150 | else: 151 | model.module.partialBN(True) 152 | 153 | # switch to train mode 154 | model.train() 155 | 156 | end = time.time() 157 | for i, (input, target) in enumerate(train_loader): 158 | # measure data loading time 159 | data_time.update(time.time() - end) 160 | 161 | target = target.cuda(async=True) 162 | input_var = torch.autograd.Variable(input) 163 | target_var = torch.autograd.Variable(target) 164 | 165 | # compute output 166 | output = model(input_var) 167 | loss = criterion(output, target_var) 168 | 169 | # measure accuracy and record loss 170 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 171 | losses.update(loss.data[0], input.size(0)) 172 | top1.update(prec1[0], input.size(0)) 173 | top5.update(prec5[0], input.size(0)) 174 | 175 | 176 | # compute gradient and do SGD step 177 | optimizer.zero_grad() 178 | 179 | loss.backward() 180 | 181 | if args.clip_gradient is not None: 182 | total_norm = clip_grad_norm(model.parameters(), args.clip_gradient) 183 | if total_norm > args.clip_gradient: 184 | print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm)) 185 | 186 | optimizer.step() 187 | 188 | # measure elapsed time 189 | batch_time.update(time.time() - end) 190 | end = time.time() 191 | 192 | if i % args.print_freq == 0: 193 | print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 194 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 195 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 196 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 197 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 198 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 199 | epoch, i, len(train_loader), batch_time=batch_time, 200 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr']))) 201 | 202 | 203 | def validate(val_loader, model, criterion, iter, logger=None): 204 | batch_time = AverageMeter() 205 | losses = AverageMeter() 206 | top1 = AverageMeter() 207 | top5 = AverageMeter() 208 | 209 | # switch to evaluate mode 210 | model.eval() 211 | 212 | end = time.time() 213 | for i, (input, target) in enumerate(val_loader): 214 | target = target.cuda(async=True) 215 | input_var = torch.autograd.Variable(input, volatile=True) 216 | target_var = torch.autograd.Variable(target, volatile=True) 217 | 218 | # compute output 219 | output = model(input_var) 220 | loss = criterion(output, target_var) 221 | 
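# Note: target.cuda(async=True), Variable(..., volatile=True) and loss.data[0], as used
# in train() and validate() here, are pre-0.4 PyTorch idioms. On Python 3.7+ with a
# recent PyTorch they would need to become target.cuda(non_blocking=True), a
# torch.no_grad() context and loss.item(); as written, this script assumes the
# 0.3-era environment the repository targets.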
222 | # measure accuracy and record loss 223 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 224 | 225 | losses.update(loss.data[0], input.size(0)) 226 | top1.update(prec1[0], input.size(0)) 227 | top5.update(prec5[0], input.size(0)) 228 | 229 | # measure elapsed time 230 | batch_time.update(time.time() - end) 231 | end = time.time() 232 | 233 | if i % args.print_freq == 0: 234 | print(('Test: [{0}/{1}]\t' 235 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 236 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 237 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 238 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 239 | i, len(val_loader), batch_time=batch_time, loss=losses, 240 | top1=top1, top5=top5))) 241 | 242 | print(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 243 | .format(top1=top1, top5=top5, loss=losses))) 244 | 245 | return top1.avg 246 | 247 | 248 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 249 | filename = '_'.join((args.snapshot_pref, args.modality.lower(), filename)) 250 | torch.save(state, filename) 251 | if is_best: 252 | best_name = '_'.join((args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar')) 253 | shutil.copyfile(filename, best_name) 254 | 255 | 256 | class AverageMeter(object): 257 | """Computes and stores the average and current value""" 258 | def __init__(self): 259 | self.reset() 260 | 261 | def reset(self): 262 | self.val = 0 263 | self.avg = 0 264 | self.sum = 0 265 | self.count = 0 266 | 267 | def update(self, val, n=1): 268 | self.val = val 269 | self.sum += val * n 270 | self.count += n 271 | self.avg = self.sum / self.count 272 | 273 | 274 | def adjust_learning_rate(optimizer, epoch, lr_steps): 275 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 276 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 277 | lr = args.lr * decay 278 | decay = args.weight_decay 279 | for param_group in optimizer.param_groups: 280 | param_group['lr'] = lr * param_group['lr_mult'] 281 | param_group['weight_decay'] = decay * param_group['decay_mult'] 282 | 283 | 284 | def accuracy(output, target, topk=(1,)): 285 | """Computes the precision@k for the specified values of k""" 286 | maxk = max(topk) 287 | batch_size = target.size(0) 288 | 289 | _, pred = output.topk(maxk, 1, True, True) 290 | pred = pred.t() 291 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 292 | 293 | res = [] 294 | for k in topk: 295 | correct_k = correct[:k].view(-1).float().sum(0) 296 | res.append(correct_k.mul_(100.0 / batch_size)) 297 | return res 298 | 299 | 300 | if __name__ == '__main__': 301 | main() 302 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | 
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 47 | """ 48 | def __init__(self, is_flow=False): 49 | self.is_flow = is_flow 50 | 51 | def __call__(self, img_group, is_flow=False): 52 | v = random.random() 53 | if v < 0.5: 54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 55 | if self.is_flow: 56 | for i in range(0, len(ret), 2): 57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 58 | return ret 59 | else: 60 | return img_group 61 | 62 | 63 | class GroupNormalize(object): 64 | def __init__(self, mean, std): 65 | self.mean = mean 66 | self.std = std 67 | 68 | def __call__(self, tensor): 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 70 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 71 | 72 | # TODO: make efficient 73 | for t, m, s in zip(tensor, rep_mean, rep_std): 74 | t.sub_(m).div_(s) 75 | 76 | return tensor 77 | 78 | 79 | class GroupScale(object): 80 | """ Rescales the input PIL.Image to the given 'size'. 81 | 'size' will be the size of the smaller edge. 82 | For example, if height > width, then image will be 83 | rescaled to (size * height / width, size) 84 | size: size of the smaller edge 85 | interpolation: Default: PIL.Image.BILINEAR 86 | """ 87 | 88 | def __init__(self, size, interpolation=Image.BILINEAR): 89 | self.worker = torchvision.transforms.Scale(size, interpolation) 90 | 91 | def __call__(self, img_group): 92 | return [self.worker(img) for img in img_group] 93 | 94 | 95 | class GroupOverSample(object): 96 | def __init__(self, crop_size, scale_size=None): 97 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 98 | 99 | if scale_size is not None: 100 | self.scale_worker = GroupScale(scale_size) 101 | else: 102 | self.scale_worker = None 103 | 104 | def __call__(self, img_group): 105 | 106 | if self.scale_worker is not None: 107 | img_group = self.scale_worker(img_group) 108 | 109 | image_w, image_h = img_group[0].size 110 | crop_w, crop_h = self.crop_size 111 | 112 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 113 | oversample_group = list() 114 | for o_w, o_h in offsets: 115 | normal_group = list() 116 | flip_group = list() 117 | for i, img in enumerate(img_group): 118 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 119 | normal_group.append(crop) 120 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 121 | 122 | if img.mode == 'L' and i % 2 == 0: 123 | flip_group.append(ImageOps.invert(flip_crop)) 124 | else: 125 | flip_group.append(flip_crop) 126 | 127 | oversample_group.extend(normal_group) 128 | oversample_group.extend(flip_group) 129 | return oversample_group 130 | 131 | 132 | class GroupMultiScaleCrop(object): 133 | 134 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 135 | self.scales = scales if scales is not None else [1, .875, .75, .66] 136 | self.max_distort = max_distort 137 | self.fix_crop = fix_crop 138 | self.more_fix_crop = more_fix_crop 139 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 140 | 
self.interpolation = Image.BILINEAR 141 | 142 | def __call__(self, img_group): 143 | 144 | im_size = img_group[0].size 145 | 146 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 147 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 148 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 149 | for img in crop_img_group] 150 | return ret_img_group 151 | 152 | def _sample_crop_size(self, im_size): 153 | image_w, image_h = im_size[0], im_size[1] 154 | 155 | # find a crop size 156 | base_size = min(image_w, image_h) 157 | crop_sizes = [int(base_size * x) for x in self.scales] 158 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 159 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 160 | 161 | pairs = [] 162 | for i, h in enumerate(crop_h): 163 | for j, w in enumerate(crop_w): 164 | if abs(i - j) <= self.max_distort: 165 | pairs.append((w, h)) 166 | 167 | crop_pair = random.choice(pairs) 168 | if not self.fix_crop: 169 | w_offset = random.randint(0, image_w - crop_pair[0]) 170 | h_offset = random.randint(0, image_h - crop_pair[1]) 171 | else: 172 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 173 | 174 | return crop_pair[0], crop_pair[1], w_offset, h_offset 175 | 176 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 177 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 178 | return random.choice(offsets) 179 | 180 | @staticmethod 181 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 182 | w_step = (image_w - crop_w) // 4 183 | h_step = (image_h - crop_h) // 4 184 | 185 | ret = list() 186 | ret.append((0, 0)) # upper left 187 | ret.append((4 * w_step, 0)) # upper right 188 | ret.append((0, 4 * h_step)) # lower left 189 | ret.append((4 * w_step, 4 * h_step)) # lower right 190 | ret.append((2 * w_step, 2 * h_step)) # center 191 | 192 | if more_fix_crop: 193 | ret.append((0, 2 * h_step)) # center left 194 | ret.append((4 * w_step, 2 * h_step)) # center right 195 | ret.append((2 * w_step, 4 * h_step)) # lower center 196 | ret.append((2 * w_step, 0 * h_step)) # upper center 197 | 198 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 199 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 200 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 201 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 202 | 203 | return ret 204 | 205 | 206 | class GroupRandomSizedCrop(object): 207 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 208 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 209 | This is popularly used to train the Inception networks 210 | size: size of the smaller edge 211 | interpolation: Default: PIL.Image.BILINEAR 212 | """ 213 | def __init__(self, size, interpolation=Image.BILINEAR): 214 | self.size = size 215 | self.interpolation = interpolation 216 | 217 | def __call__(self, img_group): 218 | for attempt in range(10): 219 | area = img_group[0].size[0] * img_group[0].size[1] 220 | target_area = random.uniform(0.08, 1.0) * area 221 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 222 | 223 | w = int(round(math.sqrt(target_area * aspect_ratio))) 224 | h = int(round(math.sqrt(target_area / aspect_ratio))) 225 | 226 | if random.random() < 0.5: 227 | w, h = h, w 228 | 229 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 230 | x1 = random.randint(0, img_group[0].size[0] - w) 231 | y1 = random.randint(0, img_group[0].size[1] - h) 232 | found = True 233 | break 234 | else: 235 | found = False 236 | x1 = 0 237 | y1 = 0 238 | 239 | if found: 240 | out_group = list() 241 | for img in img_group: 242 | img = img.crop((x1, y1, x1 + w, y1 + h)) 243 | assert(img.size == (w, h)) 244 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 245 | return out_group 246 | else: 247 | # Fallback 248 | scale = GroupScale(self.size, interpolation=self.interpolation) 249 | crop = GroupRandomCrop(self.size) 250 | return crop(scale(img_group)) 251 | 252 | 253 | class Stack(object): 254 | 255 | def __init__(self, roll=False): 256 | self.roll = roll 257 | 258 | def __call__(self, img_group): 259 | if img_group[0].mode == 'L': 260 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 261 | elif img_group[0].mode == 'RGB': 262 | if self.roll: 263 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 264 | else: 265 | return np.concatenate(img_group, axis=2) 266 | 267 | 268 | class ToTorchFormatTensor(object): 269 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 270 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 271 | def __init__(self, div=True): 272 | self.div = div 273 | 274 | def __call__(self, pic): 275 | if isinstance(pic, np.ndarray): 276 | # handle numpy array 277 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 278 | else: 279 | # handle PIL Image 280 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 281 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 282 | # put it from HWC to CHW format 283 | # yikes, this transpose takes 80% of the loading time/CPU 284 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 285 | return img.float().div(255) if self.div else img.float() 286 | 287 | 288 | class IdentityTransform(object): 289 | 290 | def __call__(self, data): 291 | return data 292 | 293 | 294 | if __name__ == "__main__": 295 | trans = torchvision.transforms.Compose([ 296 | GroupScale(256), 297 | GroupRandomCrop(224), 298 | Stack(), 299 | ToTorchFormatTensor(), 300 | GroupNormalize( 301 | mean=[.485, .456, .406], 302 | std=[.229, .224, .225] 303 | )] 304 | ) 305 | 306 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 307 | 308 | color_group = [im] * 3 309 | rst = trans(color_group) 310 | 311 | gray_group = [im.convert('L')] * 9 312 | gray_rst = trans(gray_group) 313 | 314 | trans2 = torchvision.transforms.Compose([ 315 | GroupRandomSizedCrop(256), 316 | Stack(), 317 | ToTorchFormatTensor(), 318 | GroupNormalize( 319 | mean=[.485, .456, .406], 320 | std=[.229, .224, .225]) 321 | ]) 322 | print(trans2(color_group)) 323 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ops.basic_ops import ConsensusModule, Identity 4 | from transforms import * 5 | from torch.nn.init import normal, constant 6 | 7 | class TSN(nn.Module): 8 | def __init__(self, num_class, num_segments, modality, 9 | base_model='resnet101', new_length=None, 10 | 
consensus_type='avg', before_softmax=True, 11 | dropout=0.8, 12 | crop_num=1, partial_bn=True): 13 | super(TSN, self).__init__() 14 | self.modality = modality 15 | self.num_segments = num_segments 16 | self.reshape = True 17 | self.before_softmax = before_softmax 18 | self.dropout = dropout 19 | self.crop_num = crop_num 20 | self.consensus_type = consensus_type 21 | if not before_softmax and consensus_type != 'avg': 22 | raise ValueError("Only avg consensus can be used after Softmax") 23 | 24 | if new_length is None: 25 | self.new_length = 1 if modality == "RGB" else 5 26 | else: 27 | self.new_length = new_length 28 | 29 | print((""" 30 | Initializing TSN with base model: {}. 31 | TSN Configurations: 32 | input_modality: {} 33 | num_segments: {} 34 | new_length: {} 35 | consensus_module: {} 36 | dropout_ratio: {} 37 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout))) 38 | 39 | self._prepare_base_model(base_model) 40 | 41 | feature_dim = self._prepare_tsn(num_class) 42 | 43 | if self.modality == 'Flow': 44 | print("Converting the ImageNet model to a flow init model") 45 | self.base_model = self._construct_flow_model(self.base_model) 46 | print("Done. Flow model ready...") 47 | elif self.modality == 'RGBDiff': 48 | print("Converting the ImageNet model to RGB+Diff init model") 49 | self.base_model = self._construct_diff_model(self.base_model) 50 | print("Done. RGBDiff model ready.") 51 | 52 | self.consensus = ConsensusModule(consensus_type) 53 | 54 | if not self.before_softmax: 55 | self.softmax = nn.Softmax() 56 | 57 | self._enable_pbn = partial_bn 58 | if partial_bn: 59 | self.partialBN(True) 60 | 61 | def _prepare_tsn(self, num_class): 62 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 63 | if self.dropout == 0: 64 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) 65 | self.new_fc = None 66 | else: 67 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 68 | self.new_fc = nn.Linear(feature_dim, num_class) 69 | 70 | std = 0.001 71 | if self.new_fc is None: 72 | normal(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) 73 | constant(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 74 | else: 75 | normal(self.new_fc.weight, 0, std) 76 | constant(self.new_fc.bias, 0) 77 | return feature_dim 78 | 79 | def _prepare_base_model(self, base_model): 80 | 81 | if 'resnet' in base_model or 'vgg' in base_model: 82 | self.base_model = getattr(torchvision.models, base_model)(True) 83 | self.base_model.last_layer_name = 'fc' 84 | self.input_size = 224 85 | self.input_mean = [0.485, 0.456, 0.406] 86 | self.input_std = [0.229, 0.224, 0.225] 87 | 88 | if self.modality == 'Flow': 89 | self.input_mean = [0.5] 90 | self.input_std = [np.mean(self.input_std)] 91 | elif self.modality == 'RGBDiff': 92 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 93 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 94 | elif base_model == 'BNInception': 95 | import tf_model_zoo 96 | self.base_model = getattr(tf_model_zoo, base_model)() 97 | self.base_model.last_layer_name = 'fc' 98 | self.input_size = 224 99 | self.input_mean = [104, 117, 128] 100 | self.input_std = [1] 101 | 102 | if self.modality == 'Flow': 103 | self.input_mean = [128] 104 | elif self.modality == 'RGBDiff': 105 | self.input_mean = self.input_mean * (1 + self.new_length) 106 | 107 | elif 
'inception' in base_model: 108 | import tf_model_zoo 109 | self.base_model = getattr(tf_model_zoo, base_model)() 110 | self.base_model.last_layer_name = 'classif' 111 | self.input_size = 299 112 | self.input_mean = [0.5] 113 | self.input_std = [0.5] 114 | else: 115 | raise ValueError('Unknown base model: {}'.format(base_model)) 116 | 117 | def train(self, mode=True): 118 | """ 119 | Override the default train() to freeze the BN parameters 120 | :return: 121 | """ 122 | super(TSN, self).train(mode) 123 | count = 0 124 | if self._enable_pbn: 125 | print("Freezing BatchNorm2D except the first one.") 126 | for m in self.base_model.modules(): 127 | if isinstance(m, nn.BatchNorm2d): 128 | count += 1 129 | if count >= (2 if self._enable_pbn else 1): 130 | m.eval() 131 | 132 | # shutdown update in frozen mode 133 | m.weight.requires_grad = False 134 | m.bias.requires_grad = False 135 | 136 | def partialBN(self, enable): 137 | self._enable_pbn = enable 138 | 139 | def get_optim_policies(self): 140 | first_conv_weight = [] 141 | first_conv_bias = [] 142 | normal_weight = [] 143 | normal_bias = [] 144 | bn = [] 145 | 146 | conv_cnt = 0 147 | bn_cnt = 0 148 | for m in self.modules(): 149 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d): 150 | ps = list(m.parameters()) 151 | conv_cnt += 1 152 | if conv_cnt == 1: 153 | first_conv_weight.append(ps[0]) 154 | if len(ps) == 2: 155 | first_conv_bias.append(ps[1]) 156 | else: 157 | normal_weight.append(ps[0]) 158 | if len(ps) == 2: 159 | normal_bias.append(ps[1]) 160 | elif isinstance(m, torch.nn.Linear): 161 | ps = list(m.parameters()) 162 | normal_weight.append(ps[0]) 163 | if len(ps) == 2: 164 | normal_bias.append(ps[1]) 165 | 166 | elif isinstance(m, torch.nn.BatchNorm1d): 167 | bn.extend(list(m.parameters())) 168 | elif isinstance(m, torch.nn.BatchNorm2d): 169 | bn_cnt += 1 170 | # later BN's are frozen 171 | if not self._enable_pbn or bn_cnt == 1: 172 | bn.extend(list(m.parameters())) 173 | elif len(m._modules) == 0: 174 | if len(list(m.parameters())) > 0: 175 | raise ValueError("New atomic module type: {}. 
Need to give it a learning policy".format(type(m))) 176 | 177 | return [ 178 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 179 | 'name': "first_conv_weight"}, 180 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 181 | 'name': "first_conv_bias"}, 182 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 183 | 'name': "normal_weight"}, 184 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 185 | 'name': "normal_bias"}, 186 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 187 | 'name': "BN scale/shift"}, 188 | ] 189 | 190 | def forward(self, input): 191 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 192 | 193 | if self.modality == 'RGBDiff': 194 | sample_len = 3 * self.new_length 195 | input = self._get_diff(input) 196 | 197 | base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:])) 198 | 199 | if self.dropout > 0: 200 | base_out = self.new_fc(base_out) 201 | 202 | if not self.before_softmax: 203 | base_out = self.softmax(base_out) 204 | if self.reshape: 205 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 206 | 207 | output = self.consensus(base_out) 208 | return output.squeeze(1) 209 | 210 | def _get_diff(self, input, keep_rgb=False): 211 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 212 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 213 | if keep_rgb: 214 | new_data = input_view.clone() 215 | else: 216 | new_data = input_view[:, :, 1:, :, :, :].clone() 217 | 218 | for x in reversed(list(range(1, self.new_length + 1))): 219 | if keep_rgb: 220 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 221 | else: 222 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 223 | 224 | return new_data 225 | 226 | 227 | def _construct_flow_model(self, base_model): 228 | # modify the convolution layers 229 | # Torch models are usually defined in a hierarchical way. 230 | # nn.modules.children() return all sub modules in a DFS manner 231 | modules = list(self.base_model.modules()) 232 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 233 | conv_layer = modules[first_conv_idx] 234 | container = modules[first_conv_idx - 1] 235 | 236 | # modify parameters, assume the first blob contains the convolution kernels 237 | params = [x.clone() for x in conv_layer.parameters()] 238 | kernel_size = params[0].size() 239 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] 240 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 241 | 242 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 243 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 244 | bias=True if len(params) == 2 else False) 245 | new_conv.weight.data = new_kernels 246 | if len(params) == 2: 247 | new_conv.bias.data = params[1].data # add bias if neccessary 248 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 249 | 250 | # replace the first convlution layer 251 | setattr(container, layer_name, new_conv) 252 | return base_model 253 | 254 | def _construct_diff_model(self, base_model, keep_rgb=False): 255 | # modify the convolution layers 256 | # Torch models are usually defined in a hierarchical way. 
257 | # nn.modules.children() return all sub modules in a DFS manner 258 | modules = list(self.base_model.modules()) 259 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 260 | conv_layer = modules[first_conv_idx] 261 | container = modules[first_conv_idx - 1] 262 | 263 | # modify parameters, assume the first blob contains the convolution kernels 264 | params = [x.clone() for x in conv_layer.parameters()] 265 | kernel_size = params[0].size() 266 | if not keep_rgb: 267 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 268 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 269 | else: 270 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 271 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 272 | 1) 273 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:] 274 | 275 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 276 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 277 | bias=True if len(params) == 2 else False) 278 | new_conv.weight.data = new_kernels 279 | if len(params) == 2: 280 | new_conv.bias.data = params[1].data # add bias if neccessary 281 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 282 | 283 | # replace the first convolution layer 284 | setattr(container, layer_name, new_conv) 285 | return base_model 286 | 287 | @property 288 | def crop_size(self): 289 | return self.input_size 290 | 291 | @property 292 | def scale_size(self): 293 | return self.input_size * 256 // 224 294 | 295 | def get_augmentation(self): 296 | if self.modality == 'RGB': 297 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 298 | GroupRandomHorizontalFlip(is_flow=False)]) 299 | elif self.modality == 'Flow': 300 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 301 | GroupRandomHorizontalFlip(is_flow=True)]) 302 | elif self.modality == 'RGBDiff': 303 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 304 | GroupRandomHorizontalFlip(is_flow=False)]) 305 | --------------------------------------------------------------------------------
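Finally, a minimal end-to-end smoke test of the model path defined in `models.py` and `ops/basic_ops.py`. This is a sketch only: it assumes the PyTorch 0.3-era environment this code base targets and that the pretrained ResNet-101 weights can be downloaded.

```python
import torch
from torch.autograd import Variable
from models import TSN

num_class, num_segments = 101, 3
model = TSN(num_class, num_segments, 'RGB',
            base_model='resnet101', consensus_type='avg', dropout=0.8)

# The transforms stack the sampled segments along the channel axis, so one RGB
# sample is (num_segments * 3) x 224 x 224; use a dummy batch of 2 here.
dummy = Variable(torch.randn(2, num_segments * 3, 224, 224))
out = model(dummy)

# forward() reshapes the per-segment scores to (batch, num_segments, num_class) and
# ConsensusModule averages over the segment dimension, giving one score vector per video.
print(out.size())  # expected: (2, 101)
```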