├── .gitignore ├── .gitmodules ├── LICENSE ├── MLPmodule.py ├── README.md ├── dataset.py ├── datasets_video.py ├── images ├── motion_fused_frames.jpg └── network_arch.jpg ├── main.py ├── models.py ├── ops ├── __init__.py ├── basic_ops.py └── utils.py ├── opts.py ├── pretrained_models ├── MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar └── MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar ├── process_dataset.py ├── requirements.txt ├── test_models.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # files types to exculde 2 | *.mp4 3 | *.h5 4 | 5 | jester/ 6 | #*.txt 7 | 8 | __pycache__ 9 | *.pyc 10 | 11 | model/ 12 | #*.pth.tar 13 | #*.pth 14 | 15 | log/ 16 | *.csv 17 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "model_zoo"] 2 | path = model_zoo 3 | url = https://github.com/yjxiong/tensorflow-model-zoo.torch.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License for Motion Fused Frames 2 | 3 | Copyright (c) 2017, Okan Köpüklü 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | 28 | BSD 2-Clause License for TSN-PyTorch 29 | 30 | Copyright (c) 2017, Multimedia Laboratary, The Chinese University of Hong Kong 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 
42 | 43 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 44 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 45 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 46 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 47 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 48 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 49 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 50 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 51 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 52 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 53 | -------------------------------------------------------------------------------- /MLPmodule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MLPmodule(torch.nn.Module): 5 | """ 6 | This is the 2-layer MLP implementation used for linking spatio-temporal 7 | features coming from different segments. 8 | """ 9 | def __init__(self, img_feature_dim, num_frames, num_class): 10 | super(MLPmodule, self).__init__() 11 | self.num_frames = num_frames 12 | self.num_class = num_class 13 | self.img_feature_dim = img_feature_dim 14 | self.num_bottleneck = 512 15 | self.classifier = nn.Sequential( 16 | nn.ReLU(), 17 | nn.Linear(self.num_frames * self.img_feature_dim, 18 | self.num_bottleneck), 19 | #nn.Dropout(0.90), # Add an extra DO if necess. 20 | nn.ReLU(), 21 | nn.Linear(self.num_bottleneck,self.num_class), 22 | ) 23 | def forward(self, input): 24 | input = input.view(input.size(0), self.num_frames*self.img_feature_dim) 25 | input = self.classifier(input) 26 | return input 27 | 28 | 29 | def return_MLP(relation_type, img_feature_dim, num_frames, num_class): 30 | MLPmodel = MLPmodule(img_feature_dim, num_frames, num_class) 31 | 32 | return MLPmodel 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Motion Fused Frames (MFFs) 2 | 3 | Pytorch implementation of the article [Motion fused frames: Data level fusion strategy for hand gesture recognition](http://openaccess.thecvf.com/content_cvpr_2018_workshops/papers/w41/Kopuklu_Motion_Fused_Frames_CVPR_2018_paper.pdf) 4 | 5 | ```diff 6 | - Update: Code is updated for Pytorch 1.5.0 and CUDA 10.2 7 | ``` 8 | 9 |

10 | 11 |  12 | ### Installation 13 | * Clone the repo with the following command: 14 | ```bash 15 | git clone https://github.com/okankop/MFF-pytorch.git 16 | ``` 17 | 18 | * Set up a virtual environment and install the requirements: 19 | ```bash 20 | conda create -n MFF python=3.7.4 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### Dataset Preparation 25 | Download the [jester dataset](https://www.twentybn.com/datasets/something-something) or [NVIDIA dynamic hand gestures dataset](http://research.nvidia.com/publication/online-detection-and-classification-dynamic-hand-gestures-recurrent-3d-convolutional) or [ChaLearn LAP IsoGD dataset](http://www.cbsr.ia.ac.cn/users/jwan/database/isogd.html). Decompress them into the same folder and use [process_dataset.py](process_dataset.py) to generate the index files for the train, val, and test splits. Properly set up the train, validation, and category meta files in [datasets_video.py](datasets_video.py). Finally, use the [flow_computation](https://github.com/okankop/flow_computation) repository to compute the optical flow images with the Brox method. 26 | 27 | The data directories are assumed to be structured as follows: 28 | 29 | ```misc 30 | ~/MFF-pytorch/ 31 | datasets/ 32 | jester/ 33 | rgb/ 34 | .../ (directories of video samples) 35 | .../ (jpg color frames) 36 | flow/ 37 | u/ 38 | .../ (directories of video samples) 39 | .../ (jpg optical-flow-u frames) 40 | v/ 41 | .../ (directories of video samples) 42 | .../ (jpg optical-flow-v frames) 43 | model/ 44 | .../(saved models for the last checkpoint and best model) 45 | ``` 46 | 47 | 48 | ### Running the Code 49 | The following are example commands for training under different scenarios: 50 | 51 | * Train a 4-segment network with 3 flow and 1 color frame per segment (4-MFFs-3f1c architecture) 52 | ```bash 53 | python main.py jester RGBFlow --arch BNInception --num_segments 4 \ 54 | --consensus_type MLP --num_motion 3 --batch-size 32 55 | ``` 56 | 57 | * Resume training from the last checkpoint (4-MFFs-3f1c architecture) 58 | ```bash 59 | python main.py jester RGBFlow --resume= --arch BNInception \ 60 | --consensus_type MLP --num_segments 4 --num_motion 3 --batch-size 32 61 | ``` 62 | 63 | * Test a trained model (4-MFFs-3f1c architecture). Pretrained models are provided under [pretrained_models](pretrained_models). 64 | 65 | ```bash 66 | python test_models.py jester RGBFlow pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar --arch BNInception --consensus_type MLP --test_crops 1 --num_motion 3 --test_segments 4 67 | ``` 68 | 69 | All available GPUs are used for training by default. To train on only a subset of them, prefix the command with CUDA_VISIBLE_DEVICES=... 70 | 71 | ### Citation 72 | If you use this code or pre-trained models, please cite the following: 73 | 74 | ```bibtex 75 | @InProceedings{Kopuklu_2018_CVPR_Workshops, 76 | author = {Kopuklu, Okan and Kose, Neslihan and Rigoll, Gerhard}, 77 | title = {Motion Fused Frames: Data Level Fusion Strategy for Hand Gesture Recognition}, 78 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 79 | month = {June}, 80 | year = {2018} 81 | } 82 | ``` 83 | 84 | ### Acknowledgement 85 | This project is built on top of the [TSN-pytorch](https://github.com/yjxiong/temporal-segment-networks) codebase. We thank Yuanjun Xiong for releasing it. 
We also thank Bolei Zhou for the insprational work [Temporal Segment Networks](https://arxiv.org/pdf/1711.08496.pdf), from which we imported [process_dataset.py](https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py) to our project. 86 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | import random 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import numpy as np 8 | from numpy.random import randint 9 | 10 | class VideoRecord(object): 11 | def __init__(self, row): 12 | self._data = row 13 | 14 | @property 15 | def path(self): 16 | return self._data[0] 17 | 18 | @property 19 | def num_frames(self): 20 | return int(self._data[1]) 21 | 22 | @property 23 | def label(self): 24 | return int(self._data[2]) 25 | 26 | 27 | class TSNDataSet(data.Dataset): 28 | def __init__(self, root_path, list_file, 29 | num_segments=3, new_length=1, modality='RGB', 30 | image_tmpl='img_{:05d}.jpg', transform=None, 31 | force_grayscale=False, random_shift=True, 32 | test_mode=False, dataset='jester'): 33 | 34 | self.root_path = root_path 35 | self.list_file = list_file 36 | self.num_segments = num_segments 37 | self.new_length = new_length 38 | self.modality = modality 39 | self.image_tmpl = image_tmpl 40 | self.transform = transform 41 | self.random_shift = random_shift 42 | self.test_mode = test_mode 43 | self.dataset = dataset 44 | 45 | if self.modality == 'RGBDiff' or self.modality == 'RGBFlow': 46 | self.new_length += 1# Diff needs one more image to calculate diff 47 | 48 | self._parse_list() 49 | 50 | def _load_image(self, directory, idx, isLast=False): 51 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 52 | try: 53 | return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))).convert('RGB')] 54 | except Exception: 55 | print('error loading image:', os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))) 56 | return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(1))).convert('RGB')] 57 | 58 | elif self.modality == 'Flow': 59 | try: 60 | x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(idx))).convert('L') 61 | y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))).convert('L') 62 | except Exception: 63 | print('error loading flow file:', os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))) 64 | x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(1))).convert('L') 65 | y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(1))).convert('L') 66 | return [x_img, y_img] 67 | 68 | elif self.modality == 'RGBFlow': 69 | if isLast: 70 | return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))).convert('RGB')] 71 | else: 72 | x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(idx))).convert('L') 73 | y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))).convert('L') 74 | return [x_img, y_img] 75 | 76 | 77 | def _parse_list(self): 78 | # check the frame number is large >3: 79 | # usualy it is [video_id, num_frames, class_idx] 80 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 81 | tmp = [item 
for item in tmp if int(item[1])>=3] 82 | self.video_list = [VideoRecord(item) for item in tmp] 83 | print('video number:%d'%(len(self.video_list))) 84 | 85 | def _sample_indices(self, record): 86 | """ 87 | 88 | :param record: VideoRecord 89 | :return: list 90 | """ 91 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 92 | 93 | if average_duration > 0: 94 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments) 95 | elif record.num_frames > self.num_segments: 96 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 97 | else: 98 | offsets = np.zeros((self.num_segments,)) 99 | return offsets + 1 100 | 101 | def _get_val_indices(self, record): 102 | if record.num_frames > self.num_segments + self.new_length - 1: 103 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 104 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 105 | else: 106 | offsets = np.zeros((self.num_segments,)) 107 | return offsets + 1 108 | 109 | def _get_test_indices(self, record): 110 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 111 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 112 | return offsets + 1 113 | 114 | def __getitem__(self, index): 115 | record = self.video_list[index] 116 | # check this is a legit video folder 117 | if self.modality == 'RGBFlow': 118 | while not os.path.exists(os.path.join(self.root_path, "rgb", record.path, self.image_tmpl.format(1))): 119 | index = np.random.randint(len(self.video_list)) 120 | record = self.video_list[index] 121 | else: 122 | while not os.path.exists(os.path.join(self.root_path, "rgb", record.path, self.image_tmpl.format(1))): 123 | index = np.random.randint(len(self.video_list)) 124 | record = self.video_list[index] 125 | 126 | if not self.test_mode: 127 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 128 | else: 129 | segment_indices = self._get_test_indices(record) 130 | 131 | return self.get(record, segment_indices) 132 | 133 | def get(self, record, indices): 134 | images = list() 135 | for seg_ind in indices: 136 | p = int(seg_ind) 137 | for i in range(self.new_length): 138 | if self.modality == 'RGBFlow': 139 | if i == self.new_length - 1: 140 | seg_imgs = self._load_image(record.path, p, True) 141 | else: 142 | if p == record.num_frames: 143 | seg_imgs = self._load_image(record.path, p-1) 144 | else: 145 | seg_imgs = self._load_image(record.path, p) 146 | else: 147 | seg_imgs = self._load_image(record.path, p) 148 | 149 | images.extend(seg_imgs) 150 | if p < record.num_frames: 151 | p += 1 152 | 153 | process_data = self.transform(images) 154 | return process_data, record.label 155 | 156 | def __len__(self): 157 | return len(self.video_list) 158 | -------------------------------------------------------------------------------- /datasets_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import torchvision.datasets as datasets 5 | 6 | 7 | ROOT_DATASET= '/usr/home/kop/MFF-pytorch' 8 | 9 | def return_jester(modality): 10 | filename_categories = 'jester/category.txt' 11 | filename_imglist_train = 'jester/train_videofolder.txt' 12 | filename_imglist_val = 'jester/val_videofolder.txt' 13 | if modality == 'RGB': 14 | prefix = '{:05d}.jpg' 15 | 
root_data = '/usr/home/kop/MFF-pytorch/datasets/jester' 16 | elif modality == 'RGBFlow': 17 | prefix = '{:05d}.jpg' 18 | root_data = '/usr/home/kop/MFF-pytorch/datasets/jester' 19 | else: 20 | print('no such modality:'+modality) 21 | os.exit() 22 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 23 | 24 | def return_nvgesture(modality): 25 | filename_categories = 'nvgesture/category.txt' 26 | filename_imglist_train = 'nvgesture/train_videofolder.txt' 27 | filename_imglist_val = 'nvgesture/val_videofolder.txt' 28 | if modality == 'RGB': 29 | prefix = '{:05d}.jpg' 30 | root_data = '/data2/nvGesture' 31 | elif modality == 'RGBFlow': 32 | prefix = '{:05d}.jpg' 33 | root_data = '/data2/nvGesture' 34 | else: 35 | print('no such modality:'+modality) 36 | os.exit() 37 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 38 | 39 | def return_chalearn(modality): 40 | filename_categories = 'chalearn/category.txt' 41 | filename_imglist_train = 'chalearn/train_videofolder.txt' 42 | filename_imglist_val = 'chalearn/val_videofolder.txt' 43 | #filename_imglist_val = 'chalearn/test_videofolder.txt' 44 | if modality == 'RGB': 45 | prefix = '{:05d}.jpg' 46 | root_data = '/data2/ChaLearn' 47 | elif modality == 'RGBFlow': 48 | prefix = '{:05d}.jpg' 49 | root_data = '/data2/ChaLearn' 50 | else: 51 | print('no such modality:'+modality) 52 | os.exit() 53 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 54 | 55 | def return_dataset(dataset, modality): 56 | dict_single = {'jester':return_jester, 'nvgesture': return_nvgesture, 'chalearn': return_chalearn} 57 | if dataset in dict_single: 58 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality) 59 | else: 60 | raise ValueError('Unknown dataset '+dataset) 61 | 62 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train) 63 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val) 64 | file_categories = os.path.join(ROOT_DATASET, file_categories) 65 | with open(file_categories) as f: 66 | lines = f.readlines() 67 | categories = [item.rstrip() for item in lines] 68 | return categories, file_imglist_train, file_imglist_val, root_data, prefix 69 | 70 | -------------------------------------------------------------------------------- /images/motion_fused_frames.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/images/motion_fused_frames.jpg -------------------------------------------------------------------------------- /images/network_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/images/network_arch.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import torch 6 | import torchvision 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torch.autograd import Variable 11 | from torch.nn.utils import clip_grad_norm 12 | from torch.utils.data.sampler import SequentialSampler 13 | 14 | from dataset import TSNDataSet 15 | from models import TSN 16 | from transforms 
import * 17 | from opts import parser 18 | import datasets_video 19 | 20 | 21 | best_prec1 = 0 22 | 23 | def main(): 24 | global args, best_prec1 25 | args = parser.parse_args() 26 | check_rootfolders() 27 | 28 | categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality) 29 | num_class = len(categories) 30 | 31 | 32 | args.store_name = '_'.join(['MFF', args.dataset, args.modality, args.arch, 33 | 'segment%d'% args.num_segments, '%df1c'% args.num_motion]) 34 | print('storing name: ' + args.store_name) 35 | 36 | model = TSN(num_class, args.num_segments, args.modality, 37 | base_model=args.arch, 38 | consensus_type=args.consensus_type, 39 | dropout=args.dropout, num_motion=args.num_motion, 40 | img_feature_dim=args.img_feature_dim, 41 | partial_bn=not args.no_partialbn, 42 | dataset=args.dataset) 43 | 44 | crop_size = model.crop_size 45 | scale_size = model.scale_size 46 | input_mean = model.input_mean 47 | input_std = model.input_std 48 | train_augmentation = model.get_augmentation() 49 | 50 | policies = model.get_optim_policies() 51 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 52 | 53 | if args.resume: 54 | if os.path.isfile(args.resume): 55 | print(("=> loading checkpoint '{}'".format(args.resume))) 56 | checkpoint = torch.load(args.resume) 57 | args.start_epoch = checkpoint['epoch'] 58 | best_prec1 = checkpoint['best_prec1'] 59 | model.load_state_dict(checkpoint['state_dict']) 60 | print(("=> loaded checkpoint '{}' (epoch {})" 61 | .format(args.evaluate, checkpoint['epoch']))) 62 | else: 63 | print(("=> no checkpoint found at '{}'".format(args.resume))) 64 | 65 | print(model) 66 | cudnn.benchmark = True 67 | 68 | # Data loading code 69 | if ((args.modality != 'RGBDiff') | (args.modality != 'RGBFlow')): 70 | normalize = GroupNormalize(input_mean, input_std) 71 | else: 72 | normalize = IdentityTransform() 73 | 74 | if args.modality == 'RGB': 75 | data_length = 1 76 | elif args.modality in ['Flow', 'RGBDiff']: 77 | data_length = 5 78 | elif args.modality == 'RGBFlow': 79 | data_length = args.num_motion 80 | 81 | train_loader = torch.utils.data.DataLoader( 82 | TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments, 83 | new_length=data_length, 84 | modality=args.modality, 85 | image_tmpl=prefix, 86 | dataset=args.dataset, 87 | transform=torchvision.transforms.Compose([ 88 | train_augmentation, 89 | Stack(roll=(args.arch in ['BNInception','InceptionV3']), isRGBFlow = (args.modality == 'RGBFlow')), 90 | ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])), 91 | normalize, 92 | ])), 93 | batch_size=args.batch_size, shuffle=True, 94 | num_workers=args.workers, pin_memory=False) 95 | 96 | val_loader = torch.utils.data.DataLoader( 97 | TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments, 98 | new_length=data_length, 99 | modality=args.modality, 100 | image_tmpl=prefix, 101 | dataset=args.dataset, 102 | random_shift=False, 103 | transform=torchvision.transforms.Compose([ 104 | GroupScale(int(scale_size)), 105 | GroupCenterCrop(crop_size), 106 | Stack(roll=(args.arch in ['BNInception','InceptionV3']), isRGBFlow = (args.modality == 'RGBFlow')), 107 | ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])), 108 | normalize, 109 | ])), 110 | batch_size=args.batch_size, shuffle=False, 111 | num_workers=args.workers, pin_memory=False) 112 | 113 | # define loss function (criterion) and optimizer 114 | if args.loss_type == 'nll': 115 | 
criterion = torch.nn.CrossEntropyLoss().cuda() 116 | else: 117 | raise ValueError("Unknown loss type") 118 | 119 | for group in policies: 120 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 121 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 122 | 123 | optimizer = torch.optim.SGD(policies, 124 | args.lr, 125 | momentum=args.momentum, 126 | weight_decay=args.weight_decay) 127 | 128 | if args.evaluate: 129 | validate(val_loader, model, criterion, 0) 130 | return 131 | 132 | log_training = open(os.path.join(args.root_log, '%s.csv' % args.store_name), 'w') 133 | for epoch in range(args.start_epoch, args.epochs): 134 | adjust_learning_rate(optimizer, epoch, args.lr_steps) 135 | 136 | # train for one epoch 137 | train(train_loader, model, criterion, optimizer, epoch, log_training) 138 | 139 | # evaluate on validation set 140 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 141 | prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader), log_training) 142 | 143 | # remember best prec@1 and save checkpoint 144 | is_best = prec1 > best_prec1 145 | best_prec1 = max(prec1, best_prec1) 146 | save_checkpoint({ 147 | 'epoch': epoch + 1, 148 | 'arch': args.arch, 149 | 'state_dict': model.state_dict(), 150 | 'best_prec1': best_prec1, 151 | }, is_best) 152 | 153 | 154 | def train(train_loader, model, criterion, optimizer, epoch, log): 155 | batch_time = AverageMeter() 156 | data_time = AverageMeter() 157 | losses = AverageMeter() 158 | top1 = AverageMeter() 159 | top5 = AverageMeter() 160 | 161 | if args.no_partialbn: 162 | model.module.partialBN(False) 163 | else: 164 | model.module.partialBN(True) 165 | 166 | # switch to train mode 167 | model.train() 168 | 169 | end = time.time() 170 | for i, (input, target) in enumerate(train_loader): 171 | # measure data loading time 172 | data_time.update(time.time() - end) 173 | 174 | target = target.cuda() 175 | input_var = Variable(input) 176 | target_var = Variable(target) 177 | 178 | # compute output 179 | output = model(input_var) 180 | loss = criterion(output, target_var) 181 | 182 | # measure accuracy and record loss 183 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 184 | losses.update(loss.data[0], input.size(0)) 185 | top1.update(prec1[0], input.size(0)) 186 | top5.update(prec5[0], input.size(0)) 187 | 188 | 189 | # compute gradient and do SGD step 190 | optimizer.zero_grad() 191 | 192 | loss.backward() 193 | 194 | if args.clip_gradient is not None: 195 | total_norm = clip_grad_norm(model.parameters(), args.clip_gradient) 196 | # if total_norm > args.clip_gradient: 197 | # print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm)) 198 | 199 | optimizer.step() 200 | 201 | # measure elapsed time 202 | batch_time.update(time.time() - end) 203 | end = time.time() 204 | 205 | if i % args.print_freq == 0: 206 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 207 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 208 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 209 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 210 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 211 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 212 | epoch, i, len(train_loader), batch_time=batch_time, 213 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'])) 214 | print(output) 215 | log.write(output + '\n') 216 | log.flush() 217 | 218 | 219 | 220 | def validate(val_loader, model, criterion, iter, 
log): 221 | batch_time = AverageMeter() 222 | losses = AverageMeter() 223 | top1 = AverageMeter() 224 | top5 = AverageMeter() 225 | 226 | # switch to evaluate mode 227 | model.eval() 228 | 229 | end = time.time() 230 | for i, (input, target) in enumerate(val_loader): 231 | target = target.cuda() 232 | with torch.no_grad(): 233 | input_var = Variable(input) 234 | target_var = Variable(target) 235 | 236 | # compute output 237 | output = model(input_var) 238 | loss = criterion(output, target_var) 239 | 240 | # measure accuracy and record loss 241 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 242 | 243 | losses.update(loss.data[0], input.size(0)) 244 | top1.update(prec1[0], input.size(0)) 245 | top5.update(prec5[0], input.size(0)) 246 | 247 | # measure elapsed time 248 | batch_time.update(time.time() - end) 249 | end = time.time() 250 | 251 | if i % args.print_freq == 0: 252 | output = ('Test: [{0}/{1}]\t' 253 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 254 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 255 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 256 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 257 | i, len(val_loader), batch_time=batch_time, loss=losses, 258 | top1=top1, top5=top5)) 259 | print(output) 260 | log.write(output + '\n') 261 | log.flush() 262 | 263 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 264 | .format(top1=top1, top5=top5, loss=losses)) 265 | print(output) 266 | output_best = '\nBest Prec@1: %.3f'%(best_prec1) 267 | print(output_best) 268 | log.write(output + ' ' + output_best + '\n') 269 | log.flush() 270 | 271 | return top1.avg 272 | 273 | 274 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 275 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name)) 276 | if is_best: 277 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name),'%s/%s_best.pth.tar' % (args.root_model, args.store_name)) 278 | 279 | class AverageMeter(object): 280 | """Computes and stores the average and current value""" 281 | def __init__(self): 282 | self.reset() 283 | 284 | def reset(self): 285 | self.val = 0 286 | self.avg = 0 287 | self.sum = 0 288 | self.count = 0 289 | 290 | def update(self, val, n=1): 291 | self.val = val 292 | self.sum += val * n 293 | self.count += n 294 | self.avg = self.sum / self.count 295 | 296 | 297 | def adjust_learning_rate(optimizer, epoch, lr_steps): 298 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 299 | decay = 0.5 ** (sum(epoch >= np.array(lr_steps))) 300 | lr = args.lr * decay 301 | decay = args.weight_decay 302 | for param_group in optimizer.param_groups: 303 | param_group['lr'] = lr * param_group['lr_mult'] 304 | param_group['weight_decay'] = decay * param_group['decay_mult'] 305 | 306 | 307 | def accuracy(output, target, topk=(1,)): 308 | """Computes the precision@k for the specified values of k""" 309 | maxk = max(topk) 310 | batch_size = target.size(0) 311 | 312 | _, pred = output.topk(maxk, 1, True, True) 313 | pred = pred.t() 314 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 315 | 316 | res = [] 317 | for k in topk: 318 | correct_k = correct[:k].view(-1).float().sum(0) 319 | res.append(correct_k.mul_(100.0 / batch_size)) 320 | return res 321 | 322 | 323 | def count_parameters(model): 324 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 325 | 326 | 327 | def check_rootfolders(): 328 | """Create log and model folder""" 329 | folders_util = 
[args.root_log, args.root_model, args.root_output] 330 | for folder in folders_util: 331 | if not os.path.exists(folder): 332 | print('creating folder ' + folder) 333 | os.mkdir(folder) 334 | 335 | 336 | if __name__ == '__main__': 337 | main() 338 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ops.basic_ops import ConsensusModule, Identity 4 | from transforms import * 5 | from torch.nn.init import normal, constant 6 | 7 | import pretrainedmodels 8 | import MLPmodule 9 | 10 | class TSN(nn.Module): 11 | def __init__(self, num_class, num_segments, modality, 12 | base_model='resnet101', new_length=None, 13 | consensus_type='avg', before_softmax=True, num_motion=3, 14 | dropout=0.8,img_feature_dim=256, dataset='jester', 15 | crop_num=1, partial_bn=True, print_spec=True): 16 | super(TSN, self).__init__() 17 | self.modality = modality 18 | self.num_segments = num_segments 19 | self.num_motion = num_motion 20 | self.reshape = True 21 | self.before_softmax = before_softmax 22 | self.dropout = dropout 23 | self.dataset = dataset 24 | self.crop_num = crop_num 25 | self.consensus_type = consensus_type 26 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame 27 | if not before_softmax and consensus_type != 'avg': 28 | raise ValueError("Only avg consensus can be used after Softmax") 29 | 30 | if new_length is None: 31 | if modality == "RGB": 32 | self.new_length = 1 33 | elif modality == "Flow": 34 | self.new_length = 5 35 | elif modality == "RGBFlow": 36 | #self.new_length = 1 37 | self.new_length = self.num_motion 38 | else: 39 | self.new_length = new_length 40 | if print_spec == True: 41 | print((""" 42 | Initializing TSN with base model: {}. 43 | TSN Configurations: 44 | input_modality: {} 45 | num_segments: {} 46 | new_length: {} 47 | consensus_module: {} 48 | dropout_ratio: {} 49 | img_feature_dim: {} 50 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) 51 | 52 | self._prepare_base_model(base_model) 53 | 54 | feature_dim = self._prepare_tsn(num_class) 55 | 56 | if self.modality == 'Flow': 57 | print("Converting the ImageNet model to a flow init model") 58 | self.base_model = self._construct_flow_model(self.base_model) 59 | print("Done. Flow model ready...") 60 | elif self.modality == 'RGBDiff': 61 | print("Converting the ImageNet model to RGB+Diff init model") 62 | self.base_model = self._construct_diff_model(self.base_model) 63 | print("Done. RGBDiff model ready.") 64 | elif self.modality == 'RGBFlow': 65 | print("Converting the ImageNet model to RGB+Flow init model") 66 | self.base_model = self._construct_rgbflow_model(self.base_model) 67 | print("Done. 
RGBFlow model ready.") 68 | if consensus_type == 'MLP': 69 | self.consensus = MLPmodule.return_MLP(consensus_type, self.img_feature_dim, self.num_segments, num_class) 70 | else: 71 | self.consensus = ConsensusModule(consensus_type) 72 | 73 | if not self.before_softmax: 74 | self.softmax = nn.Softmax() 75 | 76 | self._enable_pbn = partial_bn 77 | if partial_bn: 78 | self.partialBN(True) 79 | 80 | 81 | def _prepare_tsn(self, num_class): 82 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 83 | if self.dropout == 0: 84 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) 85 | self.new_fc = None 86 | else: 87 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 88 | if self.consensus_type == 'MLP': 89 | # set the MFFs feature dimension 90 | self.new_fc = nn.Linear(feature_dim, self.img_feature_dim) 91 | else: 92 | # the default consensus types in TSN 93 | self.new_fc = nn.Linear(feature_dim, num_class) 94 | 95 | std = 0.001 96 | if self.new_fc is None: 97 | normal(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) 98 | constant(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 99 | else: 100 | normal(self.new_fc.weight, 0, std) 101 | constant(self.new_fc.bias, 0) 102 | return feature_dim 103 | 104 | def _prepare_base_model(self, base_model): 105 | 106 | if 'resnet' in base_model or 'vgg' in base_model or 'squeezenet1_1' in base_model: 107 | self.base_model = pretrainedmodels.__dict__[base_model](num_classes=1000, pretrained='imagenet') 108 | if base_model == 'squeezenet1_1': 109 | self.base_model = self.base_model.features 110 | self.base_model.last_layer_name = '12' 111 | else: 112 | self.base_model.last_layer_name = 'fc' 113 | self.input_size = 224 114 | self.input_mean = [0.485, 0.456, 0.406] 115 | self.input_std = [0.229, 0.224, 0.225] 116 | 117 | if self.modality == 'Flow': 118 | self.input_mean = [0.5] 119 | self.input_std = [np.mean(self.input_std)] 120 | elif self.modality == 'RGBDiff': 121 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 122 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 123 | elif base_model == 'BNInception': 124 | self.base_model = pretrainedmodels.__dict__['bninception'](num_classes=1000, pretrained='imagenet') 125 | self.base_model.last_layer_name = 'last_linear' 126 | self.input_size = 224 127 | self.input_mean = [104, 117, 128] 128 | self.input_std = [1] 129 | if self.modality == 'Flow': 130 | self.input_mean = [128] 131 | elif self.modality == 'RGBDiff': 132 | self.input_mean = self.input_mean * (1 + self.new_length) 133 | elif 'resnext101' in base_model: 134 | self.base_model = pretrainedmodels.__dict__[base_model](num_classes=1000, pretrained='imagenet') 135 | print(self.base_model) 136 | self.base_model.last_layer_name = 'last_linear' 137 | self.input_size = 224 138 | self.input_mean = [0.485, 0.456, 0.406] 139 | self.input_std = [0.229, 0.224, 0.225] 140 | if self.modality == 'Flow': 141 | self.input_mean = [128] 142 | elif self.modality == 'RGBDiff': 143 | self.input_mean = self.input_mean * (1 + self.new_length) 144 | else: 145 | raise ValueError('Unknown base model: {}'.format(base_model)) 146 | 147 | def train(self, mode=True): 148 | """ 149 | Override the default train() to freeze the BN parameters 150 | :return: 151 | """ 152 | super(TSN, self).train(mode) 153 | count = 0 154 | if self._enable_pbn: 155 | print("Freezing BatchNorm2D except 
the first one.") 156 | for m in self.base_model.modules(): 157 | if isinstance(m, nn.BatchNorm2d): 158 | count += 1 159 | if count >= (2 if self._enable_pbn else 1): 160 | m.eval() 161 | 162 | # shutdown update in frozen mode 163 | m.weight.requires_grad = False 164 | m.bias.requires_grad = False 165 | 166 | 167 | def partialBN(self, enable): 168 | self._enable_pbn = enable 169 | 170 | def get_optim_policies(self): 171 | first_conv_weight = [] 172 | first_conv_bias = [] 173 | normal_weight = [] 174 | normal_bias = [] 175 | bn = [] 176 | 177 | conv_cnt = 0 178 | bn_cnt = 0 179 | for m in self.modules(): 180 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d): 181 | ps = list(m.parameters()) 182 | conv_cnt += 1 183 | if conv_cnt == 1: 184 | first_conv_weight.append(ps[0]) 185 | if len(ps) == 2: 186 | first_conv_bias.append(ps[1]) 187 | else: 188 | normal_weight.append(ps[0]) 189 | if len(ps) == 2: 190 | normal_bias.append(ps[1]) 191 | elif isinstance(m, torch.nn.Linear): 192 | ps = list(m.parameters()) 193 | normal_weight.append(ps[0]) 194 | if len(ps) == 2: 195 | normal_bias.append(ps[1]) 196 | 197 | elif isinstance(m, torch.nn.BatchNorm1d): 198 | bn.extend(list(m.parameters())) 199 | elif isinstance(m, torch.nn.BatchNorm2d): 200 | bn_cnt += 1 201 | # later BN's are frozen 202 | if not self._enable_pbn or bn_cnt == 1: 203 | bn.extend(list(m.parameters())) 204 | elif len(m._modules) == 0: 205 | if len(list(m.parameters())) > 0: 206 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m))) 207 | 208 | return [ 209 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 210 | 'name': "first_conv_weight"}, 211 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 212 | 'name': "first_conv_bias"}, 213 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 214 | 'name': "normal_weight"}, 215 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 216 | 'name': "normal_bias"}, 217 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 218 | 'name': "BN scale/shift"}, 219 | ] 220 | 221 | def forward(self, input): 222 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 223 | 224 | if self.modality == 'RGBDiff': 225 | sample_len = 3 * self.new_length 226 | input = self._get_diff(input) 227 | 228 | if self.modality == 'RGBFlow': 229 | sample_len = 3 + 2 * self.new_length 230 | 231 | base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:])) 232 | 233 | if self.dropout > 0: 234 | base_out = self.new_fc(base_out) 235 | 236 | if not self.before_softmax: 237 | base_out = self.softmax(base_out) 238 | if self.reshape: 239 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) 240 | 241 | output = self.consensus(base_out) 242 | return output.squeeze(1) 243 | 244 | def _get_diff(self, input, keep_rgb=False): 245 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 246 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 247 | if keep_rgb: 248 | new_data = input_view.clone() 249 | else: 250 | new_data = input_view[:, :, 1:, :, :, :].clone() 251 | 252 | for x in reversed(list(range(1, self.new_length + 1))): 253 | if keep_rgb: 254 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 255 | else: 256 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 257 | 258 | return new_data 
259 | 260 | """ # There is no need now!! 261 | def _get_rgbflow(self, input): 262 | input_c = 3 + 2 * self.new_length # 3 is rgb channels, and 2 is coming for x & y channels of opt.flow 263 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 264 | new_data = input_view.clone() 265 | return new_data 266 | """ 267 | 268 | def _construct_rgbflow_model(self, base_model): 269 | # modify the convolution layers 270 | # Torch models are usually defined in a hierarchical way. 271 | # nn.modules.children() return all sub modules in a DFS manner 272 | modules = list(self.base_model.modules()) 273 | filter_conv2d = filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))) 274 | first_conv_idx = next(filter_conv2d) 275 | conv_layer = modules[first_conv_idx] 276 | container = modules[first_conv_idx - 1] 277 | 278 | # modify parameters, assume the first blob contains the convolution kernels 279 | params = [x.clone() for x in conv_layer.parameters()] 280 | kernel_size = params[0].size() 281 | new_kernel_size = kernel_size[:1] + (2 * self.new_length,) + kernel_size[2:] 282 | new_kernels = torch.cat((params[0].data.mean(dim=1,keepdim=True).expand(new_kernel_size).contiguous(), params[0].data), 1) # NOTE: Concatanating might be other way around. Check it! 283 | new_kernel_size = kernel_size[:1] + (3 + 2 * self.new_length,) + kernel_size[2:] 284 | 285 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 286 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 287 | bias=True if len(params) == 2 else False) 288 | new_conv.weight.data = new_kernels 289 | if len(params) == 2: 290 | new_conv.bias.data = params[1].data # add bias if neccessary 291 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 292 | 293 | # replace the first convolution layer 294 | setattr(container, layer_name, new_conv) 295 | return base_model 296 | 297 | def _construct_flow_model(self, base_model): 298 | # modify the convolution layers 299 | # Torch models are usually defined in a hierarchical way. 
300 | # nn.modules.children() return all sub modules in a DFS manner 301 | modules = list(self.base_model.modules()) 302 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 303 | conv_layer = modules[first_conv_idx] 304 | container = modules[first_conv_idx - 1] 305 | 306 | # modify parameters, assume the first blob contains the convolution kernels 307 | params = [x.clone() for x in conv_layer.parameters()] 308 | kernel_size = params[0].size() 309 | print(kernel_size) 310 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] 311 | print(new_kernel_size) 312 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 313 | 314 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 315 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 316 | bias=True if len(params) == 2 else False) 317 | new_conv.weight.data = new_kernels 318 | if len(params) == 2: 319 | new_conv.bias.data = params[1].data # add bias if neccessary 320 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 321 | 322 | # replace the first convlution layer 323 | setattr(container, layer_name, new_conv) 324 | return base_model 325 | 326 | def _construct_diff_model(self, base_model, keep_rgb=False): 327 | # modify the convolution layers 328 | # Torch models are usually defined in a hierarchical way. 329 | # nn.modules.children() return all sub modules in a DFS manner 330 | modules = list(self.base_model.modules()) 331 | first_conv_idx = filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules))))[0] 332 | conv_layer = modules[first_conv_idx] 333 | container = modules[first_conv_idx - 1] 334 | 335 | # modify parameters, assume the first blob contains the convolution kernels 336 | params = [x.clone() for x in conv_layer.parameters()] 337 | kernel_size = params[0].size() 338 | if not keep_rgb: 339 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 340 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 341 | else: 342 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 343 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 344 | 1) 345 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:] 346 | 347 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 348 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 349 | bias=True if len(params) == 2 else False) 350 | new_conv.weight.data = new_kernels 351 | if len(params) == 2: 352 | new_conv.bias.data = params[1].data # add bias if neccessary 353 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 354 | 355 | # replace the first convolution layer 356 | setattr(container, layer_name, new_conv) 357 | return base_model 358 | 359 | @property 360 | def crop_size(self): 361 | return self.input_size 362 | 363 | @property 364 | def scale_size(self): 365 | return self.input_size * 256 // 224 366 | 367 | def get_augmentation(self): 368 | if self.modality == 'RGB': 369 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 370 | GroupRandomHorizontalFlip(is_flow=False)]) 371 | elif self.modality == 'Flow': 372 | return 
torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 373 | GroupRandomHorizontalFlip(is_flow=True)]) 374 | elif self.modality == 'RGBDiff': 375 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 376 | GroupRandomHorizontalFlip(is_flow=False)]) 377 | elif self.modality == 'RGBFlow': 378 | return torchvision.transforms.Compose([GroupMultiScaleResize(0.2), 379 | GroupMultiScaleRotate(20), 380 | #GroupSpatialElasticDisplacement(), 381 | GroupMultiScaleCrop(self.input_size, 382 | [1, .875, 383 | .75, 384 | .66]), 385 | #GroupRandomHorizontalFlip(is_flow=False) 386 | ]) 387 | 388 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class Identity(torch.nn.Module): 6 | def forward(self, input): 7 | return input 8 | 9 | 10 | class SegmentConsensus(torch.autograd.Function): 11 | 12 | def __init__(self, consensus_type, dim=1): 13 | self.consensus_type = consensus_type 14 | self.dim = dim 15 | self.shape = None 16 | 17 | def forward(self, input_tensor): 18 | self.shape = input_tensor.size() 19 | if self.consensus_type == 'avg': 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) 21 | elif self.consensus_type == 'identity': 22 | output = input_tensor 23 | else: 24 | output = None 25 | 26 | return output 27 | 28 | def backward(self, grad_output): 29 | if self.consensus_type == 'avg': 30 | grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim]) 31 | elif self.consensus_type == 'identity': 32 | grad_in = grad_output 33 | else: 34 | grad_in = None 35 | 36 | return grad_in 37 | 38 | 39 | class ConsensusModule(torch.nn.Module): 40 | 41 | def __init__(self, consensus_type, dim=1): 42 | super(ConsensusModule, self).__init__() 43 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 44 | self.dim = dim 45 | 46 | def forward(self, input): 47 | return SegmentConsensus(self.consensus_type, self.dim)(input) 48 | -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def get_grad_hook(name): 6 | def hook(m, grad_in, grad_out): 7 | print((name, grad_out[0].data.abs().mean(), grad_in[0].data.abs().mean())) 8 | print((grad_out[0].size())) 9 | print((grad_in[0].size())) 10 | 11 | print((grad_out[0])) 12 | print((grad_in[0])) 13 | 14 | return hook 15 | 16 | 17 | def softmax(scores): 18 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 19 | return es / es.sum(axis=-1)[..., None] 20 | 21 | 22 | def log_add(log_a, log_b): 23 | return log_a + np.log(1 + np.exp(log_b - log_a)) 24 | 25 | 26 | def class_accuracy(prediction, label): 27 | cf = confusion_matrix(prediction, label) 28 | cls_cnt = cf.sum(axis=1) 29 | cls_hit = np.diag(cf) 30 | 31 | cls_acc = cls_hit / cls_cnt.astype(float) 32 | 33 | mean_cls_acc = cls_acc.mean() 34 | 35 | return cls_acc, mean_cls_acc -------------------------------------------------------------------------------- /opts.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 3 | parser.add_argument('dataset', type=str, choices=['jester', 'nvgesture', 'chalearn']) 4 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff', 'RGBFlow']) 5 | parser.add_argument('--train_list', type=str,default="") 6 | parser.add_argument('--val_list', type=str, default="") 7 | parser.add_argument('--root_path', type=str, default="") 8 | parser.add_argument('--store_name', type=str, default="") 9 | # ========================= Model Configs ========================== 10 | parser.add_argument('--arch', type=str, default="BNInception") 11 | parser.add_argument('--num_segments', type=int, default=4) 12 | parser.add_argument('--num_motion', type=int, default=3) 13 | parser.add_argument('--consensus_type', type=str, default='avg') 14 | parser.add_argument('--k', type=int, default=3) 15 | 16 | parser.add_argument('--dropout', '--do', default=0.8, type=float, 17 | metavar='DO', help='dropout ratio (default: 0.5)') 18 | parser.add_argument('--loss_type', type=str, default="nll", 19 | choices=['nll']) 20 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame") 21 | 22 | # ========================= Learning Configs ========================== 23 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 24 | help='number of total epochs to run') 25 | parser.add_argument('-b', '--batch-size', default=128, type=int, 26 | metavar='N', help='mini-batch size (default: 256)') 27 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 28 | metavar='LR', help='initial learning rate') 29 | parser.add_argument('--lr_steps', default=[25, 40], type=float, nargs="+", 30 | metavar='LRSteps', help='epochs to decay learning rate by 10') 31 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 32 | help='momentum') 33 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 34 | metavar='W', help='weight decay (default: 5e-4)') 35 | parser.add_argument('--clip-gradient', '--gd', default=20, type=float, 36 | metavar='W', help='gradient norm clipping (default: disabled)') 37 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 38 | 39 | 40 | # ========================= Monitor Configs ========================== 41 | parser.add_argument('--print-freq', '-p', default=10, type=int, 42 | metavar='N', help='print frequency (default: 10)') 43 | parser.add_argument('--eval-freq', '-ef', default=1, type=int, 44 | metavar='N', help='evaluation frequency (default: 5)') 45 | 46 | 47 | # ========================= Runtime Configs ========================== 48 | parser.add_argument('-j', '--workers', default=30, type=int, metavar='N', 49 | help='number of data loading workers (default: 4)') 50 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 51 | help='path to latest checkpoint (default: none)') 52 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 53 | help='evaluate model on validation set') 54 | parser.add_argument('--snapshot_pref', type=str, default="") 55 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 56 | help='manual epoch number (useful on restarts)') 57 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 58 | parser.add_argument('--flow_prefix', 
default="", type=str) 59 | parser.add_argument('--root_log',type=str, default='log') 60 | parser.add_argument('--root_model', type=str, default='model') 61 | parser.add_argument('--root_output',type=str, default='output') 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar -------------------------------------------------------------------------------- /pretrained_models/MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/pretrained_models/MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar -------------------------------------------------------------------------------- /process_dataset.py: -------------------------------------------------------------------------------- 1 | # This code hase been acquired from TRN-pytorch repository 2 | # 'https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py' 3 | # which is prepared by Bolei Zhou 4 | # 5 | # Processing the raw dataset of Jester 6 | # 7 | # generate the meta files: 8 | # category.txt: the list of categories. 9 | # train_videofolder.txt: each row contains [videoname num_frames classIDX] 10 | # val_videofolder.txt: same as above 11 | # 12 | # Created by Bolei Zhou, Dec.2 2017 13 | 14 | import os 15 | import pdb 16 | dataset_name = 'jester-v1' 17 | with open('%s-labels.csv'% dataset_name) as f: 18 | lines = f.readlines() 19 | categories = [] 20 | for line in lines: 21 | line = line.rstrip() 22 | categories.append(line) 23 | categories = sorted(categories) 24 | with open('category.txt','w') as f: 25 | f.write('\n'.join(categories)) 26 | 27 | dict_categories = {} 28 | for i, category in enumerate(categories): 29 | dict_categories[category] = i 30 | 31 | files_input = ['%s-validation.csv'%dataset_name,'%s-train.csv'%dataset_name] 32 | files_output = ['val_videofolder.txt','train_videofolder.txt'] 33 | for (filename_input, filename_output) in zip(files_input, files_output): 34 | with open(filename_input) as f: 35 | lines = f.readlines() 36 | folders = [] 37 | idx_categories = [] 38 | for line in lines: 39 | line = line.rstrip() 40 | items = line.split(';') 41 | folders.append(items[0]) 42 | idx_categories.append(os.path.join(dict_categories[items[1]])) 43 | output = [] 44 | for i in range(len(folders)): 45 | curFolder = folders[i] 46 | curIDX = idx_categories[i] 47 | # counting the number of frames in each video folders 48 | dir_files = os.listdir(os.path.join('20bn-%s'%dataset_name, curFolder)) 49 | output.append('%s %d %d'%(curFolder, len(dir_files), curIDX)) 50 | print('%d/%d'%(i, len(folders))) 51 | with open(filename_output,'w') as f: 52 | f.write('\n'.join(output)) 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.2 2 | opencv-python==4.2.0.32 3 | Pillow>=6.2.2 4 | pretrainedmodels==0.7.4 5 | python-dateutil==2.8.0 6 | pytz==2019.2 7 | six==1.12.0 8 | torch==1.5.0 9 | torchvision==0.6.0 10 | 
-------------------------------------------------------------------------------- /test_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import torch.nn.parallel 6 | import torch.optim 7 | from sklearn.metrics import confusion_matrix 8 | from dataset import TSNDataSet 9 | from models import TSN 10 | from transforms import * 11 | from ops import ConsensusModule 12 | import datasets_video 13 | import pdb 14 | from torch.nn import functional as F 15 | 16 | 17 | # options 18 | parser = argparse.ArgumentParser( 19 | description="MFF testing on the full validation set") 20 | parser.add_argument('dataset', type=str, choices=['jester', 'nvgesture', 'chalearn']) 21 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff', 'RGBFlow']) 22 | parser.add_argument('weights', type=str) 23 | parser.add_argument('--arch', type=str, default="resnet101") 24 | parser.add_argument('--save_scores', type=str, default=None) 25 | parser.add_argument('--test_segments', type=int, default=25) 26 | parser.add_argument('--max_num', type=int, default=-1) 27 | parser.add_argument('--test_crops', type=int, default=10) 28 | parser.add_argument('--input_size', type=int, default=224) 29 | parser.add_argument('--num_motion', type=int, default=3) 30 | parser.add_argument('--consensus_type', type=str, default='MLP', choices=['avg', 'MLP']) 31 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 34 | parser.add_argument('--img_feature_dim',type=int, default=256) 35 | parser.add_argument('--num_set_segments',type=int, default=1) 36 | parser.add_argument('--softmax', type=int, default=0) 37 | 38 | args = parser.parse_args() 39 | 40 | class AverageMeter(object): 41 | """Computes and stores the average and current value""" 42 | def __init__(self): 43 | self.reset() 44 | 45 | def reset(self): 46 | self.val = 0 47 | self.avg = 0 48 | self.sum = 0 49 | self.count = 0 50 | 51 | def update(self, val, n=1): 52 | self.val = val 53 | self.sum += val * n 54 | self.count += n 55 | self.avg = self.sum / self.count 56 | 57 | def accuracy(output, target, topk=(1,)): 58 | """Computes the precision@k for the specified values of k""" 59 | maxk = max(topk) 60 | batch_size = target.size(0) 61 | _, pred = output.topk(maxk, 1, True, True) 62 | pred = pred.t() 63 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 64 | res = [] 65 | for k in topk: 66 | correct_k = correct[:k].view(-1).float().sum(0) 67 | res.append(correct_k.mul_(100.0 / batch_size)) 68 | return res 69 | 70 | 71 | 72 | categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality) 73 | num_class = len(categories) 74 | 75 | net = TSN(num_class, args.test_segments if args.consensus_type in ['MLP'] else 1, args.modality, 76 | base_model=args.arch, 77 | consensus_type=args.consensus_type, 78 | img_feature_dim=args.img_feature_dim, 79 | ) 80 | 81 | checkpoint = torch.load(args.weights) 82 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1'])) 83 | 84 | base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint['state_dict'].items())} 85 | net.load_state_dict(base_dict) 86 | 87 | if args.test_crops == 1: 88 | cropping = torchvision.transforms.Compose([ 89 | GroupScale(net.scale_size), 90 | 
GroupCenterCrop(net.input_size), 91 | ]) 92 | elif args.test_crops == 10: 93 | cropping = torchvision.transforms.Compose([ 94 | #GroupOverSample(net.input_size, net.input_size) 95 | GroupOverSample(net.input_size, net.scale_size) 96 | ]) 97 | else: 98 | raise ValueError("Only 1 and 10 crops are supported while we got {}".format(args.test_crops)) 99 | 100 | ############ Data Loading Part ##### 101 | if args.modality == 'RGB': 102 | data_length = 1 103 | elif args.modality in ['Flow', 'RGBDiff']: 104 | data_length = 5 105 | elif args.modality == 'RGBFlow': 106 | data_length = args.num_motion 107 | 108 | data_loader = torch.utils.data.DataLoader( 109 | TSNDataSet(args.root_path, args.val_list, num_segments=args.test_segments, 110 | new_length=data_length, 111 | modality=args.modality, 112 | image_tmpl=prefix, 113 | dataset=args.dataset, 114 | test_mode=True, 115 | transform=torchvision.transforms.Compose([ 116 | cropping, 117 | Stack(roll=(args.arch in 118 | ['BNInception','InceptionV3']), isRGBFlow=(args.modality == 'RGBFlow')), 119 | ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])), 120 | GroupNormalize(net.input_mean, net.input_std), 121 | ])), 122 | batch_size=1, shuffle=False, 123 | num_workers=args.workers*2, pin_memory=False) 124 | 125 | if args.gpus is not None: 126 | devices = [args.gpus[i] for i in range(args.workers)] 127 | else: 128 | devices = list(range(args.workers)) 129 | 130 | 131 | #net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices) 132 | net = torch.nn.DataParallel(net.cuda()) 133 | net.eval() 134 | 135 | data_gen = enumerate(data_loader) 136 | 137 | total_num = len(data_loader.dataset) 138 | output = [] 139 | 140 | 141 | def eval_video(video_data): 142 | i, data, label = video_data 143 | num_crop = args.test_crops 144 | 145 | if args.modality == 'RGB': 146 | length = 3 147 | elif args.modality == 'Flow': 148 | length = 10 149 | elif args.modality == 'RGBDiff': 150 | length = 18 151 | elif args.modality == 'RGBFlow': 152 | length = 3 + 2 * args.num_motion # 3 rgb channels and 3*2=6 flow channels 153 | else: 154 | raise ValueError("Unknown modality "+args.modality) 155 | 156 | input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)), 157 | volatile=True) 158 | rst = net(input_var) 159 | if args.softmax==1: 160 | # take the softmax to normalize the output to probability 161 | rst = F.softmax(rst) 162 | 163 | rst = rst.data.cpu().numpy().copy() 164 | 165 | if args.consensus_type in ['MLP']: 166 | rst = rst.reshape(-1, 1, num_class) 167 | else: 168 | rst = rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape((args.test_segments, 1, num_class)) 169 | 170 | return i, rst, label[0] 171 | 172 | 173 | proc_start_time = time.time() 174 | max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset) 175 | 176 | top1 = AverageMeter() 177 | top5 = AverageMeter() 178 | 179 | for i, (data, label) in data_gen: 180 | if i >= max_num: 181 | break 182 | rst = eval_video((i, data, label)) 183 | output.append(rst[1:]) 184 | cnt_time = time.time() - proc_start_time 185 | prec1, prec5 = accuracy(torch.from_numpy(np.mean(rst[1], axis=0)), label, topk=(1, 5)) 186 | top1.update(prec1[0], 1) 187 | top5.update(prec5[0], 1) 188 | print('video {} done, total {}/{}, average {:.3f} sec/video, moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i, i+1, 189 | total_num, 190 | float(cnt_time) / (i+1), top1.avg, top5.avg)) 191 | 192 | video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output] 193 | 194 | 
video_labels = [x[1] for x in output] 195 | 196 | 197 | cf = confusion_matrix(video_labels, video_pred).astype(float) 198 | 199 | cls_cnt = cf.sum(axis=1) 200 | cls_hit = np.diag(cf) 201 | 202 | cls_acc = cls_hit / cls_cnt 203 | 204 | print('-----Evaluation is finished------') 205 | print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 206 | print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg)) 207 | 208 | if args.save_scores is not None: 209 | 210 | # reorder before saving 211 | name_list = [x.strip().split()[0] for x in open(args.val_list)] 212 | order_dict = {e:i for i, e in enumerate(sorted(name_list))} 213 | reorder_output = [None] * len(output) 214 | reorder_label = [None] * len(output) 215 | reorder_pred = [None] * len(output) 216 | output_csv = [] 217 | for i in range(len(output)): 218 | idx = order_dict[name_list[i]] 219 | reorder_output[idx] = output[i] 220 | reorder_label[idx] = video_labels[i] 221 | reorder_pred[idx] = video_pred[i] 222 | output_csv.append('%s;%s'%(name_list[i], categories[video_pred[i]])) 223 | 224 | np.savez(args.save_scores, scores=reorder_output, labels=reorder_label, predictions=reorder_pred, cf=cf) 225 | 226 | with open(args.save_scores.replace('npz','csv'),'w') as f: 227 | f.write('\n'.join(output_csv)) 228 | 229 | 230 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import cv2 6 | import numbers 7 | import math 8 | import torch 9 | 10 | 11 | class GroupRandomCrop(object): 12 | def __init__(self, size): 13 | if isinstance(size, numbers.Number): 14 | self.size = (int(size), int(size)) 15 | else: 16 | self.size = size 17 | 18 | def __call__(self, img_group): 19 | 20 | w, h = img_group[0].size 21 | th, tw = self.size 22 | 23 | out_images = list() 24 | 25 | x1 = random.randint(0, w - tw) 26 | y1 = random.randint(0, h - th) 27 | 28 | for img in img_group: 29 | assert(img.size[0] == w and img.size[1] == h) 30 | if w == tw and h == th: 31 | out_images.append(img) 32 | else: 33 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 34 | 35 | return out_images 36 | 37 | 38 | class GroupCenterCrop(object): 39 | def __init__(self, size): 40 | self.worker = torchvision.transforms.CenterCrop(size) 41 | 42 | def __call__(self, img_group): 43 | return [self.worker(img) for img in img_group] 44 | 45 | 46 | class GroupRandomHorizontalFlip(object): 47 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 48 | """ 49 | def __init__(self, is_flow=False): 50 | self.is_flow = is_flow 51 | 52 | def __call__(self, img_group, is_flow=False): 53 | v = random.random() 54 | if v < 0.5: 55 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 56 | if self.is_flow: 57 | for i in range(0, len(ret), 2): 58 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 59 | return ret 60 | else: 61 | return img_group 62 | 63 | 64 | class GroupNormalize(object): 65 | def __init__(self, mean, std): 66 | self.mean = mean 67 | self.std = std 68 | 69 | def __call__(self, tensor): 70 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 71 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 72 | 73 | # TODO: make efficient 74 | for t, m, s in zip(tensor, rep_mean, rep_std): 75 | t.sub_(m).div_(s) 76 | 77 | return tensor 78 | 79 | 80 | class GroupScale(object): 81 | 
""" Rescales the input PIL.Image to the given 'size'. 82 | 'size' will be the size of the smaller edge. 83 | For example, if height > width, then image will be 84 | rescaled to (size * height / width, size) 85 | size: size of the smaller edge 86 | interpolation: Default: PIL.Image.BILINEAR 87 | """ 88 | 89 | def __init__(self, size, interpolation=Image.BILINEAR): 90 | self.worker = torchvision.transforms.Scale(size, interpolation) 91 | 92 | def __call__(self, img_group): 93 | return [self.worker(img) for img in img_group] 94 | 95 | 96 | class GroupOverSample(object): 97 | def __init__(self, crop_size, scale_size=None): 98 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 99 | 100 | if scale_size is not None: 101 | self.scale_worker = GroupScale(scale_size) 102 | else: 103 | self.scale_worker = None 104 | 105 | def __call__(self, img_group): 106 | 107 | if self.scale_worker is not None: 108 | img_group = self.scale_worker(img_group) 109 | 110 | image_w, image_h = img_group[0].size 111 | crop_w, crop_h = self.crop_size 112 | 113 | offsets = GroupMultiScaleCrop.fill_fix_offset(True, image_w, image_h, crop_w, crop_h) 114 | oversample_group = list() 115 | for o_w, o_h in offsets: 116 | normal_group = list() 117 | flip_group = list() 118 | for i, img in enumerate(img_group): 119 | #print(img.size) 120 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 121 | #print([o_w, o_h, o_w + crop_w, o_h + crop_h]) 122 | normal_group.append(crop) 123 | #flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 124 | #flip_group.append(flip_crop) 125 | 126 | #if img.mode == 'L' and i % 2 == 0: 127 | #flip_group.append(ImageOps.invert(flip_crop)) 128 | #else: 129 | #flip_group.append(flip_crop) 130 | 131 | oversample_group.extend(normal_group) 132 | #oversample_group.extend(flip_group) 133 | return oversample_group 134 | 135 | class GroupSpatialElasticDisplacement(object): 136 | 137 | def __init__(self): 138 | self.displacement = 20 139 | self.displacement_kernel = 25 140 | self.displacement_magnification = 0.60 141 | 142 | 143 | def __call__(self, img_group): 144 | v = random.random() 145 | if v < 0.5: 146 | im_size = img_group[0].size 147 | image_w, image_h = im_size[0], im_size[1] 148 | displacement_map = np.random.rand(image_h, image_w, 2) * 2 * self.displacement - self.displacement 149 | displacement_map = cv2.GaussianBlur(displacement_map, None, self.displacement_kernel) 150 | displacement_map *= self.displacement_magnification * self.displacement_kernel 151 | displacement_map = np.floor(displacement_map).astype('int32') 152 | 153 | displacement_map_rows = displacement_map[..., 0] + np.tile(np.arange(image_h), (image_w, 1)).T.astype('int32') 154 | displacement_map_rows = np.clip(displacement_map_rows, 0, image_h - 1) 155 | 156 | displacement_map_cols = displacement_map[..., 1] + np.tile(np.arange(image_w), (image_h, 1)).astype('int32') 157 | displacement_map_cols = np.clip(displacement_map_cols, 0, image_w - 1) 158 | ret_img_group = [Image.fromarray(np.asarray(img)[(displacement_map_rows.flatten(), displacement_map_cols.flatten())].reshape(np.asarray(img).shape)) for img in img_group] 159 | return ret_img_group 160 | 161 | else: 162 | return img_group 163 | 164 | 165 | 166 | class GroupMultiScaleResize(object): 167 | 168 | def __init__(self, scale): 169 | self.scale = scale 170 | 171 | def __call__(self, img_group): 172 | im_size = img_group[0].size 173 | self.resize_const = random.uniform(1.0 - self.scale, 1.0 + self.scale) # Aplly random resize constant 
174 | resize_img_group = [img.resize((int(im_size[0]*self.resize_const), int(im_size[1]*self.resize_const))) for img in img_group] 175 | 176 | return resize_img_group 177 | 178 | 179 | 180 | class GroupMultiScaleRotate(object): 181 | 182 | def __init__(self, degree): 183 | self.degree = degree 184 | self.interpolation = Image.BILINEAR 185 | 186 | def __call__(self, img_group): 187 | im_size = img_group[0].size 188 | self.rotate_angle = random.randint(-self.degree, self.degree) # Apply random rotation angle 189 | ret_img_group = [img.rotate(self.rotate_angle, resample=self.interpolation) for img in img_group] 190 | 191 | return ret_img_group 192 | 193 | 194 | 195 | class GroupMultiScaleCrop(object): 196 | 197 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=False): 198 | self.scales = scales if scales is not None else [1, .875, .75, .66] 199 | self.max_distort = max_distort 200 | self.fix_crop = fix_crop 201 | self.more_fix_crop = more_fix_crop 202 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 203 | self.interpolation = Image.BILINEAR 204 | 205 | def __call__(self, img_group): 206 | 207 | im_size = img_group[0].size 208 | #self.scales = [1, random.uniform(0.85, 1.0)] 209 | 210 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 211 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 212 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 213 | for img in crop_img_group] 214 | 215 | return ret_img_group 216 | 217 | def _sample_crop_size(self, im_size): 218 | image_w, image_h = im_size[0], im_size[1] 219 | 220 | # find a crop size 221 | base_size = min(image_w, image_h) 222 | crop_sizes = [int(base_size * x) for x in self.scales] 223 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 224 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 225 | 226 | pairs = [] 227 | for i, h in enumerate(crop_h): 228 | for j, w in enumerate(crop_w): 229 | if abs(i - j) <= self.max_distort: 230 | pairs.append((w, h)) 231 | 232 | crop_pair = random.choice(pairs) 233 | if not self.fix_crop: 234 | w_offset = random.randint(0, image_w - crop_pair[0]) 235 | h_offset = random.randint(0, image_h - crop_pair[1]) 236 | else: 237 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 238 | 239 | return crop_pair[0], crop_pair[1], w_offset, h_offset 240 | 241 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 242 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 243 | return random.choice(offsets) 244 | 245 | @staticmethod 246 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 247 | w_step = (image_w - crop_w) // 4 248 | h_step = (image_h - crop_h) // 4 249 | 250 | ret = list() 251 | ret.append((0, 0)) # upper left 252 | ret.append((4 * w_step, 0)) # upper right 253 | ret.append((0, 4 * h_step)) # lower left 254 | ret.append((4 * w_step, 4 * h_step)) # lower right 255 | ret.append((2 * w_step, 2 * h_step)) # center 256 | 257 | if more_fix_crop: 258 | ret.append((0 * w_step, 2 * h_step)) # center left 259 | ret.append((4 * w_step, 2 * h_step)) # center right 260 | ret.append((2 * w_step, 4 * h_step)) # lower center 261 | ret.append((2 * w_step, 0 * h_step)) # upper center 262 | 263 | ret.append((1 * w_step, 1 * h_step)) #
upper left quarter 264 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 265 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 266 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 267 | 268 | return ret 269 | 270 | 271 | class GroupRandomSizedCrop(object): 272 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 273 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 274 | This is popularly used to train the Inception networks 275 | size: size of the smaller edge 276 | interpolation: Default: PIL.Image.BILINEAR 277 | """ 278 | def __init__(self, size, interpolation=Image.BILINEAR): 279 | self.size = size 280 | self.interpolation = interpolation 281 | 282 | def __call__(self, img_group): 283 | for attempt in range(10): 284 | area = img_group[0].size[0] * img_group[0].size[1] 285 | target_area = random.uniform(0.08, 1.0) * area 286 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 287 | 288 | w = int(round(math.sqrt(target_area * aspect_ratio))) 289 | h = int(round(math.sqrt(target_area / aspect_ratio))) 290 | 291 | if random.random() < 0.5: 292 | w, h = h, w 293 | 294 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 295 | x1 = random.randint(0, img_group[0].size[0] - w) 296 | y1 = random.randint(0, img_group[0].size[1] - h) 297 | found = True 298 | break 299 | else: 300 | found = False 301 | x1 = 0 302 | y1 = 0 303 | 304 | if found: 305 | out_group = list() 306 | for img in img_group: 307 | img = img.crop((x1, y1, x1 + w, y1 + h)) 308 | assert(img.size == (w, h)) 309 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 310 | return out_group 311 | else: 312 | # Fallback 313 | scale = GroupScale(self.size, interpolation=self.interpolation) 314 | crop = GroupRandomCrop(self.size) 315 | return crop(scale(img_group)) 316 | 317 | 318 | class Stack(object): 319 | 320 | def __init__(self, roll=False, isRGBFlow=False): 321 | self.roll = roll 322 | self.isRGBFlow = isRGBFlow 323 | 324 | def __call__(self, img_group): 325 | if self.isRGBFlow: 326 | stacked_array = np.array([]) 327 | for x in img_group: 328 | if x.mode == 'L': 329 | if stacked_array.size ==0: 330 | stacked_array = np.expand_dims(x, 2) 331 | else: 332 | stacked_array = np.concatenate([stacked_array, np.expand_dims(x, 2)], axis=2) 333 | elif x.mode == 'RGB': 334 | if self.roll: 335 | stacked_array = np.concatenate([stacked_array, np.array(x)[:, :, ::-1]], axis=2) 336 | else: 337 | stacked_array = np.concatenate([stacked_array, np.array(x)], axis=2) 338 | return stacked_array 339 | 340 | else: 341 | if img_group[0].mode == 'L': 342 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 343 | elif img_group[0].mode == 'RGB': 344 | if self.roll: 345 | # roll channels from RGB to BGR and stack along the channel axis 346 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 347 | else: 348 | return np.concatenate(img_group, axis=2) 349 | 350 | 351 | class ToTorchFormatTensor(object): 352 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 353 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 354 | def __init__(self, div=True): 355 | self.div = div 356 | 357 | def __call__(self, pic): 358 | if isinstance(pic, np.ndarray): 359 | # handle numpy array 360 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 361 | else: 362 | # handle PIL Image 363 | img =
torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 364 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 365 | # put it from HWC to CHW format 366 | # yikes, this transpose takes 80% of the loading time/CPU 367 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 368 | return img.float().div(255) if self.div else img.float() 369 | 370 | 371 | class IdentityTransform(object): 372 | 373 | def __call__(self, data): 374 | return data 375 | 376 | 377 | if __name__ == "__main__": 378 | trans = torchvision.transforms.Compose([ 379 | GroupScale(256), 380 | GroupRandomCrop(224), 381 | Stack(), 382 | ToTorchFormatTensor(), 383 | GroupNormalize( 384 | mean=[.485, .456, .406], 385 | std=[.229, .224, .225] 386 | )] 387 | ) 388 | 389 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 390 | 391 | color_group = [im] * 3 392 | rst = trans(color_group) 393 | 394 | gray_group = [im.convert('L')] * 9 395 | gray_rst = trans(gray_group) 396 | 397 | trans2 = torchvision.transforms.Compose([ 398 | GroupRandomSizedCrop(256), 399 | Stack(), 400 | ToTorchFormatTensor(), 401 | GroupNormalize( 402 | mean=[.485, .456, .406], 403 | std=[.229, .224, .225]) 404 | ]) 405 | print(trans2(color_group)) 406 | --------------------------------------------------------------------------------
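For context on how these group transforms fit together, the following is a minimal sketch (not part of the repository) of a test-style preprocessing chain for the RGBFlow modality, loosely mirroring the pipeline built in test_models.py; the frame ordering, crop sizes, and normalization constants are assumptions for illustration only:

```python
# Minimal sketch: chain the group transforms from transforms.py for a single
# RGBFlow segment. Flow (grayscale) frames come first and the RGB frame last,
# since Stack(isRGBFlow=True) starts its stacked array from an 'L' image.
import torchvision
from PIL import Image
from transforms import (GroupScale, GroupCenterCrop, Stack,
                        ToTorchFormatTensor, GroupNormalize)

preprocess = torchvision.transforms.Compose([
    GroupScale(256),                    # resize the shorter edge of every frame in the group
    GroupCenterCrop(224),               # identical center crop for all frames
    Stack(roll=True, isRGBFlow=True),   # concatenate flow (L) and rolled RGB frames along channels
    ToTorchFormatTensor(div=False),     # HWC uint8 array -> CHW float tensor, no /255 (BNInception-style)
    GroupNormalize(mean=[128], std=[1]),  # assumed normalization; real values come from the TSN model
])

flow_frames = [Image.new('L', (320, 240)) for _ in range(6)]  # 3 motion steps x (x, y) flow images
rgb_frame = Image.new('RGB', (320, 240))

tensor = preprocess(flow_frames + [rgb_frame])
print(tensor.shape)  # torch.Size([9, 224, 224]) -> 2*3 flow channels + 3 RGB channels
```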