├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── ExtractVideoFeature.iml
├── README.md
├── utils.py
├── .gitignore
├── demo_extract_i3d_feature.py
├── videotransforms.py
├── extract_features_from_video.py
└── i3d.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/workspace.xml
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ExtractVideoFeature
Extract features from videos. Currently the supported model is I3D; more models will be added over time.

This project is based on [piergiaj/pytorch-i3d](https://github.com/piergiaj/pytorch-i3d). Converted models can be downloaded from [GitHub](https://github.com/piergiaj/pytorch-i3d/tree/master/models) or [BaiDuYun](https://pan.baidu.com/s/1m1yG4JUUtSix0-MTQKpoDA) with code: s416.
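## Usage

A typical invocation of the demo script (paths are illustrative; see `args_parser` in `demo_extract_i3d_feature.py` for the full argument list and defaults):

```bash
python demo_extract_i3d_feature.py \
    -vid_dir videos_1 \
    -feature_dir features_1 \
    -mode rgb \
    -chunk_size 16 -stride 1 -batch_size 8 \
    -weight_file models/rgb_imagenet.pt
```

Each video in `-vid_dir` produces one `.npz` file in `-feature_dir` containing the keys `feature`, `frame_cnt`, and `video_name`.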
--------------------------------------------------------------------------------
/.idea/ExtractVideoFeature.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''
Helper functions for loading pre-extracted frames from zip archives.
Frames are expected to be 340x256 images; pixel values are rescaled to [-1, 1].
'''

import io

import numpy as np
from PIL import Image


def load_zipframe(zipdata, name, resize=False):
    # Read one frame image out of an open zipfile.ZipFile object.
    stream = zipdata.read(name)
    data = Image.open(io.BytesIO(stream))

    assert data.size[1] == 256
    assert data.size[0] == 340

    if resize:
        # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS,
        # which was removed in Pillow 10.
        data = data.resize((224, 224), Image.LANCZOS)

    # Rescale pixel values from [0, 255] to [-1, 1].
    data = np.array(data).astype(float)
    data = (data * 2 / 255) - 1

    assert data.max() <= 1.0
    assert data.min() >= -1.0

    return data


def load_ziprgb_batch(rgb_zipdata, rgb_files,
                      frame_indices, resize=False):
    # frame_indices: [batch, chunk_size]; returns [batch, chunk_size, h, w, 3].
    if resize:
        batch_data = np.zeros(frame_indices.shape + (224, 224, 3))
    else:
        batch_data = np.zeros(frame_indices.shape + (256, 340, 3))

    for i in range(frame_indices.shape[0]):
        for j in range(frame_indices.shape[1]):
            batch_data[i, j, :, :, :] = load_zipframe(rgb_zipdata,
                rgb_files[frame_indices[i][j]], resize)

    return batch_data


def load_zipflow_batch(flow_x_zipdata, flow_y_zipdata,
                       flow_x_files, flow_y_files,
                       frame_indices, resize=False):
    # Optical flow is stored as separate x/y archives; stack them as 2 channels.
    if resize:
        batch_data = np.zeros(frame_indices.shape + (224, 224, 2))
    else:
        batch_data = np.zeros(frame_indices.shape + (256, 340, 2))

    for i in range(frame_indices.shape[0]):
        for j in range(frame_indices.shape[1]):
            batch_data[i, j, :, :, 0] = load_zipframe(flow_x_zipdata,
                flow_x_files[frame_indices[i][j]], resize)

            batch_data[i, j, :, :, 1] = load_zipframe(flow_y_zipdata,
                flow_y_files[frame_indices[i][j]], resize)

    return batch_data
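
if __name__ == '__main__':
    # Minimal usage sketch. The archive name and layout are hypothetical:
    # a zip of pre-extracted 340x256 RGB frames, one image file per frame.
    import zipfile

    rgb_zip = zipfile.ZipFile('sample_rgb_frames.zip', 'r')
    rgb_files = sorted(rgb_zip.namelist())

    # Two chunks of 16 consecutive frames each.
    indices = np.arange(32).reshape(2, 16)
    batch = load_ziprgb_batch(rgb_zip, rgb_files, indices, resize=True)
    print(batch.shape)  # expected: (2, 16, 224, 224, 3)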
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Project-specific directories
slowfast/
models/
features_1/
videos_1/
--------------------------------------------------------------------------------
/demo_extract_i3d_feature.py:
--------------------------------------------------------------------------------
'''
Videos could also be read with torchvision.io.read_video,
but that API is only available in recent torchvision releases:

    vframes, _, info = torchvision.io.read_video(vid_file)
    print('video frames info', vframes.shape, info)
'''
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ['OMP_NUM_THREADS'] = '1'

import argparse
from i3d import network_init
from extract_features_from_video import run
import numpy as np
import cv2


def args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('-vid_dir', help='directory containing video files, e.g., *.mp4', default='/disk2/yangle/ExtractVideoFeature/videos_1')
    parser.add_argument('-feature_dir', help='directory to save feature files', default='/disk2/yangle/ExtractVideoFeature/features_1')
    parser.add_argument('-stride', type=int, default=1, help='feature extraction stride')
    parser.add_argument('-chunk_size', type=int, default=16, help='number of frames per feature-extraction snippet')
    parser.add_argument('-batch_size', type=int, default=8)
    # Note: run() currently implements only the RGB pipeline; flow mode would
    # additionally need a 2-channel flow loader.
    parser.add_argument('-mode', default='rgb', help='rgb or flow')
    parser.add_argument('-weight_file', default='/disk2/yangle/ExtractVideoFeature/models/rgb_imagenet.pt')

    args = parser.parse_args()
    return args


def video2array(video_file):
    cap = cv2.VideoCapture(video_file)
    frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Resize while loading frames to keep memory bounded.
    datas = np.zeros((frameCount, 224, 224, 3), np.dtype('uint8'))

    fc = 0
    ret = True
    while (fc < frameCount) and ret:
        ret, img = cap.read()
        if ret:
            # OpenCV decodes to BGR; the pretrained I3D weights expect RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            data = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
            datas[fc, :, :, :] = data
            fc += 1
    cap.release()
    # CAP_PROP_FRAME_COUNT can overestimate; drop any undecoded trailing slots.
    return datas[:fc]


def extract_feature(args):
    model_i3d = network_init(num_devices=1, mode=args.mode, weight_file=args.weight_file)
    if not os.path.exists(args.feature_dir):
        os.makedirs(args.feature_dir)

    vid_name_set = os.listdir(args.vid_dir)
    vid_name_set.sort(reverse=False)

    for video_name in vid_name_set:
        vid_name = os.path.splitext(video_name)[0]

        feature_file = os.path.join(args.feature_dir, vid_name + '.npz')
        if os.path.exists(feature_file):
            print('Video %s already calculated, skip' % video_name)
            continue

        vid_file = os.path.join(args.vid_dir, video_name)
        vframes = video2array(vid_file)
        print('video frames info', vframes.shape)
        if vframes.shape[0] < args.chunk_size:
            print('video %s contains %d frames, too short, skip' % (video_name, vframes.shape[0]))
            continue

        run(model_i3d, vid_name, vframes, chunk_size=args.chunk_size,
            stride=args.stride,
            output_dir=args.feature_dir, batch_size=args.batch_size)


if __name__ == '__main__':
    args = args_parser()
    extract_feature(args)
--------------------------------------------------------------------------------
/videotransforms.py:
--------------------------------------------------------------------------------
import numpy as np
import numbers
import random


class RandomCrop(object):
    """Crop the given video sequence (t x h x w x c) at a random location.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.
        Args:
            img (numpy.ndarray): video clip of shape (t, h, w, c) to be cropped.
            output_size (tuple): Expected output size of the crop.
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        t, h, w, c = img.shape
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th) if h != th else 0
        j = random.randint(0, w - tw) if w != tw else 0
        return i, j, th, tw

    def __call__(self, imgs):
        i, j, h, w = self.get_params(imgs, self.size)
        imgs = imgs[:, i:i + h, j:j + w, :]
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class CenterCrop(object):
    """Crop the given video sequence at the center.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.ndarray): video clip of shape (t, h, w, c) to be cropped.
        Returns:
            numpy.ndarray: center-cropped clip.
        """
        t, h, w, c = imgs.shape
        th, tw = self.size
        i = int(np.round((h - th) / 2.))
        j = int(np.round((w - tw) / 2.))

        return imgs[:, i:i + th, j:j + tw, :]

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class RandomHorizontalFlip(object):
    """Horizontally flip the given video sequence randomly with a given probability.
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.ndarray): video clip of shape (t, h, w, c) to be flipped.
        Returns:
            numpy.ndarray: randomly flipped clip.
        """
        if random.random() < self.p:
            # Flip along the width axis of (t, h, w, c).
            return np.flip(imgs, axis=2).copy()
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
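
if __name__ == '__main__':
    # Quick self-check of the transforms on a random clip; shapes follow the
    # (t, h, w, c) convention used throughout this file.
    clip = np.random.randint(0, 256, size=(16, 256, 340, 3), dtype=np.uint8)
    cropped = CenterCrop(224)(clip)
    flipped = RandomHorizontalFlip(p=1.0)(cropped)
    print(cropped.shape, flipped.shape)  # (16, 224, 224, 3) for both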
--------------------------------------------------------------------------------
/extract_features_from_video.py:
--------------------------------------------------------------------------------
import os
import math

import numpy as np
import torch
from PIL import Image


def load_frame(frame_file, resize=False):
    # Load a single pre-extracted 340x256 frame and rescale it to [-1, 1].
    data = Image.open(frame_file)

    assert data.size[1] == 256
    assert data.size[0] == 340

    if resize:
        # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS,
        # which was removed in Pillow 10.
        data = data.resize((224, 224), Image.LANCZOS)

    data = np.array(data).astype(float)
    data = (data * 2 / 255) - 1

    assert data.max() <= 1.0
    assert data.min() >= -1.0

    return data


def vframes_pre_process(vframes):
    # Rescale pixel values from [0, 255] to [-1, 1]
    vframes = vframes.astype(float)
    datas = (vframes * 2 / 255) - 1
    return datas


def load_rgb_batch(vframes, frame_indices):
    # frame_indices: [batch, chunk_size]
    batch_data = np.zeros(frame_indices.shape + (224, 224, 3))

    for i in range(frame_indices.shape[0]):
        # Place each frame at its position within the chunk. (Indexing by
        # j % chunk_size, as before, scrambles frames whenever stride != chunk_size.)
        for pos, j in enumerate(frame_indices[i, :]):
            batch_data[i, pos, :, :, :] = vframes[j, :, :, :]

    return batch_data  # [batch, chunk_size, 224, 224, 3]


def load_flow_batch(frames_dir, flow_x_files, flow_y_files,
                    frame_indices, resize=False):
    if resize:
        batch_data = np.zeros(frame_indices.shape + (224, 224, 2))
    else:
        batch_data = np.zeros(frame_indices.shape + (256, 340, 2))

    for i in range(frame_indices.shape[0]):
        for j in range(frame_indices.shape[1]):
            batch_data[i, j, :, :, 0] = load_frame(os.path.join(frames_dir,
                flow_x_files[frame_indices[i][j]]), resize)

            batch_data[i, j, :, :, 1] = load_frame(os.path.join(frames_dir,
                flow_y_files[frame_indices[i][j]]), resize)

    return batch_data


def run(model, video_name, vframes, chunk_size=16, stride=16, output_dir='', batch_size=1):

    def forward_batch(b_data):
        b_data = b_data.transpose([0, 4, 1, 2, 3])
        b_data = torch.from_numpy(b_data)  # b,c,t,h,w, e.g. 8x3x16x224x224

        b_data = b_data.cuda().float()
        with torch.no_grad():
            b_features = model(b_data)

        b_features = b_features.cpu().numpy()
        return b_features

    save_file = '{}.npz'.format(video_name)

    frame_cnt = vframes.shape[0]
    vframes = vframes_pre_process(vframes)

    # Cut frames into overlapping chunks of chunk_size frames, chunk_num in total.
    chunk_num = math.floor((frame_cnt - chunk_size) / stride) + 1
    frame_indices = list()

    # The +1 keeps the final chunk starting at frame_cnt - chunk_size, so the
    # number of chunks produced here matches chunk_num above.
    for i in range(0, frame_cnt - chunk_size + 1, stride):
        frame_indices.append(list(range(i, i + chunk_size)))
    frame_indices = np.array(frame_indices)

    batch_num = int(np.ceil(chunk_num / batch_size))  # Chunks to batches
    frame_indices = np.array_split(frame_indices, batch_num, axis=0)

    full_features = list()
    for batch_id in range(batch_num):
        batch_data = load_rgb_batch(vframes, frame_indices[batch_id])
        full_features.append(forward_batch(batch_data))

    full_features = np.concatenate(full_features, axis=0)

    np.savez(os.path.join(output_dir, save_file),
             feature=full_features,
             frame_cnt=frame_cnt,
             video_name=video_name)

    print('{} done: {} frames, {} chunks, feature shape {}'.format(
        video_name, frame_cnt, chunk_num, full_features.shape))
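
if __name__ == '__main__':
    # Minimal sketch of reading a saved feature file back; the path is
    # hypothetical, matching the np.savez call in run() above.
    saved = np.load('features_1/vid001.npz')
    print(saved['feature'].shape)  # (chunk_num, 1024)
    print(saved['frame_cnt'], saved['video_name'])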
--------------------------------------------------------------------------------
/i3d.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

class MaxPool3dSamePadding(nn.MaxPool3d):

    def compute_pad(self, dim, s):
        # Total padding needed along one dimension for 'same' output size.
        if s % self.stride[dim] == 0:
            return max(self.kernel_size[dim] - self.stride[dim], 0)
        else:
            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)

    def forward(self, x):
        # Compute TensorFlow-style 'same' padding from the input size.
        (batch, channel, t, h, w) = x.size()
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        # Split each total padding between the front and back of the dimension.
        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)
        return super(MaxPool3dSamePadding, self).forward(x)

class Unit3D(nn.Module):

    def __init__(self, in_channels,
                 output_channels,
                 kernel_shape=(1, 1, 1),
                 stride=(1, 1, 1),
                 padding=0,
                 activation_fn=F.relu,
                 use_batch_norm=True,
                 use_bias=False,
                 name='unit_3d'):
        """Initializes Unit3D module."""
        super(Unit3D, self).__init__()

        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._stride = stride
        self._use_batch_norm = use_batch_norm
        self._activation_fn = activation_fn
        self._use_bias = use_bias
        self.name = name
        self.padding = padding

        # Padding stays 0 here; we pad dynamically in forward() based on the
        # input size, to reproduce TensorFlow's 'same' padding.
        self.conv3d = nn.Conv3d(in_channels=in_channels,
                                out_channels=self._output_channels,
                                kernel_size=self._kernel_shape,
                                stride=self._stride,
                                padding=0,
                                bias=self._use_bias)

        if self._use_batch_norm:
            self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)

    def compute_pad(self, dim, s):
        # Total padding needed along one dimension for 'same' output size.
        if s % self._stride[dim] == 0:
            return max(self._kernel_shape[dim] - self._stride[dim], 0)
        else:
            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)

    def forward(self, x):
        # Compute TensorFlow-style 'same' padding from the input size.
        (batch, channel, t, h, w) = x.size()
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)

        x = self.conv3d(x)
        if self._use_batch_norm:
            x = self.bn(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x

class InceptionModule(nn.Module):
    # Four parallel branches (1x1, 1x1->3x3, 1x1->3x3, pool->1x1) whose
    # outputs are concatenated along the channel dimension.
    def __init__(self, in_channels, out_channels, name):
        super(InceptionModule, self).__init__()

        self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
                         name=name + '/Branch_0/Conv3d_0a_1x1')
        self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_1/Conv3d_0a_1x1')
        self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
                          name=name + '/Branch_1/Conv3d_0b_3x3')
        self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_2/Conv3d_0a_1x1')
        self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
                          name=name + '/Branch_2/Conv3d_0b_3x3')
        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
                                        stride=(1, 1, 1), padding=0)
        self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_3/Conv3d_0b_1x1')
        self.name = name

    def forward(self, x):
        b0 = self.b0(x)
        b1 = self.b1b(self.b1a(x))
        b2 = self.b2b(self.b2a(x))
        b3 = self.b3b(self.b3a(x))
        return torch.cat([b0, b1, b2, b3], dim=1)

class InceptionI3d(nn.Module):
    """Inception-v1 I3D architecture.
    The model is introduced in:
        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
        Joao Carreira, Andrew Zisserman
        https://arxiv.org/pdf/1705.07750v1.pdf.
    See also the Inception architecture, introduced in:
        Going deeper with convolutions
        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
        http://arxiv.org/pdf/1409.4842v1.pdf.
    """

    # Endpoints of the model in order. During construction, the network is
    # built up to (and including) the designated `final_endpoint`.
    VALID_ENDPOINTS = (
        'Conv3d_1a_7x7',
        'MaxPool3d_2a_3x3',
        'Conv3d_2b_1x1',
        'Conv3d_2c_3x3',
        'MaxPool3d_3a_3x3',
        'Mixed_3b',
        'Mixed_3c',
        'MaxPool3d_4a_3x3',
        'Mixed_4b',
        'Mixed_4c',
        'Mixed_4d',
        'Mixed_4e',
        'Mixed_4f',
        'MaxPool3d_5a_2x2',
        'Mixed_5b',
        'Mixed_5c',
        'Logits',
        'Predictions',
    )

    def __init__(self, num_classes=400, spatial_squeeze=True,
                 final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
        """Initializes I3D model instance.
        Args:
            num_classes: The number of outputs in the logit layer (default 400,
                which matches the Kinetics dataset).
            spatial_squeeze: Whether to squeeze the spatial dimensions for the
                logits before returning (default True).
            final_endpoint: The model contains many possible endpoints.
                `final_endpoint` specifies the last endpoint for the model to be
                built up to. It must be one of InceptionI3d.VALID_ENDPOINTS
                (default 'Logits').
            name: A string (optional). The name of this module.
        Raises:
            ValueError: if `final_endpoint` is not recognized.
        """

        if final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError('Unknown final endpoint %s' % final_endpoint)

        super(InceptionI3d, self).__init__()
        self._num_classes = num_classes
        self._spatial_squeeze = spatial_squeeze
        self._final_endpoint = final_endpoint
        self.logits = None

        self.end_points = {}
        end_point = 'Conv3d_1a_7x7'
        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
                                            stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_2a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Conv3d_2b_1x1'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
                                            name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Conv3d_2c_3x3'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
                                            name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_3a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_3b'
        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_3c'
        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_4a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4b'
        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4c'
        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4d'
        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4e'
        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4f'
        self.end_points[end_point] = InceptionModule(112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
                                                     name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_5a_2x2'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_5b'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
                                                     name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_5c'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
                                                     name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Logits'
        self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
                                     stride=(1, 1, 1))
        self.dropout = nn.Dropout(dropout_keep_prob)
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

        self.build()

    def replace_logits(self, num_classes):
        self._num_classes = num_classes
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

    def build(self):
        # Register the endpoint modules so their parameters are tracked.
        for k in self.end_points.keys():
            self.add_module(k, self.end_points[k])

    def forward(self, x):
        # Note: this feature-extraction port never applies self.dropout or
        # self.logits; those layers exist so pretrained checkpoints load cleanly.
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)  # use _modules to work with dataparallel

        x = self.avg_pool(x)
        # (b, 1024, t', 1, 1) -> (b, 1024) for 16-frame inputs.
        x = x.squeeze(4).squeeze(3).squeeze(2)
        return x

    def extract_features(self, x):
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)
        return self.avg_pool(x)

def network_init(num_devices, mode, weight_file):
    # 'rgb' expects 3-channel frames; 'flow' expects 2-channel (x, y) flow fields.
    if mode == 'rgb':
        model = InceptionI3d(400, in_channels=3)
    else:
        model = InceptionI3d(400, in_channels=2)
    model.load_state_dict(torch.load(weight_file))
    model = model.cuda()
    model.eval()
    model = nn.DataParallel(model, device_ids=list(range(num_devices)))
    print('I3D model with mode %s initialized in evaluation mode' % mode)

    return model

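# A typical call, mirroring demo_extract_i3d_feature.py (the weight path is
# illustrative; pretrained checkpoints are linked in README.md):
#   model = network_init(num_devices=1, mode='rgb', weight_file='models/rgb_imagenet.pt')
#   features = model(clips)  # clips: (b, 3, 16, 224, 224) CUDA float -> (b, 1024)
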
if __name__ == '__main__':
    i3d = InceptionI3d(400, in_channels=3)
    i3d = i3d.cuda()
    data = torch.randn(2, 3, 16, 224, 224).cuda()
    fea = i3d(data)
    # fea = i3d.extract_features(data)
    print(fea.size())  # torch.Size([2, 1024])
--------------------------------------------------------------------------------