├── ops
│   ├── __init__.py
│   ├── utils.py
│   └── basic_ops.py
├── models
│   ├── __init__.py
│   └── i3dnon.py
├── README.md
├── train.sh
├── .gitignore
├── opts.py
├── dataset.py
├── i3d.py
├── transforms.py
├── testmodel.py
└── main.py
/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .i3dnon import * 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # i3d-nonlocal-pytorch 2 | A PyTorch implementation of I3D (Inflated 3D ConvNet) with non-local blocks 3 | 4 | ## Usage 5 | 6 | sh train.sh 7 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python main.py kinetics RGB kinetics_files/train_all.txt kinetics_files/val_all.txt --arch resnet101 --snapshot_pref kinetics_i3dresnet101_ \ 2 | --lr 0.001 --lr_steps 20 80 --epochs 120 \ 3 | -b 32 -j 8 --dropout 0.5 -p 20 --gd 20 4 | -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def get_grad_hook(name): 6 | def hook(m, grad_in, grad_out): 7 | print((name, grad_out[0].data.abs().mean(), grad_in[0].data.abs().mean())) 8 | print((grad_out[0].size())) 9 | print((grad_in[0].size())) 10 | 11 | print((grad_out[0])) 12 | print((grad_in[0])) 13 | 14 | return hook 15 | 16 | 17 | def softmax(scores): 18 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 19 | return es / es.sum(axis=-1)[..., None] 20 | 21 | 22 | def log_add(log_a, log_b): 23 | return log_a + np.log(1 + np.exp(log_b - log_a)) # numerically stable when log_a >= log_b 24 | 25 | 26 | def class_accuracy(prediction, label): 27 | cf = confusion_matrix(label, prediction) # sklearn expects (y_true, y_pred) 28 | cls_cnt = cf.sum(axis=1) 29 | cls_hit = np.diag(cf) 30 | 31 | cls_acc = cls_hit / cls_cnt.astype(float) 32 | 33 | mean_cls_acc = cls_acc.mean() 34 | 35 | return cls_acc, mean_cls_acc 36 | -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class Identity(torch.nn.Module): 6 | def forward(self, input): 7 | return input 8 | 9 | 10 | class SegmentConsensus(torch.autograd.Function): 11 | 12 | def __init__(self, consensus_type, dim=1): 13 | self.consensus_type = consensus_type 14 | self.dim = dim 15 | self.shape = None 16 | 17 | def forward(self, input_tensor): 18 | self.shape = input_tensor.size() 19 | if self.consensus_type == 'avg': 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) 21 | elif self.consensus_type == 'identity': 22 | output = input_tensor 23 | else: 24 | output = None 25 | 26 | return output 27 | 28 | def backward(self, grad_output): 29 | if self.consensus_type == 'avg': 30 | grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim]) 31 | elif self.consensus_type == 'identity': 32 | grad_in = grad_output 33 | else: 34 | grad_in = None 35 | 36 | return grad_in 37 | 38 | 39 | class ConsensusModule(torch.nn.Module): 40 | 41 | def __init__(self, consensus_type, dim=1): 42 | super(ConsensusModule, self).__init__() 43 | 
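# Usage sketch (assumed shapes, not from the original repo): given per-segment
# logits of shape (batch, num_segments, num_class), ConsensusModule('avg')
# averages over dim=1 into video-level logits of shape (batch, 1, num_class):
#   consensus = ConsensusModule('avg')
#   video_logits = consensus(segment_logits)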
self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 44 | self.dim = dim 45 | 46 | def forward(self, input): 47 | return SegmentConsensus(self.consensus_type, self.dim)(input) 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 3 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics']) 4 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff']) 5 | parser.add_argument('train_list', type=str) 6 | parser.add_argument('val_list', type=str) 7 | 8 | # ========================= Model Configs ========================== 9 | parser.add_argument('--arch', type=str, default="resnet101") 10 | parser.add_argument('--sample_frames', type=int, default=32) 11 | parser.add_argument('--k', type=int, default=3) 12 | 13 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 14 | metavar='DO', help='dropout ratio (default: 0.5)') 15 | parser.add_argument('--loss_type', type=str, default="nll", 16 | choices=['nll']) 17 | 18 | # ========================= Learning Configs ========================== 19 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 20 | help='number of total epochs to run') 21 | parser.add_argument('-b', '--batch-size', default=256, type=int, 22 | metavar='N', help='mini-batch size (default: 256)') 23 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 24 | metavar='LR', help='initial learning rate') 25 | parser.add_argument('--lr_steps', 
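# Interpreted by adjust_learning_rate() in main.py as lr = args.lr * 0.1 ** (steps already passed).
# Worked example with the train.sh settings --lr 0.001 --lr_steps 20 80:
# lr is 1e-3 for epochs 0-19, 1e-4 for epochs 20-79, and 1e-5 afterwards.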
default=[45, 90], type=float, nargs="+", 26 | metavar='LRSteps', help='epochs to decay learning rate by 10') 27 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 28 | help='momentum') 29 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 30 | metavar='W', help='weight decay (default: 1e-4)') 31 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 32 | metavar='W', help='gradient norm clipping (default: disabled)') 33 | 34 | # ========================= Monitor Configs ========================== 35 | parser.add_argument('--print-freq', '-p', default=20, type=int, 36 | metavar='N', help='print frequency (default: 20)') 37 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 38 | metavar='N', help='evaluation frequency (default: 5)') 39 | 40 | 41 | # ========================= Runtime Configs ========================== 42 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 43 | help='number of data loading workers (default: 4)') 44 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 45 | help='path to latest checkpoint (default: none)') 46 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 47 | help='evaluate model on validation set') 48 | parser.add_argument('--snapshot_pref', type=str, default="") 49 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 50 | help='manual epoch number (useful on restarts)') 51 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from numpy.random import randint 8 | 9 | class VideoRecord(object): 10 | def __init__(self, row): 11 | self._data = row 12 | 13 | @property 14 | def path(self): 15 | return self._data[0] 16 | 17 | @property 18 | def num_frames(self): 19 | return int(self._data[1]) 20 | 21 | @property 22 | def label(self): 23 | return int(self._data[2]) 24 | 25 | 26 | class I3DDataSet(data.Dataset): 27 | def __init__(self, root_path, list_file, 28 | sample_frames=32, modality='RGB', 29 | image_tmpl='img_{:05d}.jpg', transform=None, 30 | force_grayscale=False, train_mode=True, test_clips=10): 31 | 32 | self.root_path = root_path 33 | self.list_file = list_file 34 | self.sample_frames = sample_frames 35 | self.modality = modality 36 | self.image_tmpl = image_tmpl 37 | self.transform = transform 38 | self.train_mode = train_mode 39 | if not self.train_mode: 40 | self.num_clips = test_clips 41 | 42 | self._parse_list() 43 | 44 | def _load_image(self, directory, idx): 45 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 46 | img_path = os.path.join(directory, self.image_tmpl.format(idx)) 47 | try: 48 | return [Image.open(img_path).convert('RGB')] 49 | except Exception: 50 | print("Couldn't load image: {}".format(img_path)) 51 | return None 52 | elif self.modality == 'Flow': 53 | x_img = Image.open(os.path.join(directory, self.image_tmpl.format('x', idx))).convert('L') 54 | y_img = Image.open(os.path.join(directory, self.image_tmpl.format('y', idx))).convert('L') 55 | 56 | return [x_img, y_img] 57 | 58 | def _parse_list(self): 59 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)] 60 | 61 | def 
_sample_indices(self, record): 62 | """ 63 | :param record: VideoRecord 64 | :return: list 65 | """ 66 | expanded_sample_length = self.sample_frames * 4 # in order to drop every other frame 67 | if record.num_frames >= expanded_sample_length: 68 | start_pos = randint(record.num_frames - expanded_sample_length + 1) 69 | offsets = range(start_pos, start_pos + expanded_sample_length, 4) 70 | elif record.num_frames > self.sample_frames*2: 71 | start_pos = randint(record.num_frames - self.sample_frames*2 + 1) 72 | offsets = range(start_pos, start_pos + self.sample_frames*2, 2) 73 | elif record.num_frames > self.sample_frames: 74 | start_pos = randint(record.num_frames - self.sample_frames + 1) 75 | offsets = range(start_pos, start_pos + self.sample_frames, 1) 76 | else: 77 | offsets = np.sort(randint(record.num_frames, size=self.sample_frames)) 78 | 79 | offsets =[int(v)+1 for v in offsets] # images are 1-indexed 80 | return offsets 81 | 82 | def _get_test_indices(self, record): 83 | tick = (record.num_frames - self.sample_frames*2 + 1) / float(self.num_clips) 84 | sample_start_pos = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_clips)]) 85 | offsets = [] 86 | for p in sample_start_pos: 87 | offsets.extend(range(p,p+self.sample_frames*2,2)) 88 | 89 | checked_offsets = [] 90 | for f in offsets: 91 | new_f = int(f) + 1 92 | if new_f < 1: 93 | new_f = 1 94 | elif new_f >= record.num_frames: 95 | new_f = record.num_frames - 1 96 | checked_offsets.append(new_f) 97 | 98 | return checked_offsets 99 | 100 | 101 | def __getitem__(self, index): 102 | record = self.video_list[index] 103 | 104 | if self.train_mode: 105 | segment_indices = self._sample_indices(record) 106 | process_data, label = self.get(record, segment_indices) 107 | while process_data is None: 108 | index = randint(0, len(self.video_list) - 1) 109 | process_data, label = self.__getitem__(index) 110 | else: 111 | segment_indices = self._get_test_indices(record) 112 | process_data,label = self.get(record, segment_indices) 113 | if process_data is None: 114 | raise ValueError('sample indices:', record.path, segment_indices) 115 | 116 | return process_data,label 117 | 118 | 119 | def get(self, record, indices): 120 | 121 | images = list() 122 | for ind in indices: 123 | seg_img = self._load_image(record.path, ind) 124 | if seg_img is None: 125 | return None,None 126 | images.extend(seg_img) 127 | 128 | process_data = self.transform(images) 129 | return process_data, record.label 130 | 131 | def __len__(self): 132 | return len(self.video_list) 133 | -------------------------------------------------------------------------------- /i3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.nn.init import normal, constant 4 | from transforms import * 5 | import models.i3dnon 6 | class I3DModel(torch.nn.Module): 7 | def __init__(self, num_class, sample_frames, modality, 8 | base_model='resnet101', 9 | dropout=0.8): 10 | super(I3DModel, self).__init__() 11 | self.modality = modality 12 | self.sample_frames = sample_frames 13 | self.reshape = True 14 | self.dropout = dropout 15 | 16 | print((""" 17 | Initializing I3D with base model: {}. 
18 | I3D Configurations: 19 | input_modality: {} 20 | sample_frames: {} 21 | dropout_ratio: {} 22 | """.format(base_model, self.modality, self.sample_frames, self.dropout))) 23 | 24 | self._prepare_base_model(base_model) 25 | self._prepare_i3d(num_class) 26 | 27 | 28 | def _prepare_i3d(self, num_class): 29 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 30 | if self.dropout == 0: 31 | setattr(self.base_model, self.base_model.last_layer_name, torch.nn.Linear(feature_dim, num_class)) 32 | self.new_fc = None 33 | else: 34 | setattr(self.base_model, self.base_model.last_layer_name, torch.nn.Dropout(p=self.dropout)) 35 | self.new_fc = torch.nn.Linear(feature_dim, num_class) 36 | 37 | std = 0.001 38 | if self.new_fc is None: 39 | normal(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) 40 | constant(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 41 | else: 42 | normal(self.new_fc.weight, 0, std) 43 | constant(self.new_fc.bias, 0) 44 | 45 | def _prepare_base_model(self, base_model): 46 | if 'resnet101' in base_model or 'resnet152' in base_model: 47 | self.base_model = getattr(models, base_model)(pretrained=True) 48 | self.base_model.last_layer_name = 'fc' 49 | self.input_size = 224 50 | self.input_mean = [0.485, 0.456, 0.406] 51 | self.input_std = [0.229, 0.224, 0.225] 52 | else: 53 | raise ValueError('Unknown base model: {}'.format(base_model)) 54 | 55 | 56 | def get_optim_policies(self): 57 | first_conv_weight = [] 58 | first_conv_bias = [] 59 | normal_weight = [] 60 | normal_bias = [] 61 | bn = [] 62 | 63 | conv_cnt = 0 64 | for m in self.modules(): 65 | if isinstance(m, torch.nn.Conv3d): 66 | ps = list(m.parameters()) 67 | conv_cnt += 1 68 | if conv_cnt == 1: 69 | first_conv_weight.append(ps[0]) 70 | if len(ps) == 2: 71 | first_conv_bias.append(ps[1]) 72 | else: 73 | normal_weight.append(ps[0]) 74 | if len(ps) == 2: 75 | normal_bias.append(ps[1]) 76 | elif isinstance(m, torch.nn.Linear): 77 | ps = list(m.parameters()) 78 | normal_weight.append(ps[0]) 79 | if len(ps) == 2: 80 | normal_bias.append(ps[1]) 81 | elif isinstance(m, torch.nn.BatchNorm3d): # enable BN 82 | bn.extend(list(m.parameters())) 83 | elif len(m._modules) == 0: 84 | if len(list(m.parameters())) > 0: 85 | raise ValueError("New atomic module type: {}. 
Need to give it a learning policy".format(type(m))) 86 | 87 | return [ 88 | {'params': first_conv_weight, 'lr_mult': 1, 'decay_mult': 1, 89 | 'name': "first_conv_weight"}, 90 | {'params': first_conv_bias, 'lr_mult': 2, 'decay_mult': 0, 91 | 'name': "first_conv_bias"}, 92 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 93 | 'name': "normal_weight"}, 94 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 95 | 'name': "normal_bias"}, 96 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 97 | 'name': "BN scale/shift"}, 98 | ] 99 | 100 | 101 | def forward(self, input): 102 | out = self.base_model(input) 103 | if self.dropout > 0: 104 | out = self.new_fc(out) 105 | 106 | return out 107 | 108 | 109 | @property 110 | def crop_size(self): 111 | return self.input_size 112 | 113 | @property 114 | def scale_size(self): 115 | return self.input_size * 256 // 224 116 | 117 | def get_augmentation(self,mode='train'): 118 | resize_range_min = self.scale_size 119 | if mode == 'train': 120 | resize_range_max = self.input_size * 320 // 224 121 | return torchvision.transforms.Compose( 122 | [GroupRandomResizeCrop([resize_range_min, resize_range_max], self.input_size), 123 | GroupRandomHorizontalFlip(is_flow=False), 124 | GroupColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)]) 125 | elif mode == 'val': 126 | return torchvision.transforms.Compose([GroupScale(resize_range_min), 127 | GroupCenterCrop(self.input_size)]) 128 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupColorJitter(object): 46 | def __init__(self, brightness=0,contrast=0,saturation=0,hue=0): 47 | self.worker = torchvision.transforms.ColorJitter(brightness,contrast,saturation,hue) 48 | 49 | def __call__(self, img_group): 50 | return [self.worker(img) for img in img_group] 51 | 52 | 53 | class GroupRandomHorizontalFlip(object): 54 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 55 | """ 56 | def __init__(self, is_flow=False): 57 | self.is_flow = is_flow 58 | 59 | def __call__(self, img_group, is_flow=False): 60 | v = random.random() 61 | if v < 0.5: 62 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 63 | if self.is_flow: 64 | for i in range(0, len(ret), 2): 65 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 66 | return ret 67 | else: 68 | return img_group 69 | 70 | 71 | class 
GroupNormalize(object): 72 | def __init__(self, mean, std): 73 | self.mean = mean 74 | self.std = std 75 | 76 | def __call__(self, tensor): 77 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 78 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 79 | 80 | # TODO: make efficient 81 | for t, m, s in zip(tensor, rep_mean, rep_std): 82 | t.sub_(m).div_(s) 83 | 84 | return tensor 85 | 86 | 87 | class GroupScale(object): 88 | """ Rescales the input PIL.Image to the given 'size'. 89 | 'size' will be the size of the smaller edge. 90 | For example, if height > width, then image will be 91 | rescaled to (size * height / width, size) 92 | size: size of the smaller edge 93 | interpolation: Default: PIL.Image.BILINEAR 94 | """ 95 | 96 | def __init__(self, size, interpolation=Image.BILINEAR): 97 | self.worker = torchvision.transforms.Scale(size, interpolation) 98 | 99 | def __call__(self, img_group): 100 | return [self.worker(img) for img in img_group] 101 | 102 | 103 | 104 | class GroupRandomResizeCrop(object): 105 | """ 106 | random resize image to shorter size = [256,320] (e.g.), 107 | and random crop image to 224[e.g.] 108 | p.s.: if input size > 224, resize_range should be enlarged in equal proportion 109 | """ 110 | def __init__(self, resize_range, input_size, interpolation=Image.BILINEAR): 111 | self.resize_range = resize_range 112 | self.crop_worker = GroupRandomCrop(input_size) 113 | self.interpolation = interpolation 114 | 115 | def __call__(self, img_group): 116 | resize_size = random.randint(self.resize_range[0],self.resize_range[1]) 117 | resize_worker = GroupScale(resize_size) 118 | resized_img_group = resize_worker(img_group) 119 | crop_img_group = self.crop_worker(resized_img_group) 120 | 121 | return crop_img_group 122 | 123 | 124 | class Stack(object): 125 | 126 | def __call__(self, img_group): 127 | stacked_group = np.concatenate([np.expand_dims(x, 3) for x in img_group], axis=3) 128 | 129 | return stacked_group 130 | 131 | 132 | class ToTorchFormatTensor(object): 133 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C x D) in the range [0, 255] 134 | to a torch.FloatTensor of shape (C x D x H x W) in the range [0.0, 1.0] """ 135 | 136 | def __call__(self, pic): 137 | # handle numpy array 138 | img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() 139 | 140 | return img.float().div(255) 141 | 142 | 143 | class IdentityTransform(object): 144 | 145 | def __call__(self, data): 146 | return data 147 | 148 | 149 | if __name__ == "__main__": 150 | trans = torchvision.transforms.Compose([ 151 | GroupScale(256), 152 | GroupRandomCrop(224), 153 | Stack(), 154 | ToTorchFormatTensor(), 155 | GroupNormalize( 156 | mean=[.485, .456, .406], 157 | std=[.229, .224, .225] 158 | )] 159 | ) 160 | 161 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 162 | 163 | color_group = [im] * 3 164 | rst = trans(color_group) 165 | -------------------------------------------------------------------------------- /testmodel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import torch.nn.parallel 6 | import torch.optim 7 | import torchvision 8 | from sklearn.metrics import confusion_matrix 9 | 10 | from dataset import I3DDataSet 11 | from i3d import I3DModel 12 | from transforms import * 13 | 14 | # options 15 | parser = argparse.ArgumentParser( 16 | description="Standard video-level testing") 17 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 
'kinetics', 'ucf-crime']) 18 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff']) 19 | parser.add_argument('test_list', type=str) 20 | parser.add_argument('weights', type=str) 21 | parser.add_argument('--arch', type=str, default="resnet101") 22 | parser.add_argument('--save_scores', type=str, default=None) 23 | parser.add_argument('--test_clips', type=int, default=10) 24 | parser.add_argument('--sample_frames', type=int, default=32) 25 | parser.add_argument('--test_crops', type=int, default=1) 26 | parser.add_argument('--input_size', type=int, default=224) 27 | parser.add_argument('--k', type=int, default=3) 28 | parser.add_argument('--dropout', type=float, default=0.5) 29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 32 | parser.add_argument('--flow_prefix', type=str, default='') 33 | 34 | args = parser.parse_args() 35 | 36 | 37 | if args.dataset == 'ucf101': 38 | num_class = 101 39 | elif args.dataset == 'hmdb51': 40 | num_class = 51 41 | elif args.dataset == 'kinetics': 42 | num_class = 400 43 | else: 44 | raise ValueError('Unknown dataset '+args.dataset) 45 | 46 | i3d_model = I3DModel(num_class, args.sample_frames, args.modality, 47 | base_model=args.arch, dropout=args.dropout) 48 | 49 | checkpoint = torch.load(args.weights) 50 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1'])) 51 | 52 | base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint['state_dict'].items())} 53 | i3d_model.load_state_dict(base_dict) 54 | 55 | if args.test_crops == 1: 56 | cropping = torchvision.transforms.Compose([ 57 | GroupScale(i3d_model.scale_size), 58 | GroupCenterCrop(i3d_model.input_size), 59 | ]) 60 | elif args.test_crops == 10: 61 | cropping = torchvision.transforms.Compose([ 62 | GroupOverSample(i3d_model.input_size, i3d_model.scale_size) # see the GroupOverSample sketch at the end of this document 63 | ]) 64 | else: 65 | raise ValueError("Only 1 and 10 crops are supported, but got {}".format(args.test_crops)) 66 | 67 | data_loader = torch.utils.data.DataLoader( 68 | I3DDataSet("", args.test_list, sample_frames=args.sample_frames, 69 | modality=args.modality, 70 | image_tmpl="image_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg", 71 | train_mode=False, test_clips=args.test_clips, 72 | transform=torchvision.transforms.Compose([ 73 | cropping, 74 | Stack(), 75 | ToTorchFormatTensor(), 76 | GroupNormalize(i3d_model.input_mean, i3d_model.input_std), 77 | ])), 78 | batch_size=1, shuffle=False, 79 | num_workers=args.workers * 2, pin_memory=True) 80 | 81 | if args.gpus is not None: 82 | devices = args.gpus # use the listed device ids; args.workers counts dataloader processes, not GPUs 83 | else: 84 | devices = list(range(torch.cuda.device_count())) 85 | 86 | i3d_model = torch.nn.DataParallel(i3d_model.cuda(devices[0]), device_ids=devices) 87 | i3d_model.eval() 88 | 89 | data_gen = enumerate(data_loader) 90 | 91 | total_num = len(data_loader.dataset) 92 | output = [] 93 | 94 | def eval_video(video_data): 95 | i, data, label = video_data 96 | 97 | if args.modality == 'RGB': 98 | num_channel = 3 99 | num_depth = args.sample_frames # frames per clip (32 by default) 100 | else: 101 | raise ValueError("Unknown modality "+args.modality) 102 | data = data.squeeze(0) 103 | data = data.view(num_channel,-1,num_depth,data.size(2),data.size(3)).contiguous() 104 | data = data.permute(1,0,2,3,4).contiguous() 105 | #data = data.view(data.size(0),num_channel,-1,num_depth,data.size(3),data.size(4)) 106 | #data 
= data.squeeze(0).permute(1,0,2,3,4).contiguous() 107 | input_var = torch.autograd.Variable(data, volatile=True) 108 | rst = i3d_model(input_var).data.cpu().numpy().copy() 109 | return i, rst.reshape((args.test_clips*args.test_crops, num_class)).mean(axis=0).reshape((1, num_class)), \ 110 | label[0] 111 | 112 | 113 | proc_start_time = time.time() 114 | 115 | for i, (data, label) in data_gen: 116 | rst = eval_video((i, data, label)) 117 | output.append(rst[1:]) 118 | cnt_time = time.time() - proc_start_time 119 | print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1, 120 | total_num, 121 | float(cnt_time) / (i+1))) 122 | 123 | video_pred = [np.argmax(x[0]) for x in output] 124 | 125 | video_labels = [x[1] for x in output] 126 | 127 | 128 | cf = confusion_matrix(video_labels, video_pred).astype(float) 129 | 130 | cls_cnt = cf.sum(axis=1) 131 | cls_hit = np.diag(cf) 132 | 133 | cls_acc = cls_hit / cls_cnt 134 | 135 | print(cls_acc) 136 | 137 | print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 138 | 139 | if args.save_scores is not None: 140 | 141 | # reorder before saving 142 | name_list = [x.strip().split()[0] for x in open(args.test_list)] 143 | 144 | order_dict = {e:i for i, e in enumerate(sorted(name_list))} 145 | 146 | reorder_output = [None] * len(output) 147 | reorder_label = [None] * len(output) 148 | 149 | for i in range(len(output)): 150 | idx = order_dict[name_list[i]] 151 | reorder_output[idx] = output[i] 152 | reorder_label[idx] = video_labels[i] 153 | 154 | np.savez(args.save_scores, scores=reorder_output, labels=reorder_label) 155 | 156 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import torch 5 | import torchvision 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | from torch.nn.utils import clip_grad_norm 10 | 11 | from dataset import I3DDataSet 12 | from i3d import I3DModel 13 | from transforms import * 14 | from opts import parser 15 | 16 | best_prec1 = 0 17 | 18 | def main(): 19 | global args, best_prec1 20 | args = parser.parse_args() 21 | 22 | if args.dataset == 'ucf101': 23 | num_class = 101 24 | elif args.dataset == 'hmdb51': 25 | num_class = 51 26 | elif args.dataset == 'kinetics': 27 | num_class = 400 28 | else: 29 | raise ValueError('Unknown dataset '+args.dataset) 30 | 31 | i3d_model = I3DModel(num_class, args.sample_frames, args.modality, 32 | base_model=args.arch, dropout=args.dropout) 33 | if args.resume: 34 | if os.path.isfile(args.resume): 35 | print(("=> loading checkpoint '{}'".format(args.resume))) 36 | checkpoint = torch.load(args.resume) 37 | args.start_epoch = checkpoint['epoch'] 38 | best_prec1 = checkpoint['best_prec1'] 39 | i3d_model.load_state_dict(checkpoint['state_dict']) 40 | print(("=> loaded checkpoint '{}' (epoch {})" 41 | .format(args.evaluate, checkpoint['epoch']))) 42 | else: 43 | print(("=> no checkpoint found at '{}'".format(args.resume))) 44 | 45 | cudnn.benchmark = False 46 | 47 | # Data loading code 48 | input_mean = i3d_model.input_mean 49 | input_std = i3d_model.input_std 50 | policies = i3d_model.get_optim_policies() 51 | train_augmentation = i3d_model.get_augmentation(mode='train') 52 | val_trans = i3d_model.get_augmentation(mode='val') 53 | normalize = GroupNormalize(input_mean, input_std) 54 | 55 | i3d_model = torch.nn.DataParallel(i3d_model, device_ids=args.gpus).cuda() 56 | 
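# Both args.train_list and args.val_list are plain-text files with one
# space-separated record per line, matching VideoRecord in dataset.py:
#   <frame_dir> <num_frames> <label>
# e.g. (hypothetical entry): /data/kinetics/frames/abseiling/xyz 300 0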
train_loader = torch.utils.data.DataLoader( 57 | I3DDataSet("", args.train_list, 58 | sample_frames=args.sample_frames, 59 | modality=args.modality, 60 | image_tmpl="image_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg", 61 | transform=torchvision.transforms.Compose([ 62 | train_augmentation, 63 | Stack(), 64 | ToTorchFormatTensor(), 65 | normalize, 66 | ])), 67 | batch_size=args.batch_size, shuffle=True, 68 | num_workers=args.workers, pin_memory=True) 69 | 70 | val_loader = torch.utils.data.DataLoader( 71 | I3DDataSet("", args.val_list, 72 | sample_frames=args.sample_frames, 73 | modality=args.modality, 74 | image_tmpl="image_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg", 75 | transform=torchvision.transforms.Compose([ 76 | val_trans, 77 | Stack(), 78 | ToTorchFormatTensor(), 79 | normalize, 80 | ])), 81 | batch_size=args.batch_size, shuffle=False, 82 | num_workers=args.workers, pin_memory=True) 83 | 84 | # define loss function (criterion) and optimizer 85 | if args.loss_type == 'nll': 86 | criterion = torch.nn.CrossEntropyLoss().cuda() 87 | else: 88 | raise ValueError("Unknown loss type") 89 | 90 | for group in policies: 91 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 92 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 93 | 94 | optimizer = torch.optim.SGD(policies, args.lr, 95 | momentum=args.momentum, 96 | weight_decay=args.weight_decay) 97 | 98 | if args.evaluate: 99 | validate(val_loader,i3d_model, criterion, 0) 100 | return 101 | 102 | for epoch in range(args.start_epoch, args.epochs): 103 | adjust_learning_rate(optimizer, epoch, args.lr_steps) 104 | 105 | # train for one epoch 106 | train(train_loader, i3d_model, criterion, optimizer, epoch) 107 | 108 | # evaluate on validation set 109 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 110 | prec1 = validate(val_loader, i3d_model, criterion, (epoch + 1) * len(train_loader)) 111 | 112 | # remember best prec@1 and save checkpoint 113 | is_best = prec1 > best_prec1 114 | best_prec1 = max(prec1, best_prec1) 115 | save_checkpoint({ 116 | 'epoch': epoch + 1, 117 | 'arch': args.arch, 118 | 'state_dict': i3d_model.state_dict(), 119 | 'best_prec1': best_prec1, 120 | }, is_best) 121 | 122 | 123 | def train(train_loader, model, criterion, optimizer, epoch): 124 | batch_time = AverageMeter() 125 | data_time = AverageMeter() 126 | losses = AverageMeter() 127 | top1 = AverageMeter() 128 | top5 = AverageMeter() 129 | 130 | # switch to train mode 131 | model.train() 132 | 133 | end = time.time() 134 | for i, (input, target) in enumerate(train_loader): 135 | # measure data loading time 136 | data_time.update(time.time() - end) 137 | 138 | target = target.cuda(async=True) 139 | input_var = torch.autograd.Variable(input) 140 | target_var = torch.autograd.Variable(target) 141 | 142 | # compute output 143 | output = model(input_var) 144 | loss = criterion(output, target_var) 145 | 146 | # measure accuracy and record loss 147 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 148 | losses.update(loss.data[0], input.size(0)) 149 | top1.update(prec1[0], input.size(0)) 150 | top5.update(prec5[0], input.size(0)) 151 | 152 | 153 | # compute gradient and do SGD step 154 | optimizer.zero_grad() 155 | 156 | loss.backward() 157 | 158 | if args.clip_gradient is not None: 159 | total_norm = clip_grad_norm(model.parameters(), args.clip_gradient) 160 | if total_norm > args.clip_gradient: 161 | 
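# clip_grad_norm returns the total gradient norm measured *before* clipping, so a
# value above the threshold means the gradients were rescaled by
# args.clip_gradient / total_norm -- the coefficient reported below.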
print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm)) 162 | 163 | optimizer.step() 164 | 165 | # measure elapsed time 166 | batch_time.update(time.time() - end) 167 | end = time.time() 168 | 169 | if i % args.print_freq == 0: 170 | print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 171 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 172 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 173 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 174 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 175 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 176 | epoch, i, len(train_loader), batch_time=batch_time, 177 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr']))) 178 | 179 | 180 | def validate(val_loader, model, criterion, iter, logger=None): 181 | batch_time = AverageMeter() 182 | losses = AverageMeter() 183 | top1 = AverageMeter() 184 | top5 = AverageMeter() 185 | 186 | # switch to evaluate mode 187 | model.eval() 188 | 189 | end = time.time() 190 | for i, (input, target) in enumerate(val_loader): 191 | target = target.cuda(async=True) 192 | input_var = torch.autograd.Variable(input, volatile=True) 193 | target_var = torch.autograd.Variable(target, volatile=True) 194 | 195 | # compute output 196 | output = model(input_var) 197 | loss = criterion(output, target_var) 198 | 199 | # measure accuracy and record loss 200 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 201 | 202 | losses.update(loss.data[0], input.size(0)) 203 | top1.update(prec1[0], input.size(0)) 204 | top5.update(prec5[0], input.size(0)) 205 | 206 | # measure elapsed time 207 | batch_time.update(time.time() - end) 208 | end = time.time() 209 | 210 | if i % args.print_freq == 0: 211 | print(('Test: [{0}/{1}]\t' 212 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 213 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 214 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 215 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 216 | i, len(val_loader), batch_time=batch_time, loss=losses, 217 | top1=top1, top5=top5))) 218 | 219 | print(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 220 | .format(top1=top1, top5=top5, loss=losses))) 221 | 222 | return top1.avg 223 | 224 | 225 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 226 | filename = '_'.join((args.snapshot_pref, args.modality.lower(), filename)) 227 | torch.save(state, filename) 228 | if is_best: 229 | best_name = '_'.join((args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar')) 230 | shutil.copyfile(filename, best_name) 231 | 232 | 233 | class AverageMeter(object): 234 | """Computes and stores the average and current value""" 235 | def __init__(self): 236 | self.reset() 237 | 238 | def reset(self): 239 | self.val = 0 240 | self.avg = 0 241 | self.sum = 0 242 | self.count = 0 243 | 244 | def update(self, val, n=1): 245 | self.val = val 246 | self.sum += val * n 247 | self.count += n 248 | self.avg = self.sum / self.count 249 | 250 | 251 | def adjust_learning_rate(optimizer, epoch, lr_steps): 252 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 253 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 254 | lr = args.lr * decay 255 | decay = args.weight_decay 256 | for param_group in optimizer.param_groups: 257 | param_group['lr'] = lr * param_group['lr_mult'] 258 | param_group['weight_decay'] = decay * param_group['decay_mult'] 259 | 260 | 261 | def accuracy(output, target, topk=(1,)): 262 | 
"""Computes the precision@k for the specified values of k""" 263 | maxk = max(topk) 264 | batch_size = target.size(0) 265 | 266 | _, pred = output.topk(maxk, 1, True, True) 267 | pred = pred.t() 268 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 269 | 270 | res = [] 271 | for k in topk: 272 | correct_k = correct[:k].view(-1).float().sum(0) 273 | res.append(correct_k.mul_(100.0 / batch_size)) 274 | return res 275 | 276 | 277 | if __name__ == '__main__': 278 | main() 279 | -------------------------------------------------------------------------------- /models/i3dnon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | __all__ = ['I3DResNet', 'resnet50', 'resnet101', 'resnet152'] 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, time_kernel=1, space_stride=1, downsample=None,addnon = None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(time_kernel,1,1), padding=(int((time_kernel-1)/2), 0,0),bias=False) # timepadding: make sure time-dim not reduce 12 | self.bn1 = nn.BatchNorm3d(planes) 13 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,space_stride,space_stride), 14 | padding=(0,1,1), bias=False) 15 | self.bn2 = nn.BatchNorm3d(planes) 16 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=(1,1,1), bias=False) 17 | self.bn3 = nn.BatchNorm3d(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.addnon = addnon 20 | self.downsample = downsample 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv3(out) 34 | out = self.bn3(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | if self.addnon is not None: 42 | out = nonlocalnet(out,out.size(1)) 43 | return out 44 | 45 | 46 | 47 | class I3DResNet(nn.Module): 48 | 49 | def __init__(self, block, layers, frame_num=32, num_classes=400): 50 | self.inplanes = 64 51 | super(I3DResNet, self).__init__() 52 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(5,7,7), stride=(2,2,2), padding=(2,3,3), 53 | bias=False) 54 | self.bn1 = nn.BatchNorm3d(64) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.maxpool = nn.MaxPool3d(kernel_size=(3,3,3), stride=(2,2,2), padding=(1,1,1)) 57 | self.layer1 = self._make_layer_inflat(block, 64, layers[0]) 58 | self.temporalpool = nn.MaxPool3d(kernel_size=(3,1,1), stride=(2,1,1), padding=(1,0,0)) 59 | self.layer2 = self._make_layer_inflat(block, 128, layers[1], space_stride=2) 60 | self.layer3 = self._make_layer_inflat(block, 256, layers[2], space_stride=2) 61 | self.layer4 = self._make_layer_inflat(block, 512, layers[3], space_stride=2) 62 | self.avgpool = nn.AvgPool3d((int(frame_num/8),7,7)) 63 | self.avgdrop =nn.Dropout(0.5) 64 | self.fc = nn.Linear(512 * block.expansion, num_classes) 65 | 66 | 67 | def _make_layer_inflat(self, block, planes, blocks, space_stride=1): 68 | downsample = None 69 | if space_stride != 1 or self.inplanes != planes * block.expansion: 70 | downsample = nn.Sequential( 71 | nn.Conv3d(self.inplanes, planes * block.expansion, 72 | kernel_size=(1,1,1), stride=(1,space_stride,space_stride), bias=False), 73 | nn.BatchNorm3d(planes * block.expansion), 74 | ) 75 | 76 | layers = [] 77 | 
time_kernel = 3 #making I3D(3*1*1) 78 | 79 | 80 | layers.append(block(self.inplanes, planes, time_kernel, space_stride, downsample,addnon= None)) 81 | self.inplanes = planes * block.expansion 82 | if blocks == 3: 83 | for i in range(1, blocks): 84 | if i % 2 == 1: 85 | time_kernel = 3 86 | else: 87 | time_kernel = 1 88 | layers.append(block(self.inplanes, planes, time_kernel)) 89 | 90 | elif blocks == 4: 91 | for i in range(1, blocks): 92 | 93 | if i % 2 == 1: 94 | time_kernel = 3 95 | layers.append(block(self.inplanes, planes, time_kernel,addnon= True)) 96 | else: 97 | time_kernel = 1 98 | layers.append(block(self.inplanes, planes, time_kernel)) 99 | 100 | elif blocks == 23: 101 | for i in range(1, blocks): 102 | if i % 2 == 1 : 103 | #addnon = True 104 | time_kernel = 3 105 | else: 106 | time_kernel = 1 107 | if i % 7 == 6: 108 | addnon=True 109 | layers.append(block(self.inplanes, planes, time_kernel,addnon=True)) 110 | 111 | else: 112 | layers.append(block(self.inplanes, planes, time_kernel)) 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | x = self.relu(x) 119 | x = self.maxpool(x) 120 | x = self.layer1(x) 121 | x = self.temporalpool(x) 122 | x = self.layer2(x) 123 | x = self.layer3(x) 124 | x = self.layer4(x) 125 | x = self.avgpool(x) 126 | x = x.permute(0, 2, 1, 3, 4).contiguous() 127 | x = x.view(x.size(0), -1) 128 | x = self.avgdrop(x) 129 | x = self.fc(x) 130 | 131 | return x 132 | 133 | 134 | def resnet50(pretrained=False, **kwargs): 135 | """Constructs a ResNet-50 model. 136 | Args: 137 | pretrained (bool): If True, returns a model pre-trained on ImageNet 138 | """ 139 | model = I3DResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 140 | 141 | return model 142 | 143 | 144 | def resnet101(pretrained=False, **kwargs): 145 | """Constructs a ResNet-101 model. 146 | Args: 147 | pretrained (bool): If True, returns a model pre-trained on ImageNet 148 | """ 149 | model = I3DResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 150 | if pretrained: 151 | import torchvision 152 | if torch.cuda.is_available(): 153 | pretrained_model = torch.load('/workspace/fyzhang/projects/i3d-nonlocal-affine/kinetics_i3dresnet101__rgb_model_best.pth.tar') 154 | base_dict = {'.'.join(k.split('.')[2:]): v for k,v in list(pretrained_model['state_dict'].items())} 155 | model = affine_weights(base_dict, model) 156 | print("ok") 157 | else: 158 | #pretrained_model = torchvision.models.resnet101(pretrained=True) 159 | pretrained_model = torch.load('/Users/fyzhang/Downloads/vscode/nonlocal/i3d-pytorch/kinetics_i3dresnet101__rgb_model_best.pth.tar') 160 | model = affine_weights(pretrained_model, model) 161 | return model 162 | 163 | 164 | def resnet152(pretrained=False, **kwargs): 165 | """Constructs a ResNet-152 model. 
166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = I3DResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 170 | 171 | return model 172 | 173 | 174 | # def inflat_weights(model_2d, model_3d): 175 | # pretrained_dict_2d = model_2d.state_dict() 176 | # model_dict_3d = model_3d.state_dict() 177 | # for key,weight_2d in pretrained_dict_2d.items(): 178 | # if key in model_dict_3d: 179 | # if 'conv' in key: 180 | # time_kernel_size = model_dict_3d[key].shape[2] 181 | # if 'weight' in key: 182 | # weight_3d = weight_2d.unsqueeze(2).repeat(1,1,time_kernel_size,1,1) 183 | # weight_3d = weight_3d / time_kernel_size 184 | # model_dict_3d[key] = weight_3d 185 | # elif 'bias' in key: 186 | # model_dict_3d[key] = weight_2d 187 | # elif 'bn' in key: 188 | # model_dict_3d[key] = weight_2d 189 | # #elif 'fc' in key: 190 | # # if 'weight' in key: 191 | # # time_kernel_size = model_dict_3d[key].shape[1] / weight_2d.shape[1] 192 | # # weight_3d = weight_2d.repeat(1, int(time_kernel_size)) 193 | # # weight_3d = weight_3d / time_kernel_size 194 | # # model_dict_3d[key] = weight_3d 195 | # #elif 'bias' in key: 196 | # # model_dict_3d[key] = weight_2d 197 | # elif 'downsample' in key: 198 | # if '0.weight' in key: 199 | # time_kernel_size = model_dict_3d[key].shape[2] 200 | # weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_kernel_size, 1, 1) 201 | # weight_3d = weight_3d / time_kernel_size 202 | # model_dict_3d[key] = weight_3d 203 | # else: 204 | # model_dict_3d[key] = weight_2d 205 | 206 | # model_3d.load_state_dict(model_dict_3d) 207 | # return model_3d 208 | 209 | 210 | def affine_weights(model_2d, model_3d): 211 | #pretrained_dict_2d = model_2d.state_dict() 212 | pretrained_dict_2d = model_2d#.state_dict() 213 | pretrained_dict_2d['fc.weight']=pretrained_dict_2d.pop('weight') 214 | pretrained_dict_2d['fc.bias']=pretrained_dict_2d.pop('bias') 215 | model_dict_3d = model_3d.state_dict() 216 | for key,weight_2d in pretrained_dict_2d.items(): 217 | if key in model_dict_3d: 218 | if 'conv' in key: 219 | #time_kernel_size = model_dict_3d[key].shape[2] 220 | if 'weight' in key: 221 | weight_3d = weight_2d#.unsqueeze(2).repeat(1,1,time_kernel_size,1,1) 222 | #weight_3d = weight_3d / time_kernel_size 223 | model_dict_3d[key] = weight_3d 224 | elif 'bias' in key: 225 | model_dict_3d[key] = weight_2d 226 | elif 'bn' in key: 227 | model_dict_3d[key] = weight_2d 228 | elif 'fc' in key: 229 | if 'weight' in key: 230 | weight_3d = weight_2d#.unsqueeze(2).repeat(1,1,time_kernel_size,1,1) 231 | #weight_3d = weight_3d / time_kernel_size 232 | model_dict_3d[key] = weight_3d 233 | elif 'bias' in key: 234 | model_dict_3d[key] = weight_2d 235 | elif 'downsample' in key: 236 | if '0.weight' in key: 237 | #time_kernel_size = model_dict_3d[key].shape[2] 238 | weight_3d = weight_2d#.unsqueeze(2).repeat(1, 1, time_kernel_size, 1, 1) 239 | #weight_3d = weight_3d / time_kernel_size 240 | model_dict_3d[key] = weight_3d 241 | else: 242 | model_dict_3d[key] = weight_2d 243 | 244 | model_3d.load_state_dict(model_dict_3d) 245 | return model_3d.cuda() 246 | 247 | 248 | class _NonLocalBlockND(nn.Module): 249 | def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', 250 | sub_sample=True, bn_layer=True): 251 | super(_NonLocalBlockND, self).__init__() 252 | 253 | assert dimension in [1, 2, 3] 254 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation'] 255 | 256 | # print('Dimension: %d, mode: %s' % (dimension, mode)) 257 | 258 | 
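# Embedded-Gaussian non-local operation (Wang et al. 2018), as implemented below:
#   y_i = sum_j softmax_j( theta(x_i)^T phi(x_j) ) * g(x_j),   z = W y + x
# theta, phi and g are 1x1(x1) convolutions to inter_channels (C/2 by default);
# W maps back to in_channels, and with bn_layer=True its BatchNorm weight is
# zero-initialised so each block starts out as an identity residual.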
self.mode = mode 259 | self.dimension = dimension 260 | self.sub_sample = sub_sample 261 | 262 | self.in_channels = in_channels 263 | self.inter_channels = inter_channels 264 | 265 | if self.inter_channels is None: 266 | self.inter_channels = in_channels // 2 267 | if self.inter_channels == 0: 268 | self.inter_channels = 1 269 | 270 | if dimension == 3: 271 | conv_nd = nn.Conv3d 272 | max_pool = nn.MaxPool3d 273 | bn = nn.BatchNorm3d 274 | elif dimension == 2: 275 | conv_nd = nn.Conv2d 276 | max_pool = nn.MaxPool2d 277 | bn = nn.BatchNorm2d 278 | else: 279 | conv_nd = nn.Conv1d 280 | max_pool = nn.MaxPool1d 281 | bn = nn.BatchNorm1d 282 | 283 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 284 | kernel_size=1, stride=1, padding=0) 285 | nn.init.kaiming_normal(self.g.weight) 286 | nn.init.constant(self.g.bias,0) 287 | if bn_layer: 288 | self.W = nn.Sequential( 289 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 290 | kernel_size=1, stride=1, padding=0), 291 | bn(self.in_channels) 292 | ) 293 | nn.init.kaiming_normal(self.W[0].weight) 294 | nn.init.constant(self.W[0].bias, 0) 295 | nn.init.constant(self.W[1].weight, 0) 296 | nn.init.constant(self.W[1].bias, 0) 297 | 298 | 299 | else: 300 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 301 | kernel_size=1, stride=1, padding=0) 302 | nn.init.kaiming_normal(self.W.weight) 303 | nn.init.constant(self.W.bias, 0) 304 | 305 | self.theta = None 306 | self.phi = None 307 | 308 | if mode in ['embedded_gaussian', 'dot_product']: 309 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 310 | kernel_size=1, stride=1, padding=0) 311 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 312 | kernel_size=1, stride=1, padding=0) 313 | 314 | if mode == 'embedded_gaussian': 315 | self.operation_function = self._embedded_gaussian 316 | else: 317 | self.operation_function = self._dot_product 318 | 319 | elif mode == 'gaussian': 320 | self.operation_function = self._gaussian 321 | else: 322 | raise NotImplementedError('Mode concatenation has not been implemented.') 323 | 324 | if sub_sample: 325 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 326 | if self.phi is None: 327 | self.phi = max_pool(kernel_size=2) 328 | else: 329 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 330 | 331 | def forward(self, x): 332 | ''' 333 | :param x: (b, c, t, h, w) 334 | :return: 335 | ''' 336 | 337 | output = self.operation_function(x) 338 | return output 339 | 340 | def _embedded_gaussian(self, x): 341 | batch_size = x.size(0) 342 | 343 | # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) 344 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 345 | g_x = g_x.permute(0, 2, 1) 346 | 347 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) 348 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) 349 | # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) 350 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 351 | theta_x = theta_x.permute(0, 2, 1) 352 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 353 | f = torch.matmul(theta_x, phi_x) 354 | f_div_C = F.softmax(f, dim=-1) 355 | 356 | # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) 357 | y = torch.matmul(f_div_C, g_x) 358 | y = y.permute(0, 2, 1).contiguous() 359 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 360 | W_y = 
self.W(y) 361 | z = W_y + x 362 | 363 | return z 364 | 365 | def _gaussian(self, x): 366 | batch_size = x.size(0) 367 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 368 | g_x = g_x.permute(0, 2, 1) 369 | 370 | theta_x = x.view(batch_size, self.in_channels, -1) 371 | theta_x = theta_x.permute(0, 2, 1) 372 | 373 | if self.sub_sample: 374 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 375 | else: 376 | phi_x = x.view(batch_size, self.in_channels, -1) 377 | 378 | f = torch.matmul(theta_x, phi_x) 379 | f_div_C = F.softmax(f, dim=-1) 380 | 381 | y = torch.matmul(f_div_C, g_x) 382 | y = y.permute(0, 2, 1).contiguous() 383 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 384 | W_y = self.W(y) 385 | z = W_y + x 386 | 387 | return z 388 | 389 | def _dot_product(self, x): 390 | batch_size = x.size(0) 391 | 392 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 393 | g_x = g_x.permute(0, 2, 1) 394 | 395 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 396 | theta_x = theta_x.permute(0, 2, 1) 397 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 398 | f = torch.matmul(theta_x, phi_x) 399 | N = f.size(-1) 400 | f_div_C = f / N 401 | 402 | y = torch.matmul(f_div_C, g_x) 403 | y = y.permute(0, 2, 1).contiguous() 404 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 405 | W_y = self.W(y) 406 | z = W_y + x 407 | 408 | return z 409 | 410 | 411 | class NONLocalBlock1D(_NonLocalBlockND): 412 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 413 | super(NONLocalBlock1D, self).__init__(in_channels, 414 | inter_channels=inter_channels, 415 | dimension=1, mode=mode, 416 | sub_sample=sub_sample, 417 | bn_layer=bn_layer) 418 | 419 | 420 | class NONLocalBlock2D(_NonLocalBlockND): 421 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 422 | super(NONLocalBlock2D, self).__init__(in_channels, 423 | inter_channels=inter_channels, 424 | dimension=2, mode=mode, 425 | sub_sample=sub_sample, 426 | bn_layer=bn_layer) 427 | 428 | 429 | class NONLocalBlock3D(_NonLocalBlockND): 430 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 431 | super(NONLocalBlock3D, self).__init__(in_channels, 432 | inter_channels=inter_channels, 433 | dimension=3, mode=mode, 434 | sub_sample=sub_sample, 435 | bn_layer=bn_layer) 436 | 437 | def nonlocalnet(input_layer,input_channel): 438 | if torch.cuda.is_available(): 439 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 440 | net = NONLocalBlock3D(in_channels=input_channel,mode='embedded_gaussian') 441 | out = net(input_layer) 442 | else: 443 | net = NONLocalBlock3D(in_channels=input_channel,mode='embedded_gaussian') 444 | out = net(input_layer) 445 | return out 446 | 447 | 448 | 449 | if __name__ == '__main__': 450 | import torchvision 451 | import numpy as np 452 | import torch 453 | from torch.autograd import Variable 454 | 455 | resnet = torchvision.models.resnet101(pretrained=True) 456 | resnet_i3d = resnet101(pretrained=True) 457 | 458 | data = np.ones((1, 3, 224, 224), dtype=np.float32) 459 | tensor = torch.from_numpy(data) 460 | inputs = Variable(tensor) 461 | out1 = resnet(inputs) 462 | print(out1) 463 | 464 | data2 = np.ones((1, 3, 32, 224, 224), dtype=np.float32) 465 | tensor2 = torch.from_numpy(data2) 466 | inputs2 = Variable(tensor2) 467 | inputs2 = inputs2.cuda() 468 | out2 = resnet_i3d(inputs2) 469 | 
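# With the (1, 3, 32, 224, 224) input above and the default num_classes=400,
# the forward pass should print logits of shape (1, 400).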
print(out2) 470 | --------------------------------------------------------------------------------
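Note: testmodel.py uses GroupOverSample when --test_crops 10 is passed, but transforms.py above never defines it. The sketch below is a minimal stand-in written for this repo's group-transform conventions (it reuses GroupScale and would live in transforms.py); it follows the usual TSN-style 10-crop scheme (4 corners + center, each plus a horizontal flip) and is an assumption, not a verified upstream implementation:

from PIL import Image

class GroupOverSample(object):
    """10-crop oversampling: 4 corner/center crops plus their horizontal flips (sketch)."""

    def __init__(self, crop_size, scale_size=None):
        self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
        # optionally rescale the shorter edge first, as the 1-crop path in testmodel.py does
        self.scale_worker = GroupScale(scale_size) if scale_size is not None else None

    @staticmethod
    def fill_fix_offset(image_w, image_h, crop_w, crop_h):
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4
        # upper-left, upper-right, lower-left, lower-right, center
        return [(0, 0), (4 * w_step, 0), (0, 4 * h_step),
                (4 * w_step, 4 * h_step), (2 * w_step, 2 * h_step)]

    def __call__(self, img_group):
        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)
        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size
        oversample_group = []
        for o_w, o_h in self.fill_fix_offset(image_w, image_h, crop_w, crop_h):
            normal_group, flip_group = [], []
            for img in img_group:
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                flip_group.append(crop.copy().transpose(Image.FLIP_LEFT_RIGHT))
            # keep each crop's frames contiguous so Stack() and the reshape in
            # eval_video still see whole clips
            oversample_group.extend(normal_group)
            oversample_group.extend(flip_group)
        return oversample_group

This matches the call site GroupOverSample(i3d_model.input_size, i3d_model.scale_size); note that with 10 crops, the reshape in eval_video expects test_clips * test_crops clips per video.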