├── .gitignore
├── .gitmodules
├── LICENSE
├── MLPmodule.py
├── README.md
├── dataset.py
├── datasets_video.py
├── images
│   ├── motion_fused_frames.jpg
│   └── network_arch.jpg
├── main.py
├── models.py
├── ops
│   ├── __init__.py
│   ├── basic_ops.py
│   └── utils.py
├── opts.py
├── pretrained_models
│   ├── MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar
│   └── MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar
├── process_dataset.py
├── requirements.txt
├── test_models.py
└── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # file types to exclude
2 | *.mp4
3 | *.h5
4 |
5 | jester/
6 | #*.txt
7 |
8 | __pycache__
9 | *.pyc
10 |
11 | model/
12 | #*.pth.tar
13 | #*.pth
14 |
15 | log/
16 | *.csv
17 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "model_zoo"]
2 | path = model_zoo
3 | url = https://github.com/yjxiong/tensorflow-model-zoo.torch.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License for Motion Fused Frames
2 |
3 | Copyright (c) 2017, Okan Köpüklü
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 |
28 | BSD 2-Clause License for TSN-PyTorch
29 |
30 | Copyright (c) 2017, Multimedia Laboratary, The Chinese University of Hong Kong
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
44 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
45 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
46 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
47 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
48 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
49 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
50 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
51 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
52 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
53 |
--------------------------------------------------------------------------------
/MLPmodule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class MLPmodule(torch.nn.Module):
5 | """
6 | This is the 2-layer MLP implementation used for linking spatio-temporal
7 | features coming from different segments.
8 | """
9 | def __init__(self, img_feature_dim, num_frames, num_class):
10 | super(MLPmodule, self).__init__()
11 | self.num_frames = num_frames
12 | self.num_class = num_class
13 | self.img_feature_dim = img_feature_dim
14 | self.num_bottleneck = 512
15 | self.classifier = nn.Sequential(
16 | nn.ReLU(),
17 | nn.Linear(self.num_frames * self.img_feature_dim,
18 | self.num_bottleneck),
19 | #nn.Dropout(0.90), # Add an extra DO if necess.
20 | nn.ReLU(),
21 | nn.Linear(self.num_bottleneck,self.num_class),
22 | )
23 | def forward(self, input):
24 | input = input.view(input.size(0), self.num_frames*self.img_feature_dim)
25 | input = self.classifier(input)
26 | return input
27 |
28 |
29 | def return_MLP(relation_type, img_feature_dim, num_frames, num_class):
30 | MLPmodel = MLPmodule(img_feature_dim, num_frames, num_class)
31 |
32 | return MLPmodel
33 |
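# Usage sketch (illustrative, not part of the original file): models.py reshapes the per-segment
# CNN features to (batch, num_segments, img_feature_dim) before calling the consensus module,
# so with 4 segments, 256-dim features and the 27 Jester classes one would expect:
#
#   mlp = return_MLP('MLP', img_feature_dim=256, num_frames=4, num_class=27)
#   feats = torch.randn(8, 4, 256)   # (batch, num_segments, img_feature_dim)
#   logits = mlp(feats)              # -> shape (8, 27)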
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Motion Fused Frames (MFFs)
2 |
3 | PyTorch implementation of the paper [Motion fused frames: Data level fusion strategy for hand gesture recognition](http://openaccess.thecvf.com/content_cvpr_2018_workshops/papers/w41/Kopuklu_Motion_Fused_Frames_CVPR_2018_paper.pdf)
4 |
5 | ```diff
6 | - Update: Code is updated for PyTorch 1.5.0 and CUDA 10.2
7 | ```
8 |
9 | ![Motion Fused Frames](images/motion_fused_frames.jpg)
10 |
11 |
12 | ### Installation
13 | * Clone the repo with the following command:
14 | ```bash
15 | git clone https://github.com/okankop/MFF-pytorch.git
16 | ```
17 |
18 | * Set up a virtual environment, activate it, and install the requirements:
19 | ```bash
20 | conda create -n MFF python=3.7.4
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ### Dataset Preparation
25 | Download the [jester dataset](https://www.twentybn.com/datasets/something-something) or [NVIDIA dynamic hand gestures dataset](http://research.nvidia.com/publication/online-detection-and-classification-dynamic-hand-gestures-recurrent-3d-convolutional) or [ChaLearn LAP IsoGD dataset](http://www.cbsr.ia.ac.cn/users/jwan/database/isogd.html). Decompress them into the same folder and use [process_dataset.py](process_dataset.py) to generate the index files for the train, val, and test splits. Properly set up the train, validation, and category meta files in [datasets_video.py](datasets_video.py). Finally, use the [flow_computation](https://github.com/okankop/flow_computation) repository to calculate the optical flow images with the Brox method. A sketch of these steps is given after the directory layout below.
26 |
27 | Assume the structure of data directories is the following:
28 |
29 | ```misc
30 | ~/MFF-pytorch/
31 | datasets/
32 | jester/
33 | rgb/
34 | .../ (directories of video samples)
35 | .../ (jpg color frames)
36 | flow/
37 | u/
38 | .../ (directories of video samples)
39 | .../ (jpg optical-flow-u frames)
40 | v/
41 | .../ (directories of video samples)
42 | .../ (jpg optical-flow-v frames)
43 | model/
44 | .../(saved models for the last checkpoint and best model)
45 | ```
46 |
47 |
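The exact commands depend on where the annotation files and frame folders live, but a minimal sketch of the Jester preparation (paths and folder names below are assumptions; adapt them to your setup) could look like this, run from the repository root (`ROOT_DATASET` in [datasets_video.py](datasets_video.py)):

```bash
# process_dataset.py expects the jester-v1-*.csv annotation files and the extracted
# 20bn-jester-v1 frame folders in the current working directory
python process_dataset.py

# datasets_video.py looks for the generated meta files under <ROOT_DATASET>/jester/
mkdir -p jester
mv category.txt train_videofolder.txt val_videofolder.txt jester/
```

The optical flow images (the `flow/u` and `flow/v` directories above) are computed separately with the [flow_computation](https://github.com/okankop/flow_computation) tool.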
48 | ### Running the Code
49 | The following are example commands for training under different scenarios:
50 |
51 | * Train 4-segment network with 3 flow, 1 color frames (4-MFFs-3f1c architecture)
52 | ```bash
53 | python main.py jester RGBFlow --arch BNInception --num_segments 4 \
54 | --consensus_type MLP --num_motion 3 --batch-size 32
55 | ```
56 |
57 | * Train resuming the last checkpoint (4-MFFs-3f1c architecture)
58 | ```bash
59 | python main.py jester RGBFlow --resume=<path_to_checkpoint> --arch BNInception \
60 | --consensus_type MLP --num_segments 4 --num_motion 3 --batch-size 32
61 | ```
62 |
63 | * The command to test trained models (4-MFFs-3f1c architecture). Pretrained models are under [pretrained_models](pretrained_models).
64 |
65 | ```bash
66 | python test_models.py jester RGBFlow pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar --arch BNInception --consensus_type MLP --test_crops 1 --num_motion 3 --test_segments 4
67 | ```
68 |
69 | By default, all available GPUs are used for training. To train on a subset of GPUs, set CUDA_VISIBLE_DEVICES=... as shown below.
70 |
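For example, to restrict training to the first two GPUs (an illustrative sketch; the device ids depend on your machine):

```bash
CUDA_VISIBLE_DEVICES=0,1 python main.py jester RGBFlow --arch BNInception --num_segments 4 \
    --consensus_type MLP --num_motion 3 --batch-size 32
```
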
71 | ### Citation
72 | If you use this code or pre-trained models, please cite the following:
73 |
74 | ```bibtex
75 | @InProceedings{Kopuklu_2018_CVPR_Workshops,
76 | author = {Kopuklu, Okan and Kose, Neslihan and Rigoll, Gerhard},
77 | title = {Motion Fused Frames: Data Level Fusion Strategy for Hand Gesture Recognition},
78 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
79 | month = {June},
80 | year = {2018}
81 | }
82 | ```
83 |
84 | ### Acknowledgement
85 | This project is built on top of the [TSN-pytorch codebase](https://github.com/yjxiong/temporal-segment-networks), and we thank Yuanjun Xiong for releasing it. We also thank Bolei Zhou for the inspirational work [Temporal Relational Reasoning in Videos](https://arxiv.org/pdf/1711.08496.pdf), from which we imported [process_dataset.py](https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py) into our project.
86 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | import random
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import numpy as np
8 | from numpy.random import randint
9 |
10 | class VideoRecord(object):
11 | def __init__(self, row):
12 | self._data = row
13 |
14 | @property
15 | def path(self):
16 | return self._data[0]
17 |
18 | @property
19 | def num_frames(self):
20 | return int(self._data[1])
21 |
22 | @property
23 | def label(self):
24 | return int(self._data[2])
25 |
26 |
27 | class TSNDataSet(data.Dataset):
28 | def __init__(self, root_path, list_file,
29 | num_segments=3, new_length=1, modality='RGB',
30 | image_tmpl='img_{:05d}.jpg', transform=None,
31 | force_grayscale=False, random_shift=True,
32 | test_mode=False, dataset='jester'):
33 |
34 | self.root_path = root_path
35 | self.list_file = list_file
36 | self.num_segments = num_segments
37 | self.new_length = new_length
38 | self.modality = modality
39 | self.image_tmpl = image_tmpl
40 | self.transform = transform
41 | self.random_shift = random_shift
42 | self.test_mode = test_mode
43 | self.dataset = dataset
44 |
45 | if self.modality == 'RGBDiff' or self.modality == 'RGBFlow':
46 |             self.new_length += 1  # Diff and RGBFlow need one extra image
47 |
48 | self._parse_list()
49 |
50 | def _load_image(self, directory, idx, isLast=False):
51 | if self.modality == 'RGB' or self.modality == 'RGBDiff':
52 | try:
53 | return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))).convert('RGB')]
54 | except Exception:
55 | print('error loading image:', os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx)))
56 | return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(1))).convert('RGB')]
57 |
58 | elif self.modality == 'Flow':
59 | try:
60 | x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(idx))).convert('L')
61 | y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))).convert('L')
62 | except Exception:
63 | print('error loading flow file:', os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx)))
64 | x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(1))).convert('L')
65 | y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(1))).convert('L')
66 | return [x_img, y_img]
67 |
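        # For RGBFlow, each sampled stack consists of (new_length - 1) optical-flow (u, v) pairs
        # followed by a single RGB frame; isLast=True marks that final RGB frame.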
68 | elif self.modality == 'RGBFlow':
69 | if isLast:
70 | return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))).convert('RGB')]
71 | else:
72 | x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(idx))).convert('L')
73 | y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))).convert('L')
74 | return [x_img, y_img]
75 |
76 |
77 | def _parse_list(self):
78 |         # keep only videos with at least 3 frames
79 |         # usually each row is [video_id, num_frames, class_idx]
80 | tmp = [x.strip().split(' ') for x in open(self.list_file)]
81 | tmp = [item for item in tmp if int(item[1])>=3]
82 | self.video_list = [VideoRecord(item) for item in tmp]
83 | print('video number:%d'%(len(self.video_list)))
84 |
85 | def _sample_indices(self, record):
86 | """
87 |         Randomly sample one frame index per segment (temporal jittering for training).
88 | :param record: VideoRecord
89 | :return: list
90 | """
91 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
92 |
93 | if average_duration > 0:
94 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
95 | elif record.num_frames > self.num_segments:
96 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
97 | else:
98 | offsets = np.zeros((self.num_segments,))
99 | return offsets + 1
100 |
101 | def _get_val_indices(self, record):
102 | if record.num_frames > self.num_segments + self.new_length - 1:
103 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
104 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
105 | else:
106 | offsets = np.zeros((self.num_segments,))
107 | return offsets + 1
108 |
109 | def _get_test_indices(self, record):
110 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
111 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
112 | return offsets + 1
113 |
114 | def __getitem__(self, index):
115 | record = self.video_list[index]
116 | # check this is a legit video folder
117 | if self.modality == 'RGBFlow':
118 | while not os.path.exists(os.path.join(self.root_path, "rgb", record.path, self.image_tmpl.format(1))):
119 | index = np.random.randint(len(self.video_list))
120 | record = self.video_list[index]
121 | else:
122 | while not os.path.exists(os.path.join(self.root_path, "rgb", record.path, self.image_tmpl.format(1))):
123 | index = np.random.randint(len(self.video_list))
124 | record = self.video_list[index]
125 |
126 | if not self.test_mode:
127 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
128 | else:
129 | segment_indices = self._get_test_indices(record)
130 |
131 | return self.get(record, segment_indices)
132 |
133 | def get(self, record, indices):
134 | images = list()
135 | for seg_ind in indices:
136 | p = int(seg_ind)
137 | for i in range(self.new_length):
138 | if self.modality == 'RGBFlow':
139 | if i == self.new_length - 1:
140 | seg_imgs = self._load_image(record.path, p, True)
141 | else:
142 | if p == record.num_frames:
143 | seg_imgs = self._load_image(record.path, p-1)
144 | else:
145 | seg_imgs = self._load_image(record.path, p)
146 | else:
147 | seg_imgs = self._load_image(record.path, p)
148 |
149 | images.extend(seg_imgs)
150 | if p < record.num_frames:
151 | p += 1
152 |
153 | process_data = self.transform(images)
154 | return process_data, record.label
155 |
156 | def __len__(self):
157 | return len(self.video_list)
158 |
--------------------------------------------------------------------------------
/datasets_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import torchvision.datasets as datasets
5 |
6 |
7 | ROOT_DATASET = '/usr/home/kop/MFF-pytorch'
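# NOTE: ROOT_DATASET and the root_data paths below are machine-specific; adjust them to your own checkout and dataset locations.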
8 |
9 | def return_jester(modality):
10 | filename_categories = 'jester/category.txt'
11 | filename_imglist_train = 'jester/train_videofolder.txt'
12 | filename_imglist_val = 'jester/val_videofolder.txt'
13 | if modality == 'RGB':
14 | prefix = '{:05d}.jpg'
15 | root_data = '/usr/home/kop/MFF-pytorch/datasets/jester'
16 | elif modality == 'RGBFlow':
17 | prefix = '{:05d}.jpg'
18 | root_data = '/usr/home/kop/MFF-pytorch/datasets/jester'
19 | else:
20 |         # unsupported modality for this dataset
21 |         raise ValueError('no such modality: ' + modality)
22 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
23 |
24 | def return_nvgesture(modality):
25 | filename_categories = 'nvgesture/category.txt'
26 | filename_imglist_train = 'nvgesture/train_videofolder.txt'
27 | filename_imglist_val = 'nvgesture/val_videofolder.txt'
28 | if modality == 'RGB':
29 | prefix = '{:05d}.jpg'
30 | root_data = '/data2/nvGesture'
31 | elif modality == 'RGBFlow':
32 | prefix = '{:05d}.jpg'
33 | root_data = '/data2/nvGesture'
34 | else:
35 |         # unsupported modality for this dataset
36 |         raise ValueError('no such modality: ' + modality)
37 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
38 |
39 | def return_chalearn(modality):
40 | filename_categories = 'chalearn/category.txt'
41 | filename_imglist_train = 'chalearn/train_videofolder.txt'
42 | filename_imglist_val = 'chalearn/val_videofolder.txt'
43 | #filename_imglist_val = 'chalearn/test_videofolder.txt'
44 | if modality == 'RGB':
45 | prefix = '{:05d}.jpg'
46 | root_data = '/data2/ChaLearn'
47 | elif modality == 'RGBFlow':
48 | prefix = '{:05d}.jpg'
49 | root_data = '/data2/ChaLearn'
50 | else:
51 |         # unsupported modality for this dataset
52 |         raise ValueError('no such modality: ' + modality)
53 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
54 |
55 | def return_dataset(dataset, modality):
56 | dict_single = {'jester':return_jester, 'nvgesture': return_nvgesture, 'chalearn': return_chalearn}
57 | if dataset in dict_single:
58 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)
59 | else:
60 | raise ValueError('Unknown dataset '+dataset)
61 |
62 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
63 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
64 | file_categories = os.path.join(ROOT_DATASET, file_categories)
65 | with open(file_categories) as f:
66 | lines = f.readlines()
67 | categories = [item.rstrip() for item in lines]
68 | return categories, file_imglist_train, file_imglist_val, root_data, prefix
69 |
70 |
--------------------------------------------------------------------------------
/images/motion_fused_frames.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/images/motion_fused_frames.jpg
--------------------------------------------------------------------------------
/images/network_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/images/network_arch.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import torch
6 | import torchvision
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim
10 | from torch.autograd import Variable
11 | from torch.nn.utils import clip_grad_norm
12 | from torch.utils.data.sampler import SequentialSampler
13 |
14 | from dataset import TSNDataSet
15 | from models import TSN
16 | from transforms import *
17 | from opts import parser
18 | import datasets_video
19 |
20 |
21 | best_prec1 = 0
22 |
23 | def main():
24 | global args, best_prec1
25 | args = parser.parse_args()
26 | check_rootfolders()
27 |
28 | categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality)
29 | num_class = len(categories)
30 |
31 |
32 | args.store_name = '_'.join(['MFF', args.dataset, args.modality, args.arch,
33 | 'segment%d'% args.num_segments, '%df1c'% args.num_motion])
34 | print('storing name: ' + args.store_name)
35 |
36 | model = TSN(num_class, args.num_segments, args.modality,
37 | base_model=args.arch,
38 | consensus_type=args.consensus_type,
39 | dropout=args.dropout, num_motion=args.num_motion,
40 | img_feature_dim=args.img_feature_dim,
41 | partial_bn=not args.no_partialbn,
42 | dataset=args.dataset)
43 |
44 | crop_size = model.crop_size
45 | scale_size = model.scale_size
46 | input_mean = model.input_mean
47 | input_std = model.input_std
48 | train_augmentation = model.get_augmentation()
49 |
50 | policies = model.get_optim_policies()
51 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
52 |
53 | if args.resume:
54 | if os.path.isfile(args.resume):
55 | print(("=> loading checkpoint '{}'".format(args.resume)))
56 | checkpoint = torch.load(args.resume)
57 | args.start_epoch = checkpoint['epoch']
58 | best_prec1 = checkpoint['best_prec1']
59 | model.load_state_dict(checkpoint['state_dict'])
60 | print(("=> loaded checkpoint '{}' (epoch {})"
61 |                   .format(args.resume, checkpoint['epoch'])))
62 | else:
63 | print(("=> no checkpoint found at '{}'".format(args.resume)))
64 |
65 | print(model)
66 | cudnn.benchmark = True
67 |
68 | # Data loading code
69 |     if args.modality not in ['RGBDiff', 'RGBFlow']:
70 | normalize = GroupNormalize(input_mean, input_std)
71 | else:
72 | normalize = IdentityTransform()
73 |
74 | if args.modality == 'RGB':
75 | data_length = 1
76 | elif args.modality in ['Flow', 'RGBDiff']:
77 | data_length = 5
78 | elif args.modality == 'RGBFlow':
79 | data_length = args.num_motion
80 |
81 | train_loader = torch.utils.data.DataLoader(
82 | TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
83 | new_length=data_length,
84 | modality=args.modality,
85 | image_tmpl=prefix,
86 | dataset=args.dataset,
87 | transform=torchvision.transforms.Compose([
88 | train_augmentation,
89 | Stack(roll=(args.arch in ['BNInception','InceptionV3']), isRGBFlow = (args.modality == 'RGBFlow')),
90 | ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
91 | normalize,
92 | ])),
93 | batch_size=args.batch_size, shuffle=True,
94 | num_workers=args.workers, pin_memory=False)
95 |
96 | val_loader = torch.utils.data.DataLoader(
97 | TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
98 | new_length=data_length,
99 | modality=args.modality,
100 | image_tmpl=prefix,
101 | dataset=args.dataset,
102 | random_shift=False,
103 | transform=torchvision.transforms.Compose([
104 | GroupScale(int(scale_size)),
105 | GroupCenterCrop(crop_size),
106 | Stack(roll=(args.arch in ['BNInception','InceptionV3']), isRGBFlow = (args.modality == 'RGBFlow')),
107 | ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
108 | normalize,
109 | ])),
110 | batch_size=args.batch_size, shuffle=False,
111 | num_workers=args.workers, pin_memory=False)
112 |
113 | # define loss function (criterion) and optimizer
114 | if args.loss_type == 'nll':
115 | criterion = torch.nn.CrossEntropyLoss().cuda()
116 | else:
117 | raise ValueError("Unknown loss type")
118 |
119 | for group in policies:
120 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
121 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))
122 |
123 | optimizer = torch.optim.SGD(policies,
124 | args.lr,
125 | momentum=args.momentum,
126 | weight_decay=args.weight_decay)
127 |
128 | if args.evaluate:
129 | validate(val_loader, model, criterion, 0)
130 | return
131 |
132 | log_training = open(os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
133 | for epoch in range(args.start_epoch, args.epochs):
134 | adjust_learning_rate(optimizer, epoch, args.lr_steps)
135 |
136 | # train for one epoch
137 | train(train_loader, model, criterion, optimizer, epoch, log_training)
138 |
139 | # evaluate on validation set
140 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
141 | prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader), log_training)
142 |
143 | # remember best prec@1 and save checkpoint
144 | is_best = prec1 > best_prec1
145 | best_prec1 = max(prec1, best_prec1)
146 | save_checkpoint({
147 | 'epoch': epoch + 1,
148 | 'arch': args.arch,
149 | 'state_dict': model.state_dict(),
150 | 'best_prec1': best_prec1,
151 | }, is_best)
152 |
153 |
154 | def train(train_loader, model, criterion, optimizer, epoch, log):
155 | batch_time = AverageMeter()
156 | data_time = AverageMeter()
157 | losses = AverageMeter()
158 | top1 = AverageMeter()
159 | top5 = AverageMeter()
160 |
161 | if args.no_partialbn:
162 | model.module.partialBN(False)
163 | else:
164 | model.module.partialBN(True)
165 |
166 | # switch to train mode
167 | model.train()
168 |
169 | end = time.time()
170 | for i, (input, target) in enumerate(train_loader):
171 | # measure data loading time
172 | data_time.update(time.time() - end)
173 |
174 | target = target.cuda()
175 | input_var = Variable(input)
176 | target_var = Variable(target)
177 |
178 | # compute output
179 | output = model(input_var)
180 | loss = criterion(output, target_var)
181 |
182 | # measure accuracy and record loss
183 | prec1, prec5 = accuracy(output.data, target, topk=(1,5))
184 |         losses.update(loss.item(), input.size(0))
185 |         top1.update(prec1.item(), input.size(0))
186 |         top5.update(prec5.item(), input.size(0))
187 |
188 |
189 | # compute gradient and do SGD step
190 | optimizer.zero_grad()
191 |
192 | loss.backward()
193 |
194 | if args.clip_gradient is not None:
195 | total_norm = clip_grad_norm(model.parameters(), args.clip_gradient)
196 | # if total_norm > args.clip_gradient:
197 | # print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm))
198 |
199 | optimizer.step()
200 |
201 | # measure elapsed time
202 | batch_time.update(time.time() - end)
203 | end = time.time()
204 |
205 | if i % args.print_freq == 0:
206 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
207 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
208 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
209 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
210 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
211 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
212 | epoch, i, len(train_loader), batch_time=batch_time,
213 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr']))
214 | print(output)
215 | log.write(output + '\n')
216 | log.flush()
217 |
218 |
219 |
220 | def validate(val_loader, model, criterion, iter, log):
221 | batch_time = AverageMeter()
222 | losses = AverageMeter()
223 | top1 = AverageMeter()
224 | top5 = AverageMeter()
225 |
226 | # switch to evaluate mode
227 | model.eval()
228 |
229 | end = time.time()
230 | for i, (input, target) in enumerate(val_loader):
231 | target = target.cuda()
232 | with torch.no_grad():
233 | input_var = Variable(input)
234 | target_var = Variable(target)
235 |
236 | # compute output
237 | output = model(input_var)
238 | loss = criterion(output, target_var)
239 |
240 | # measure accuracy and record loss
241 | prec1, prec5 = accuracy(output.data, target, topk=(1,5))
242 |
243 |             losses.update(loss.item(), input.size(0))
244 |             top1.update(prec1.item(), input.size(0))
245 |             top5.update(prec5.item(), input.size(0))
246 |
247 | # measure elapsed time
248 | batch_time.update(time.time() - end)
249 | end = time.time()
250 |
251 | if i % args.print_freq == 0:
252 | output = ('Test: [{0}/{1}]\t'
253 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
254 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
255 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
256 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
257 | i, len(val_loader), batch_time=batch_time, loss=losses,
258 | top1=top1, top5=top5))
259 | print(output)
260 | log.write(output + '\n')
261 | log.flush()
262 |
263 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
264 | .format(top1=top1, top5=top5, loss=losses))
265 | print(output)
266 | output_best = '\nBest Prec@1: %.3f'%(best_prec1)
267 | print(output_best)
268 | log.write(output + ' ' + output_best + '\n')
269 | log.flush()
270 |
271 | return top1.avg
272 |
273 |
274 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
275 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name))
276 | if is_best:
277 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name),'%s/%s_best.pth.tar' % (args.root_model, args.store_name))
278 |
279 | class AverageMeter(object):
280 | """Computes and stores the average and current value"""
281 | def __init__(self):
282 | self.reset()
283 |
284 | def reset(self):
285 | self.val = 0
286 | self.avg = 0
287 | self.sum = 0
288 | self.count = 0
289 |
290 | def update(self, val, n=1):
291 | self.val = val
292 | self.sum += val * n
293 | self.count += n
294 | self.avg = self.sum / self.count
295 |
296 |
297 | def adjust_learning_rate(optimizer, epoch, lr_steps):
298 |     """Decays the learning rate by a factor of 2 at each milestone in lr_steps"""
299 | decay = 0.5 ** (sum(epoch >= np.array(lr_steps)))
300 | lr = args.lr * decay
301 | decay = args.weight_decay
302 | for param_group in optimizer.param_groups:
303 | param_group['lr'] = lr * param_group['lr_mult']
304 | param_group['weight_decay'] = decay * param_group['decay_mult']
305 |
306 |
307 | def accuracy(output, target, topk=(1,)):
308 | """Computes the precision@k for the specified values of k"""
309 | maxk = max(topk)
310 | batch_size = target.size(0)
311 |
312 | _, pred = output.topk(maxk, 1, True, True)
313 | pred = pred.t()
314 | correct = pred.eq(target.view(1, -1).expand_as(pred))
315 |
316 | res = []
317 | for k in topk:
318 |         correct_k = correct[:k].contiguous().view(-1).float().sum(0)
319 | res.append(correct_k.mul_(100.0 / batch_size))
320 | return res
321 |
322 |
323 | def count_parameters(model):
324 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
325 |
326 |
327 | def check_rootfolders():
328 | """Create log and model folder"""
329 | folders_util = [args.root_log, args.root_model, args.root_output]
330 | for folder in folders_util:
331 | if not os.path.exists(folder):
332 | print('creating folder ' + folder)
333 | os.mkdir(folder)
334 |
335 |
336 | if __name__ == '__main__':
337 | main()
338 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from ops.basic_ops import ConsensusModule, Identity
4 | from transforms import *
5 | from torch.nn.init import normal_, constant_
6 |
7 | import pretrainedmodels
8 | import MLPmodule
9 |
10 | class TSN(nn.Module):
11 | def __init__(self, num_class, num_segments, modality,
12 | base_model='resnet101', new_length=None,
13 | consensus_type='avg', before_softmax=True, num_motion=3,
14 | dropout=0.8,img_feature_dim=256, dataset='jester',
15 | crop_num=1, partial_bn=True, print_spec=True):
16 | super(TSN, self).__init__()
17 | self.modality = modality
18 | self.num_segments = num_segments
19 | self.num_motion = num_motion
20 | self.reshape = True
21 | self.before_softmax = before_softmax
22 | self.dropout = dropout
23 | self.dataset = dataset
24 | self.crop_num = crop_num
25 | self.consensus_type = consensus_type
26 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame
27 | if not before_softmax and consensus_type != 'avg':
28 | raise ValueError("Only avg consensus can be used after Softmax")
29 |
30 | if new_length is None:
31 | if modality == "RGB":
32 | self.new_length = 1
33 | elif modality == "Flow":
34 | self.new_length = 5
35 | elif modality == "RGBFlow":
36 | #self.new_length = 1
37 | self.new_length = self.num_motion
38 | else:
39 | self.new_length = new_length
40 | if print_spec == True:
41 | print(("""
42 | Initializing TSN with base model: {}.
43 | TSN Configurations:
44 | input_modality: {}
45 | num_segments: {}
46 | new_length: {}
47 | consensus_module: {}
48 | dropout_ratio: {}
49 | img_feature_dim: {}
50 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim)))
51 |
52 | self._prepare_base_model(base_model)
53 |
54 | feature_dim = self._prepare_tsn(num_class)
55 |
56 | if self.modality == 'Flow':
57 | print("Converting the ImageNet model to a flow init model")
58 | self.base_model = self._construct_flow_model(self.base_model)
59 | print("Done. Flow model ready...")
60 | elif self.modality == 'RGBDiff':
61 | print("Converting the ImageNet model to RGB+Diff init model")
62 | self.base_model = self._construct_diff_model(self.base_model)
63 | print("Done. RGBDiff model ready.")
64 | elif self.modality == 'RGBFlow':
65 | print("Converting the ImageNet model to RGB+Flow init model")
66 | self.base_model = self._construct_rgbflow_model(self.base_model)
67 | print("Done. RGBFlow model ready.")
68 | if consensus_type == 'MLP':
69 | self.consensus = MLPmodule.return_MLP(consensus_type, self.img_feature_dim, self.num_segments, num_class)
70 | else:
71 | self.consensus = ConsensusModule(consensus_type)
72 |
73 | if not self.before_softmax:
74 | self.softmax = nn.Softmax()
75 |
76 | self._enable_pbn = partial_bn
77 | if partial_bn:
78 | self.partialBN(True)
79 |
80 |
81 | def _prepare_tsn(self, num_class):
82 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
83 | if self.dropout == 0:
84 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
85 | self.new_fc = None
86 | else:
87 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
88 | if self.consensus_type == 'MLP':
89 | # set the MFFs feature dimension
90 | self.new_fc = nn.Linear(feature_dim, self.img_feature_dim)
91 | else:
92 | # the default consensus types in TSN
93 | self.new_fc = nn.Linear(feature_dim, num_class)
94 |
95 | std = 0.001
96 | if self.new_fc is None:
97 |             normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
98 |             constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
99 |         else:
100 |             normal_(self.new_fc.weight, 0, std)
101 |             constant_(self.new_fc.bias, 0)
102 | return feature_dim
103 |
104 | def _prepare_base_model(self, base_model):
105 |
106 | if 'resnet' in base_model or 'vgg' in base_model or 'squeezenet1_1' in base_model:
107 | self.base_model = pretrainedmodels.__dict__[base_model](num_classes=1000, pretrained='imagenet')
108 | if base_model == 'squeezenet1_1':
109 | self.base_model = self.base_model.features
110 | self.base_model.last_layer_name = '12'
111 | else:
112 | self.base_model.last_layer_name = 'fc'
113 | self.input_size = 224
114 | self.input_mean = [0.485, 0.456, 0.406]
115 | self.input_std = [0.229, 0.224, 0.225]
116 |
117 | if self.modality == 'Flow':
118 | self.input_mean = [0.5]
119 | self.input_std = [np.mean(self.input_std)]
120 | elif self.modality == 'RGBDiff':
121 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
122 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
123 | elif base_model == 'BNInception':
124 | self.base_model = pretrainedmodels.__dict__['bninception'](num_classes=1000, pretrained='imagenet')
125 | self.base_model.last_layer_name = 'last_linear'
126 | self.input_size = 224
127 | self.input_mean = [104, 117, 128]
128 | self.input_std = [1]
129 | if self.modality == 'Flow':
130 | self.input_mean = [128]
131 | elif self.modality == 'RGBDiff':
132 | self.input_mean = self.input_mean * (1 + self.new_length)
133 | elif 'resnext101' in base_model:
134 | self.base_model = pretrainedmodels.__dict__[base_model](num_classes=1000, pretrained='imagenet')
135 | print(self.base_model)
136 | self.base_model.last_layer_name = 'last_linear'
137 | self.input_size = 224
138 | self.input_mean = [0.485, 0.456, 0.406]
139 | self.input_std = [0.229, 0.224, 0.225]
140 | if self.modality == 'Flow':
141 | self.input_mean = [128]
142 | elif self.modality == 'RGBDiff':
143 | self.input_mean = self.input_mean * (1 + self.new_length)
144 | else:
145 | raise ValueError('Unknown base model: {}'.format(base_model))
146 |
147 | def train(self, mode=True):
148 | """
149 | Override the default train() to freeze the BN parameters
150 | :return:
151 | """
152 | super(TSN, self).train(mode)
153 | count = 0
154 | if self._enable_pbn:
155 | print("Freezing BatchNorm2D except the first one.")
156 | for m in self.base_model.modules():
157 | if isinstance(m, nn.BatchNorm2d):
158 | count += 1
159 | if count >= (2 if self._enable_pbn else 1):
160 | m.eval()
161 |
162 | # shutdown update in frozen mode
163 | m.weight.requires_grad = False
164 | m.bias.requires_grad = False
165 |
166 |
167 | def partialBN(self, enable):
168 | self._enable_pbn = enable
169 |
170 | def get_optim_policies(self):
171 | first_conv_weight = []
172 | first_conv_bias = []
173 | normal_weight = []
174 | normal_bias = []
175 | bn = []
176 |
177 | conv_cnt = 0
178 | bn_cnt = 0
179 | for m in self.modules():
180 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d):
181 | ps = list(m.parameters())
182 | conv_cnt += 1
183 | if conv_cnt == 1:
184 | first_conv_weight.append(ps[0])
185 | if len(ps) == 2:
186 | first_conv_bias.append(ps[1])
187 | else:
188 | normal_weight.append(ps[0])
189 | if len(ps) == 2:
190 | normal_bias.append(ps[1])
191 | elif isinstance(m, torch.nn.Linear):
192 | ps = list(m.parameters())
193 | normal_weight.append(ps[0])
194 | if len(ps) == 2:
195 | normal_bias.append(ps[1])
196 |
197 | elif isinstance(m, torch.nn.BatchNorm1d):
198 | bn.extend(list(m.parameters()))
199 | elif isinstance(m, torch.nn.BatchNorm2d):
200 | bn_cnt += 1
201 | # later BN's are frozen
202 | if not self._enable_pbn or bn_cnt == 1:
203 | bn.extend(list(m.parameters()))
204 | elif len(m._modules) == 0:
205 | if len(list(m.parameters())) > 0:
206 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))
207 |
208 | return [
209 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
210 | 'name': "first_conv_weight"},
211 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
212 | 'name': "first_conv_bias"},
213 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
214 | 'name': "normal_weight"},
215 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
216 | 'name': "normal_bias"},
217 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
218 | 'name': "BN scale/shift"},
219 | ]
220 |
221 | def forward(self, input):
222 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
223 |
224 | if self.modality == 'RGBDiff':
225 | sample_len = 3 * self.new_length
226 | input = self._get_diff(input)
227 |
228 | if self.modality == 'RGBFlow':
229 | sample_len = 3 + 2 * self.new_length
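            # a Motion Fused Frame stacks 2 optical-flow channels (u, v) per motion frame on top of
            # 3 RGB channels, hence 3 + 2 * new_length input channels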
230 |
231 | base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:]))
232 |
233 | if self.dropout > 0:
234 | base_out = self.new_fc(base_out)
235 |
236 | if not self.before_softmax:
237 | base_out = self.softmax(base_out)
238 | if self.reshape:
239 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
240 |
241 | output = self.consensus(base_out)
242 | return output.squeeze(1)
243 |
244 | def _get_diff(self, input, keep_rgb=False):
245 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2
246 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:])
247 | if keep_rgb:
248 | new_data = input_view.clone()
249 | else:
250 | new_data = input_view[:, :, 1:, :, :, :].clone()
251 |
252 | for x in reversed(list(range(1, self.new_length + 1))):
253 | if keep_rgb:
254 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
255 | else:
256 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
257 |
258 | return new_data
259 |
260 | """ # There is no need now!!
261 | def _get_rgbflow(self, input):
262 | input_c = 3 + 2 * self.new_length # 3 is rgb channels, and 2 is coming for x & y channels of opt.flow
263 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:])
264 | new_data = input_view.clone()
265 | return new_data
266 | """
267 |
268 | def _construct_rgbflow_model(self, base_model):
269 | # modify the convolution layers
270 | # Torch models are usually defined in a hierarchical way.
271 | # nn.modules.children() return all sub modules in a DFS manner
272 | modules = list(self.base_model.modules())
273 | filter_conv2d = filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules))))
274 | first_conv_idx = next(filter_conv2d)
275 | conv_layer = modules[first_conv_idx]
276 | container = modules[first_conv_idx - 1]
277 |
278 | # modify parameters, assume the first blob contains the convolution kernels
279 | params = [x.clone() for x in conv_layer.parameters()]
280 | kernel_size = params[0].size()
281 | new_kernel_size = kernel_size[:1] + (2 * self.new_length,) + kernel_size[2:]
282 |         new_kernels = torch.cat((params[0].data.mean(dim=1,keepdim=True).expand(new_kernel_size).contiguous(), params[0].data), 1) # NOTE: Concatenating might be the other way around. Check it!
283 | new_kernel_size = kernel_size[:1] + (3 + 2 * self.new_length,) + kernel_size[2:]
284 |
285 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
286 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
287 | bias=True if len(params) == 2 else False)
288 | new_conv.weight.data = new_kernels
289 | if len(params) == 2:
290 |             new_conv.bias.data = params[1].data  # add bias if necessary
291 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
292 |
293 | # replace the first convolution layer
294 | setattr(container, layer_name, new_conv)
295 | return base_model
296 |
297 | def _construct_flow_model(self, base_model):
298 | # modify the convolution layers
299 | # Torch models are usually defined in a hierarchical way.
300 | # nn.modules.children() return all sub modules in a DFS manner
301 | modules = list(self.base_model.modules())
302 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
303 | conv_layer = modules[first_conv_idx]
304 | container = modules[first_conv_idx - 1]
305 |
306 | # modify parameters, assume the first blob contains the convolution kernels
307 | params = [x.clone() for x in conv_layer.parameters()]
308 | kernel_size = params[0].size()
309 | print(kernel_size)
310 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:]
311 | print(new_kernel_size)
312 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
313 |
314 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels,
315 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
316 | bias=True if len(params) == 2 else False)
317 | new_conv.weight.data = new_kernels
318 | if len(params) == 2:
319 |             new_conv.bias.data = params[1].data  # add bias if necessary
320 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
321 |
322 |         # replace the first convolution layer
323 | setattr(container, layer_name, new_conv)
324 | return base_model
325 |
326 | def _construct_diff_model(self, base_model, keep_rgb=False):
327 | # modify the convolution layers
328 | # Torch models are usually defined in a hierarchical way.
329 | # nn.modules.children() return all sub modules in a DFS manner
330 | modules = list(self.base_model.modules())
331 |         first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
332 | conv_layer = modules[first_conv_idx]
333 | container = modules[first_conv_idx - 1]
334 |
335 | # modify parameters, assume the first blob contains the convolution kernels
336 | params = [x.clone() for x in conv_layer.parameters()]
337 | kernel_size = params[0].size()
338 | if not keep_rgb:
339 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
340 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
341 | else:
342 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
343 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()),
344 | 1)
345 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:]
346 |
347 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
348 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
349 | bias=True if len(params) == 2 else False)
350 | new_conv.weight.data = new_kernels
351 | if len(params) == 2:
352 |             new_conv.bias.data = params[1].data  # add bias if necessary
353 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name
354 |
355 | # replace the first convolution layer
356 | setattr(container, layer_name, new_conv)
357 | return base_model
358 |
359 | @property
360 | def crop_size(self):
361 | return self.input_size
362 |
363 | @property
364 | def scale_size(self):
365 | return self.input_size * 256 // 224
366 |
367 | def get_augmentation(self):
368 | if self.modality == 'RGB':
369 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),
370 | GroupRandomHorizontalFlip(is_flow=False)])
371 | elif self.modality == 'Flow':
372 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
373 | GroupRandomHorizontalFlip(is_flow=True)])
374 | elif self.modality == 'RGBDiff':
375 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
376 | GroupRandomHorizontalFlip(is_flow=False)])
377 | elif self.modality == 'RGBFlow':
378 | return torchvision.transforms.Compose([GroupMultiScaleResize(0.2),
379 | GroupMultiScaleRotate(20),
380 | #GroupSpatialElasticDisplacement(),
381 | GroupMultiScaleCrop(self.input_size,
382 | [1, .875,
383 | .75,
384 | .66]),
385 | #GroupRandomHorizontalFlip(is_flow=False)
386 | ])
387 |
388 |
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from ops.basic_ops import *
--------------------------------------------------------------------------------
/ops/basic_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | class Identity(torch.nn.Module):
6 | def forward(self, input):
7 | return input
8 |
9 |
10 | class SegmentConsensus(torch.autograd.Function):
11 |     """Segment consensus as a new-style autograd Function (static forward/backward);
12 |     legacy Function objects with a non-static forward are no longer supported in
13 |     recent PyTorch releases."""
14 |
15 |     @staticmethod
16 |     def forward(ctx, input_tensor, consensus_type, dim):
17 |         ctx.consensus_type = consensus_type
18 |         ctx.dim = dim
19 |         ctx.shape = input_tensor.size()
20 |         if consensus_type == 'avg':
21 |             output = input_tensor.mean(dim=dim, keepdim=True)
22 |         elif consensus_type == 'identity':
23 |             output = input_tensor
24 |         else:
25 |             output = None
26 |
27 |         return output
28 |
29 |     @staticmethod
30 |     def backward(ctx, grad_output):
31 |         if ctx.consensus_type == 'avg':
32 |             grad_in = grad_output.expand(ctx.shape) / float(ctx.shape[ctx.dim])
33 |         elif ctx.consensus_type == 'identity':
34 |             grad_in = grad_output
35 |         else:
36 |             grad_in = None
37 |
38 |         # no gradients for the non-tensor arguments (consensus_type, dim)
39 |         return grad_in, None, None
40 |
41 |
42 | class ConsensusModule(torch.nn.Module):
43 |
44 |     def __init__(self, consensus_type, dim=1):
45 |         super(ConsensusModule, self).__init__()
46 |         self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
47 |         self.dim = dim
48 |
49 |     def forward(self, input):
50 |         return SegmentConsensus.apply(input, self.consensus_type, self.dim)
--------------------------------------------------------------------------------
/ops/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import confusion_matrix
4 |
5 | def get_grad_hook(name):
6 | def hook(m, grad_in, grad_out):
7 | print((name, grad_out[0].data.abs().mean(), grad_in[0].data.abs().mean()))
8 | print((grad_out[0].size()))
9 | print((grad_in[0].size()))
10 |
11 | print((grad_out[0]))
12 | print((grad_in[0]))
13 |
14 | return hook
15 |
16 |
17 | def softmax(scores):
18 | es = np.exp(scores - scores.max(axis=-1)[..., None])
19 | return es / es.sum(axis=-1)[..., None]
20 |
21 |
22 | def log_add(log_a, log_b):
23 | return log_a + np.log(1 + np.exp(log_b - log_a))
24 |
25 |
26 | def class_accuracy(prediction, label):
27 | cf = confusion_matrix(prediction, label)
28 | cls_cnt = cf.sum(axis=1)
29 | cls_hit = np.diag(cf)
30 |
31 | cls_acc = cls_hit / cls_cnt.astype(float)
32 |
33 | mean_cls_acc = cls_acc.mean()
34 |
35 | return cls_acc, mean_cls_acc
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(description="PyTorch implementation of Motion Fused Frames (MFF)")
3 | parser.add_argument('dataset', type=str, choices=['jester', 'nvgesture', 'chalearn'])
4 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff', 'RGBFlow'])
5 | parser.add_argument('--train_list', type=str,default="")
6 | parser.add_argument('--val_list', type=str, default="")
7 | parser.add_argument('--root_path', type=str, default="")
8 | parser.add_argument('--store_name', type=str, default="")
9 | # ========================= Model Configs ==========================
10 | parser.add_argument('--arch', type=str, default="BNInception")
11 | parser.add_argument('--num_segments', type=int, default=4)
12 | parser.add_argument('--num_motion', type=int, default=3)
13 | parser.add_argument('--consensus_type', type=str, default='avg')
14 | parser.add_argument('--k', type=int, default=3)
15 |
16 | parser.add_argument('--dropout', '--do', default=0.8, type=float,
17 |                     metavar='DO', help='dropout ratio (default: 0.8)')
18 | parser.add_argument('--loss_type', type=str, default="nll",
19 | choices=['nll'])
20 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
21 |
22 | # ========================= Learning Configs ==========================
23 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
24 | help='number of total epochs to run')
25 | parser.add_argument('-b', '--batch-size', default=128, type=int,
26 |                     metavar='N', help='mini-batch size (default: 128)')
27 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
28 | metavar='LR', help='initial learning rate')
29 | parser.add_argument('--lr_steps', default=[25, 40], type=float, nargs="+",
30 |                     metavar='LRSteps', help='epochs at which to decay the learning rate by a factor of 2')
31 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
32 | help='momentum')
33 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
34 | metavar='W', help='weight decay (default: 5e-4)')
35 | parser.add_argument('--clip-gradient', '--gd', default=20, type=float,
36 |                     metavar='W', help='gradient norm clipping (default: 20)')
37 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
38 |
39 |
40 | # ========================= Monitor Configs ==========================
41 | parser.add_argument('--print-freq', '-p', default=10, type=int,
42 | metavar='N', help='print frequency (default: 10)')
43 | parser.add_argument('--eval-freq', '-ef', default=1, type=int,
44 |                     metavar='N', help='evaluation frequency (default: 1)')
45 |
46 |
47 | # ========================= Runtime Configs ==========================
48 | parser.add_argument('-j', '--workers', default=30, type=int, metavar='N',
49 |                     help='number of data loading workers (default: 30)')
50 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
51 | help='path to latest checkpoint (default: none)')
52 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
53 | help='evaluate model on validation set')
54 | parser.add_argument('--snapshot_pref', type=str, default="")
55 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
56 | help='manual epoch number (useful on restarts)')
57 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
58 | parser.add_argument('--flow_prefix', default="", type=str)
59 | parser.add_argument('--root_log',type=str, default='log')
60 | parser.add_argument('--root_model', type=str, default='model')
61 | parser.add_argument('--root_output',type=str, default='output')
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar
--------------------------------------------------------------------------------
/pretrained_models/MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/okankop/MFF-pytorch/77bb9b14d294cad83f07c6394aeee14780950edb/pretrained_models/MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar
--------------------------------------------------------------------------------
/process_dataset.py:
--------------------------------------------------------------------------------
1 | # This code has been acquired from the TRN-pytorch repository
2 | # 'https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py'
3 | # which is prepared by Bolei Zhou
4 | #
5 | # Processing the raw dataset of Jester
6 | #
7 | # generate the meta files:
8 | # category.txt: the list of categories.
9 | # train_videofolder.txt: each row contains [videoname num_frames classIDX]
10 | # val_videofolder.txt: same as above
11 | #
12 | # Created by Bolei Zhou, Dec.2 2017
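#
# For illustration (hypothetical values), a row of train_videofolder.txt looks like:
#   12345 37 11
# i.e. video folder '12345' contains 37 frames and belongs to class index 11.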
13 |
14 | import os
15 | import pdb
16 | dataset_name = 'jester-v1'
17 | with open('%s-labels.csv'% dataset_name) as f:
18 | lines = f.readlines()
19 | categories = []
20 | for line in lines:
21 | line = line.rstrip()
22 | categories.append(line)
23 | categories = sorted(categories)
24 | with open('category.txt','w') as f:
25 | f.write('\n'.join(categories))
26 |
27 | dict_categories = {}
28 | for i, category in enumerate(categories):
29 | dict_categories[category] = i
30 |
31 | files_input = ['%s-validation.csv'%dataset_name,'%s-train.csv'%dataset_name]
32 | files_output = ['val_videofolder.txt','train_videofolder.txt']
33 | for (filename_input, filename_output) in zip(files_input, files_output):
34 | with open(filename_input) as f:
35 | lines = f.readlines()
36 | folders = []
37 | idx_categories = []
38 | for line in lines:
39 | line = line.rstrip()
40 | items = line.split(';')
41 | folders.append(items[0])
42 |         idx_categories.append(dict_categories[items[1]])
43 | output = []
44 | for i in range(len(folders)):
45 | curFolder = folders[i]
46 | curIDX = idx_categories[i]
47 | # counting the number of frames in each video folders
48 | dir_files = os.listdir(os.path.join('20bn-%s'%dataset_name, curFolder))
49 | output.append('%s %d %d'%(curFolder, len(dir_files), curIDX))
50 | print('%d/%d'%(i, len(folders)))
51 | with open(filename_output,'w') as f:
52 | f.write('\n'.join(output))
53 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.17.2
2 | opencv-python==4.2.0.32
3 | Pillow>=6.2.2
4 | pretrainedmodels==0.7.4
5 | python-dateutil==2.8.0
6 | pytz==2019.2
7 | six==1.12.0
8 | torch==1.5.0
9 | torchvision==0.6.0
10 |
--------------------------------------------------------------------------------
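The pins above target the PyTorch 1.5 / torchvision 0.6 era. Assuming Python 3 and a fresh virtual environment, installation is the usual:

    pip install -r requirements.txt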
/test_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import numpy as np
5 | import torch.nn.parallel
6 | import torch.optim
7 | from sklearn.metrics import confusion_matrix
8 | from dataset import TSNDataSet
9 | from models import TSN
10 | from transforms import *
11 | from ops import ConsensusModule
12 | import datasets_video
13 | import pdb
14 | from torch.nn import functional as F
15 |
16 |
17 | # options
18 | parser = argparse.ArgumentParser(
19 | description="MFF testing on the full validation set")
20 | parser.add_argument('dataset', type=str, choices=['jester', 'nvgesture', 'chalearn'])
21 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff', 'RGBFlow'])
22 | parser.add_argument('weights', type=str)
23 | parser.add_argument('--arch', type=str, default="resnet101")
24 | parser.add_argument('--save_scores', type=str, default=None)
25 | parser.add_argument('--test_segments', type=int, default=25)
26 | parser.add_argument('--max_num', type=int, default=-1)
27 | parser.add_argument('--test_crops', type=int, default=10)
28 | parser.add_argument('--input_size', type=int, default=224)
29 | parser.add_argument('--num_motion', type=int, default=3)
30 | parser.add_argument('--consensus_type', type=str, default='MLP', choices=['avg', 'MLP'])
31 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
32 |                     help='number of data loading workers (default: 8)')
33 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
34 | parser.add_argument('--img_feature_dim',type=int, default=256)
35 | parser.add_argument('--num_set_segments',type=int, default=1)
36 | parser.add_argument('--softmax', type=int, default=0)
37 |
38 | args = parser.parse_args()
39 |
40 | class AverageMeter(object):
41 | """Computes and stores the average and current value"""
42 | def __init__(self):
43 | self.reset()
44 |
45 | def reset(self):
46 | self.val = 0
47 | self.avg = 0
48 | self.sum = 0
49 | self.count = 0
50 |
51 | def update(self, val, n=1):
52 | self.val = val
53 | self.sum += val * n
54 | self.count += n
55 | self.avg = self.sum / self.count
56 |
57 | def accuracy(output, target, topk=(1,)):
58 | """Computes the precision@k for the specified values of k"""
59 | maxk = max(topk)
60 | batch_size = target.size(0)
61 | _, pred = output.topk(maxk, 1, True, True)
62 | pred = pred.t()
63 | correct = pred.eq(target.view(1, -1).expand_as(pred))
64 | res = []
65 | for k in topk:
66 | correct_k = correct[:k].view(-1).float().sum(0)
67 | res.append(correct_k.mul_(100.0 / batch_size))
68 | return res
69 |
70 |
71 |
72 | categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality)
73 | num_class = len(categories)
74 |
75 | net = TSN(num_class, args.test_segments if args.consensus_type in ['MLP'] else 1, args.modality,
76 | base_model=args.arch,
77 | consensus_type=args.consensus_type,
78 | img_feature_dim=args.img_feature_dim,
79 | )
80 |
81 | checkpoint = torch.load(args.weights)
82 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))
83 |
84 | base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint['state_dict'].items())}
85 | net.load_state_dict(base_dict)
86 |
87 | if args.test_crops == 1:
88 | cropping = torchvision.transforms.Compose([
89 | GroupScale(net.scale_size),
90 | GroupCenterCrop(net.input_size),
91 | ])
92 | elif args.test_crops == 10:
93 | cropping = torchvision.transforms.Compose([
94 | #GroupOverSample(net.input_size, net.input_size)
95 | GroupOverSample(net.input_size, net.scale_size)
96 | ])
97 | else:
98 |     raise ValueError("Only 1 and 10 test crops are supported, but got {}".format(args.test_crops))
99 |
100 | ############ Data Loading Part #####
101 | if args.modality == 'RGB':
102 | data_length = 1
103 | elif args.modality in ['Flow', 'RGBDiff']:
104 | data_length = 5
105 | elif args.modality == 'RGBFlow':
106 | data_length = args.num_motion
107 |
108 | data_loader = torch.utils.data.DataLoader(
109 | TSNDataSet(args.root_path, args.val_list, num_segments=args.test_segments,
110 | new_length=data_length,
111 | modality=args.modality,
112 | image_tmpl=prefix,
113 | dataset=args.dataset,
114 | test_mode=True,
115 | transform=torchvision.transforms.Compose([
116 | cropping,
117 | Stack(roll=(args.arch in
118 | ['BNInception','InceptionV3']), isRGBFlow=(args.modality == 'RGBFlow')),
119 | ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
120 | GroupNormalize(net.input_mean, net.input_std),
121 | ])),
122 | batch_size=1, shuffle=False,
123 | num_workers=args.workers*2, pin_memory=False)
124 |
125 | if args.gpus is not None:
126 |     devices = args.gpus  # use the GPU ids given on the command line
127 | else:
128 |     devices = list(range(torch.cuda.device_count()))
129 |
130 |
131 | #net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
132 | net = torch.nn.DataParallel(net.cuda())
133 | net.eval()
134 |
135 | data_gen = enumerate(data_loader)
136 |
137 | total_num = len(data_loader.dataset)
138 | output = []
139 |
140 |
141 | def eval_video(video_data):
142 | i, data, label = video_data
143 | num_crop = args.test_crops
144 |
145 | if args.modality == 'RGB':
146 | length = 3
147 | elif args.modality == 'Flow':
148 | length = 10
149 | elif args.modality == 'RGBDiff':
150 | length = 18
151 | elif args.modality == 'RGBFlow':
152 |         length = 3 + 2 * args.num_motion  # 3 RGB channels plus 2 flow channels (x and y) per motion frame
153 | else:
154 | raise ValueError("Unknown modality "+args.modality)
155 |
156 |     with torch.no_grad():  # replaces the removed Variable(..., volatile=True) pattern
157 |         input_var = data.view(-1, length, data.size(2), data.size(3))
158 |         rst = net(input_var)
159 |     if args.softmax == 1:
160 |         # take the softmax to normalize the output into class probabilities
161 |         rst = F.softmax(rst, dim=1)
162 |
163 | rst = rst.data.cpu().numpy().copy()
164 |
165 | if args.consensus_type in ['MLP']:
166 | rst = rst.reshape(-1, 1, num_class)
167 | else:
168 | rst = rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape((args.test_segments, 1, num_class))
169 |
170 | return i, rst, label[0]
171 |
172 |
173 | proc_start_time = time.time()
174 | max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)
175 |
176 | top1 = AverageMeter()
177 | top5 = AverageMeter()
178 |
179 | for i, (data, label) in data_gen:
180 | if i >= max_num:
181 | break
182 | rst = eval_video((i, data, label))
183 | output.append(rst[1:])
184 | cnt_time = time.time() - proc_start_time
185 | prec1, prec5 = accuracy(torch.from_numpy(np.mean(rst[1], axis=0)), label, topk=(1, 5))
186 |     top1.update(prec1.item(), 1)
187 |     top5.update(prec5.item(), 1)
188 | print('video {} done, total {}/{}, average {:.3f} sec/video, moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i, i+1,
189 | total_num,
190 | float(cnt_time) / (i+1), top1.avg, top5.avg))
191 |
192 | video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output]
193 |
194 | video_labels = [x[1] for x in output]
195 |
196 |
197 | cf = confusion_matrix(video_labels, video_pred).astype(float)
198 |
199 | cls_cnt = cf.sum(axis=1)
200 | cls_hit = np.diag(cf)
201 |
202 | cls_acc = cls_hit / cls_cnt
203 |
204 | print('-----Evaluation is finished------')
205 | print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
206 | print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
207 |
208 | if args.save_scores is not None:
209 |
210 | # reorder before saving
211 | name_list = [x.strip().split()[0] for x in open(args.val_list)]
212 | order_dict = {e:i for i, e in enumerate(sorted(name_list))}
213 | reorder_output = [None] * len(output)
214 | reorder_label = [None] * len(output)
215 | reorder_pred = [None] * len(output)
216 | output_csv = []
217 | for i in range(len(output)):
218 | idx = order_dict[name_list[i]]
219 | reorder_output[idx] = output[i]
220 | reorder_label[idx] = video_labels[i]
221 | reorder_pred[idx] = video_pred[i]
222 | output_csv.append('%s;%s'%(name_list[i], categories[video_pred[i]]))
223 |
224 | np.savez(args.save_scores, scores=reorder_output, labels=reorder_label, predictions=reorder_pred, cf=cf)
225 |
226 | with open(args.save_scores.replace('npz','csv'),'w') as f:
227 | f.write('\n'.join(output_csv))
228 |
229 |
230 |
--------------------------------------------------------------------------------
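As a usage sketch, the options defined above can be combined to evaluate one of the bundled checkpoints; the values below are illustrative rather than prescribed, and the Jester frame folders and list files resolved by datasets_video.py must already be in place:

    python test_models.py jester RGBFlow \
        pretrained_models/MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar \
        --arch BNInception --consensus_type MLP --test_segments 8 \
        --test_crops 1 --num_motion 3 --save_scores scores_jester.npz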
/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import cv2
6 | import numbers
7 | import math
8 | import torch
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_group):
19 |
20 | w, h = img_group[0].size
21 | th, tw = self.size
22 |
23 | out_images = list()
24 |
25 | x1 = random.randint(0, w - tw)
26 | y1 = random.randint(0, h - th)
27 |
28 | for img in img_group:
29 | assert(img.size[0] == w and img.size[1] == h)
30 | if w == tw and h == th:
31 | out_images.append(img)
32 | else:
33 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
34 |
35 | return out_images
36 |
37 |
38 | class GroupCenterCrop(object):
39 | def __init__(self, size):
40 | self.worker = torchvision.transforms.CenterCrop(size)
41 |
42 | def __call__(self, img_group):
43 | return [self.worker(img) for img in img_group]
44 |
45 |
46 | class GroupRandomHorizontalFlip(object):
47 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
48 | """
49 | def __init__(self, is_flow=False):
50 | self.is_flow = is_flow
51 |
52 | def __call__(self, img_group, is_flow=False):
53 | v = random.random()
54 | if v < 0.5:
55 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
56 | if self.is_flow:
57 | for i in range(0, len(ret), 2):
58 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping
59 | return ret
60 | else:
61 | return img_group
62 |
63 |
64 | class GroupNormalize(object):
65 | def __init__(self, mean, std):
66 | self.mean = mean
67 | self.std = std
68 |
69 | def __call__(self, tensor):
70 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
71 | rep_std = self.std * (tensor.size()[0]//len(self.std))
72 |
73 | # TODO: make efficient
74 | for t, m, s in zip(tensor, rep_mean, rep_std):
75 | t.sub_(m).div_(s)
76 |
77 | return tensor
78 |
79 |
80 | class GroupScale(object):
81 | """ Rescales the input PIL.Image to the given 'size'.
82 | 'size' will be the size of the smaller edge.
83 |     For example, if height > width, then the image will be
84 |     rescaled to (size * height / width, size).
85 | size: size of the smaller edge
86 | interpolation: Default: PIL.Image.BILINEAR
87 | """
88 |
89 | def __init__(self, size, interpolation=Image.BILINEAR):
90 |         self.worker = torchvision.transforms.Resize(size, interpolation)  # transforms.Scale is deprecated in favor of Resize
91 |
92 | def __call__(self, img_group):
93 | return [self.worker(img) for img in img_group]
94 |
95 |
96 | class GroupOverSample(object):
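    """Optionally rescales the group, then takes the fixed corner/center crops
    produced by GroupMultiScaleCrop.fill_fix_offset from every frame and returns
    all crops concatenated into a single, longer image group."""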
97 | def __init__(self, crop_size, scale_size=None):
98 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
99 |
100 | if scale_size is not None:
101 | self.scale_worker = GroupScale(scale_size)
102 | else:
103 | self.scale_worker = None
104 |
105 | def __call__(self, img_group):
106 |
107 | if self.scale_worker is not None:
108 | img_group = self.scale_worker(img_group)
109 |
110 | image_w, image_h = img_group[0].size
111 | crop_w, crop_h = self.crop_size
112 |
113 | offsets = GroupMultiScaleCrop.fill_fix_offset(True, image_w, image_h, crop_w, crop_h)
114 | oversample_group = list()
115 | for o_w, o_h in offsets:
116 | normal_group = list()
117 | flip_group = list()
118 | for i, img in enumerate(img_group):
119 | #print(img.size)
120 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
121 | #print([o_w, o_h, o_w + crop_w, o_h + crop_h])
122 | normal_group.append(crop)
123 | #flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
124 | #flip_group.append(flip_crop)
125 |
126 | #if img.mode == 'L' and i % 2 == 0:
127 | #flip_group.append(ImageOps.invert(flip_crop))
128 | #else:
129 | #flip_group.append(flip_crop)
130 |
131 | oversample_group.extend(normal_group)
132 | #oversample_group.extend(flip_group)
133 | return oversample_group
134 |
135 | class GroupSpatialElasticDisplacement(object):
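    """With probability 0.5, applies one random elastic deformation (a Gaussian-smoothed
    random displacement field, rounded to integer pixel offsets) and warps every frame
    in the group with that same field, so the clip stays temporally consistent;
    otherwise returns the group unchanged."""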
136 |
137 | def __init__(self):
138 | self.displacement = 20
139 | self.displacement_kernel = 25
140 | self.displacement_magnification = 0.60
141 |
142 |
143 | def __call__(self, img_group):
144 | v = random.random()
145 | if v < 0.5:
146 | im_size = img_group[0].size
147 | image_w, image_h = im_size[0], im_size[1]
148 | displacement_map = np.random.rand(image_h, image_w, 2) * 2 * self.displacement - self.displacement
149 | displacement_map = cv2.GaussianBlur(displacement_map, None, self.displacement_kernel)
150 | displacement_map *= self.displacement_magnification * self.displacement_kernel
151 | displacement_map = np.floor(displacement_map).astype('int32')
152 |
153 | displacement_map_rows = displacement_map[..., 0] + np.tile(np.arange(image_h), (image_w, 1)).T.astype('int32')
154 | displacement_map_rows = np.clip(displacement_map_rows, 0, image_h - 1)
155 |
156 | displacement_map_cols = displacement_map[..., 1] + np.tile(np.arange(image_w), (image_h, 1)).astype('int32')
157 | displacement_map_cols = np.clip(displacement_map_cols, 0, image_w - 1)
158 | ret_img_group = [Image.fromarray(np.asarray(img)[(displacement_map_rows.flatten(), displacement_map_cols.flatten())].reshape(np.asarray(img).shape)) for img in img_group]
159 | return ret_img_group
160 |
161 | else:
162 | return img_group
163 |
164 |
165 |
166 | class GroupMultiScaleResize(object):
167 |
168 | def __init__(self, scale):
169 | self.scale = scale
170 |
171 | def __call__(self, img_group):
172 | im_size = img_group[0].size
173 |         self.resize_const = random.uniform(1.0 - self.scale, 1.0 + self.scale)  # Apply a random resize factor
174 | resize_img_group = [img.resize((int(im_size[0]*self.resize_const), int(im_size[1]*self.resize_const))) for img in img_group]
175 |
176 | return resize_img_group
177 |
178 |
179 |
180 | class GroupMultiScaleRotate(object):
181 |
182 | def __init__(self, degree):
183 | self.degree = degree
184 | self.interpolation = Image.BILINEAR
185 |
186 | def __call__(self, img_group):
187 | im_size = img_group[0].size
188 |         self.rotate_angle = random.randint(-self.degree, self.degree)  # Apply a random rotation angle
189 | ret_img_group = [img.rotate(self.rotate_angle, resample=self.interpolation) for img in img_group]
190 |
191 | return ret_img_group
192 |
193 |
194 |
195 | class GroupMultiScaleCrop(object):
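    """Samples one crop size from `scales` (subject to `max_distort` between the width
    and height scale indices) and one offset (random, or drawn from a fixed grid when
    `fix_crop` is set), crops every frame identically, and resizes the crops to
    `input_size` with bilinear interpolation."""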
196 |
197 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=False):
198 |         self.scales = scales if scales is not None else [1, .875, .75, .66]
199 | self.max_distort = max_distort
200 | self.fix_crop = fix_crop
201 | self.more_fix_crop = more_fix_crop
202 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
203 | self.interpolation = Image.BILINEAR
204 |
205 | def __call__(self, img_group):
206 |
207 | im_size = img_group[0].size
208 | #self.scales = [1, random.uniform(0.85, 1.0)]
209 |
210 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
211 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
212 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
213 | for img in crop_img_group]
214 |
215 | return ret_img_group
216 |
217 | def _sample_crop_size(self, im_size):
218 | image_w, image_h = im_size[0], im_size[1]
219 |
220 | # find a crop size
221 | base_size = min(image_w, image_h)
222 | crop_sizes = [int(base_size * x) for x in self.scales]
223 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
224 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
225 |
226 | pairs = []
227 | for i, h in enumerate(crop_h):
228 | for j, w in enumerate(crop_w):
229 | if abs(i - j) <= self.max_distort:
230 | pairs.append((w, h))
231 |
232 | crop_pair = random.choice(pairs)
233 | if not self.fix_crop:
234 | w_offset = random.randint(0, image_w - crop_pair[0])
235 | h_offset = random.randint(0, image_h - crop_pair[1])
236 | else:
237 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
238 |
239 | return crop_pair[0], crop_pair[1], w_offset, h_offset
240 |
241 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
242 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
243 | return random.choice(offsets)
244 |
245 | @staticmethod
246 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
247 | w_step = (image_w - crop_w) // 4
248 | h_step = (image_h - crop_h) // 4
249 |
250 | ret = list()
251 | ret.append((0, 0)) # upper left
252 | ret.append((4 * w_step, 0)) # upper right
253 | ret.append((0, 4 * h_step)) # lower left
254 | ret.append((4 * w_step, 4 * h_step)) # lower right
255 | ret.append((2 * w_step, 2 * h_step)) # center
256 |
257 | if more_fix_crop:
258 | ret.append((0 * w_step, 2 * h_step)) # center left
259 | ret.append((4 * w_step, 2 * h_step)) # center right
260 | ret.append((2 * w_step, 4 * h_step)) # lower center
261 | ret.append((2 * w_step, 0 * h_step)) # upper center
262 |
263 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
264 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
265 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
266 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
267 |
268 | return ret
269 |
270 |
271 | class GroupRandomSizedCrop(object):
272 |     """Randomly crops the given PIL.Image to a random size of (0.08 to 1.0) of the original size
273 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
274 |     This is popularly used to train the Inception networks.
275 | size: size of the smaller edge
276 | interpolation: Default: PIL.Image.BILINEAR
277 | """
278 | def __init__(self, size, interpolation=Image.BILINEAR):
279 | self.size = size
280 | self.interpolation = interpolation
281 |
282 | def __call__(self, img_group):
283 | for attempt in range(10):
284 | area = img_group[0].size[0] * img_group[0].size[1]
285 | target_area = random.uniform(0.08, 1.0) * area
286 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
287 |
288 | w = int(round(math.sqrt(target_area * aspect_ratio)))
289 | h = int(round(math.sqrt(target_area / aspect_ratio)))
290 |
291 | if random.random() < 0.5:
292 | w, h = h, w
293 |
294 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
295 | x1 = random.randint(0, img_group[0].size[0] - w)
296 | y1 = random.randint(0, img_group[0].size[1] - h)
297 | found = True
298 | break
299 | else:
300 | found = False
301 | x1 = 0
302 | y1 = 0
303 |
304 | if found:
305 | out_group = list()
306 | for img in img_group:
307 | img = img.crop((x1, y1, x1 + w, y1 + h))
308 | assert(img.size == (w, h))
309 | out_group.append(img.resize((self.size, self.size), self.interpolation))
310 | return out_group
311 | else:
312 | # Fallback
313 | scale = GroupScale(self.size, interpolation=self.interpolation)
314 | crop = GroupRandomCrop(self.size)
315 | return crop(scale(img_group))
316 |
317 |
318 | class Stack(object):
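    """Concatenates a group of PIL images along the channel axis. With `isRGBFlow`,
    grayscale flow frames contribute one channel each and RGB frames three, matching
    the 3 + 2*num_motion input channels used for the RGBFlow modality; `roll` reverses
    RGB to BGR for Caffe-style BNInception/InceptionV3 base models."""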
319 |
320 | def __init__(self, roll=False, isRGBFlow=False):
321 | self.roll = roll
322 | self.isRGBFlow = isRGBFlow
323 |
324 | def __call__(self, img_group):
325 | if self.isRGBFlow:
326 | stacked_array = np.array([])
327 | for x in img_group:
328 | if x.mode == 'L':
329 |                     if stacked_array.size == 0:
330 | stacked_array = np.expand_dims(x, 2)
331 | else:
332 | stacked_array = np.concatenate([stacked_array, np.expand_dims(x, 2)], axis=2)
333 | elif x.mode == 'RGB':
334 | if self.roll:
335 | stacked_array = np.concatenate([stacked_array, np.array(x)[:, :, ::-1]], axis=2)
336 | else:
337 | stacked_array = np.concatenate([stacked_array, np.array(x)], axis=2)
338 | return stacked_array
339 |
340 | else:
341 | if img_group[0].mode == 'L':
342 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)
343 | elif img_group[0].mode == 'RGB':
344 | if self.roll:
345 |                     # roll reverses channel order (RGB -> BGR) for Caffe-style pretrained models
346 |                     return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
347 | else:
348 | return np.concatenate(img_group, axis=2)
349 |
350 |
351 | class ToTorchFormatTensor(object):
352 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
353 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
354 | def __init__(self, div=True):
355 | self.div = div
356 |
357 | def __call__(self, pic):
358 | if isinstance(pic, np.ndarray):
359 | # handle numpy array
360 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
361 | else:
362 | # handle PIL Image
363 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
364 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
365 | # put it from HWC to CHW format
366 | # yikes, this transpose takes 80% of the loading time/CPU
367 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
368 | return img.float().div(255) if self.div else img.float()
369 |
370 |
371 | class IdentityTransform(object):
372 |
373 | def __call__(self, data):
374 | return data
375 |
376 |
377 | if __name__ == "__main__":
378 | trans = torchvision.transforms.Compose([
379 | GroupScale(256),
380 | GroupRandomCrop(224),
381 | Stack(),
382 | ToTorchFormatTensor(),
383 | GroupNormalize(
384 | mean=[.485, .456, .406],
385 | std=[.229, .224, .225]
386 | )]
387 | )
388 |
389 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
390 |
391 | color_group = [im] * 3
392 | rst = trans(color_group)
393 |
394 | gray_group = [im.convert('L')] * 9
395 | gray_rst = trans(gray_group)
396 |
397 | trans2 = torchvision.transforms.Compose([
398 | GroupRandomSizedCrop(256),
399 | Stack(),
400 | ToTorchFormatTensor(),
401 | GroupNormalize(
402 | mean=[.485, .456, .406],
403 | std=[.229, .224, .225])
404 | ])
405 | print(trans2(color_group))
406 |
--------------------------------------------------------------------------------