├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── ExtractVideoFeature.iml
├── README.md
├── utils.py
├── .gitignore
├── demo_extract_i3d_feature.py
├── videotransforms.py
├── extract_features_from_video.py
└── i3d.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 | 
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
6 | 
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ExtractVideoFeature
2 | Extract features from videos. Currently the only supported model is I3D; more models will be added over time.
3 | 
4 | This project is based on [piergiaj/pytorch-i3d](https://github.com/piergiaj/pytorch-i3d). Converted models can be downloaded from [GitHub](https://github.com/piergiaj/pytorch-i3d/tree/master/models) or [BaiDuYun](https://pan.baidu.com/s/1m1yG4JUUtSix0-MTQKpoDA) with extraction code: s416.
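
For example, assuming the videos are `.mp4` files and the converted RGB weights have been downloaded (all paths below are placeholders), features can be extracted with:

```bash
python demo_extract_i3d_feature.py -vid_dir /path/to/videos -feature_dir /path/to/features -mode rgb -weight_file models/rgb_imagenet.pt
```

One `.npz` file is written per video, containing the keys `feature`, `frame_cnt`, and `video_name`.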
5 | 
--------------------------------------------------------------------------------
/.idea/ExtractVideoFeature.iml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
14 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Helper functions for loading RGB and optical-flow frames from zip archives.
3 | '''
4 | 
5 | from PIL import Image
6 | import zipfile
7 | import io
8 | import cv2
9 | import numpy as np
10 | 
11 | def load_zipframe(zipdata, name, resize=False):
12 | 
13 |     stream = zipdata.read(name)
14 |     data = Image.open(io.BytesIO(stream))
15 | 
16 |     assert(data.size[1] == 256)
17 |     assert(data.size[0] == 340)
18 | 
19 |     if resize:
20 |         data = data.resize((224, 224), Image.ANTIALIAS)
21 | 
22 |     data = np.array(data)
23 |     data = data.astype(float)
24 |     data = (data * 2 / 255) - 1  # scale pixel values to [-1, 1]
25 | 
26 |     assert(data.max()<=1.0)
27 |     assert(data.min()>=-1.0)
28 | 
29 |     return data
30 | 
31 | 
32 | def load_ziprgb_batch(rgb_zipdata, rgb_files,
33 |         frame_indices, resize=False):
34 | 
35 |     if resize:
36 |         batch_data = np.zeros(frame_indices.shape + (224,224,3))
37 |     else:
38 |         batch_data = np.zeros(frame_indices.shape + (256,340,3))
39 | 
40 |     for i in range(frame_indices.shape[0]):
41 |         for j in range(frame_indices.shape[1]):
42 | 
43 |             batch_data[i,j,:,:,:] = load_zipframe(rgb_zipdata,
44 |                 rgb_files[frame_indices[i][j]], resize)
45 | 
46 |     return batch_data
47 | 
48 | 
49 | def load_zipflow_batch(flow_x_zipdata, flow_y_zipdata,
50 |         flow_x_files, flow_y_files,
51 |         frame_indices, resize=False):
52 | 
53 |     if resize:
54 |         batch_data = np.zeros(frame_indices.shape + (224,224,2))
55 |     else:
56 |         batch_data = np.zeros(frame_indices.shape + (256,340,2))
57 | 
58 |     for i in range(frame_indices.shape[0]):
59 |         for j in range(frame_indices.shape[1]):
60 | 
61 |             batch_data[i,j,:,:,0] = load_zipframe(flow_x_zipdata,
62 |                 flow_x_files[frame_indices[i][j]], resize)
63 | 
64 |             batch_data[i,j,:,:,1] = load_zipframe(flow_y_zipdata,
65 |                 flow_y_files[frame_indices[i][j]], resize)
66 | 
67 |     return batch_data
68 | 
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | slowfast/
131 | models/
132 | features_1/
133 | videos_1/
--------------------------------------------------------------------------------
/demo_extract_i3d_feature.py:
--------------------------------------------------------------------------------
1 | '''
2 | We could use torchvision.io.read_video to read videos; however, it is only
3 | available in the latest torchvision, which is still being updated, e.g.:
4 | 
5 | vframes, _, info = torchvision.io.read_video(vid_file)
6 | print('video frames info', vframes.shape, info)
7 | '''
8 | import os
9 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
10 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
11 | os.environ['OMP_NUM_THREADS'] = '1'
12 | 
13 | import argparse
14 | from i3d import network_init
15 | from extract_features_from_video import run
16 | import numpy as np
17 | import cv2
18 | 
19 | 
20 | def args_parser():
21 |     parser = argparse.ArgumentParser()
22 |     parser.add_argument('-vid_dir', help='directory containing video files, e.g., *.mp4', default='/disk2/yangle/ExtractVideoFeature/videos_1')
23 |     parser.add_argument('-feature_dir', help='directory to save feature files', default='/disk2/yangle/ExtractVideoFeature/features_1')
24 |     parser.add_argument('-stride', default=1, type=int, help='feature extraction stride')
25 |     parser.add_argument('-chunk_size', help='snippet frame number for one feature extraction', default=16, type=int)
26 |     parser.add_argument('-batch_size', default=8, type=int)
27 |     parser.add_argument('-mode', default='rgb', help='rgb or flow')
28 |     parser.add_argument('-weight_file', default='/disk2/yangle/ExtractVideoFeature/models/rgb_imagenet.pt')
29 | 
30 |     args = parser.parse_args()
31 |     return args
32 | 
33 | 
34 | def video2array(video_file):
35 |     cap = cv2.VideoCapture(video_file)
36 |     frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
37 |     # Note: frames are resized to 224x224 as they are loaded
38 |     datas = np.zeros((frameCount, 224, 224, 3), np.dtype('uint8'))
39 | 
40 |     fc = 0
41 |     ret = True
42 |     while (fc < frameCount) and ret:
43 |         ret, img = cap.read()
44 |         if ret:
45 |             # data = cv2.resize(img, (224, 224), interpolation=cv2.INTER_BITS)
46 |             data = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
47 |             datas[fc, :, :, :] = data
48 |             fc += 1
49 |     cap.release()
50 |     return datas
51 | 
52 | 
53 | def extract_feature(args):
54 |     model_i3d = network_init(num_devices=1, mode=args.mode, weight_file=args.weight_file)
55 |     if not os.path.exists(args.feature_dir):
56 |         os.makedirs(args.feature_dir)
57 | 
58 |     vid_name_set = os.listdir(args.vid_dir)
59 |     vid_name_set.sort(reverse=False)
60 | 
61 |     for video_name in vid_name_set:
62 |         vid_name = video_name[:-4]  # strip the '.mp4' extension
63 | 
64 |         feature_file = os.path.join(args.feature_dir, vid_name + '.npz')
65 |         if os.path.exists(feature_file):
66 |             print('Video %s already calculated, skip' % video_name)
67 |             continue
68 | 
69 |         vid_file = os.path.join(args.vid_dir, vid_name + '.mp4')
70 |         vframes = video2array(vid_file)
71 |         print('video frames info', vframes.shape)
72 |         if vframes.shape[0] < args.chunk_size:
73 |             print('video %s contains %d frames, too short, skip' % (video_name, vframes.shape[0]))
74 |             continue
75 | 
76 |         run(model_i3d, vid_name, vframes, chunk_size=args.chunk_size,
77 |             stride=args.stride,
78 |             output_dir=args.feature_dir, batch_size=args.batch_size)
79 | 
80 | 
81 | if __name__ == '__main__':
82 |     args = args_parser()
83 |     extract_feature(args)
84 | 
--------------------------------------------------------------------------------
/videotransforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numbers
3 | import random
4 | 
5 | class RandomCrop(object):
6 |     """Crop the given video sequences (t x h x w) at a random location.
7 |     Args:
8 |         size (sequence or int): Desired output size of the crop. If size is an
9 |             int instead of sequence like (h, w), a square crop (size, size) is
10 |             made.
11 | """ 12 | 13 | def __init__(self, size): 14 | if isinstance(size, numbers.Number): 15 | self.size = (int(size), int(size)) 16 | else: 17 | self.size = size 18 | 19 | @staticmethod 20 | def get_params(img, output_size): 21 | """Get parameters for ``crop`` for a random crop. 22 | Args: 23 | img (PIL Image): Image to be cropped. 24 | output_size (tuple): Expected output size of the crop. 25 | Returns: 26 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 27 | """ 28 | t, h, w, c = img.shape 29 | th, tw = output_size 30 | if w == tw and h == th: 31 | return 0, 0, h, w 32 | 33 | i = random.randint(0, h - th) if h!=th else 0 34 | j = random.randint(0, w - tw) if w!=tw else 0 35 | return i, j, th, tw 36 | 37 | def __call__(self, imgs): 38 | 39 | i, j, h, w = self.get_params(imgs, self.size) 40 | 41 | imgs = imgs[:, i:i+h, j:j+w, :] 42 | return imgs 43 | 44 | def __repr__(self): 45 | return self.__class__.__name__ + '(size={0})'.format(self.size) 46 | 47 | class CenterCrop(object): 48 | """Crops the given seq Images at the center. 49 | Args: 50 | size (sequence or int): Desired output size of the crop. If size is an 51 | int instead of sequence like (h, w), a square crop (size, size) is 52 | made. 53 | """ 54 | 55 | def __init__(self, size): 56 | if isinstance(size, numbers.Number): 57 | self.size = (int(size), int(size)) 58 | else: 59 | self.size = size 60 | 61 | def __call__(self, imgs): 62 | """ 63 | Args: 64 | img (PIL Image): Image to be cropped. 65 | Returns: 66 | PIL Image: Cropped image. 67 | """ 68 | t, h, w, c = imgs.shape 69 | th, tw = self.size 70 | i = int(np.round((h - th) / 2.)) 71 | j = int(np.round((w - tw) / 2.)) 72 | 73 | return imgs[:, i:i+th, j:j+tw, :] 74 | 75 | 76 | def __repr__(self): 77 | return self.__class__.__name__ + '(size={0})'.format(self.size) 78 | 79 | 80 | class RandomHorizontalFlip(object): 81 | """Horizontally flip the given seq Images randomly with a given probability. 82 | Args: 83 | p (float): probability of the image being flipped. Default value is 0.5 84 | """ 85 | 86 | def __init__(self, p=0.5): 87 | self.p = p 88 | 89 | def __call__(self, imgs): 90 | """ 91 | Args: 92 | img (seq Images): seq Images to be flipped. 93 | Returns: 94 | seq Images: Randomly flipped seq images. 
95 |         """
96 |         if random.random() < self.p:
97 |             # imgs has shape t x h x w x c; axis=2 flips along the width axis
98 |             return np.flip(imgs, axis=2).copy()
99 |         return imgs
100 | 
101 |     def __repr__(self):
102 |         return self.__class__.__name__ + '(p={})'.format(self.p)
103 | 
--------------------------------------------------------------------------------
/extract_features_from_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | import math
5 | from PIL import Image
6 | import numpy as np
7 | 
8 | 
9 | def load_frame(frame_file, resize=False):
10 |     data = Image.open(frame_file)
11 | 
12 |     assert (data.size[1] == 256)
13 |     assert (data.size[0] == 340)
14 | 
15 |     if resize:
16 |         data = data.resize((224, 224), Image.ANTIALIAS)
17 | 
18 |     data = np.array(data)
19 |     data = data.astype(float)
20 |     data = (data * 2 / 255) - 1
21 | 
22 |     assert (data.max() <= 1.0)
23 |     assert (data.min() >= -1.0)
24 | 
25 |     return data
26 | 
27 | 
28 | def vframes_pre_process(vframes):
29 |     # Rescale pixel values to [-1, 1]
30 |     vframes = vframes.astype(float)
31 |     datas = (vframes * 2 / 255) - 1
32 |     return datas
33 | 
34 | 
35 | def load_rgb_batch(vframes, frame_indices, chunk_size):
36 |     # frame_indices: [batch, chunk_size]
37 |     batch_data = np.zeros(frame_indices.shape + (224, 224, 3))
38 | 
39 |     # copy each chunk's frames into batch_data, keeping their order within the chunk
40 |     for i in range(frame_indices.shape[0]):
41 |         indices_list = list(frame_indices[i, :])
42 |         for idx_set, j in enumerate(indices_list):
43 |             # idx_set: position inside the chunk, j: global frame index
44 |             batch_data[i, idx_set, :, :, :] = vframes[j, :, :, :]
45 | 
46 |     return batch_data  # [batch, chunk_size, 224, 224, 3]
47 | 
48 | 
49 | def load_flow_batch(frames_dir, flow_x_files, flow_y_files,
50 |                     frame_indices, resize=False):
51 |     if resize:
52 |         batch_data = np.zeros(frame_indices.shape + (224, 224, 2))
53 |     else:
54 |         batch_data = np.zeros(frame_indices.shape + (256, 340, 2))
55 | 
56 |     for i in range(frame_indices.shape[0]):
57 |         for j in range(frame_indices.shape[1]):
58 |             batch_data[i, j, :, :, 0] = load_frame(os.path.join(frames_dir,
59 |                 flow_x_files[frame_indices[i][j]]), resize)
60 | 
61 |             batch_data[i, j, :, :, 1] = load_frame(os.path.join(frames_dir,
62 |                 flow_y_files[frame_indices[i][j]]), resize)
63 | 
64 |     return batch_data
65 | 
66 | 
67 | def run(model, video_name, vframes, chunk_size=16, stride=16, output_dir='', batch_size=1):
68 | 
69 |     def forward_batch(b_data):
70 |         b_data = b_data.transpose([0, 4, 1, 2, 3])
71 |         b_data = torch.from_numpy(b_data)  # b,c,t,h,w # e.g., 40x3x16x224x224
72 | 
73 |         b_data = b_data.cuda().float()
74 |         with torch.no_grad():
75 |             b_features = model(b_data)
76 | 
77 |         b_features = b_features.data.cpu().numpy()[:, :]
78 |         return b_features
79 | 
80 |     save_file = '{}.npz'.format(video_name)
81 | 
82 |     frame_cnt = vframes.shape[0]
83 |     vframes = vframes_pre_process(vframes)
84 | 
85 |     # Cut the video into chunks of chunk_size frames, starting every stride frames
86 |     clipped_length = math.floor((frame_cnt - chunk_size) / stride) + 1
87 |     frame_indices = list()
88 | 
89 |     for i in range(0, frame_cnt - chunk_size + 1, stride):
90 |         indices = [j for j in range(i, i+chunk_size)]
91 |         frame_indices.append(indices)
92 |     frame_indices = np.array(frame_indices)
93 | 
94 |     chunk_num = frame_indices.shape[0]
95 | 
96 |     batch_num = int(np.ceil(chunk_num / batch_size))  # Chunks to batches
97 |     frame_indices = np.array_split(frame_indices, batch_num, axis=0)
98 | 
99 |     full_features = list()
100 |     for batch_id in range(batch_num):
101 |         batch_data = load_rgb_batch(vframes, frame_indices[batch_id], chunk_size)
102 |         full_features.append(forward_batch(batch_data))
103 | 
104 |     full_features = 
np.concatenate(full_features, axis=0) 105 | 106 | np.savez(os.path.join(output_dir, save_file), 107 | feature=full_features, 108 | frame_cnt=frame_cnt, 109 | video_name=video_name) 110 | 111 | print('{} done: {} / {}, {}'.format(video_name, frame_cnt, clipped_length, full_features.shape)) 112 | -------------------------------------------------------------------------------- /i3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | 8 | 9 | class MaxPool3dSamePadding(nn.MaxPool3d): 10 | 11 | def compute_pad(self, dim, s): 12 | if s % self.stride[dim] == 0: 13 | return max(self.kernel_size[dim] - self.stride[dim], 0) 14 | else: 15 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 16 | 17 | def forward(self, x): 18 | # compute 'same' padding 19 | (batch, channel, t, h, w) = x.size() 20 | # print t,h,w 21 | out_t = np.ceil(float(t) / float(self.stride[0])) 22 | out_h = np.ceil(float(h) / float(self.stride[1])) 23 | out_w = np.ceil(float(w) / float(self.stride[2])) 24 | # print out_t, out_h, out_w 25 | pad_t = self.compute_pad(0, t) 26 | pad_h = self.compute_pad(1, h) 27 | pad_w = self.compute_pad(2, w) 28 | # print pad_t, pad_h, pad_w 29 | 30 | pad_t_f = pad_t // 2 31 | pad_t_b = pad_t - pad_t_f 32 | pad_h_f = pad_h // 2 33 | pad_h_b = pad_h - pad_h_f 34 | pad_w_f = pad_w // 2 35 | pad_w_b = pad_w - pad_w_f 36 | 37 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 38 | # print x.size() 39 | # print pad 40 | x = F.pad(x, pad) 41 | return super(MaxPool3dSamePadding, self).forward(x) 42 | 43 | 44 | class Unit3D(nn.Module): 45 | 46 | def __init__(self, in_channels, 47 | output_channels, 48 | kernel_shape=(1, 1, 1), 49 | stride=(1, 1, 1), 50 | padding=0, 51 | activation_fn=F.relu, 52 | use_batch_norm=True, 53 | use_bias=False, 54 | name='unit_3d'): 55 | 56 | """Initializes Unit3D module.""" 57 | super(Unit3D, self).__init__() 58 | 59 | self._output_channels = output_channels 60 | self._kernel_shape = kernel_shape 61 | self._stride = stride 62 | self._use_batch_norm = use_batch_norm 63 | self._activation_fn = activation_fn 64 | self._use_bias = use_bias 65 | self.name = name 66 | self.padding = padding 67 | 68 | self.conv3d = nn.Conv3d(in_channels=in_channels, 69 | out_channels=self._output_channels, 70 | kernel_size=self._kernel_shape, 71 | stride=self._stride, 72 | padding=0, 73 | # we always want padding to be 0 here. 
We will dynamically pad based on input size in forward function 74 | bias=self._use_bias) 75 | 76 | if self._use_batch_norm: 77 | self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01) 78 | 79 | def compute_pad(self, dim, s): 80 | if s % self._stride[dim] == 0: 81 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 82 | else: 83 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 84 | 85 | def forward(self, x): 86 | # compute 'same' padding 87 | (batch, channel, t, h, w) = x.size() 88 | # print t,h,w 89 | out_t = np.ceil(float(t) / float(self._stride[0])) 90 | out_h = np.ceil(float(h) / float(self._stride[1])) 91 | out_w = np.ceil(float(w) / float(self._stride[2])) 92 | # print out_t, out_h, out_w 93 | pad_t = self.compute_pad(0, t) 94 | pad_h = self.compute_pad(1, h) 95 | pad_w = self.compute_pad(2, w) 96 | # print pad_t, pad_h, pad_w 97 | 98 | pad_t_f = pad_t // 2 99 | pad_t_b = pad_t - pad_t_f 100 | pad_h_f = pad_h // 2 101 | pad_h_b = pad_h - pad_h_f 102 | pad_w_f = pad_w // 2 103 | pad_w_b = pad_w - pad_w_f 104 | 105 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 106 | # print x.size() 107 | # print pad 108 | x = F.pad(x, pad) 109 | # print x.size() 110 | 111 | x = self.conv3d(x) 112 | if self._use_batch_norm: 113 | x = self.bn(x) 114 | if self._activation_fn is not None: 115 | x = self._activation_fn(x) 116 | return x 117 | 118 | 119 | class InceptionModule(nn.Module): 120 | def __init__(self, in_channels, out_channels, name): 121 | super(InceptionModule, self).__init__() 122 | 123 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 124 | name=name + '/Branch_0/Conv3d_0a_1x1') 125 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 126 | name=name + '/Branch_1/Conv3d_0a_1x1') 127 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], 128 | name=name + '/Branch_1/Conv3d_0b_3x3') 129 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 130 | name=name + '/Branch_2/Conv3d_0a_1x1') 131 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], 132 | name=name + '/Branch_2/Conv3d_0b_3x3') 133 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 134 | stride=(1, 1, 1), padding=0) 135 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 136 | name=name + '/Branch_3/Conv3d_0b_1x1') 137 | self.name = name 138 | 139 | def forward(self, x): 140 | b0 = self.b0(x) 141 | b1 = self.b1b(self.b1a(x)) 142 | b2 = self.b2b(self.b2a(x)) 143 | b3 = self.b3b(self.b3a(x)) 144 | return torch.cat([b0, b1, b2, b3], dim=1) 145 | 146 | 147 | class InceptionI3d(nn.Module): 148 | """Inception-v1 I3D architecture. 149 | The model is introduced in: 150 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 151 | Joao Carreira, Andrew Zisserman 152 | https://arxiv.org/pdf/1705.07750v1.pdf. 153 | See also the Inception architecture, introduced in: 154 | Going deeper with convolutions 155 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 156 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 157 | http://arxiv.org/pdf/1409.4842v1.pdf. 158 | """ 159 | 160 | # Endpoints of the model in order. 
During construction, all the endpoints up 161 | # to a designated `final_endpoint` are returned in a dictionary as the 162 | # second return value. 163 | VALID_ENDPOINTS = ( 164 | 'Conv3d_1a_7x7', 165 | 'MaxPool3d_2a_3x3', 166 | 'Conv3d_2b_1x1', 167 | 'Conv3d_2c_3x3', 168 | 'MaxPool3d_3a_3x3', 169 | 'Mixed_3b', 170 | 'Mixed_3c', 171 | 'MaxPool3d_4a_3x3', 172 | 'Mixed_4b', 173 | 'Mixed_4c', 174 | 'Mixed_4d', 175 | 'Mixed_4e', 176 | 'Mixed_4f', 177 | 'MaxPool3d_5a_2x2', 178 | 'Mixed_5b', 179 | 'Mixed_5c', 180 | 'Logits', 181 | 'Predictions', 182 | ) 183 | 184 | def __init__(self, num_classes=400, spatial_squeeze=True, 185 | final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): 186 | """Initializes I3D model instance. 187 | Args: 188 | num_classes: The number of outputs in the logit layer (default 400, which 189 | matches the Kinetics dataset). 190 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 191 | before returning (default True). 192 | final_endpoint: The model contains many possible endpoints. 193 | `final_endpoint` specifies the last endpoint for the model to be built 194 | up to. In addition to the output at `final_endpoint`, all the outputs 195 | at endpoints up to `final_endpoint` will also be returned, in a 196 | dictionary. `final_endpoint` must be one of 197 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 198 | name: A string (optional). The name of this module. 199 | Raises: 200 | ValueError: if `final_endpoint` is not recognized. 201 | """ 202 | 203 | if final_endpoint not in self.VALID_ENDPOINTS: 204 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 205 | 206 | super(InceptionI3d, self).__init__() 207 | self._num_classes = num_classes 208 | self._spatial_squeeze = spatial_squeeze 209 | self._final_endpoint = final_endpoint 210 | self.logits = None 211 | 212 | if self._final_endpoint not in self.VALID_ENDPOINTS: 213 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint) 214 | 215 | self.end_points = {} 216 | end_point = 'Conv3d_1a_7x7' 217 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], 218 | stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point) 219 | if self._final_endpoint == end_point: return 220 | 221 | end_point = 'MaxPool3d_2a_3x3' 222 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 223 | padding=0) 224 | if self._final_endpoint == end_point: return 225 | 226 | end_point = 'Conv3d_2b_1x1' 227 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 228 | name=name + end_point) 229 | if self._final_endpoint == end_point: return 230 | 231 | end_point = 'Conv3d_2c_3x3' 232 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, 233 | name=name + end_point) 234 | if self._final_endpoint == end_point: return 235 | 236 | end_point = 'MaxPool3d_3a_3x3' 237 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 238 | padding=0) 239 | if self._final_endpoint == end_point: return 240 | 241 | end_point = 'Mixed_3b' 242 | self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point) 243 | if self._final_endpoint == end_point: return 244 | 245 | end_point = 'Mixed_3c' 246 | self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point) 247 | if self._final_endpoint == end_point: 
return 248 | 249 | end_point = 'MaxPool3d_4a_3x3' 250 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), 251 | padding=0) 252 | if self._final_endpoint == end_point: return 253 | 254 | end_point = 'Mixed_4b' 255 | self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) 256 | if self._final_endpoint == end_point: return 257 | 258 | end_point = 'Mixed_4c' 259 | self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) 260 | if self._final_endpoint == end_point: return 261 | 262 | end_point = 'Mixed_4d' 263 | self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point) 264 | if self._final_endpoint == end_point: return 265 | 266 | end_point = 'Mixed_4e' 267 | self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point) 268 | if self._final_endpoint == end_point: return 269 | 270 | end_point = 'Mixed_4f' 271 | self.end_points[end_point] = InceptionModule(112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], 272 | name + end_point) 273 | if self._final_endpoint == end_point: return 274 | 275 | end_point = 'MaxPool3d_5a_2x2' 276 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), 277 | padding=0) 278 | if self._final_endpoint == end_point: return 279 | 280 | end_point = 'Mixed_5b' 281 | self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], 282 | name + end_point) 283 | if self._final_endpoint == end_point: return 284 | 285 | end_point = 'Mixed_5c' 286 | self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], 287 | name + end_point) 288 | if self._final_endpoint == end_point: return 289 | 290 | end_point = 'Logits' 291 | self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], 292 | stride=(1, 1, 1)) 293 | self.dropout = nn.Dropout(dropout_keep_prob) 294 | self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes, 295 | kernel_shape=[1, 1, 1], 296 | padding=0, 297 | activation_fn=None, 298 | use_batch_norm=False, 299 | use_bias=True, 300 | name='logits') 301 | 302 | self.build() 303 | 304 | def replace_logits(self, num_classes): 305 | self._num_classes = num_classes 306 | self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes, 307 | kernel_shape=[1, 1, 1], 308 | padding=0, 309 | activation_fn=None, 310 | use_batch_norm=False, 311 | use_bias=True, 312 | name='logits') 313 | 314 | def build(self): 315 | for k in self.end_points.keys(): 316 | self.add_module(k, self.end_points[k]) 317 | 318 | def forward(self, x): 319 | for end_point in self.VALID_ENDPOINTS: 320 | if end_point in self.end_points: 321 | x = self._modules[end_point](x) # use _modules to work with dataparallel 322 | 323 | x = self.avg_pool(x) 324 | x = x.squeeze(4).squeeze(3).squeeze(2) 325 | return x 326 | 327 | def extract_features(self, x): 328 | for end_point in self.VALID_ENDPOINTS: 329 | if end_point in self.end_points: 330 | x = self._modules[end_point](x) 331 | return self.avg_pool(x) 332 | 333 | 334 | def network_init(num_devices, mode, weight_file): 335 | if mode == 'rgb': 336 | model = InceptionI3d(400, in_channels=3) 337 | else: 338 | model = InceptionI3d(400, in_channels=2) 339 | model.load_state_dict(torch.load(weight_file)) 340 | model = model.cuda() 341 | model.eval() 342 | model = 
nn.DataParallel(model, device_ids=list(range(num_devices))) 343 | print('I3D model with mode %s initialized for evaluation mode' % mode) 344 | 345 | return model 346 | 347 | 348 | if __name__ == '__main__': 349 | i3d = InceptionI3d(400, in_channels=3) 350 | i3d = i3d.cuda() 351 | data = torch.randn(2, 3, 16, 224, 224).cuda() 352 | fea = i3d(data) 353 | # fea = i3d.extract_features(data) 354 | print(fea.size()) 355 | 356 | --------------------------------------------------------------------------------
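
As a usage sketch (file names below are placeholders), each `.npz` file written by `extract_features_from_video.run` stores the chunk features together with the frame count and video name, and can be read back with NumPy:

```python
import numpy as np

# Placeholder path: one .npz file is written per video into the chosen feature_dir.
data = np.load('features_1/example_video.npz')

features = data['feature']   # [num_chunks, 1024] for the default 16-frame RGB setting
print(data['video_name'], data['frame_cnt'], features.shape)
```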