├── .gitignore
├── QuoVadis.png
├── README.md
├── dataset.py
├── main.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   └── resnet3d.py
│   └── i3d.py
├── opts.py
├── test_models.py
└── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea/
2 |
--------------------------------------------------------------------------------
/QuoVadis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PPPrior/i3d-pytorch/e3a022e0892be5553e82e0ee152dd3e73f9f6bac/QuoVadis.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # I3D-PyTorch
2 | This is a simple and crude implementation of Inflated 3D ConvNet Models (I3D) in PyTorch. Different from the models reported in "[Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset](https://arxiv.org/abs/1705.07750)" by Joao Carreira and Andrew Zisserman, this implementation uses [ResNet](https://arxiv.org/pdf/1512.03385.pdf) as the backbone.
3 |
4 | ![QuoVadis](QuoVadis.png)
5 |
6 |
7 |
8 | This implementation is based on OpenMMLab's [MMAction2](https://github.com/open-mmlab/mmaction2).
9 |
10 | ## Data Preparation
11 |
12 | For optical flow extraction and video list generation, please refer to [TSN](https://github.com/yjxiong/temporal-segment-networks#code--data-preparation) for details.
13 |
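Frame folders are expected under `<root_path>/rawframes/`, and the train/val list files (looked up relative to `<root_path>`) contain one video per line: the frame directory, its frame count, and the numeric label, separated by spaces (see `VideoRecord` in [dataset.py](dataset.py)). The entries below are only illustrative:

```
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01 121 0
Archery/v_Archery_g01_c06 98 2
```
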
14 | ## Training
15 |
16 | To train a new model, use the `main.py` script.
17 |
18 | For example, the command to train a model with the RGB modality on UCF101 can be:
19 |
20 | ```bash
21 | python main.py ucf101 RGB \
22 | <root_path> <train_list> <val_list> \
23 | --arch i3d_resnet50 --clip_length 64 \
24 | --lr 0.001 --lr_steps 30 60 --epochs 80 \
25 | -b 32 -j 8 --dropout 0.8 \
26 | --snapshot_pref ucf101_i3d_resnet50
27 | ```
28 |
29 | For flow models:
30 |
31 | ```bash
32 | python main.py ucf101 Flow \
33 | <root_path> <train_list> <val_list> \
34 | --arch i3d_resnet50 --clip_length 64 \
35 | --lr 0.001 --lr_steps 15 30 --epochs 40 \
36 | -b 64 -j 8 --dropout 0.8 \
37 | --snapshot_pref ucf101_i3d_resnet50
38 | ```
39 |
40 | Please refer to [main.py](main.py) for more details.
41 |
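The same pieces can also be assembled in your own script. Below is a minimal sketch of the RGB pipeline used by `main.py`; the `root_path` and list file name are placeholders:

```python
import torchvision

from models import i3d
from dataset import I3DDataSet
from transforms import (GroupMultiScaleCrop, GroupRandomHorizontalFlip,
                        ToNumpyNDArray, ToTorchFormatTensor, GroupNormalize)

# i3d_resnet50 builds a ResNet3d-50 backbone and, by default, inflates
# ImageNet-pretrained 2D weights into its 3D kernels.
model = i3d.i3d_resnet50(modality='RGB', num_classes=101, dropout_ratio=0.8)

train_set = I3DDataSet(
    root_path='../data/ucf101/',                     # placeholder
    list_file='ucf101_train_split_1_rawframes.txt',  # placeholder
    clip_length=64,
    modality='RGB',
    transform=torchvision.transforms.Compose([
        GroupMultiScaleCrop(224, [1, .875, .75, .66]),
        GroupRandomHorizontalFlip(is_flow=False),
        ToNumpyNDArray(),       # list of PIL images -> (T, H, W, C) array
        ToTorchFormatTensor(),  # -> float tensor of shape (C, T, H, W)
        GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))

clip, label = train_set[0]
print(clip.shape, label)  # torch.Size([3, 64, 224, 224]) and an integer label
```
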
42 | ## Testing
43 |
44 | After training, checkpoints will be saved by PyTorch, for example `ucf101_i3d_resnet50_rgb_model_best.pth.tar`.
45 |
46 | Use the following command to test its performance:
47 |
48 | ```bash
49 | python test_models.py ucf101 RGB \
50 | <root_path> <test_list> ucf101_i3d_resnet50_rgb_model_best.pth.tar \
51 | --arch i3d_resnet50 --save_scores
52 | ```
53 |
54 | Or for flow models:
55 |
56 | ```bash
57 | python test_models.py ucf101 Flow \
58 | <root_path> <test_list> ucf101_i3d_resnet50_flow_model_best.pth.tar \
59 | --arch i3d_resnet50 --save_scores
60 | ```
61 |
62 | Please refer to [test_models.py](test_models.py) for more details.
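When `--save_scores` is given, the per-video scores and labels are written with `np.savez`. A minimal sketch for reloading them and recomputing top-1 accuracy (the file name is whatever was passed to `--save_scores`):

```python
import numpy as np

data = np.load('ucf101_i3d_resnet50_rgb_scores.npz', allow_pickle=True)  # placeholder name
scores, labels = data['scores'], data['labels']

# each entry of `scores` is a (score_array, label) pair as stored by test_models.py
preds = np.array([np.argmax(s[0]) for s in scores])
labels = np.array([int(l) for l in labels])
print('Top-1 accuracy: {:.2f}%'.format((preds == labels).mean() * 100))
```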
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import numpy as np
5 | import os
6 | import os.path
7 | from numpy.random import randint
8 |
9 |
10 | class VideoRecord(object):
11 | def __init__(self, row):
12 | self._data = row
13 |
14 | @property
15 | def path(self):
16 | return self._data[0]
17 |
18 | @property
19 | def num_frames(self):
20 | return int(self._data[1])
21 |
22 | @property
23 | def label(self):
24 | return int(self._data[2])
25 |
26 |
27 | class I3DDataSet(data.Dataset):
28 | def __init__(self, root_path, list_file, clip_length=64, frame_size=(320, 240),
29 | modality='RGB', image_tmpl='img_{:05d}.jpg',
30 | transform=None, random_shift=True, test_mode=False):
31 | self.root_path = root_path
32 | self.list_file = list_file
33 | self.clip_length = clip_length
34 | self.frame_size = frame_size
35 | self.modality = modality
36 | self.image_tmpl = image_tmpl
37 | self.transform = transform
38 | self.random_shift = random_shift
39 | self.test_mode = test_mode
40 |
41 | self._parse_list()
42 |
43 | def _load_image(self, directory, idx):
44 | root_path = os.path.join(self.root_path, 'rawframes/') # ../data/ucf101/rawframes/
45 | directory = os.path.join(root_path, directory)
46 |
47 | if self.modality == 'RGB' or self.modality == 'RGBDiff':
48 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')]
49 | elif self.modality == 'Flow':
50 | x_img = Image.open(os.path.join(directory, self.image_tmpl.format('x', idx))).convert('L')
51 | y_img = Image.open(os.path.join(directory, self.image_tmpl.format('y', idx))).convert('L')
52 |
53 | return [x_img, y_img]
54 |
55 | def _parse_list(self):
56 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in
57 | open(os.path.join(self.root_path, self.list_file))]
58 |
59 | def _sample_indices(self, record):
60 | if not self.test_mode and self.random_shift:
61 | average_duration = record.num_frames // self.clip_length
62 | if average_duration > 0:
63 | offsets = np.sort(
64 | np.multiply(list(range(self.clip_length)), average_duration) + randint(average_duration,
65 | size=self.clip_length))
66 | else:
67 | offsets = np.sort(randint(record.num_frames, size=self.clip_length))
68 | else:
69 | tick = record.num_frames / float(self.clip_length)
70 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.clip_length)])
71 | return offsets + 1
72 |
73 | def __getitem__(self, index):
74 | record = self.video_list[index]
75 | indices = self._sample_indices(record)
76 | return self.get(record, indices)
77 |
78 | def get(self, record, indices):
79 | images = list()
80 | for index in indices:
81 | img = self._load_image(record.path, int(index))
82 | images.extend(img)
83 | process_data = self.transform(images)
84 | return process_data, record.label
85 |
86 | def __len__(self):
87 | return len(self.video_list)
88 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import shutil
4 | import torchvision.transforms
5 | import torch.utils.data.dataloader
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim
9 | from torch.nn.utils import clip_grad_norm_
10 |
11 | from models import i3d
12 | from dataset import I3DDataSet
13 | from transforms import *
14 | from opts import parser
15 |
16 | best_prec1 = 0
17 |
18 |
19 | def main():
20 | global args, best_prec1
21 | args = parser.parse_args()
22 |
23 | if args.dataset == 'ucf101':
24 | num_classes = 101
25 | elif args.dataset == 'hmdb51':
26 | num_classes = 51
27 | elif args.dataset == 'kinetics':
28 | num_classes = 400
29 | else:
30 | raise ValueError('Unknown dataset ' + args.dataset)
31 |
32 | model = getattr(i3d, args.arch)(modality=args.modality, num_classes=num_classes,
33 | dropout_ratio=args.dropout)
34 |
35 | crop_size = args.input_size
36 | scale_size = args.input_size * 256 // 224
37 | input_mean = [0.485, 0.456, 0.406]
38 | input_std = [0.229, 0.224, 0.225]
39 | if args.modality == 'Flow':
40 | input_mean = [0.5]
41 | input_std = [np.mean(input_std)]
42 |
43 | train_augmentation = get_augmentation(args.modality, args.input_size)
44 |
45 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
46 |
47 | if args.resume:
48 | if os.path.isfile(args.resume):
49 | print("=> loading checkpoint '{}'".format(args.resume))
50 | checkpoint = torch.load(args.resume)
51 | args.start_epoch = checkpoint['epoch']
52 | best_prec1 = checkpoint['best_prec1']
53 | model.load_state_dict(checkpoint['state_dict'])
54 | print("=> loaded checkpoint '{}' (epoch {})"
55 | .format(args.resume, checkpoint['epoch']))
56 | else:
57 | print("=> no checkpoint found at '{}'".format(args.resume))
58 |
59 | cudnn.benchmark = True
60 |
61 | # Data loading code
62 | train_loader = torch.utils.data.DataLoader(
63 | I3DDataSet(args.root_path, args.train_list, clip_length=args.clip_length, modality=args.modality,
64 | image_tmpl="img_{:05d}.jpg" if args.modality == "RGB" else args.flow_prefix + "{}_{:05d}.jpg",
65 | transform=torchvision.transforms.Compose([
66 | train_augmentation,
67 | ToNumpyNDArray(),
68 | ToTorchFormatTensor(),
69 | GroupNormalize(input_mean, input_std),
70 | ])),
71 | batch_size=args.batch_size, shuffle=True,
72 | num_workers=args.workers, pin_memory=True)
73 |
74 | val_loader = torch.utils.data.DataLoader(
75 | I3DDataSet(args.root_path, args.val_list, clip_length=args.clip_length, modality=args.modality,
76 | image_tmpl="img_{:05d}.jpg" if args.modality == "RGB" else args.flow_prefix + "{}_{:05d}.jpg",
77 | random_shift=False,
78 | transform=torchvision.transforms.Compose([
79 | GroupScale(int(scale_size)),
80 | GroupCenterCrop(crop_size),
81 | ToNumpyNDArray(),
82 | ToTorchFormatTensor(),
83 | GroupNormalize(input_mean, input_std),
84 | ])),
85 | batch_size=args.batch_size, shuffle=False,
86 | num_workers=args.workers, pin_memory=True
87 | )
88 |
89 | # define loss function (criterion) and optimizer
90 | if args.loss_type == 'nll':
91 | criterion = torch.nn.CrossEntropyLoss().cuda()
92 | else:
93 | raise ValueError("Unknown loss type")
94 |
95 | optimizer = torch.optim.SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum,
96 | weight_decay=args.weight_decay)
97 |
98 | if args.evaluate:
99 | validate(val_loader, model, criterion, 0)
100 | return
101 |
102 | for epoch in range(args.start_epoch, args.epochs):
103 | adjust_learning_rate(optimizer, epoch, args.lr_steps)
104 |
105 | # train for one epoch
106 | train(train_loader, model, criterion, optimizer, epoch)
107 |
108 | # evaluate on validation set
109 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
110 | prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader))
111 |
112 | # remember best prec@1 and save checkpoint
113 | is_best = prec1 > best_prec1
114 | best_prec1 = max(prec1, best_prec1)
115 | save_checkpoint({
116 | 'epoch': epoch + 1,
117 | 'arch': args.arch,
118 | 'state_dict': model.state_dict(),
119 | 'best_prec1': best_prec1,
120 | }, is_best)
121 |
122 |
123 | def train(train_loader, model, criterion, optimizer, epoch):
124 | batch_time = AverageMeter()
125 | data_time = AverageMeter()
126 | losses = AverageMeter()
127 | top1 = AverageMeter()
128 | top5 = AverageMeter()
129 |
130 | # switch to train mode
131 | model.train()
132 |
133 | end = time.time()
134 | for i, (input, target) in enumerate(train_loader):
135 | # measure data loading time
136 | data_time.update(time.time() - end)
137 |
138 | target = target.cuda()
139 | output = model(input)
140 | loss = criterion(output, target)
141 |
142 | # measure accuracy and record loss
143 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
144 | losses.update(loss.item(), input.size(0))
145 | top1.update(prec1.item(), input.size(0))
146 | top5.update(prec5.item(), input.size(0))
147 |
148 | # compute gradient and do SGD step
149 | optimizer.zero_grad()
150 |
151 | loss.backward()
152 |
153 | if args.clip_gradient is not None:
154 | total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
155 | if total_norm > args.clip_gradient:
156 | print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm))
157 |
158 | optimizer.step()
159 |
160 | # measure elapsed time
161 | batch_time.update(time.time() - end)
162 | end = time.time()
163 |
164 | if i % args.print_freq == 0:
165 | print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
166 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
167 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
168 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
169 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
170 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
171 | epoch, i, len(train_loader), batch_time=batch_time,
172 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'])))
173 |
174 |
175 | def validate(val_loader, model, criterion, iter, logger=None):
176 | batch_time = AverageMeter()
177 | losses = AverageMeter()
178 | top1 = AverageMeter()
179 | top5 = AverageMeter()
180 |
181 | # switch to evaluate mode
182 | model.eval()
183 | with torch.no_grad():
184 | end = time.time()
185 | for i, (input, target) in enumerate(val_loader):
186 | target = target.cuda()
187 | output = model(input)
188 | loss = criterion(output, target)
189 |
190 | # measure accuracy and record loss
191 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
192 |
193 | losses.update(loss.item(), input.size(0))
194 | top1.update(prec1.item(), input.size(0))
195 | top5.update(prec5.item(), input.size(0))
196 |
197 | # measure elapsed time
198 | batch_time.update(time.time() - end)
199 | end = time.time()
200 |
201 | if i % args.print_freq == 0:
202 | print(('Test: [{0}/{1}]\t'
203 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
204 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
205 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
206 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
207 | i, len(val_loader), batch_time=batch_time, loss=losses,
208 | top1=top1, top5=top5)))
209 |
210 | print(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
211 | .format(top1=top1, top5=top5, loss=losses)))
212 |
213 | return top1.avg
214 |
215 |
216 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
217 | filename = '_'.join((args.snapshot_pref, args.modality.lower(), filename))
218 | torch.save(state, filename)
219 | if is_best:
220 | best_name = '_'.join((args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar'))
221 | shutil.copyfile(filename, best_name)
222 |
223 |
224 | def get_augmentation(modality, input_size):
225 | if modality == 'RGB':
226 | return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size, [1, .875, .75, .66]),
227 | GroupRandomHorizontalFlip(is_flow=False)])
228 | elif modality == 'Flow':
229 | return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size, [1, .875, .75]),
230 | GroupRandomHorizontalFlip(is_flow=True)])
231 |
232 |
233 | class AverageMeter(object):
234 | """Computes and stores the average and current value"""
235 |
236 | def __init__(self):
237 | self.reset()
238 |
239 | def reset(self):
240 | self.val = 0
241 | self.avg = 0
242 | self.sum = 0
243 | self.count = 0
244 |
245 | def update(self, val, n=1):
246 | self.val = val
247 | self.sum += val * n
248 | self.count += n
249 | self.avg = self.sum / self.count
250 |
251 |
252 | def adjust_learning_rate(optimizer, epoch, lr_steps):
253 | """Sets the learning rate to the initial LR decayed by a factor of 10 at each milestone in lr_steps"""
254 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
255 | lr = args.lr * decay
256 | decay = args.weight_decay
257 | for param_group in optimizer.param_groups:
258 | param_group['lr'] = lr
259 | param_group['weight_decay'] = decay
260 |
261 |
262 | def accuracy(output, target, topk=(1,)):
263 | """Computes the precision@k for the specified values of k"""
264 | maxk = max(topk)
265 | batch_size = target.size(0)
266 |
267 | _, pred = output.topk(maxk, 1, True, True)
268 | pred = pred.t()
269 | correct = pred.eq(target.view(1, -1).expand_as(pred))
270 |
271 | res = []
272 | for k in topk:
273 | correct_k = correct[:k].reshape(-1).float().sum(0)
274 | res.append(correct_k.mul_(100.0 / batch_size))
275 | return res
276 |
277 |
278 | if __name__ == '__main__':
279 | main()
280 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbones import *
2 | from .i3d import *
3 |
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet3d import *
2 |
--------------------------------------------------------------------------------
/models/backbones/resnet3d.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.hub import load_state_dict_from_url
3 |
4 | __all__ = ['resnet3d']
5 |
6 | model_urls = {
7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
12 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
13 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
14 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
15 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
16 | }
17 |
18 |
19 | def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
20 | """3x3x3 convolution with padding"""
21 | if isinstance(stride, int):
22 | stride = (1, stride, stride)
23 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=dilation, dilation=dilation, groups=groups, bias=False)
25 |
26 |
27 | def conv1x1x1(in_planes, out_planes, stride=1):
28 | """1x1x1 convolution"""
29 | if isinstance(stride, int):
30 | stride = (1, stride, stride)
31 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32 |
33 |
34 | class BasicBlock3d(nn.Module):
35 | expansion = 1
36 |
37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
38 | base_width=64, dilation=1, norm_layer=None):
39 | super(BasicBlock3d, self).__init__()
40 | if norm_layer is None:
41 | norm_layer = nn.BatchNorm3d
42 | if groups != 1 or base_width != 64:
43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
44 | if dilation > 1:
45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
47 | self.conv1 = conv3x3x3(inplanes, planes, stride)
48 | self.bn1 = norm_layer(planes)
49 | self.relu = nn.ReLU(inplace=True)
50 | self.conv2 = conv3x3x3(planes, planes)
51 | self.bn2 = norm_layer(planes)
52 | self.downsample = downsample
53 | self.stride = stride
54 |
55 | def forward(self, x):
56 | identity = x
57 |
58 | out = self.conv1(x)
59 | out = self.bn1(out)
60 | out = self.relu(out)
61 |
62 | out = self.conv2(out)
63 | out = self.bn2(out)
64 |
65 | if self.downsample is not None:
66 | identity = self.downsample(x)
67 |
68 | out += identity
69 | out = self.relu(out)
70 |
71 | return out
72 |
73 |
74 | class Bottleneck3d(nn.Module):
75 | expansion = 4
76 |
77 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
78 | base_width=64, dilation=1, norm_layer=None):
79 | super(Bottleneck3d, self).__init__()
80 | if norm_layer is None:
81 | norm_layer = nn.BatchNorm3d
82 | width = int(planes * (base_width / 64.)) * groups
83 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
84 | self.conv1 = conv1x1x1(inplanes, width)
85 | self.bn1 = norm_layer(width)
86 | self.conv2 = conv3x3x3(width, width, stride, groups, dilation)
87 | self.bn2 = norm_layer(width)
88 | self.conv3 = conv1x1x1(width, planes * self.expansion)
89 | self.bn3 = norm_layer(planes * self.expansion)
90 | self.relu = nn.ReLU(inplace=True)
91 | self.downsample = downsample
92 | self.stride = stride
93 |
94 | def forward(self, x):
95 | identity = x
96 |
97 | out = self.conv1(x)
98 | out = self.bn1(out)
99 | out = self.relu(out)
100 |
101 | out = self.conv2(out)
102 | out = self.bn2(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv3(out)
106 | out = self.bn3(out)
107 |
108 | if self.downsample is not None:
109 | identity = self.downsample(x)
110 |
111 | out += identity
112 | out = self.relu(out)
113 |
114 | return out
115 |
116 |
117 | class ResNet3d(nn.Module):
118 |
119 | def __init__(self, block, layers, zero_init_residual=False,
120 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
121 | norm_layer=None, modality='RGB'):
122 | super(ResNet3d, self).__init__()
123 | if norm_layer is None:
124 | norm_layer = nn.BatchNorm3d
125 | self._norm_layer = norm_layer
126 |
127 | self.modality = modality
128 | self.inplanes = 64
129 | self.dilation = 1
130 | if replace_stride_with_dilation is None:
131 | # each element in the tuple indicates if we should replace
132 | # the 2x2x2 stride with a dilated convolution instead
133 | replace_stride_with_dilation = [False, False, False]
134 | if len(replace_stride_with_dilation) != 3:
135 | raise ValueError("replace_stride_with_dilation should be None "
136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
137 | self.groups = groups
138 | self.base_width = width_per_group
139 |
140 | self._make_stem_layer()
141 |
142 | self.layer1 = self._make_layer(block, 64, layers[0])
143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
144 | dilate=replace_stride_with_dilation[0])
145 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
146 | dilate=replace_stride_with_dilation[1])
147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
148 | dilate=replace_stride_with_dilation[2])
149 |
150 | for m in self.modules(): # self.modules() --> Depth-First-Search the Net
151 | if isinstance(m, nn.Conv3d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
154 | nn.init.constant_(m.weight, 1)
155 | nn.init.constant_(m.bias, 0)
156 |
157 | # Zero-initialize the last BN in each residual branch,
158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
159 | if zero_init_residual:
160 | for m in self.modules():
161 | if isinstance(m, Bottleneck3d):
162 | nn.init.constant_(m.bn3.weight, 0)
163 | elif isinstance(m, BasicBlock3d):
164 | nn.init.constant_(m.bn2.weight, 0)
165 |
166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
167 | norm_layer = self._norm_layer
168 | downsample = None
169 | previous_dilation = self.dilation
170 | if dilate:
171 | self.dilation *= stride
172 | stride = 1
173 | if stride != 1 or self.inplanes != planes * block.expansion:
174 | downsample = nn.Sequential(
175 | conv1x1x1(self.inplanes, planes * block.expansion, stride),
176 | norm_layer(planes * block.expansion),
177 | )
178 |
179 | layers = []
180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
181 | self.base_width, previous_dilation, norm_layer))
182 | self.inplanes = planes * block.expansion
183 | for _ in range(1, blocks):
184 | layers.append(block(self.inplanes, planes, groups=self.groups,
185 | base_width=self.base_width, dilation=self.dilation,
186 | norm_layer=norm_layer))
187 |
188 | return nn.Sequential(*layers)
189 |
190 | def _make_stem_layer(self):
191 | """Construct the stem layers consists of a conv+norm+act module and a
192 | pooling layer."""
193 | if self.modality == 'RGB':
194 | inchannels = 3
195 | elif self.modality == 'Flow':
196 | inchannels = 2
197 | else:
198 | raise ValueError('Unknown modality: {}'.format(self.modality))
199 | self.conv1 = nn.Conv3d(inchannels, self.inplanes, kernel_size=(5, 7, 7),
200 | stride=2, padding=(2, 3, 3), bias=False)
201 | self.bn1 = self._norm_layer(self.inplanes)
202 | self.relu = nn.ReLU(inplace=True)
203 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=2,
204 | padding=(0, 1, 1)) # kernel_size=(2, 3, 3)
205 |
206 | def _forward_impl(self, x):
207 | # See note [TorchScript super()]
208 | x = self.conv1(x)
209 | x = self.bn1(x)
210 | x = self.relu(x)
211 | x = self.maxpool(x)
212 |
213 | x = self.layer1(x)
214 | x = self.layer2(x)
215 | x = self.layer3(x)
216 | x = self.layer4(x)
217 |
218 | return x
219 |
220 | def forward(self, x):
221 | return self._forward_impl(x)
222 |
223 | def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d,
224 | inflated_param_names):
225 | """Inflate a conv module from 2d to 3d.
226 |
227 | Args:
228 | conv3d (nn.Module): The destination conv3d module.
229 | state_dict_2d (OrderedDict): The state dict of pretrained 2d models.
230 | module_name_2d (str): The name of corresponding conv module in the
231 | 2d models.
232 | inflated_param_names (list[str]): List of parameters that have been
233 | inflated.
234 | """
235 | weight_2d_name = module_name_2d + '.weight'
236 |
237 | conv2d_weight = state_dict_2d[weight_2d_name]
238 | kernel_t = conv3d.weight.data.shape[2]
239 |
240 | new_weight = conv2d_weight.data.unsqueeze(2).expand_as(conv3d.weight) / kernel_t
241 | conv3d.weight.data.copy_(new_weight)
242 | inflated_param_names.append(weight_2d_name)
243 |
244 | if getattr(conv3d, 'bias') is not None:
245 | bias_2d_name = module_name_2d + '.bias'
246 | conv3d.bias.data.copy_(state_dict_2d[bias_2d_name])
247 | inflated_param_names.append(bias_2d_name)
248 |
249 | def _inflate_bn_params(self, bn3d, state_dict_2d, module_name_2d,
250 | inflated_param_names):
251 | """Inflate a norm module from 2d to 3d.
252 |
253 | Args:
254 | bn3d (nn.Module): The destination bn3d module.
255 | state_dict_2d (OrderedDict): The state dict of pretrained 2d models.
256 | module_name_2d (str): The name of corresponding bn module in the
257 | 2d models.
258 | inflated_param_names (list[str]): List of parameters that have been
259 | inflated.
260 | """
261 | for param_name, param in bn3d.named_parameters():
262 | param_2d_name = f'{module_name_2d}.{param_name}'
263 | param_2d = state_dict_2d[param_2d_name]
264 | param.data.copy_(param_2d)
265 | inflated_param_names.append(param_2d_name)
266 |
267 | for param_name, param in bn3d.named_buffers():
268 | param_2d_name = f'{module_name_2d}.{param_name}'
269 | # some buffers like num_batches_tracked may not exist in old
270 | # checkpoints
271 | if param_2d_name in state_dict_2d:
272 | param_2d = state_dict_2d[param_2d_name]
273 | param.data.copy_(param_2d)
274 | inflated_param_names.append(param_2d_name)
275 |
276 | def inflate_weights(self, state_dict_r2d):
277 | """Inflate the resnet2d parameters to resnet3d.
278 |
279 | The differences between resnet3d and resnet2d mainly lie in an extra
280 | axis of conv kernel. To utilize the pretrained parameters in 2d models,
281 | the weight of conv2d models should be inflated to fit in the shapes of
282 | the 3d counterpart.
283 |
284 | """
285 |
286 | inflated_param_names = []
287 | for name, module in self.named_modules():
288 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.BatchNorm3d):
289 | if name + '.weight' not in state_dict_r2d:
290 | print(f'Module not exist in the state_dict_r2d: {name}')
291 | else:
292 | shape_2d = state_dict_r2d[name + '.weight'].shape
293 | shape_3d = module.weight.data.shape
294 | if shape_2d != shape_3d[:2] + shape_3d[3:]:
295 | print(f'Weight shape mismatch for: {name}; '
296 | f'3d weight shape: {shape_3d}; '
297 | f'2d weight shape: {shape_2d}. ')
298 | else:
299 | if isinstance(module, nn.Conv3d):
300 | self._inflate_conv_params(module, state_dict_r2d, name, inflated_param_names)
301 | else:
302 | self._inflate_bn_params(module, state_dict_r2d, name, inflated_param_names)
303 |
304 | # check if any parameters in the 2d checkpoint are not loaded
305 | remaining_names = set(state_dict_r2d.keys()) - set(inflated_param_names)
306 | if remaining_names:
307 | print(f'These parameters in the 2d checkpoint are not loaded: {remaining_names}')
308 |
309 |
310 | def resnet3d(arch, progress=True, modality='RGB', pretrained2d=True, **kwargs):
311 | r"""
312 | Args:
313 | arch (str): The architecture of resnet
314 | modality (str): The modality of input, 'RGB' or 'Flow'
315 | progress (bool): If True, displays a progress bar of the download to stderr
316 | pretrained2d (bool): If True, utilize the pretrained parameters in 2d models
317 | """
318 |
319 | arch_settings = {
320 | 'resnet18': (BasicBlock3d, (2, 2, 2, 2)),
321 | 'resnet34': (BasicBlock3d, (3, 4, 6, 3)),
322 | 'resnet50': (Bottleneck3d, (3, 4, 6, 3)),
323 | 'resnet101': (Bottleneck3d, (3, 4, 23, 3)),
324 | 'resnet152': (Bottleneck3d, (3, 8, 36, 3))
325 | }
326 |
327 | model = ResNet3d(*arch_settings[arch], modality=modality, **kwargs)
328 | if pretrained2d:
329 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
330 | model.inflate_weights(state_dict)
331 | return model
332 |
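333 |
334 | if __name__ == '__main__':
335 |     # A minimal sanity-check sketch (not part of training/testing): build the RGB
336 |     # backbone without the 2d checkpoint download and run a dummy clip through it
337 |     # to inspect the spatio-temporal feature map that the I3D head will pool.
338 |     import torch
339 |     net = resnet3d('resnet50', modality='RGB', pretrained2d=False)
340 |     clip = torch.randn(1, 3, 16, 112, 112)  # (N, C, T, H, W)
341 |     print(net(clip).shape)  # e.g. torch.Size([1, 2048, 4, 4, 4]) for this input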
--------------------------------------------------------------------------------
/models/i3d.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .backbones import resnet3d
3 |
4 | __all__ = ['i3d_resnet18', 'i3d_resnet34', 'i3d_resnet50', 'i3d_resnet101', 'i3d_resnet152']
5 |
6 |
7 | class I3D(nn.Module):
8 | """
9 | Implements an I3D Network for action recognition.
10 |
11 | Arguments:
12 | backbone (nn.Module): the network used to compute the features for the model.
13 | classifier (nn.Module): module that takes the features returned from the
14 | backbone and returns classification scores.
15 | """
16 |
17 | def __init__(self, backbone, classifier):
18 | super(I3D, self).__init__()
19 | self.backbone = backbone
20 | self.classifier = classifier
21 |
22 | def forward(self, x):
23 | x = self.backbone(x)
24 | x = self.classifier(x)
25 | return x
26 |
27 |
28 | class I3DHead(nn.Module):
29 | """Classification head for I3D.
30 |
31 | Args:
32 | num_classes (int): Number of classes to be classified.
33 | in_channels (int): Number of channels in input feature.
34 | dropout_ratio (float): Probability of dropout layer. Default: 0.5.
35 | """
36 |
37 | def __init__(self, num_classes, in_channels, dropout_ratio=0.5):
38 | super(I3DHead, self).__init__()
39 | self.num_classes = num_classes
40 | self.in_channels = in_channels
41 | self.dropout_ratio = dropout_ratio
42 | # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels.
43 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
44 | if self.dropout_ratio != 0:
45 | self.dropout = nn.Dropout(p=self.dropout_ratio)
46 | else:
47 | self.dropout = None
48 | self.fc_cls = nn.Linear(self.in_channels, self.num_classes)
49 |
50 | def forward(self, x):
51 | """Defines the computation performed at every call.
52 |
53 | Args:
54 | x (torch.Tensor): The input data.
55 |
56 | Returns:
57 | torch.Tensor: The classification scores for input samples.
58 | """
59 | # [N, in_channels, 4, 7, 7]
60 | x = self.avg_pool(x)
61 | # [N, in_channels, 1, 1, 1]
62 | if self.dropout is not None:
63 | x = self.dropout(x)
64 | # [N, in_channels, 1, 1, 1]
65 | x = x.view(x.shape[0], -1)
66 | # [N, in_channels]
67 | cls_score = self.fc_cls(x)
68 | # [N, num_classes]
69 | return cls_score
70 |
71 |
72 | def _load_model(backbone_name, progress, modality, pretrained2d, num_classes, **kwargs):
73 | backbone = resnet3d(arch=backbone_name, progress=progress, modality=modality, pretrained2d=pretrained2d)
74 | classifier = I3DHead(num_classes=num_classes, in_channels=512 if backbone_name in ('resnet18', 'resnet34') else 2048, **kwargs)
75 | model = I3D(backbone, classifier)
76 | return model
77 |
78 |
79 | def i3d_resnet18(modality='RGB', pretrained2d=True, progress=True, num_classes=21, **kwargs):
80 | """Constructs an I3D model with a ResNet3d-18 backbone.
81 |
82 | Args:
83 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv
84 | accepts a 3-channel input. (Default: RGB)
85 | pretrained2d (bool): If True, the backbone utilizes the pretrained parameters in 2d
86 | models. (Default: True)
87 | progress (bool): If True, displays a progress bar of the download to stderr.
88 | (Default: True)
89 | num_classes (int): Number of dataset classes. (Default: 21)
90 | """
91 | return _load_model('resnet18', progress, modality, pretrained2d, num_classes, **kwargs)
92 |
93 |
94 | def i3d_resnet34(modality='RGB', pretrained2d=True, progress=True, num_classes=21, **kwargs):
95 | """Constructs an I3D model with a ResNet3d-34 backbone.
96 |
97 | Args:
98 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv
99 | accepts a 3-channel input. (Default: RGB)
100 | pretrained2d (bool): If True, the backbone utilizes the pretrained parameters in 2d
101 | models. (Default: True)
102 | progress (bool): If True, displays a progress bar of the download to stderr.
103 | (Default: True)
104 | num_classes (int): Number of dataset classes. (Default: 21)
105 | """
106 | return _load_model('resnet34', progress, modality, pretrained2d, num_classes, **kwargs)
107 |
108 |
109 | def i3d_resnet50(modality='RGB', pretrained2d=True, progress=True, num_classes=21, **kwargs):
110 | """Constructs an I3D model with a ResNet3d-50 backbone.
111 |
112 | Args:
113 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv
114 | accepts a 3-channel input. (Default: RGB)
115 | pretrained2d (bool): If True, the backbone utilizes the pretrained parameters in 2d
116 | models. (Default: True)
117 | progress (bool): If True, displays a progress bar of the download to stderr.
118 | (Default: True)
119 | num_classes (int): Number of dataset classes. (Default: 21)
120 | """
121 | return _load_model('resnet50', progress, modality, pretrained2d, num_classes, **kwargs)
122 |
123 |
124 | def i3d_resnet101(modality='RGB', pretrained2d=True, progress=True, num_classes=21, **kwargs):
125 | """Constructs an I3D model with a ResNet3d-101 backbone.
126 |
127 | Args:
128 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv
129 | accepts a 3-channel input. (Default: RGB)
130 | pretrained2d (bool): If True, the backbone utilizes the pretrained parameters in 2d
131 | models. (Default: True)
132 | progress (bool): If True, displays a progress bar of the download to stderr.
133 | (Default: True)
134 | num_classes (int): Number of dataset classes. (Default: 21)
135 | """
136 | return _load_model('resnet101', progress, modality, pretrained2d, num_classes, **kwargs)
137 |
138 |
139 | def i3d_resnet152(modality='RGB', pretrained2d=True, progress=True, num_classes=21, **kwargs):
140 | """Constructs an I3D model with a ResNet3d-152 backbone.
141 |
142 | Args:
143 | modality (str): The modality of input data (RGB or Flow). If 'RGB', the first Conv
144 | accepts a 3-channel input. (Default: RGB)
145 | pretrained2d (bool): If True, the backbone utilizes the pretrained parameters in 2d
146 | models. (Default: True)
147 | progress (bool): If True, displays a progress bar of the download to stderr.
148 | (Default: True)
149 | num_classes (int): Number of dataset classes. (Default: 21)
150 | """
151 | return _load_model('resnet152', progress, modality, pretrained2d, num_classes, **kwargs)
152 |
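153 |
154 | if __name__ == '__main__':
155 |     # A minimal sanity-check sketch: the head pools the backbone features and
156 |     # returns one score per class. Run from the repo root with `python -m models.i3d`
157 |     # (the relative import prevents running this file directly as a script).
158 |     import torch
159 |     net = i3d_resnet50(modality='RGB', pretrained2d=False, num_classes=101)
160 |     clip = torch.randn(1, 3, 16, 112, 112)  # (N, C, T, H, W)
161 |     print(net(clip).shape)  # expected: torch.Size([1, 101])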
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description="PyTorch implementation of Inflated 3D ConvNets")
4 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics'])
5 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow'])
6 | parser.add_argument('root_path', type=str)
7 | parser.add_argument('train_list', type=str)
8 | parser.add_argument('val_list', type=str)
9 |
10 | # ========================= Model Configs ==========================
11 | parser.add_argument('--arch', type=str, default="i3d_resnet50")
12 | parser.add_argument('--dropout', '--do', default=0.5, type=float,
13 | metavar='DO', help='dropout ratio (default: 0.5)')
14 | parser.add_argument('--clip_length', default=64, type=int, metavar='N',
15 | help='length of sequential frames (default: 64)')
16 | parser.add_argument('--input_size', default=224, type=int, metavar='N',
17 | help='size of input (default: 224)')
18 | parser.add_argument('--loss_type', type=str, default="nll",
19 | choices=['nll'])
20 |
21 | # ========================= Learning Configs ==========================
22 | parser.add_argument('--epochs', default=80, type=int, metavar='N',
23 | help='number of total epochs to run')
24 | parser.add_argument('-b', '--batch-size', default=16, type=int,
25 | metavar='N', help='mini-batch size (default: 16)')
26 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
27 | metavar='LR', help='initial learning rate')
28 | parser.add_argument('--lr_steps', default=[30, 60], type=float, nargs="+",
29 | metavar='LRSteps', help='epochs to decay learning rate by 10')
30 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
31 | help='momentum')
32 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
33 | metavar='W', help='weight decay (default: 5e-4)')
34 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
35 | metavar='W', help='gradient norm clipping (default: disabled)')
36 |
37 | # ========================= Monitor Configs ==========================
38 | parser.add_argument('--print-freq', '-p', default=20, type=int,
39 | metavar='N', help='print frequency (default: 20)')
40 | parser.add_argument('--eval-freq', '-ef', default=5, type=int,
41 | metavar='N', help='evaluation frequency (default: 5)')
42 |
43 | # ========================= Runtime Configs ==========================
44 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
45 | help='number of data loading workers (default: 4)')
46 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
47 | help='path to latest checkpoint (default: none)')
48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
49 | help='evaluate model on validation set')
50 | parser.add_argument('--snapshot_pref', type=str, default="")
51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
52 | help='manual epoch number (useful on restarts)')
53 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
54 | parser.add_argument('--flow_prefix', default="flow_", type=str)
55 |
--------------------------------------------------------------------------------
/test_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 |
5 | from sklearn.metrics import confusion_matrix
6 |
7 | from dataset import I3DDataSet
8 | from models import i3d
9 | from transforms import *
10 |
11 | # options
12 | parser = argparse.ArgumentParser(description="Standard video-level testing")
13 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics'])
14 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow'])
15 | parser.add_argument('root_path', type=str)
16 | parser.add_argument('test_list', type=str)
17 | parser.add_argument('weights', type=str)
18 | parser.add_argument('--arch', type=str, default='i3d_resnet50')
19 | parser.add_argument('--save_scores', type=str, default=None)
20 | parser.add_argument('--max_num', type=int, default=-1)
21 | parser.add_argument('--input_size', type=int, default=224)
22 | parser.add_argument('--clip_length', default=250, type=int, metavar='N',
23 | help='length of sequential frames (default: 250)')
24 | parser.add_argument('--dropout', type=float, default=0.7)
25 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
26 | help='number of data loading workers (default: 4)')
27 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
28 | parser.add_argument('--flow_prefix', type=str, default='flow_')
29 |
30 | args = parser.parse_args()
31 |
32 | if args.dataset == 'ucf101':
33 | num_classes = 101
34 | elif args.dataset == 'hmdb51':
35 | num_classes = 51
36 | elif args.dataset == 'kinetics':
37 | num_classes = 400
38 | else:
39 | raise ValueError('Unknown dataset ' + args.dataset)
40 |
41 | model = getattr(i3d, args.arch)(modality=args.modality, num_classes=num_classes,
42 | dropout_ratio=args.dropout)
43 |
44 | checkpoint = torch.load(args.weights)
45 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))
46 |
47 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint['state_dict'].items())}
48 | model.load_state_dict(base_dict)
49 |
50 | # Data loading code
51 | crop_size = args.input_size
52 | scale_size = args.input_size * 256 // 224
53 | input_mean = [0.485, 0.456, 0.406]
54 | input_std = [0.229, 0.224, 0.225]
55 | if args.modality == 'Flow':
56 | input_mean = [0.5]
57 | input_std = [np.mean(input_std)]
58 |
59 | data_loader = torch.utils.data.DataLoader(
60 | I3DDataSet(args.root_path, args.test_list, clip_length=args.clip_length, modality=args.modality,
61 | image_tmpl="img_{:05d}.jpg" if args.modality == "RGB" else args.flow_prefix + "{}_{:05d}.jpg",
62 | transform=torchvision.transforms.Compose([
63 | GroupScale(scale_size),
64 | GroupCenterCrop(crop_size),
65 | ToNumpyNDArray(),
66 | ToTorchFormatTensor(),
67 | GroupNormalize(input_mean, input_std),
68 | ]),
69 | test_mode=True),
70 | batch_size=1, shuffle=False,
71 | num_workers=args.workers * 2, pin_memory=True)
72 |
73 | if args.gpus is not None:
74 | devices = args.gpus
75 | else:
76 | devices = None
77 |
78 | model = torch.nn.DataParallel(model, device_ids=devices).cuda()
79 | model.eval()
80 |
81 | data_gen = enumerate(data_loader)
82 |
83 | total_num = len(data_loader.dataset)
84 | output = []
85 |
86 |
87 | def eval_video(video_data):
88 | i, data, label = video_data
89 |
90 | rst = model(data).data.cpu().numpy().copy()
91 | return i, rst, label[0]
92 |
93 |
94 | proc_start_time = time.time()
95 | max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)
96 |
97 | for i, (data, label) in data_gen:
98 | if i >= max_num:
99 | break
100 | rst = eval_video((i, data, label))
101 | output.append(rst[1:])
102 | cnt_time = time.time() - proc_start_time
103 | if i % 10 == 0:
104 | print('video {} done, total {}/{}, average {} sec/video'.format(i, i + 1,
105 | total_num,
106 | float(cnt_time) / (i + 1)))
107 |
108 | video_pred = [np.argmax(x[0]) for x in output]
109 |
110 | video_labels = [x[1] for x in output]
111 |
112 | cf = confusion_matrix(video_labels, video_pred).astype(float)
113 |
114 | cls_cnt = cf.sum(axis=1)
115 | cls_hit = np.diag(cf)
116 |
117 | cls_acc = cls_hit / cls_cnt
118 |
119 | print(cls_acc)
120 |
121 | print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
122 |
123 | if args.save_scores is not None:
124 |
125 | # reorder before saving
126 | name_list = [x.strip().split()[0] for x in open(args.test_list)]
127 |
128 | order_dict = {e: i for i, e in enumerate(sorted(name_list))}
129 |
130 | reorder_output = [None] * len(output)
131 | reorder_label = [None] * len(output)
132 |
133 | for i in range(len(output)):
134 | idx = order_dict[name_list[i]]
135 | reorder_output[idx] = output[i]
136 | reorder_label[idx] = video_labels[i]
137 |
138 | np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)
139 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 | import math
7 | import torch
8 |
9 |
10 | class GroupRandomCrop(object):
11 | def __init__(self, size):
12 | if isinstance(size, numbers.Number):
13 | self.size = (int(size), int(size))
14 | else:
15 | self.size = size
16 |
17 | def __call__(self, img_group):
18 |
19 | w, h = img_group[0].size
20 | th, tw = self.size
21 |
22 | out_images = list()
23 |
24 | x1 = random.randint(0, w - tw)
25 | y1 = random.randint(0, h - th)
26 |
27 | for img in img_group:
28 | assert (img.size[0] == w and img.size[1] == h)
29 | if w == tw and h == th:
30 | out_images.append(img)
31 | else:
32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33 |
34 | return out_images
35 |
36 |
37 | class GroupCenterCrop(object):
38 | def __init__(self, size):
39 | self.worker = torchvision.transforms.CenterCrop(size)
40 |
41 | def __call__(self, img_group):
42 | return [self.worker(img) for img in img_group]
43 |
44 |
45 | class GroupRandomHorizontalFlip(object):
46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
47 | """
48 |
49 | def __init__(self, is_flow=False):
50 | self.is_flow = is_flow
51 |
52 | def __call__(self, img_group, is_flow=False):
53 | v = random.random()
54 | if v < 0.5:
55 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
56 | if self.is_flow:
57 | for i in range(0, len(ret), 2):
58 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping
59 | return ret
60 | else:
61 | return img_group
62 |
63 |
64 | class GroupNormalize(object):
65 | def __init__(self, mean, std):
66 | self.mean = mean
67 | self.std = std
68 |
69 | def __call__(self, tensor):
70 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
71 | rep_std = self.std * (tensor.size()[0] // len(self.std))
72 |
73 | # TODO: make efficient
74 | for t, m, s in zip(tensor, rep_mean, rep_std):
75 | t.sub_(m).div_(s)
76 |
77 | return tensor
78 |
79 |
80 | class GroupScale(object):
81 | """ Rescales the input PIL.Image to the given 'size'.
82 | 'size' will be the size of the smaller edge.
83 | For example, if height > width, then image will be
84 | rescaled to (size * height / width, size)
85 | size: size of the smaller edge
86 | interpolation: Default: PIL.Image.BILINEAR
87 | """
88 |
89 | def __init__(self, size, interpolation=Image.BILINEAR):
90 | self.worker = torchvision.transforms.Resize(size, interpolation)
91 |
92 | def __call__(self, img_group):
93 | return [self.worker(img) for img in img_group]
94 |
95 |
96 | class GroupOverSample(object):
97 | def __init__(self, crop_size, scale_size=None):
98 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
99 |
100 | if scale_size is not None:
101 | self.scale_worker = GroupScale(scale_size)
102 | else:
103 | self.scale_worker = None
104 |
105 | def __call__(self, img_group):
106 |
107 | if self.scale_worker is not None:
108 | img_group = self.scale_worker(img_group)
109 |
110 | image_w, image_h = img_group[0].size
111 | crop_w, crop_h = self.crop_size
112 |
113 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
114 | oversample_group = list()
115 | for o_w, o_h in offsets:
116 | normal_group = list()
117 | flip_group = list()
118 | for i, img in enumerate(img_group):
119 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
120 | normal_group.append(crop)
121 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
122 |
123 | if img.mode == 'L' and i % 2 == 0:
124 | flip_group.append(ImageOps.invert(flip_crop))
125 | else:
126 | flip_group.append(flip_crop)
127 |
128 | oversample_group.extend(normal_group)
129 | oversample_group.extend(flip_group)
130 | return oversample_group
131 |
132 |
133 | class GroupMultiScaleCrop(object):
134 |
135 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
136 | self.scales = scales if scales is not None else [1, .875, .75, .66]
137 | self.max_distort = max_distort
138 | self.fix_crop = fix_crop
139 | self.more_fix_crop = more_fix_crop
140 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
141 | self.interpolation = Image.BILINEAR
142 |
143 | def __call__(self, img_group):
144 |
145 | im_size = img_group[0].size
146 |
147 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
148 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
149 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
150 | for img in crop_img_group]
151 | return ret_img_group
152 |
153 | def _sample_crop_size(self, im_size):
154 | image_w, image_h = im_size[0], im_size[1]
155 |
156 | # find a crop size
157 | base_size = min(image_w, image_h)
158 | crop_sizes = [int(base_size * x) for x in self.scales]
159 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
160 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
161 |
162 | pairs = []
163 | for i, h in enumerate(crop_h):
164 | for j, w in enumerate(crop_w):
165 | if abs(i - j) <= self.max_distort:
166 | pairs.append((w, h))
167 |
168 | crop_pair = random.choice(pairs)
169 | if not self.fix_crop:
170 | w_offset = random.randint(0, image_w - crop_pair[0])
171 | h_offset = random.randint(0, image_h - crop_pair[1])
172 | else:
173 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
174 |
175 | return crop_pair[0], crop_pair[1], w_offset, h_offset
176 |
177 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
178 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
179 | return random.choice(offsets)
180 |
181 | @staticmethod
182 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
183 | w_step = (image_w - crop_w) // 4
184 | h_step = (image_h - crop_h) // 4
185 |
186 | ret = list()
187 | ret.append((0, 0)) # upper left
188 | ret.append((4 * w_step, 0)) # upper right
189 | ret.append((0, 4 * h_step)) # lower left
190 | ret.append((4 * w_step, 4 * h_step)) # lower right
191 | ret.append((2 * w_step, 2 * h_step)) # center
192 |
193 | if more_fix_crop:
194 | ret.append((0, 2 * h_step)) # center left
195 | ret.append((4 * w_step, 2 * h_step)) # center right
196 | ret.append((2 * w_step, 4 * h_step)) # lower center
197 | ret.append((2 * w_step, 0 * h_step)) # upper center
198 |
199 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
200 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
201 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
202 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter
203 |
204 | return ret
205 |
206 |
207 | class GroupRandomSizedCrop(object):
208 | """Randomly crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
209 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
210 | This is popularly used to train the Inception networks
211 | size: size of the smaller edge
212 | interpolation: Default: PIL.Image.BILINEAR
213 | """
214 |
215 | def __init__(self, size, interpolation=Image.BILINEAR):
216 | self.size = size
217 | self.interpolation = interpolation
218 |
219 | def __call__(self, img_group):
220 | for attempt in range(10):
221 | area = img_group[0].size[0] * img_group[0].size[1]
222 | target_area = random.uniform(0.08, 1.0) * area
223 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
224 |
225 | w = int(round(math.sqrt(target_area * aspect_ratio)))
226 | h = int(round(math.sqrt(target_area / aspect_ratio)))
227 |
228 | if random.random() < 0.5:
229 | w, h = h, w
230 |
231 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
232 | x1 = random.randint(0, img_group[0].size[0] - w)
233 | y1 = random.randint(0, img_group[0].size[1] - h)
234 | found = True
235 | break
236 | else:
237 | found = False
238 | x1 = 0
239 | y1 = 0
240 |
241 | if found:
242 | out_group = list()
243 | for img in img_group:
244 | img = img.crop((x1, y1, x1 + w, y1 + h))
245 | assert (img.size == (w, h))
246 | out_group.append(img.resize((self.size, self.size), self.interpolation))
247 | return out_group
248 | else:
249 | # Fallback
250 | scale = GroupScale(self.size, interpolation=self.interpolation)
251 | crop = GroupRandomCrop(self.size)
252 | return crop(scale(img_group))
253 |
254 |
255 | class Stack(object):
256 |
257 | def __init__(self, roll=False):
258 | self.roll = roll
259 |
260 | def __call__(self, img_group):
261 | if img_group[0].mode == 'L':
262 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)
263 | elif img_group[0].mode == 'RGB':
264 | if self.roll:
265 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
266 | else:
267 | return np.concatenate(img_group, axis=2)
268 |
269 |
270 | class ToNumpyNDArray(object):
271 |
272 | def __call__(self, img_group):
273 | if img_group[0].mode == 'L':
274 | return np.array([np.stack((np.array(img_group[x]), np.array(img_group[x + 1])), axis=-1)
275 | for x in range(0, len(img_group), 2)])
276 | if img_group[0].mode == 'RGB':
277 | return np.array([np.array(x) for x in img_group])
278 |
279 |
280 | class ToTorchFormatTensor(object):
281 | """ Converts a group of PIL.Image (RGB) or numpy.ndarray (T x H x W x C) in the range
282 | [0, 255] to a torch.FloatTensor of shape (C x T x H x W) in the range [0.0, 1.0] """
283 |
284 | def __init__(self, div=True):
285 | self.div = div
286 |
287 | def __call__(self, pic):
288 | if isinstance(pic, np.ndarray):
289 | # handle numpy array
290 | # put it from THWC to CTHW format
291 | imgs = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
292 | else:
293 | # handle PIL Image
294 | imgs = list()
295 | for p in pic:
296 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(p.tobytes()))
297 | img = img.view(p.size[1], p.size[0], len(p.mode))
298 | # put it from HWC to CHW format
299 | # yikes, this transpose takes 80% of the loading time/CPU
300 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
301 | imgs.append(img)
302 | imgs = torch.stack(imgs, dim=0)
303 | return imgs.float().div(255) if self.div else imgs.float()
304 |
305 |
306 | class IdentityTransform(object):
307 |
308 | def __call__(self, data):
309 | return data
310 |
311 |
312 | if __name__ == "__main__":
313 | trans = torchvision.transforms.Compose([
314 | GroupScale(256),
315 | GroupRandomCrop(224),
316 | Stack(),
317 | ToTorchFormatTensor(),
318 | GroupNormalize(
319 | mean=[.485, .456, .406],
320 | std=[.229, .224, .225]
321 | )]
322 | )
323 |
324 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
325 |
326 | color_group = [im] * 3
327 | rst = trans(color_group)
328 |
329 | gray_group = [im.convert('L')] * 9
330 | gray_rst = trans(gray_group)
331 |
332 | trans2 = torchvision.transforms.Compose([
333 | GroupRandomSizedCrop(256),
334 | Stack(),
335 | ToTorchFormatTensor(),
336 | GroupNormalize(
337 | mean=[.485, .456, .406],
338 | std=[.229, .224, .225])
339 | ])
340 | print(trans2(color_group))
341 |
--------------------------------------------------------------------------------