├── .gitignore ├── QuoVadis.png ├── README.md ├── dataset.py ├── main.py ├── models ├── __init__.py ├── backbones │ ├── __init__.py │ └── resnet3d.py └── i3d.py ├── opts.py ├── test_models.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | -------------------------------------------------------------------------------- /QuoVadis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PPPrior/i3d-pytorch/e3a022e0892be5553e82e0ee152dd3e73f9f6bac/QuoVadis.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # I3D-PyTorch 2 | This is a simple and crude implementation of Inflated 3D ConvNet Models (I3D) in PyTorch. Different from models reported in "[Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset](https://arxiv.org/abs/1705.07750)" by Joao Carreira and Andrew Zisserman, this implementation uses [ResNet](https://arxiv.org/pdf/1512.03385.pdf) as backbone. 3 | 4 |
5 | 6 |
7 | 8 | This implementation is based on OpenMMLab's [MMAction2](https://github.com/open-mmlab/mmaction2). 9 | 10 | ## Data Preparation 11 | 12 | For optical flow extraction and video list generation, please refer to [TSN](https://github.com/yjxiong/temporal-segment-networks#code--data-preparation) for details. 13 | 14 | ## Training 15 | 16 | To train a new model, use the `main.py` script. 17 | 18 | For example, the command to train a model with the RGB modality on UCF101 is: 19 | 20 | ```bash 21 | python main.py ucf101 RGB \ 22 | ROOT_PATH TRAIN_LIST VAL_LIST \ 23 | --arch i3d_resnet50 --clip_length 64 \ 24 | --lr 0.001 --lr_steps 30 60 --epochs 80 \ 25 | -b 32 -j 8 --dropout 0.8 \ 26 | --snapshot_pref ucf101_i3d_resnet50 27 | ``` 28 | 29 | For flow models: 30 | 31 | ```bash 32 | python main.py ucf101 Flow \ 33 | ROOT_PATH TRAIN_LIST VAL_LIST \ 34 | --arch i3d_resnet50 --clip_length 64 \ 35 | --lr 0.001 --lr_steps 15 30 --epochs 40 \ 36 | -b 64 -j 8 --dropout 0.8 \ 37 | --snapshot_pref ucf101_i3d_resnet50 38 | ``` 39 | 40 | Please refer to [main.py](main.py) for more details. 41 | 42 | ## Testing 43 | 44 | After training, there will be checkpoints saved by PyTorch, for example `ucf101_i3d_resnet50_rgb_model_best.pth.tar`. 45 | 46 | Use the following command to test its performance: 47 | 48 | ```bash 49 | python test_models.py ucf101 RGB \ 50 | ucf101_i3d_resnet50_rgb_model_best.pth.tar \ 51 | --arch i3d_resnet50 --save_scores 52 | ``` 53 | 54 | Or for flow models: 55 | 56 | ```bash 57 | python test_models.py ucf101 Flow \ 58 | ucf101_i3d_resnet50_flow_model_best.pth.tar \ 59 | --arch i3d_resnet50 --save_scores 60 | ``` 61 | 62 | Please refer to [test_models.py](test_models.py) for more details. 
class I3DDataSet(data.Dataset):
    """Dataset of fixed-length frame clips read from extracted video frames.

    Every non-blank line of ``list_file`` is split on single spaces and
    wrapped in a ``VideoRecord`` (frame directory, frame count, label).
    Indexing the dataset samples ``clip_length`` frame numbers from the
    selected video, loads the matching images and returns
    ``(frames, label)``.

    Args:
        root_path: Directory containing ``list_file`` and the ``rawframes/``
            folder with one sub-directory of frames per video.
        list_file: Annotation file name, joined onto ``root_path``.
        clip_length: Number of frame indices sampled per video.
        frame_size: Stored on the instance; not used by this class itself.
        modality: ``'RGB'`` or ``'RGBDiff'`` (one image per index) or
            ``'Flow'`` (an x and a y image per index).
        image_tmpl: ``str.format`` template for frame file names. It is
            formatted with ``idx`` for RGB and with ``('x' | 'y', idx)`` for
            flow.
        transform: Callable applied to the list of loaded images. When
            ``None`` the list of PIL images is returned unchanged.
        random_shift: Together with ``test_mode=False`` enables random
            sampling; otherwise sampling is deterministic.
        test_mode: When true, sampling is always deterministic.
    """

    def __init__(self, root_path, list_file, clip_length=64, frame_size=(320, 240),
                 modality='RGB', image_tmpl='img_{:05d}.jpg',
                 transform=None, random_shift=True, test_mode=False):
        self.root_path = root_path
        self.list_file = list_file
        self.clip_length = clip_length
        self.frame_size = frame_size
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode

        self._parse_list()

    def _load_image(self, directory, idx):
        """Load the image(s) of frame ``idx`` of the video in ``directory``.

        Returns:
            A one-element list with an RGB image, or ``[x_img, y_img]`` with
            two single-channel images for the flow modality.

        Raises:
            ValueError: If ``self.modality`` is not a supported modality.
        """
        # e.g. ../data/ucf101/rawframes/<directory>
        frame_dir = os.path.join(self.root_path, 'rawframes/', directory)

        if self.modality in ('RGB', 'RGBDiff'):
            frame_path = os.path.join(frame_dir, self.image_tmpl.format(idx))
            return [Image.open(frame_path).convert('RGB')]
        if self.modality == 'Flow':
            x_path = os.path.join(frame_dir, self.image_tmpl.format('x', idx))
            y_path = os.path.join(frame_dir, self.image_tmpl.format('y', idx))
            return [Image.open(x_path).convert('L'), Image.open(y_path).convert('L')]
        # Previously this fell through and returned None, which only failed
        # later inside ``get`` with an unrelated TypeError.
        raise ValueError('Unknown modality: {}'.format(self.modality))

    def _parse_list(self):
        """Read the annotation file into ``self.video_list``."""
        list_path = os.path.join(self.root_path, self.list_file)
        # ``with`` closes the handle; blank lines (e.g. a trailing newline)
        # are skipped because they cannot form a valid record.
        with open(list_path) as list_fh:
            self.video_list = [VideoRecord(line.strip().split(' '))
                               for line in list_fh if line.strip()]

    def _sample_indices(self, record):
        """Return ``clip_length`` 1-based frame numbers for ``record``.

        Random mode (``random_shift`` and not ``test_mode``): the video is cut
        into ``clip_length`` equal segments and one random frame is drawn from
        each; if the video is shorter than ``clip_length`` the frames are
        drawn at random with repetition and sorted. Otherwise the centre of
        each of ``clip_length`` equal ticks is used.
        """
        if not self.test_mode and self.random_shift:
            average_duration = record.num_frames // self.clip_length
            if average_duration > 0:
                segment_starts = np.arange(self.clip_length) * average_duration
                jitter = randint(average_duration, size=self.clip_length)
                offsets = np.sort(segment_starts + jitter)
            else:
                offsets = np.sort(randint(record.num_frames, size=self.clip_length))
        else:
            tick = record.num_frames / float(self.clip_length)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.clip_length)])
        # Offsets are 0-based; frame numbers passed to the template are 1-based.
        return offsets + 1

    def __getitem__(self, index):
        record = self.video_list[index]
        indices = self._sample_indices(record)
        return self.get(record, indices)

    def get(self, record, indices):
        """Load the frames at ``indices`` and return ``(frames, label)``."""
        images = []
        for index in indices:
            images.extend(self._load_image(record.path, int(index)))
        # ``transform`` defaults to None; calling it unconditionally crashed.
        if self.transform is None:
            return images, record.label
        return self.transform(images), record.label

    def __len__(self):
        return len(self.video_list)
getattr(i3d, args.arch)(modality=args.modality, num_classes=num_classes, 33 | dropout_ratio=args.dropout) 34 | 35 | crop_size = args.input_size 36 | scale_size = args.input_size * 256 // 224 37 | input_mean = [0.485, 0.456, 0.406] 38 | input_std = [0.229, 0.224, 0.225] 39 | if args.modality == 'Flow': 40 | input_mean = [0.5] 41 | input_std = [np.mean(input_std)] 42 | 43 | train_augmentation = get_augmentation(args.modality, args.input_size) 44 | 45 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 46 | 47 | if args.resume: 48 | if os.path.isfile(args.resume): 49 | print("=> loading checkpoint '{}'".format(args.resume)) 50 | checkpoint = torch.load(args.resume) 51 | args.start_epoch = checkpoint['epoch'] 52 | best_prec1 = checkpoint['best_prec1'] 53 | model.load_state_dict(checkpoint['state_dict']) 54 | print("=> loaded checkpoint '{}' (epoch {})" 55 | .format(args.evaluate, checkpoint['epoch'])) 56 | else: 57 | print("=> no checkpoint found at '{}'".format(args.resume)) 58 | 59 | cudnn.benchmark = True 60 | 61 | # Data loading code 62 | train_loader = torch.utils.data.DataLoader( 63 | I3DDataSet(args.root_path, args.train_list, clip_length=args.clip_length, modality=args.modality, 64 | image_tmpl="img_{:05d}.jpg" if args.modality == "RGB" else args.flow_prefix + "{}_{:05d}.jpg", 65 | transform=torchvision.transforms.Compose([ 66 | train_augmentation, 67 | ToNumpyNDArray(), 68 | ToTorchFormatTensor(), 69 | GroupNormalize(input_mean, input_std), 70 | ])), 71 | batch_size=args.batch_size, shuffle=True, 72 | num_workers=args.workers, pin_memory=True) 73 | 74 | val_loader = torch.utils.data.DataLoader( 75 | I3DDataSet(args.root_path, args.val_list, clip_length=args.clip_length, modality=args.modality, 76 | image_tmpl="img_{:05d}.jpg" if args.modality == "RGB" else args.flow_prefix + "{}_{:05d}.jpg", 77 | random_shift=False, 78 | transform=torchvision.transforms.Compose([ 79 | GroupScale(int(scale_size)), 80 | GroupCenterCrop(crop_size), 81 | 
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch and print running statistics.

    Args:
        train_loader: Iterable yielding ``(inputs, target)`` batches.
        model: Network being optimised; switched to train mode here.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress message.

    Reads the module-level ``args`` for ``clip_gradient`` and ``print_freq``.
    """
    # Imported locally: ``clip_grad_norm`` (no trailing underscore) is the
    # deprecated alias; the in-place ``clip_grad_norm_`` is its replacement.
    from torch.nn.utils import clip_grad_norm_

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Only the target is moved to the GPU here; the input batch is passed
        # to the model exactly as the loader produced it.
        target = target.cuda()
        output = model(inputs)
        loss = criterion(output, target)

        # measure accuracy and record loss
        batch = inputs.size(0)
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), batch)
        top1.update(prec1.item(), batch)
        top5.update(prec5.item(), batch)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()

        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
            if total_norm > args.clip_gradient:
                print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm))

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'])))
class AverageMeter(object):
    """Tracks the latest value of a metric and its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and sample count to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, observed ``n`` times."""
        total = self.sum + val * n
        seen = self.count + n
        self.val, self.sum, self.count = val, total, seen
        self.avg = total / seen
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k, in percent, for each k in ``topk``.

    Args:
        output: Score tensor of shape ``(batch, num_classes)``.
        target: Tensor of ``batch`` ground-truth class indices.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # (batch, maxk) indices of the best classes, transposed to (maxk, batch)
    # so that row r holds the rank-r prediction of every sample.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor, so the slice may be
        # non-contiguous; ``view(-1)`` raises on such tensors while
        # ``reshape(-1)`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3x3 ``nn.Conv3d`` whose padding equals ``dilation``.

    An integer ``stride`` is applied to the last two axes only (the first
    axis keeps stride 1); any other value is passed to ``nn.Conv3d`` as is.
    """
    conv_stride = (1, stride, stride) if isinstance(stride, int) else stride
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=conv_stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
class Bottleneck3d(nn.Module):
    """Residual bottleneck block: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions.

    Each convolution is followed by a norm layer; the block input (or
    ``downsample(input)`` when a downsample module is given) is added to the
    result before the final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Base channel count; the block outputs ``planes * expansion``.
        stride: Stride handed to the 3x3x3 convolution.
        downsample: Optional module applied to the input on the skip path.
        groups: Groups of the 3x3x3 convolution.
        base_width: Scales the inner width as ``planes * base_width / 64``.
        dilation: Dilation of the 3x3x3 convolution.
        norm_layer: Norm layer constructor; ``nn.BatchNorm3d`` when ``None``.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck3d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        width = int(planes * (base_width / 64.)) * groups
        # self.conv2 is the layer that receives ``stride``; self.conv1 and
        # self.conv3 are built without one.
        self.conv1 = conv1x1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/norm stages and add the skip connection."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the input when a downsample module was supplied so that it
        # can be added to ``out``.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | 140 | self._make_stem_layer() 141 | 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 144 | dilate=replace_stride_with_dilation[0]) 145 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 146 | dilate=replace_stride_with_dilation[1]) 147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 148 | dilate=replace_stride_with_dilation[2]) 149 | 150 | for m in self.modules(): # self.modules() --> Depth-First-Search the Net 151 | if isinstance(m, nn.Conv3d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
159 | if zero_init_residual: 160 | for m in self.modules(): 161 | if isinstance(m, Bottleneck3d): 162 | nn.init.constant_(m.bn3.weight, 0) 163 | elif isinstance(m, BasicBlock3d): 164 | nn.init.constant_(m.bn2.weight, 0) 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 167 | norm_layer = self._norm_layer 168 | downsample = None 169 | previous_dilation = self.dilation 170 | if dilate: 171 | self.dilation *= stride 172 | stride = 1 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | conv1x1x1(self.inplanes, planes * block.expansion, stride), 176 | norm_layer(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 181 | self.base_width, previous_dilation, norm_layer)) 182 | self.inplanes = planes * block.expansion 183 | for _ in range(1, blocks): 184 | layers.append(block(self.inplanes, planes, groups=self.groups, 185 | base_width=self.base_width, dilation=self.dilation, 186 | norm_layer=norm_layer)) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def _make_stem_layer(self): 191 | """Construct the stem layers consists of a conv+norm+act module and a 192 | pooling layer.""" 193 | if self.modality == 'RGB': 194 | inchannels = 3 195 | elif self.modality == 'Flow': 196 | inchannels = 2 197 | else: 198 | raise ValueError('Unknown modality: {}'.format(self.modality)) 199 | self.conv1 = nn.Conv3d(inchannels, self.inplanes, kernel_size=(5, 7, 7), 200 | stride=2, padding=(2, 3, 3), bias=False) 201 | self.bn1 = self._norm_layer(self.inplanes) 202 | self.relu = nn.ReLU(inplace=True) 203 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=2, 204 | padding=(0, 1, 1)) # kernel_size=(2, 3, 3) 205 | 206 | def _forward_impl(self, x): 207 | # See note [TorchScript super()] 208 | x = self.conv1(x) 209 | x = self.bn1(x) 210 | x = self.relu(x) 211 | x = self.maxpool(x) 212 | 213 | x = self.layer1(x) 214 | 
x = self.layer2(x) 215 | x = self.layer3(x) 216 | x = self.layer4(x) 217 | 218 | return x 219 | 220 | def forward(self, x): 221 | return self._forward_impl(x) 222 | 223 | def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d, 224 | inflated_param_names): 225 | """Inflate a conv module from 2d to 3d. 226 | 227 | Args: 228 | conv3d (nn.Module): The destination conv3d module. 229 | state_dict_2d (OrderedDict): The state dict of pretrained 2d models. 230 | module_name_2d (str): The name of corresponding conv module in the 231 | 2d models. 232 | inflated_param_names (list[str]): List of parameters that have been 233 | inflated. 234 | """ 235 | weight_2d_name = module_name_2d + '.weight' 236 | 237 | conv2d_weight = state_dict_2d[weight_2d_name] 238 | kernel_t = conv3d.weight.data.shape[2] 239 | 240 | new_weight = conv2d_weight.data.unsqueeze(2).expand_as(conv3d.weight) / kernel_t 241 | conv3d.weight.data.copy_(new_weight) 242 | inflated_param_names.append(weight_2d_name) 243 | 244 | if getattr(conv3d, 'bias') is not None: 245 | bias_2d_name = module_name_2d + '.bias' 246 | conv3d.bias.data.copy_(state_dict_2d[bias_2d_name]) 247 | inflated_param_names.append(bias_2d_name) 248 | 249 | def _inflate_bn_params(self, bn3d, state_dict_2d, module_name_2d, 250 | inflated_param_names): 251 | """Inflate a norm module from 2d to 3d. 252 | 253 | Args: 254 | bn3d (nn.Module): The destination bn3d module. 255 | state_dict_2d (OrderedDict): The state dict of pretrained 2d models. 256 | module_name_2d (str): The name of corresponding bn module in the 257 | 2d models. 258 | inflated_param_names (list[str]): List of parameters that have been 259 | inflated. 
260 | """ 261 | for param_name, param in bn3d.named_parameters(): 262 | param_2d_name = f'{module_name_2d}.{param_name}' 263 | param_2d = state_dict_2d[param_2d_name] 264 | param.data.copy_(param_2d) 265 | inflated_param_names.append(param_2d_name) 266 | 267 | for param_name, param in bn3d.named_buffers(): 268 | param_2d_name = f'{module_name_2d}.{param_name}' 269 | # some buffers like num_batches_tracked may not exist in old 270 | # checkpoints 271 | if param_2d_name in state_dict_2d: 272 | param_2d = state_dict_2d[param_2d_name] 273 | param.data.copy_(param_2d) 274 | inflated_param_names.append(param_2d_name) 275 | 276 | def inflate_weights(self, state_dict_r2d): 277 | """Inflate the resnet2d parameters to resnet3d. 278 | 279 | The differences between resnet3d and resnet2d mainly lie in an extra 280 | axis of conv kernel. To utilize the pretrained parameters in 2d models, 281 | the weight of conv2d models should be inflated to fit in the shapes of 282 | the 3d counterpart. 283 | 284 | """ 285 | 286 | inflated_param_names = [] 287 | for name, module in self.named_modules(): 288 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.BatchNorm3d): 289 | if name + '.weight' not in state_dict_r2d: 290 | print(f'Module not exist in the state_dict_r2d: {name}') 291 | else: 292 | shape_2d = state_dict_r2d[name + '.weight'].shape 293 | shape_3d = module.weight.data.shape 294 | if shape_2d != shape_3d[:2] + shape_3d[3:]: 295 | print(f'Weight shape mismatch for: {name}' 296 | f'3d weight shape: {shape_3d}; ' 297 | f'2d weight shape: {shape_2d}. 
class I3D(nn.Module):
    """I3D action-recognition network: a feature backbone plus a classifier.

    Arguments:
        backbone (nn.Module): module that turns the input into features.
        classifier (nn.Module): module that maps the backbone's features to
            classification scores.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        """Return ``classifier(backbone(x))``."""
        features = self.backbone(x)
        return self.classifier(features)
def _load_model(backbone_name, progress, modality, pretrained2d, num_classes, **kwargs):
    """Assemble an ``I3D`` model from a ResNet3d backbone and an ``I3DHead``.

    Args:
        backbone_name (str): Architecture key passed to ``resnet3d`` as ``arch``.
        progress (bool): Forwarded to ``resnet3d``.
        modality (str): Forwarded to ``resnet3d``.
        pretrained2d (bool): Forwarded to ``resnet3d``.
        num_classes (int): Number of output classes of the head.
        **kwargs: Extra keyword arguments for ``I3DHead`` (e.g. ``dropout_ratio``).

    Returns:
        I3D: The backbone and head wrapped in a single module.
    """
    backbone = resnet3d(arch=backbone_name, progress=progress, modality=modality,
                        pretrained2d=pretrained2d)
    # ResNet3d updates ``inplanes`` to ``planes * block.expansion`` after each
    # stage, so after construction it holds the channel count of the last
    # stage: 512 for the BasicBlock3d variants and 2048 for the Bottleneck3d
    # ones. A fixed 2048 made the head's linear layer mismatch resnet18/34.
    in_channels = backbone.inplanes
    classifier = I3DHead(num_classes=num_classes, in_channels=in_channels, **kwargs)
    return I3D(backbone, classifier)
def i3d_resnet50(modality='RGB', pretrained2d=True, progress=True, num_classes=21, **kwargs):
    """Build an I3D model whose backbone is ResNet3d-50.

    Args:
        modality (str): Input modality, 'RGB' or 'Flow'. (Default: RGB)
        pretrained2d (bool): Passed to the backbone builder; if True the backbone
            uses the pretrained parameters of the 2d model. (Default: True)
        progress (bool): If True, displays a progress bar of the download to
            stderr. (Default: True)
        num_classes (int): Number of dataset classes. (Default: 21)
        **kwargs: Forwarded unchanged to `_load_model`.

    Returns:
        The model returned by `_load_model` for the 'resnet50' backbone.
    """
    model = _load_model(
        backbone_name='resnet50',
        progress=progress,
        modality=modality,
        pretrained2d=pretrained2d,
        num_classes=num_classes,
        **kwargs
    )
    return model
import argparse

# Command-line options for the training script. Other modules import `parser`
# from here and call `parser.parse_args()` themselves; nothing is parsed at
# import time.
parser = argparse.ArgumentParser(description="PyTorch implementation of Inflated 3D ConvNets")
parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics'])
parser.add_argument('modality', type=str, choices=['RGB', 'Flow'])
parser.add_argument('root_path', type=str)
parser.add_argument('train_list', type=str)
parser.add_argument('val_list', type=str)

# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="i3d_resnet50")
parser.add_argument('--dropout', '--do', default=0.5, type=float,
                    metavar='DO', help='dropout ratio (default: 0.5)')
parser.add_argument('--clip_length', default=64, type=int, metavar='N',
                    help='length of sequential frames (default: 64)')
parser.add_argument('--input_size', default=224, type=int, metavar='N',
                    help='size of input (default: 224)')
parser.add_argument('--loss_type', type=str, default="nll",
                    choices=['nll'])

# ========================= Learning Configs ==========================
parser.add_argument('--epochs', default=80, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=16, type=int,
                    metavar='N', help='mini-batch size (default: 16)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
# Values given on the command line are parsed as floats, while the default
# list holds ints; both compare correctly against integer epoch numbers.
parser.add_argument('--lr_steps', default=[30, 60], type=float, nargs="+",
                    metavar='LRSteps', help='epochs to decay learning rate by 10')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
                    metavar='W', help='gradient norm clipping (default: disabled)')

# ========================= Monitor Configs ==========================
# The help text used to advertise "default: 10" although the default is 20.
parser.add_argument('--print-freq', '-p', default=20, type=int,
                    metavar='N', help='print frequency (default: 20)')
parser.add_argument('--eval-freq', '-ef', default=5, type=int,
                    metavar='N', help='evaluation frequency (default: 5)')

# ========================= Runtime Configs ==========================
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--snapshot_pref', type=str, default="")
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', default="flow_", type=str)
# options
parser = argparse.ArgumentParser(description="Standard video-level testing")
parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics'])
parser.add_argument('modality', type=str, choices=['RGB', 'Flow'])
parser.add_argument('root_path', type=str)
parser.add_argument('test_list', type=str)
parser.add_argument('weights', type=str)
parser.add_argument('--arch', type=str, default='i3d_resnet50')
parser.add_argument('--save_scores', type=str, default=None)
parser.add_argument('--max_num', type=int, default=-1)
parser.add_argument('--input_size', type=int, default=224)
# The help text used to say "default: 64" although the default is 250.
parser.add_argument('--clip_length', default=250, type=int, metavar='N',
                    help='length of sequential frames (default: 250)')
parser.add_argument('--dropout', type=float, default=0.7)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', type=str, default='flow_')

args = parser.parse_args()

if args.dataset == 'ucf101':
    num_classes = 101
elif args.dataset == 'hmdb51':
    num_classes = 51
elif args.dataset == 'kinetics':
    num_classes = 400
else:
    raise ValueError('Unknown dataset ' + args.dataset)

model = getattr(i3d, args.arch)(modality=args.modality, num_classes=num_classes,
                                dropout_ratio=args.dropout)

checkpoint = torch.load(args.weights)
print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))

# Drop the first dotted component of every parameter name.
# NOTE(review): presumably the "module." prefix added by DataParallel at
# training time -- confirm against main.py.
base_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(base_dict)

# Data loading code
crop_size = args.input_size
scale_size = args.input_size * 256 // 224
input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]
if args.modality == 'Flow':
    input_mean = [0.5]
    input_std = [np.mean(input_std)]

data_loader = torch.utils.data.DataLoader(
    I3DDataSet(args.root_path, args.test_list, clip_length=args.clip_length, modality=args.modality,
               image_tmpl="img_{:05d}.jpg" if args.modality == "RGB" else args.flow_prefix + "{}_{:05d}.jpg",
               transform=torchvision.transforms.Compose([
                   GroupScale(scale_size),
                   GroupCenterCrop(crop_size),
                   ToNumpyNDArray(),
                   ToTorchFormatTensor(),
                   GroupNormalize(input_mean, input_std),
               ]),
               test_mode=True),
    batch_size=1, shuffle=False,
    num_workers=args.workers * 2, pin_memory=True)

# Honor --gpus. The original code indexed args.gpus with range(args.workers),
# which raised IndexError whenever fewer GPUs than workers were given, and then
# never used the resulting list.
devices = args.gpus
if devices is not None:
    # DataParallel expects the module to live on device_ids[0].
    model = torch.nn.DataParallel(model.cuda(devices[0]), device_ids=devices)
else:
    # device_ids=None lets DataParallel use every visible GPU (previous behavior).
    model = torch.nn.DataParallel(model.cuda())
model.eval()

data_gen = enumerate(data_loader)

total_num = len(data_loader.dataset)
output = []


def eval_video(video_data):
    """Run the model on one video clip.

    Args:
        video_data (tuple): `(index, data, label)` as produced by the loop below.

    Returns:
        tuple: `(index, scores, label[0])` where `scores` is the model output
        copied to a numpy array.
    """
    i, data, label = video_data

    # No gradients are needed at test time; skipping autograd saves memory.
    with torch.no_grad():
        rst = model(data).cpu().numpy().copy()
    return i, rst, label[0]


proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)

for i, (data, label) in data_gen:
    if i >= max_num:
        break
    rst = eval_video((i, data, label))
    output.append(rst[1:])
    cnt_time = time.time() - proc_start_time
    if i % 10 == 0:
        print('video {} done, total {}/{}, average {} sec/video'.format(i, i + 1,
                                                                        total_num,
                                                                        float(cnt_time) / (i + 1)))

video_pred = [np.argmax(x[0]) for x in output]

video_labels = [x[1] for x in output]

cf = confusion_matrix(video_labels, video_pred).astype(float)

cls_cnt = cf.sum(axis=1)
cls_hit = np.diag(cf)

# NOTE: a class that is predicted but never appears in the ground truth has
# cls_cnt == 0 and yields NaN here (unchanged from the original behavior).
cls_acc = cls_hit / cls_cnt

print(cls_acc)

print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))

if args.save_scores is not None:

    # reorder before saving
    with open(args.test_list) as list_file:
        name_list = [x.strip().split()[0] for x in list_file]

    # Only the first len(output) videos were evaluated (see --max_num); ranking
    # over the full list would produce indices past the end of the result lists.
    name_list = name_list[:len(output)]

    order_dict = {e: i for i, e in enumerate(sorted(name_list))}

    reorder_output = [None] * len(output)
    reorder_label = [None] * len(output)

    for i in range(len(output)):
        idx = order_dict[name_list[i]]
        reorder_output[idx] = output[i]
        reorder_label[idx] = video_labels[i]

    np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)
class GroupNormalize(object):
    """Normalize a tensor in place, one slice along dim 0 at a time.

    `mean` and `std` are lists that get tiled so that their length follows the
    size of the tensor's first dimension; slice `k` becomes
    `(slice - mean[k]) / std[k]`.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Normalize `tensor` in place and return that same tensor object."""
        leading = tensor.size()[0]
        tiled_mean = self.mean * (leading // len(self.mean))
        tiled_std = self.std * (leading // len(self.std))

        # Stop at the shortest sequence: slices beyond the tiled statistics
        # are left untouched.
        count = min(leading, len(tiled_mean), len(tiled_std))
        for idx in range(count):
            tensor[idx].sub_(tiled_mean[idx]).div_(tiled_std[idx])

        return tensor
class GroupMultiScaleCrop(object):
    """Crop every image of a group with one randomly sampled box, then resize.

    The crop size is drawn from `scales` applied to the shorter image side; the
    crop position is either uniformly random or picked from a fixed grid of
    offsets. All images in the group share the same box and are resized to
    `input_size`.
    """

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        if scales is None:
            scales = [1, .875, .75, .66]
        if isinstance(input_size, int):
            input_size = [input_size, input_size]

        self.scales = scales
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Return the cropped and resized copies of all images in `img_group`."""
        width, height, left, top = self._sample_crop_size(img_group[0].size)
        box = (left, top, left + width, top + height)
        target = (self.input_size[0], self.input_size[1])
        return [frame.crop(box).resize(target, self.interpolation) for frame in img_group]

    def _sample_crop_size(self, im_size):
        """Sample `(crop_w, crop_h, w_offset, h_offset)` for an image of `im_size`."""
        image_w, image_h = im_size[0], im_size[1]
        target_w, target_h = self.input_size[0], self.input_size[1]

        # Candidate side lengths; a side within 3 pixels of the target size is
        # snapped to the target.
        shorter = min(image_w, image_h)
        heights, widths = [], []
        for scale in self.scales:
            side = int(shorter * scale)
            heights.append(target_h if abs(side - target_h) < 3 else side)
            widths.append(target_w if abs(side - target_w) < 3 else side)

        # Pair heights and widths whose scale indices differ by at most max_distort.
        candidates = [(w, h)
                      for i, h in enumerate(heights)
                      for j, w in enumerate(widths)
                      if abs(i - j) <= self.max_distort]

        chosen_w, chosen_h = random.choice(candidates)
        if self.fix_crop:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, chosen_w, chosen_h)
        else:
            w_offset = random.randint(0, image_w - chosen_w)
            h_offset = random.randint(0, image_h - chosen_h)

        return chosen_w, chosen_h, w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset at random from the fixed grid."""
        return random.choice(
            self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h))

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """List the fixed `(w_offset, h_offset)` positions on a quarter-step grid.

        The first five entries are the four corners and the center; with
        `more_fix_crop` eight further positions (edge centers and quarter
        points) are appended.
        """
        step_w = (image_w - crop_w) // 4
        step_h = (image_h - crop_h) // 4

        # Positions expressed in multiples of the step, in the original order.
        grid = [
            (0, 0),  # upper left
            (4, 0),  # upper right
            (0, 4),  # lower left
            (4, 4),  # lower right
            (2, 2),  # center
        ]
        if more_fix_crop:
            grid += [
                (0, 2),  # center left
                (4, 2),  # center right
                (2, 4),  # lower center
                (2, 0),  # upper center
                (1, 1),  # upper left quarter
                (3, 1),  # upper right quarter
                (1, 3),  # lower left quarter
                (3, 3),  # lower right quarter
            ]

        return [(mult_w * step_w, mult_h * step_h) for mult_w, mult_h in grid]
class Stack(object):
    """Concatenate a group of images along the last (channel) axis.

    Grayscale ('L') images first get a trailing axis added; 'RGB' images are
    concatenated as they are, or with their channel order reversed when `roll`
    is set. Any other mode yields None.
    """

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode

        if mode == 'L':
            frames = [np.expand_dims(frame, 2) for frame in img_group]
        elif mode == 'RGB':
            if self.roll:
                # Reverse the channel axis of every frame.
                frames = [np.array(frame)[:, :, ::-1] for frame in img_group]
            else:
                frames = img_group
        else:
            return None

        return np.concatenate(frames, axis=2)
class ToTorchFormatTensor(object):
    """Convert image data to a float tensor.

    A numpy.ndarray laid out as (T x H x W x C) becomes a tensor of shape
    (C x T x H x W). Any other input is treated as an iterable of PIL images;
    each one is turned into a (C x H x W) tensor and the results are stacked
    along a new leading dimension. With `div=True` the values are divided by
    255.
    """

    def __init__(self, div=True):
        self.div = div

    @staticmethod
    def _pil_to_chw(image):
        """Turn one PIL image into a contiguous byte tensor of shape (C x H x W)."""
        raw = torch.ByteTensor(torch.ByteStorage.from_buffer(image.tobytes()))
        hwc = raw.view(image.size[1], image.size[0], len(image.mode))
        # put it from HWC to CHW format
        return hwc.permute(2, 0, 1).contiguous()

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # put it from THWC to CTHW format
            stacked = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
        else:
            stacked = torch.stack([self._pil_to_chw(image) for image in pic], dim=0)

        result = stacked.float()
        if self.div:
            result = result.div(255)
        return result
mean=[.485, .456, .406], 338 | std=[.229, .224, .225]) 339 | ]) 340 | print(trans2(color_group)) 341 | --------------------------------------------------------------------------------