├── .github
│   └── nn.png
├── CMakeLists.txt
├── LICENSE
├── README.md
├── src
│   ├── contextual_ms.py
│   ├── contextual_tf.py
│   └── parser.cpp
└── test
    ├── concurrent_decode.py
    ├── contextual_ms.ipynb
    └── contextual_tf.ipynb

/.github/nn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanmu97/PacketGame/99781a289bd4cb0639a1957d7ae178fefd80e751/.github/nn.png

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.25)
2 | project(parser)
3 | add_executable(parser src/parser.cpp)
4 | 
5 | set(FFMPEG_LIBRARY_PATH /usr/local/lib)
6 | include_directories(/usr/local/include/)
7 | link_directories(${FFMPEG_LIBRARY_PATH})
8 | target_link_libraries(parser
9 |     ${FFMPEG_LIBRARY_PATH}/libavcodec.so
10 |     ${FFMPEG_LIBRARY_PATH}/libavutil.so
11 | )
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Mu Yuan (袁牧)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PacketGame
2 | 
3 | PacketGame is a pre-decoding packet filter for concurrent video inference at scale.
4 | 
5 | Our paper "PacketGame: Multi-Stream Packet Gating for Concurrent Video Inference at Scale" will appear at *ACM SIGCOMM 2023*.
6 | 
7 | ## Installation
8 | 
9 | OS: Ubuntu 20.04
10 | 
11 | ### TensorFlow/MindSpore
12 | 
13 | We implement the neural network-based contextual predictor in PacketGame using TensorFlow 2.4.1 and MindSpore 2.0.0.
14 | 
15 | Please refer to the installation docs: [TensorFlow](https://www.tensorflow.org/install) / [MindSpore](https://www.mindspore.cn/install)
16 | 
17 | ### FFmpeg with nv-codec
18 | 
19 | To use FFmpeg with an NVIDIA GPU, we need to compile it from source (see the [NVIDIA doc](https://docs.nvidia.com/video-technologies/video-codec-sdk/11.1/ffmpeg-with-nvidia-gpu/index.html)).
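Building with `--enable-cuda-nvcc` assumes the NVIDIA driver and the CUDA toolkit are already installed under `/usr/local/cuda`. A quick sanity check of the toolchain before building (an illustrative step, not from the original guide):

```bash
nvidia-smi       # driver loaded and GPU visible
nvcc --version   # CUDA compiler present, required by --enable-cuda-nvcc
```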
20 | 
21 | Install ffnvcodec:
22 | ```bash
23 | git clone https://git.videolan.org/git/ffmpeg/nv-codec-headers.git
24 | cd nv-codec-headers
25 | make install
26 | ```
27 | 
28 | Install necessary packages:
29 | ```bash
30 | apt-get install yasm cmake
31 | # codecs: h.264, h.265, vp9, jp2k
32 | apt-get install libx264-dev libx265-dev libvpx-dev libopenjp2-7-dev
33 | ```
34 | 
35 | Download FFmpeg ([v5.1](https://github.com/FFmpeg/FFmpeg/tree/release/5.1)) and install it:
36 | ```bash
37 | cd FFmpeg-release-5.1/
38 | ./configure --enable-nonfree --enable-cuda-nvcc --enable-libnpp --extra-cflags=-I/usr/local/cuda/include --extra-ldflags=-L/usr/local/cuda/lib64 --disable-static --enable-shared --enable-gpl --enable-libx264 --enable-libx265 --enable-libvpx --enable-libopenjpeg
39 | make -j 8
40 | make install
41 | # test
42 | ffmpeg
43 | -------------------------------------------------------------------
44 | ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers
45 | built with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)
46 | configuration: --enable-nonfree --enable-cuda-nvcc --enable-libnpp --extra-cflags=-I/usr/local/cuda/include --extra-ldflags=-L/usr/local/cuda/lib64 --disable-static --enable-shared --enable-gpl --enable-libx264 --enable-libx265 --enable-libvpx --enable-libopenjpeg
47 | libavutil      57. 28.100 / 57. 28.100
48 | libavcodec     59. 37.100 / 59. 37.100
49 | libavformat    59. 27.100 / 59. 27.100
50 | libavdevice    59.  7.100 / 59.  7.100
51 | libavfilter     8. 44.100 /  8. 44.100
52 | libswscale      6.  7.100 /  6.  7.100
53 | libswresample   4.  7.100 /  4.  7.100
54 | libpostproc    56.  6.100 / 56.  6.100
55 | Hyper fast Audio and Video encoder
56 | usage: ffmpeg [options] [[infile options] -i infile]... {[outfile options] outfile}...
57 | ```
58 | 
59 | ## Packet Parser
60 | 
61 | The first step is to parse a video and save its metadata (packet size and picture type).
62 | 
63 | ```bash
64 | mkdir build
65 | cd build
66 | cmake ..
67 | make
68 | # test
69 | ./parser ../test/sample_video.h265 ../test/sample_video_meta.txt
70 | -----------------------------------------------------------------
71 | 251 packets parsed in 0.007604 seconds.
72 | ```
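The metadata file written by the parser is a plain CSV: a `pkt_size,pic_type` header followed by one row per packet. `pic_type` follows FFmpeg's `AVPictureType` enum (0 = undetermined, 1 = I, 2 = P, 3 = B). For illustration, with hypothetical values:

```
pkt_size,pic_type
132433,1
4523,2
1082,3
```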
73 | 
74 | ## Concurrent Decoding
75 | 
76 | Platform: Intel Core i7-5930K CPU (12 logical cores) / NVIDIA TITAN X GPU
77 | 
78 | ```bash
79 | cd test
80 | # set USEGPU=1 in concurrent_decode.py to decode on the GPU
81 | python concurrent_decode.py
82 | ----------------------------------------
83 |   concurrency    time cost (s)      fps
84 | -------------  ---------------  -------
85 |             1          1.51226  165.316
86 |             5          3.70409  337.465
87 |            10          6.72792  371.586
88 |            20         13.6818   365.449
89 |            30         19.5902   382.845
90 |            35         18.7045   467.803
91 |            40         23.5301   424.988
92 |            45         25.1256   447.75
93 |            50         27.077    461.646
94 | ```
95 | 
96 | ## Contextual Predictor
97 | 
98 | ![contextual predictor network](.github/nn.png)
99 | 
100 | Implementation using TensorFlow (`src/contextual_tf.py`):
101 | ```python
102 | def build_ensemble_threeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128):
103 |     """
104 |     build three-view neural network
105 | 
106 |     Args:
107 |         inp_len1, inp_len2, inp_len3 (int): input length of three views, 5 for inp1 & inp2 and 1 for inp3 by default
108 |         conv_units (list of int): number of conv1d units
109 |             by default, two conv1d layers with 32 units
110 |         dense_unit (int): number of dense units, 128 by default
111 |     Return:
112 |         tf.keras.Model instance
113 |     """
114 |     inp1 = layers.Input(shape=(None, inp_len1), name="View1-Indepdendent")
115 |     inp2 = layers.Input(shape=(None, inp_len2), name="View2-Predicted")
116 |     inp3 = layers.Input(shape=(inp_len3), name="View3-Temporal")
117 |     x1 = inp1
118 |     x2 = inp2
119 |     x3 = inp3
120 |     for u in conv_units:
121 |         x1 = layers.Conv1D(u, 1, activation="relu")(x1)
122 |         x2 = layers.Conv1D(u, 1, activation="relu")(x2)
123 |     x1 = layers.GlobalMaxPooling1D()(x1)
124 |     x2 = layers.GlobalMaxPooling1D()(x2)
125 |     x = layers.Concatenate()([x1, x2])
126 |     out = layers.Dense(1, activation="sigmoid")(x)
127 |     x4 = layers.Concatenate()([out, x3])
128 |     x4 = layers.Dense(dense_unit, activation="relu")(x4)
129 |     out2 = layers.Dense(1, activation='sigmoid')(x4)
130 |     return tf.keras.Model(inputs=[inp1, inp2, inp3], outputs=out2)
131 | 
132 | m = build_ensemble_threeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128)
133 | print(m.summary())
134 | --------------------------------------------------------------------------------------------------
135 | Model: "model_1"
136 | __________________________________________________________________________________________________
137 | Layer (type)                    Output Shape         Param #     Connected to
138 | ==================================================================================================
139 | View1-Indepdendent (InputLayer) [(None, None, 5)]    0
140 | __________________________________________________________________________________________________
141 | View2-Predicted (InputLayer)    [(None, None, 5)]    0
142 | __________________________________________________________________________________________________
143 | conv1d_4 (Conv1D)               (None, None, 32)     192         View1-Indepdendent[0][0]
144 | __________________________________________________________________________________________________
145 | conv1d_5 (Conv1D)               (None, None, 32)     192         View2-Predicted[0][0]
146 | __________________________________________________________________________________________________
147 | conv1d_6 (Conv1D)               (None, None, 32)     1056        conv1d_4[0][0]
148 | __________________________________________________________________________________________________
149 | conv1d_7 (Conv1D)               (None, None, 32)     1056        conv1d_5[0][0]
150 | __________________________________________________________________________________________________
151 | global_max_pooling1d_2 (GlobalM (None, 32)           0           conv1d_6[0][0]
152 | __________________________________________________________________________________________________
153 | global_max_pooling1d_3 (GlobalM (None, 32)           0           conv1d_7[0][0]
154 | __________________________________________________________________________________________________
155 | concatenate_1 (Concatenate)     (None, 64)           0           global_max_pooling1d_2[0][0]
156 |                                                                  global_max_pooling1d_3[0][0]
157 | __________________________________________________________________________________________________
158 | dense_1 (Dense)                 (None, 1)            65          concatenate_1[0][0]
159 | __________________________________________________________________________________________________
160 | View3-Temporal (InputLayer)     [(None, 1)]          0
161 | __________________________________________________________________________________________________
162 | concatenate_2 (Concatenate)     (None, 2)            0           dense_1[0][0]
163 |                                                                  View3-Temporal[0][0]
164 | __________________________________________________________________________________________________
165 | dense_2 (Dense)                 (None, 128)          384         concatenate_2[0][0]
166 | __________________________________________________________________________________________________
167 | dense_3 (Dense)                 (None, 1)            129         dense_2[0][0]
168 | ==================================================================================================
169 | Total params: 3,074
170 | Trainable params: 3,074
171 | Non-trainable params: 0
172 | __________________________________________________________________________________________________
173 | ```
174 | 
175 | Implementation using MindSpore (`src/contextual_ms.py`):
176 | ```python
177 | class EnsembleThreeview(nn.Cell):
178 |     def __init__(self, inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128):
179 |         super(EnsembleThreeview, self).__init__()
180 | 
181 |         self.view1_layers = nn.CellList([nn.Conv1d(inp_len1, conv_units[0], 1, has_bias=True, pad_mode='valid'),
182 |                                          nn.ReLU()])
183 |         self.view2_layers = nn.CellList([nn.Conv1d(inp_len2, conv_units[0], 1, has_bias=True, pad_mode='valid'),
184 |                                          nn.ReLU()])
185 | 
186 |         for i in range(len(conv_units)-1):
187 |             self.view1_layers.append(nn.Conv1d(conv_units[i], conv_units[i+1], 1, has_bias=True, pad_mode='valid'))
188 |             self.view1_layers.append(nn.ReLU())
189 | 
190 |             self.view2_layers.append(nn.Conv1d(conv_units[i], conv_units[i+1], 1, has_bias=True, pad_mode='valid'))
191 |             self.view2_layers.append(nn.ReLU())
192 | 
193 |         self.view1_layers.append(nn.AdaptiveMaxPool1d(1))
194 |         self.view2_layers.append(nn.AdaptiveMaxPool1d(1))
195 | 
196 |         self.dense = nn.Dense(conv_units[-1]*2, 1, has_bias=True)
197 |         self.sigmoid = nn.Sigmoid()
198 | 
199 |         self.dense2 = nn.Dense(1+inp_len3, dense_unit, has_bias=True)
200 |         self.relu = nn.ReLU()
201 |         self.dense3 = nn.Dense(dense_unit, 1, has_bias=True)
202 | 
203 |     def construct(self, x1, x2, x3):
204 |         for v1_layer, v2_layer in zip(self.view1_layers, self.view2_layers):
205 |             x1 = v1_layer(x1)
206 |             x2 = v2_layer(x2)
207 |         x = ops.cat((x1, x2), axis=1)
208 |         x = ops.squeeze(x, axis=-1)
209 |         x = self.dense(x)
210 |         x = self.sigmoid(x)
211 | 
212 |         x = ops.cat((x, x3), axis=1)
213 |         x = self.dense2(x)
214 |         x = self.relu(x)
215 |         x = self.dense3(x)
216 |         x = self.sigmoid(x)
217 | 
218 |         return x
219 | 
220 | net = EnsembleThreeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128)
221 | print(net)
222 | ------------------------------------------------------------------------------------
223 | EnsembleThreeview<
224 |   (view1_layers): CellList<
225 |     (0): Conv1d
226 |     (1): ReLU<>
227 |     (2): Conv1d
228 |     (3): ReLU<>
229 |     (4): AdaptiveMaxPool1d<>
230 |     >
231 |   (view2_layers): CellList<
232 |     (0): Conv1d
233 |     (1): ReLU<>
234 |     (2): Conv1d
235 |     (3): ReLU<>
236 |     (4): AdaptiveMaxPool1d<>
237 |     >
238 |   (dense): Dense
239 |   (sigmoid): Sigmoid<>
240 |   (dense2): Dense
241 |   (relu): ReLU<>
242 |   (dense3): Dense
243 |   >
244 | 
245 | x1 = mindspore.Tensor(np.ones([1, 5, 1]), mindspore.float32)
246 | x2 = mindspore.Tensor(np.ones([1, 5, 1]), mindspore.float32)
247 | x3 = mindspore.Tensor(np.ones([1, 1]), mindspore.float32)
248 | print(net(x1, x2, x3).shape)
249 | ------------------------------
250 | (1, 1)
251 | 
252 | total_params = 0
253 | for v in net.parameters_dict().values():
254 |     print(v, v.size)
255 |     total_params += v.size
256 | print(total_params)
257 | -----------------------
258 | Parameter (name=view1_layers.0.weight, shape=(32, 5, 1, 1), dtype=Float32, requires_grad=True) 160
259 | Parameter (name=view1_layers.0.bias, shape=(32,), dtype=Float32, requires_grad=True) 32
260 | Parameter (name=view1_layers.2.weight, shape=(32, 32, 1, 1), dtype=Float32, requires_grad=True) 1024
261 | Parameter (name=view1_layers.2.bias, shape=(32,), dtype=Float32, requires_grad=True) 32
262 | Parameter (name=view2_layers.0.weight, shape=(32, 5, 1, 1), dtype=Float32, requires_grad=True) 160
263 | Parameter (name=view2_layers.0.bias, shape=(32,), dtype=Float32, requires_grad=True) 32
264 | Parameter (name=view2_layers.2.weight, shape=(32, 32, 1, 1), dtype=Float32, requires_grad=True) 1024
265 | Parameter (name=view2_layers.2.bias, shape=(32,), dtype=Float32, requires_grad=True) 32
266 | Parameter (name=dense.weight, shape=(1, 64), dtype=Float32, requires_grad=True) 64
267 | Parameter (name=dense.bias, shape=(1,), dtype=Float32, requires_grad=True) 1
268 | Parameter (name=dense2.weight, shape=(128, 2), dtype=Float32, requires_grad=True) 256
269 | Parameter (name=dense2.bias, shape=(128,), dtype=Float32, requires_grad=True) 128
270 | Parameter (name=dense3.weight, shape=(1, 128), dtype=Float32, requires_grad=True) 128
271 | Parameter (name=dense3.bias, shape=(1,), dtype=Float32, requires_grad=True) 1
272 | 3074
273 | ```
274 | 
275 | ## Citation
276 | 
277 | If you find this repository helpful, please consider citing the following paper:
278 | ```
279 | Mu Yuan, Lan Zhang, Xuanke You, and Xiang-Yang Li. 2023. PacketGame: Multi-Stream Packet Gating for Concurrent Video Inference at Scale. In ACM SIGCOMM 2023 Conference (ACM SIGCOMM ’23), September 10–14, 2023, New York, NY, USA. ACM, New York, NY, USA, 14 pages. https://doi.org/10.1145/3603269.3604825
280 | ```
281 | 
282 | ## License
283 | 
284 | PacketGame is licensed under the [MIT License](./LICENSE).
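## Appendix: Training the Contextual Predictor (Sketch)

The repository ships the model definitions but no training script. Below is a minimal end-to-end sketch for the TensorFlow variant using randomly generated dummy tensors; the optimizer, loss, and data shapes are illustrative assumptions, not the paper's training setup:

```python
import sys
sys.path.append("src")  # run from the repo root so contextual_tf.py is importable

import numpy as np
import tensorflow as tf
from contextual_tf import build_ensemble_threeview

m = build_ensemble_threeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128)
# the output is a single decode-or-skip probability, so binary cross-entropy is a natural loss
m.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# dummy data: 64 samples, 10 packets per window for views 1-2 (hypothetical shapes)
x1 = np.random.rand(64, 10, 5).astype("float32")   # View1: independent features
x2 = np.random.rand(64, 10, 5).astype("float32")   # View2: predicted features
x3 = np.random.rand(64, 1).astype("float32")       # View3: temporal feature
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")

m.fit([x1, x2, x3], y, batch_size=16, epochs=2)
probs = m.predict([x1, x2, x3])  # per-packet keep probabilities in [0, 1]
```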
--------------------------------------------------------------------------------
/src/contextual_ms.py:
--------------------------------------------------------------------------------
1 | from mindspore import nn, ops
2 | 
3 | 
4 | class Conv1dTwoview(nn.Cell):
5 |     """
6 |     build two-view neural network
7 | 
8 |     Args:
9 |         inp_len1, inp_len2 (int): input length of two views, 5 by default
10 |         conv_units (list of int): number of conv1d units
11 |             by default, two conv1d layers with 32 units
12 |     """
13 |     def __init__(self, inp_len1=5, inp_len2=5, conv_units=[32, 32]):
14 |         super(Conv1dTwoview, self).__init__()
15 | 
16 |         self.view1_layers = nn.CellList([nn.Conv1d(inp_len1, conv_units[0], 1, has_bias=True, pad_mode='valid'),
17 |                                          nn.ReLU()])
18 |         self.view2_layers = nn.CellList([nn.Conv1d(inp_len2, conv_units[0], 1, has_bias=True, pad_mode='valid'),
19 |                                          nn.ReLU()])
20 | 
21 |         for i in range(len(conv_units)-1):
22 |             self.view1_layers.append(nn.Conv1d(conv_units[i], conv_units[i+1], 1, has_bias=True, pad_mode='valid'))
23 |             self.view1_layers.append(nn.ReLU())
24 | 
25 |             self.view2_layers.append(nn.Conv1d(conv_units[i], conv_units[i+1], 1, has_bias=True, pad_mode='valid'))
26 |             self.view2_layers.append(nn.ReLU())
27 | 
28 |         self.view1_layers.append(nn.AdaptiveMaxPool1d(1))
29 |         self.view2_layers.append(nn.AdaptiveMaxPool1d(1))
30 | 
31 |         self.dense = nn.Dense(conv_units[-1]*2, 1, has_bias=True)
32 |         self.sigmoid = nn.Sigmoid()
33 | 
34 |     def construct(self, x1, x2):
35 |         for v1_layer, v2_layer in zip(self.view1_layers, self.view2_layers):
36 |             x1 = v1_layer(x1)
37 |             x2 = v2_layer(x2)
38 |         x = ops.cat((x1, x2), axis=1)
39 |         x = ops.squeeze(x, axis=-1)
40 |         x = self.dense(x)
41 |         x = self.sigmoid(x)
42 |         return x
43 | 
44 | 
45 | class EnsembleThreeview(nn.Cell):
46 |     """
47 |     build three-view neural network
48 | 
49 |     Args:
50 |         inp_len1, inp_len2, inp_len3 (int): input length of three views, 5 for inp1 & inp2 and 1 for inp3 by default
51 |         conv_units (list of int): number of conv1d units
52 |             by default, two conv1d layers with 32 units
53 |         dense_unit (int): number of dense units, 128 by default
54 |     """
55 |     def __init__(self, inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128):
56 |         super(EnsembleThreeview, self).__init__()
57 | 
58 |         self.view1_layers = nn.CellList([nn.Conv1d(inp_len1, conv_units[0], 1, has_bias=True, pad_mode='valid'),
59 |                                          nn.ReLU()])
60 |         self.view2_layers = nn.CellList([nn.Conv1d(inp_len2, conv_units[0], 1, has_bias=True, pad_mode='valid'),
61 |                                          nn.ReLU()])
62 | 
63 |         for i in range(len(conv_units)-1):
64 |             self.view1_layers.append(nn.Conv1d(conv_units[i], conv_units[i+1], 1, has_bias=True, pad_mode='valid'))
65 |             self.view1_layers.append(nn.ReLU())
66 | 
67 |             self.view2_layers.append(nn.Conv1d(conv_units[i], conv_units[i+1], 1, has_bias=True, pad_mode='valid'))
68 |             self.view2_layers.append(nn.ReLU())
69 | 
70 |         self.view1_layers.append(nn.AdaptiveMaxPool1d(1))
71 |         self.view2_layers.append(nn.AdaptiveMaxPool1d(1))
72 | 
73 |         self.dense = nn.Dense(conv_units[-1]*2, 1, has_bias=True)
74 |         self.sigmoid = nn.Sigmoid()
75 | 
76 |         self.dense2 = nn.Dense(1+inp_len3, dense_unit, has_bias=True)
77 |         self.relu = nn.ReLU()
78 |         self.dense3 = nn.Dense(dense_unit, 1, has_bias=True)
79 | 
80 |     def construct(self, x1, x2, x3):
81 |         for v1_layer, v2_layer in zip(self.view1_layers, self.view2_layers):
82 |             x1 = v1_layer(x1)
83 |             x2 = v2_layer(x2)
84 |         x = ops.cat((x1, x2), axis=1)
85 |         x = ops.squeeze(x, axis=-1)
86 |         x = self.dense(x)
87 |         x = self.sigmoid(x)
88 | 
89 |         x = ops.cat((x, x3), axis=1)
90 |         x = self.dense2(x)
91 |         x = self.relu(x)
92 |         x = self.dense3(x)
93 |         x = self.sigmoid(x)
94 | 
95 |         return x
--------------------------------------------------------------------------------
/src/contextual_tf.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers
3 | 
4 | 
5 | def build_conv1d_twoview(inp_len1=5, inp_len2=5, conv_units=[32, 32]):
6 |     """
7 |     build two-view neural network
8 | 
9 |     Args:
10 |         inp_len1, inp_len2 (int): input length of two views, 5 by default
11 |         conv_units (list of int): number of conv1d units
12 |             by default, two conv1d layers with 32 units
13 |     Return:
14 |         tf.keras.Model instance
15 |     """
16 |     inp1 = layers.Input(shape=(None, inp_len1), name="View1-Indepdendent")
17 |     inp2 = layers.Input(shape=(None, inp_len2), name="View2-Predicted")
18 |     x1 = inp1
19 |     x2 = inp2
20 |     for u in conv_units:
21 |         x1 = layers.Conv1D(u, 1, activation="relu")(x1)
22 |         x2 = layers.Conv1D(u, 1, activation="relu")(x2)
23 |     x1 = layers.GlobalMaxPooling1D()(x1)
24 |     x2 = layers.GlobalMaxPooling1D()(x2)
25 |     x = layers.Concatenate()([x1, x2])
26 |     out = layers.Dense(1, activation="sigmoid")(x)
27 |     return tf.keras.Model(inputs=[inp1, inp2], outputs=out)
28 | 
29 | 
30 | def build_ensemble_threeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128):
31 |     """
32 |     build three-view neural network
33 | 
34 |     Args:
35 |         inp_len1, inp_len2, inp_len3 (int): input length of three views, 5 for inp1 & inp2 and 1 for inp3 by default
36 |         conv_units (list of int): number of conv1d units
37 |             by default, two conv1d layers with 32 units
38 |         dense_unit (int): number of dense units, 128 by default
39 |     Return:
40 |         tf.keras.Model instance
41 |     """
42 |     inp1 = layers.Input(shape=(None, inp_len1), name="View1-Indepdendent")
43 |     inp2 = layers.Input(shape=(None, inp_len2), name="View2-Predicted")
44 |     inp3 = layers.Input(shape=(inp_len3), name="View3-Temporal")
45 |     x1 = inp1
46 |     x2 = inp2
47 |     x3 = inp3
48 |     for u in conv_units:
49 |         x1 = layers.Conv1D(u, 1, activation="relu")(x1)
50 |         x2 = layers.Conv1D(u, 1, activation="relu")(x2)
51 |     x1 = layers.GlobalMaxPooling1D()(x1)
52 |     x2 = layers.GlobalMaxPooling1D()(x2)
53 |     x = layers.Concatenate()([x1, x2])
54 |     out = layers.Dense(1, activation="sigmoid")(x)
55 |     x4 = layers.Concatenate()([out, x3])
56 |     x4 = layers.Dense(dense_unit, activation="relu")(x4)
57 |     out2 = layers.Dense(1, activation='sigmoid')(x4)
58 |     return tf.keras.Model(inputs=[inp1, inp2, inp3], outputs=out2)
--------------------------------------------------------------------------------
/src/parser.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | parser.cpp
3 | parse packets from a video
4 | */
5 | #include <stdio.h>
6 | #include <stdlib.h>
7 | #include <string.h>
8 | extern "C"{
9 | #include <libavcodec/avcodec.h>
10 | }
11 | #include <time.h>
12 | #define INBUF_SIZE 4096
13 | 
14 | int main(int argc, char **argv){
15 |     const char *filename, *outfilename;
16 |     const AVCodec *codec;
17 |     AVCodecParserContext *parser;
18 |     AVCodecContext *c = NULL;
19 |     FILE *f;
20 |     uint8_t inbuf[INBUF_SIZE + AV_INPUT_BUFFER_PADDING_SIZE];
21 |     uint8_t *data;
22 |     size_t data_size;
23 |     int ret;
24 |     int eof;
25 |     AVPacket *pkt;
26 |     FILE *out_file;
27 |     int pkt_count = 0;
28 |     clock_t start_t, end_t;
29 |     double time_cost = 0.;
30 | 
31 |     filename = argv[1];
32 |     outfilename = argv[2];
33 | 
34 |     start_t = clock();
35 | 
36 |     out_file = fopen(outfilename, "w");
37 |     fprintf(out_file, "pkt_size,pic_type\n");
38 | 
39 |     pkt = av_packet_alloc();
40 |     if(!pkt){
41 |         exit(1);
42 |     }
43 |     /* zero-pad the end of the read buffer (required by the parser) */
44 |     memset(inbuf+INBUF_SIZE, 0, AV_INPUT_BUFFER_PADDING_SIZE);
45 | 
46 |     codec = avcodec_find_decoder(AV_CODEC_ID_H265);
47 |     if(!codec){
48 |         fprintf(stderr, "Codec not found\n");
49 |         exit(1);
50 |     }
51 | 
52 |     parser = av_parser_init(codec->id);
53 |     if (!parser) {
54 |         fprintf(stderr, "parser not found\n");
55 |         exit(1);
56 |     }
57 | 
58 |     c = avcodec_alloc_context3(codec);
59 |     if (!c) {
60 |         fprintf(stderr, "Could not allocate video codec context\n");
61 |         exit(1);
62 |     }
63 |     if (avcodec_open2(c, codec, NULL) < 0) {
64 |         fprintf(stderr, "Could not open codec\n");
65 |         exit(1);
66 |     }
67 | 
68 |     f = fopen(filename, "rb");
69 |     if (!f) {
70 |         fprintf(stderr, "Could not open %s\n", filename);
71 |         exit(1);
72 |     }
73 | 
74 |     do{
75 |         data_size = fread(inbuf, 1, INBUF_SIZE, f);
76 |         if(ferror(f)){
77 |             break;
78 |         }
79 |         eof = !data_size;
80 |         data = inbuf;
81 |         while(data_size > 0 || eof){
82 |             ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, data, data_size, AV_NOPTS_VALUE, AV_NOPTS_VALUE, 0);
83 |             if(ret < 0){
84 |                 fprintf(stderr, "Error while parsing\n");
85 |                 exit(1);
86 |             }
87 |             data += ret;
88 |             data_size -= ret;
89 |             if(pkt->size){
90 |                 /* pict_type follows FFmpeg's AVPictureType (1 = I, 2 = P, 3 = B) */
91 |                 fprintf(out_file, "%d,%d\n", pkt->size, parser->pict_type);
92 |                 pkt_count += 1;
93 |             }
94 |             else if(eof){
95 |                 break;
96 |             }
97 |         }
98 |     } while(!eof);
99 |     fclose(f);
100 |     av_parser_close(parser);
101 |     avcodec_free_context(&c);
102 |     av_packet_free(&pkt);
103 | 
104 |     end_t = clock();
105 |     time_cost = (double)(end_t - start_t) / CLOCKS_PER_SEC;
106 |     printf("%d packets parsed in %f seconds.\n", pkt_count, time_cost);
107 | 
108 |     return 0;
109 | }
--------------------------------------------------------------------------------
/test/concurrent_decode.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
4 | import time
5 | from tabulate import tabulate
6 | 
7 | USEGPU = 0
8 | 
9 | cmd_list = []
10 | 
11 | vid_dir = "video_h1/"
12 | for vid in os.listdir(vid_dir):
13 |     if vid.endswith(".h265"):
14 |         vid_path = os.path.join(vid_dir, vid)
15 |         # reference: https://trac.ffmpeg.org/wiki/Null
16 |         if USEGPU:
17 |             cmd_list.append(f"ffmpeg -hwaccel cuda -i {vid_path} -f null -".split())
18 |         else:
19 |             cmd_list.append(f"ffmpeg -i {vid_path} -f null -".split())
20 | 
21 | concur_levels = [10, 50]
22 | outputs = []
23 | for concur in concur_levels:
24 | 
25 |     processes = []
26 |     for cmd in cmd_list[:concur]:
27 |         p = subprocess.Popen(cmd)
28 |         processes.append(p)
29 | 
30 |     start_t = time.time()
31 |     for p in processes:
32 |         p.wait()
33 |     t = time.time() - start_t
34 |     fps = 250*concur / t  # each test video has ~250 frames
35 | 
36 |     outputs.append([concur, t, fps])
37 | 
38 | print(tabulate(outputs, headers=["concurrency", "time cost (s)", "fps"]))
--------------------------------------------------------------------------------
/test/contextual_ms.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import sys\n",
10 |     "sys.path.append(\"../src\")\n",
11 |     "import mindspore\n",
12 |     "import numpy as np\n",
13 |     "from contextual_ms import Conv1dTwoview, EnsembleThreeview"
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "code",
18 |    "execution_count": 2,
19 |    "metadata": {},
20 | 
"outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Conv1dTwoview<\n", 26 | " (view1_layers): CellList<\n", 27 | " (0): Conv1d\n", 28 | " (1): ReLU<>\n", 29 | " (2): Conv1d\n", 30 | " (3): ReLU<>\n", 31 | " (4): AdaptiveMaxPool1d<>\n", 32 | " >\n", 33 | " (view2_layers): CellList<\n", 34 | " (0): Conv1d\n", 35 | " (1): ReLU<>\n", 36 | " (2): Conv1d\n", 37 | " (3): ReLU<>\n", 38 | " (4): AdaptiveMaxPool1d<>\n", 39 | " >\n", 40 | " (dense): Dense\n", 41 | " (sigmoid): Sigmoid<>\n", 42 | " >\n" 43 | ] 44 | }, 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "(1, 1)" 49 | ] 50 | }, 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "net = Conv1dTwoview(inp_len1=5, inp_len2=5, conv_units=[32, 32])\n", 58 | "print(net)\n", 59 | "x1 = mindspore.Tensor(np.ones([1, 5, 1]), mindspore.float32)\n", 60 | "x2 = mindspore.Tensor(np.ones([1, 5, 1]), mindspore.float32)\n", 61 | "net(x1, x2).shape" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "EnsembleThreeview<\n", 74 | " (view1_layers): CellList<\n", 75 | " (0): Conv1d\n", 76 | " (1): ReLU<>\n", 77 | " (2): Conv1d\n", 78 | " (3): ReLU<>\n", 79 | " (4): AdaptiveMaxPool1d<>\n", 80 | " >\n", 81 | " (view2_layers): CellList<\n", 82 | " (0): Conv1d\n", 83 | " (1): ReLU<>\n", 84 | " (2): Conv1d\n", 85 | " (3): ReLU<>\n", 86 | " (4): AdaptiveMaxPool1d<>\n", 87 | " >\n", 88 | " (dense): Dense\n", 89 | " (sigmoid): Sigmoid<>\n", 90 | " (dense2): Dense\n", 91 | " (relu): ReLU<>\n", 92 | " (dense3): Dense\n", 93 | " >\n" 94 | ] 95 | }, 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "(1, 1)" 100 | ] 101 | }, 102 | "execution_count": 3, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "net = EnsembleThreeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128)\n", 109 | "print(net)\n", 110 | "x1 = mindspore.Tensor(np.ones([1, 5, 1]), mindspore.float32)\n", 111 | "x2 = mindspore.Tensor(np.ones([1, 5, 1]), mindspore.float32)\n", 112 | "x3 = mindspore.Tensor(np.ones([1, 1]), mindspore.float32)\n", 113 | "net(x1, x2, x3).shape" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 12, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Parameter (name=view1_layers.0.weight, shape=(32, 5, 1, 1), dtype=Float32, requires_grad=True) 160\n", 126 | "Parameter (name=view1_layers.0.bias, shape=(32,), dtype=Float32, requires_grad=True) 32\n", 127 | "Parameter (name=view1_layers.2.weight, shape=(32, 32, 1, 1), dtype=Float32, requires_grad=True) 1024\n", 128 | "Parameter (name=view1_layers.2.bias, shape=(32,), dtype=Float32, requires_grad=True) 32\n", 129 | "Parameter (name=view2_layers.0.weight, shape=(32, 5, 1, 1), dtype=Float32, requires_grad=True) 160\n", 130 | "Parameter (name=view2_layers.0.bias, shape=(32,), dtype=Float32, requires_grad=True) 32\n", 131 | "Parameter (name=view2_layers.2.weight, shape=(32, 32, 1, 1), dtype=Float32, requires_grad=True) 1024\n", 132 | "Parameter (name=view2_layers.2.bias, shape=(32,), dtype=Float32, requires_grad=True) 32\n", 133 | "Parameter (name=dense.weight, shape=(1, 64), dtype=Float32, requires_grad=True) 64\n", 134 | "Parameter (name=dense.bias, shape=(1,), dtype=Float32, requires_grad=True) 1\n", 135 | "Parameter 
(name=dense2.weight, shape=(128, 2), dtype=Float32, requires_grad=True) 256\n", 136 | "Parameter (name=dense2.bias, shape=(128,), dtype=Float32, requires_grad=True) 128\n", 137 | "Parameter (name=dense3.weight, shape=(1, 128), dtype=Float32, requires_grad=True) 128\n", 138 | "Parameter (name=dense3.bias, shape=(1,), dtype=Float32, requires_grad=True) 1\n" 139 | ] 140 | }, 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "3074" 145 | ] 146 | }, 147 | "execution_count": 12, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "total_params = 0\n", 154 | "for v in net.parameters_dict().values():\n", 155 | " print(v, v.size)\n", 156 | " total_params += v.size\n", 157 | "total_params" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [] 166 | } 167 | ], 168 | "metadata": { 169 | "kernelspec": { 170 | "display_name": "wfilter", 171 | "language": "python", 172 | "name": "wfilter" 173 | }, 174 | "language_info": { 175 | "codemirror_mode": { 176 | "name": "ipython", 177 | "version": 3 178 | }, 179 | "file_extension": ".py", 180 | "mimetype": "text/x-python", 181 | "name": "python", 182 | "nbconvert_exporter": "python", 183 | "pygments_lexer": "ipython3", 184 | "version": "3.7.10" 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 4 189 | } 190 | -------------------------------------------------------------------------------- /test/contextual_tf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append(\"../src\")\n", 11 | "from contextual_tf import build_conv1d_twoview, build_ensemble_threeview" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "Model: \"model\"\n", 24 | "__________________________________________________________________________________________________\n", 25 | "Layer (type) Output Shape Param # Connected to \n", 26 | "==================================================================================================\n", 27 | "View1-Indepdendent (InputLayer) [(None, None, 5)] 0 \n", 28 | "__________________________________________________________________________________________________\n", 29 | "View2-Predicted (InputLayer) [(None, None, 5)] 0 \n", 30 | "__________________________________________________________________________________________________\n", 31 | "conv1d (Conv1D) (None, None, 32) 192 View1-Indepdendent[0][0] \n", 32 | "__________________________________________________________________________________________________\n", 33 | "conv1d_1 (Conv1D) (None, None, 32) 192 View2-Predicted[0][0] \n", 34 | "__________________________________________________________________________________________________\n", 35 | "conv1d_2 (Conv1D) (None, None, 32) 1056 conv1d[0][0] \n", 36 | "__________________________________________________________________________________________________\n", 37 | "conv1d_3 (Conv1D) (None, None, 32) 1056 conv1d_1[0][0] \n", 38 | "__________________________________________________________________________________________________\n", 39 | "global_max_pooling1d (GlobalMax (None, 32) 0 conv1d_2[0][0] \n", 40 | 
"__________________________________________________________________________________________________\n", 41 | "global_max_pooling1d_1 (GlobalM (None, 32) 0 conv1d_3[0][0] \n", 42 | "__________________________________________________________________________________________________\n", 43 | "concatenate (Concatenate) (None, 64) 0 global_max_pooling1d[0][0] \n", 44 | " global_max_pooling1d_1[0][0] \n", 45 | "__________________________________________________________________________________________________\n", 46 | "dense (Dense) (None, 1) 65 concatenate[0][0] \n", 47 | "==================================================================================================\n", 48 | "Total params: 2,561\n", 49 | "Trainable params: 2,561\n", 50 | "Non-trainable params: 0\n", 51 | "__________________________________________________________________________________________________\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "m = build_conv1d_twoview(inp_len1=5, inp_len2=5, conv_units=[32, 32])\n", 57 | "m.summary()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Model: \"model_1\"\n", 70 | "__________________________________________________________________________________________________\n", 71 | "Layer (type) Output Shape Param # Connected to \n", 72 | "==================================================================================================\n", 73 | "View1-Indepdendent (InputLayer) [(None, None, 5)] 0 \n", 74 | "__________________________________________________________________________________________________\n", 75 | "View2-Predicted (InputLayer) [(None, None, 5)] 0 \n", 76 | "__________________________________________________________________________________________________\n", 77 | "conv1d_4 (Conv1D) (None, None, 32) 192 View1-Indepdendent[0][0] \n", 78 | "__________________________________________________________________________________________________\n", 79 | "conv1d_5 (Conv1D) (None, None, 32) 192 View2-Predicted[0][0] \n", 80 | "__________________________________________________________________________________________________\n", 81 | "conv1d_6 (Conv1D) (None, None, 32) 1056 conv1d_4[0][0] \n", 82 | "__________________________________________________________________________________________________\n", 83 | "conv1d_7 (Conv1D) (None, None, 32) 1056 conv1d_5[0][0] \n", 84 | "__________________________________________________________________________________________________\n", 85 | "global_max_pooling1d_2 (GlobalM (None, 32) 0 conv1d_6[0][0] \n", 86 | "__________________________________________________________________________________________________\n", 87 | "global_max_pooling1d_3 (GlobalM (None, 32) 0 conv1d_7[0][0] \n", 88 | "__________________________________________________________________________________________________\n", 89 | "concatenate_1 (Concatenate) (None, 64) 0 global_max_pooling1d_2[0][0] \n", 90 | " global_max_pooling1d_3[0][0] \n", 91 | "__________________________________________________________________________________________________\n", 92 | "dense_1 (Dense) (None, 1) 65 concatenate_1[0][0] \n", 93 | "__________________________________________________________________________________________________\n", 94 | "View3-Temporal (InputLayer) [(None, 1)] 0 \n", 95 | "__________________________________________________________________________________________________\n", 96 | "concatenate_2 (Concatenate) (None, 2) 0 dense_1[0][0] 
\n", 97 | " View3-Temporal[0][0] \n", 98 | "__________________________________________________________________________________________________\n", 99 | "dense_2 (Dense) (None, 128) 384 concatenate_2[0][0] \n", 100 | "__________________________________________________________________________________________________\n", 101 | "dense_3 (Dense) (None, 1) 129 dense_2[0][0] \n", 102 | "==================================================================================================\n", 103 | "Total params: 3,074\n", 104 | "Trainable params: 3,074\n", 105 | "Non-trainable params: 0\n", 106 | "__________________________________________________________________________________________________\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "m = build_ensemble_threeview(inp_len1=5, inp_len2=5, inp_len3=1, conv_units=[32, 32], dense_unit=128)\n", 112 | "m.summary()" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "wfilter", 126 | "language": "python", 127 | "name": "wfilter" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.7.10" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 4 144 | } 145 | --------------------------------------------------------------------------------