├── README.md ├── feature_extract.ipynb ├── frame_extract.ipynb └── video_retrieval.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # CCF-BDCI Video Copy Detection 2 | 3 | ## Current Approach 4 | 1. Extract keyframes from each video; 5 | 2. Extract keyframe features with resnet18; 6 | 3. Apply ~~PCA dimensionality reduction~~ (not working yet) and L2 normalization to the features; 7 | 4. Compute a pairwise similarity matrix over all videos (cosine similarity); 8 | 5. For the top-K most similar video pairs, run frame-level matching (build a graph from frame similarities and find the longest path). 9 | 10 | ## Lessons Learned 11 | 1. Features should not be too fine-grained: resnet50 features score 10-20 points worse than resnet18; 12 | 2. The current algorithm is fairly sensitive to its parameters: we currently pass the K=20 most similar videos into frame-level matching, and in the frame-matching stage the inter-frame similarity threshold is 0.85 with a maximum span of 10 frames; 13 | 3. The main bottleneck is video-level matching: as long as the target video lands in the Top-K, the frame-level match is basically correct; 14 | 4. Keeping the query and refer frame-sampling densities close may help, or it may simply be that sampling should not be too dense: comparing query at five frames per second with refer at one frame per second against both at one frame per second, the latter not only ran faster but also scored far higher. 15 | 16 | ## TODO 17 | 18 | 1. Finer-grained frame sampling (currently one frame per second, which already feels sufficient); 19 | 2. Code refactoring (video_retrieval still pending); 20 | 3. More case analysis (frames from different videos of a middle-aged woman and a young man at the same position, angle, and expression somehow reach 85% similarity, so feature extraction needs further study). 21 | 22 | ## Current Results 23 | 24 | * Preliminary round (max error 5s): F1-score = 0.60054720, rank: 10th 25 | * Final round (max error 3s): F1-score = 0.47160494, rank: 8th 26 | 27 | ## References 28 | 29 | [1] Tan H K, Ngo C W, Hong R, et al. Scalable detection of partial near-duplicate videos by visual-temporal consistency[C]// Proceedings of the 17th ACM International Conference on Multimedia, Vancouver, British Columbia, Canada, October 19-24, 2009. ACM, 2009. 30 | 31 | [2] Jiang Y G, Jiang Y, Wang J. VCDB: A Large-Scale Database for Partial Copy Detection in Videos[M]// Computer Vision – ECCV 2014. Springer International Publishing, 2014. 32 | 33 | [3] 顾佳伟, 赵瑞玮, 姜育刚. A survey of video copy detection methods (视频拷贝检测方法综述)[J]. Journal of Computer Research and Development (计算机研究与发展), 2017(6). 34 | 35 | [4] 刘红. A graph-based near-duplicate video subsequence matching algorithm (一种基于图的近重复视频子序列匹配算法)[J]. Application Research of Computers (计算机应用研究), 2013(12):343-348. 36 | 37 | [5] On FFmpeg frame extraction (FFmpeg视频抽帧那些事), an article by 阿水 on Zhihu: https://zhuanlan.zhihu.com/p/85895180 -------------------------------------------------------------------------------- /feature_extract.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Feature Extraction" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 31, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os, sys, codecs\n", 17 | "import glob\n", 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import pickle\n", 21 | "from PIL import Image\n", 22 | "from tqdm.notebook import tqdm\n", 23 | "\n", 24 | "import cv2\n", 25 | "\n", 26 | "from sklearn.preprocessing import normalize as sknormalize\n", 27 | "from sklearn.decomposition import PCA\n", 28 | "\n", 29 | "import torch\n", 30 | "torch.manual_seed(0)\n", 31 | "torch.backends.cudnn.deterministic = True\n", 32 | "torch.backends.cudnn.benchmark = False\n", 33 | "\n", 34 | "import torchvision.models as models\n", 35 | "import torchvision.transforms as transforms\n", 36 | "import torchvision.datasets as datasets\n", 37 | "import torch.nn as nn\n", 38 | "import torch.nn.functional as F\n", 39 | "import torch.optim as optim\n", 40 | "from torch.autograd import Variable\n", 41 | "from torch.utils.data.dataset import Dataset\n", 42 | "\n", 43 | "PATH = '/home/wx/work/video_copy_detection/'\n", 44 | "TRAIN_PATH = PATH + 'train/'\n", 45 | "TEST_PATH = PATH + 'test/'\n", 46 | "TRAIN_QUERY_PATH = TRAIN_PATH + 'query/'\n", 47 | "REFER_PATH = TRAIN_PATH + 'refer/'\n", 48 | "TRAIN_QUERY_FRAME_PATH = TRAIN_PATH + 'query_uniformframe/'\n", 49 | "REFER_FRAME_PATH = TRAIN_PATH + 'refer_uniformframe/'\n", 50 | "TEST_QUERY_PATH = TEST_PATH + 'query2/'\n", 51 | "TEST_QUERY_FRAME_PATH = TEST_PATH + 'query2_uniformframe/'\n", 52 | "CODE_DIR = PATH + 'code/'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | 
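The `Img2Vec` class in the cell below reads frame embeddings off a pretrained torchvision backbone by registering a forward hook on its `avgpool` layer, rather than modifying the network. A minimal, self-contained sketch of that hook technique (the path `frame.jpg` is a placeholder, and the `.convert('RGB')` call is an added safeguard for non-RGB frames, not part of the notebook):

```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

model = models.resnet18(pretrained=True).eval()
feats = []
# the hook fires during forward() and records the avgpool activation
handle = model.avgpool.register_forward_hook(
    lambda module, inp, out: feats.append(out.detach().squeeze().numpy()))

tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img = tf(Image.open('frame.jpg').convert('RGB')).unsqueeze(0)  # placeholder path

with torch.no_grad():
    model(img)
handle.remove()
vec = feats[0]  # 512-d for resnet18; the same hook yields 2048-d for resnet50
```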
"execution_count": 40, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "class QRDataset(Dataset):\n", 62 | " def __init__(self, img_path, transform = None):\n", 63 | " self.img_path = img_path\n", 64 | "\n", 65 | " self.img_label = np.zeros(len(img_path))\n", 66 | " \n", 67 | " if transform is not None:\n", 68 | " self.transform = transform\n", 69 | " else:\n", 70 | " self.transform = None\n", 71 | " \n", 72 | " def __getitem__(self, index):\n", 73 | " img = Image.open(self.img_path[index])\n", 74 | " \n", 75 | " if self.transform is not None:\n", 76 | " img = self.transform(img)\n", 77 | " \n", 78 | " return img, self.img_path[index]\n", 79 | "\n", 80 | " def __len__(self):\n", 81 | " return len(self.img_path)\n", 82 | "\n", 83 | "class Img2Vec():\n", 84 | "\n", 85 | " def __init__(self, model='resnet-18', layer='default', layer_output_size=512):\n", 86 | " \"\"\" Img2Vec\n", 87 | " :param model: String name of requested model\n", 88 | " :param layer: String or Int depending on model.\n", 89 | " :param layer_output_size: Int depicting the output size of the requested layer\n", 90 | " \"\"\"\n", 91 | " self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 92 | " self.layer_output_size = layer_output_size\n", 93 | " self.model_name = model\n", 94 | " \n", 95 | " self.model, self.extraction_layer = self._get_model_and_layer(model, layer)\n", 96 | "\n", 97 | " self.model = self.model.to(self.device)\n", 98 | "\n", 99 | " self.model.eval()\n", 100 | "\n", 101 | " self.transformer = transforms.Compose([\n", 102 | " transforms.Resize((224, 224)), \n", 103 | " transforms.ToTensor(),\n", 104 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 105 | " ])\n", 106 | "\n", 107 | " def get_vec(self, path):\n", 108 | " \"\"\" Get vector embedding from PIL image\n", 109 | " :param path: Path of image dataset\n", 110 | " :returns: Numpy ndarray\n", 111 | " \"\"\"\n", 112 | " if not isinstance(path, list):\n", 113 | " path = [path]\n", 114 | "\n", 115 | " data_loader = torch.utils.data.DataLoader(QRDataset(path, self.transformer), batch_size = 40, \n", 116 | " shuffle = False, num_workers = 16)\n", 117 | "\n", 118 | " my_embedding = []\n", 119 | "\n", 120 | " # hook function\n", 121 | " def append_data(module, input, output):\n", 122 | " my_embedding.append(output.clone().detach().cpu().numpy())\n", 123 | " \n", 124 | " with torch.no_grad():\n", 125 | " for batch_data in tqdm(data_loader):\n", 126 | " batch_x, batch_y = batch_data\n", 127 | " if torch.cuda.is_available():\n", 128 | " batch_x = Variable(batch_x, requires_grad = False).cuda()\n", 129 | " else:\n", 130 | " batch_x = Variable(batch_x, requires_grad = False)\n", 131 | "\n", 132 | " h = self.extraction_layer.register_forward_hook(append_data)\n", 133 | " h_x = self.model(batch_x)\n", 134 | " h.remove()\n", 135 | " del h_x\n", 136 | "\n", 137 | " my_embedding = np.vstack(my_embedding)\n", 138 | " if self.model_name == 'alexnet':\n", 139 | " return my_embedding[:, :]\n", 140 | " else:\n", 141 | " return my_embedding[:, :, 0, 0]\n", 142 | "\n", 143 | " def _get_model_and_layer(self, model_name, layer):\n", 144 | " \"\"\" Internal method for getting layer from model\n", 145 | " :param model_name: model name such as 'resnet-18'\n", 146 | " :param layer: layer as a string for resnet-18 or int for alexnet\n", 147 | " :returns: pytorch model, selected layer\n", 148 | " \"\"\"\n", 149 | " if model_name == 'resnet-18':\n", 150 | " model = models.resnet18(pretrained=True)\n", 151 | " if 
layer == 'default':\n", 152 | " layer = model._modules.get('avgpool')\n", 153 | " self.layer_output_size = 512\n", 154 | " else:\n", 155 | " layer = model._modules.get(layer)\n", 156 | "\n", 157 | " return model, layer\n", 158 | " \n", 159 | " elif model_name == 'resnet-50':\n", 160 | " model = models.resnet50(pretrained=True)\n", 161 | " if layer == 'default':\n", 162 | " layer = model._modules.get('avgpool')\n", 163 | " self.layer_output_size = 2048\n", 164 | " else:\n", 165 | " layer = model._modules.get(layer)\n", 166 | "\n", 167 | " return model, layer\n", 168 | " \n", 169 | " elif model_name == 'alexnet':\n", 170 | " model = models.alexnet(pretrained=True)\n", 171 | " if layer == 'default':\n", 172 | " layer = model.classifier[-2]\n", 173 | " self.layer_output_size = 4096\n", 174 | " else:\n", 175 | " layer = model.classifier[-layer]\n", 176 | "\n", 177 | " return model, layer\n", 178 | "\n", 179 | " else:\n", 180 | " raise KeyError('Model %s was not found' % model_name)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 33, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# 读取 test_query 视频的帧,并按照视频和帧时间进行排序\n", 190 | "test_query_imgs_path = []\n", 191 | "for id in pd.read_csv(TEST_PATH + 'submit_example2.csv')['query_id']:\n", 192 | " test_query_imgs_path += glob.glob(TEST_QUERY_FRAME_PATH + id + '/*.jpg')\n", 193 | "\n", 194 | "test_query_imgs_path.sort(key = lambda x: x.lower())" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 34, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# 读取 train_query 视频的帧,并按照视频和帧时间进行排序\n", 204 | "train_query_imgs_path = []\n", 205 | "for id in pd.read_csv(TRAIN_PATH + 'train.csv')['query_id']:\n", 206 | " train_query_imgs_path += glob.glob(TRAIN_QUERY_FRAME_PATH + id + '/*.jpg')\n", 207 | "\n", 208 | "train_query_imgs_path.sort(key = lambda x: x.lower())" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 35, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# 读取 refer 视频的帧,并按照视频和帧时间进行排序\n", 218 | "\n", 219 | "refer_imgs_path = glob.glob(REFER_FRAME_PATH + '*/*.jpg')\n", 220 | "refer_imgs_path.sort(key = lambda x: x.lower())" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 41, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "# Initialize Img2Vec\n", 230 | "img2vec = Img2Vec(model='resnet-50')" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 42, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "application/vnd.jupyter.widget-view+json": { 241 | "model_id": "fc58d2d3a68e4e1fbd55d5a2847ebf35", 242 | "version_major": 2, 243 | "version_minor": 0 244 | }, 245 | "text/plain": [ 246 | "HBox(children=(IntProgress(value=0, max=5045), HTML(value='')))" 247 | ] 248 | }, 249 | "metadata": {}, 250 | "output_type": "display_data" 251 | }, 252 | { 253 | "name": "stdout", 254 | "output_type": "stream", 255 | "text": [ 256 | "\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "# 抽取 test_query 关键帧特征\n", 262 | "test_query_features = img2vec.get_vec(test_query_imgs_path[:])" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 43, 268 | "metadata": {}, 269 | "outputs": [ 270 | { 271 | "data": { 272 | "application/vnd.jupyter.widget-view+json": { 273 | "model_id": "6ffab0465c67426ba3eba041bff05e8b", 274 | "version_major": 2, 275 | "version_minor": 0 276 | }, 277 | "text/plain": [ 278 | 
"HBox(children=(IntProgress(value=0, max=9393), HTML(value='')))" 279 | ] 280 | }, 281 | "metadata": {}, 282 | "output_type": "display_data" 283 | }, 284 | { 285 | "name": "stderr", 286 | "output_type": "stream", 287 | "text": [ 288 | "IOPub message rate exceeded.\n", 289 | "The notebook server will temporarily stop sending output\n", 290 | "to the client in order to avoid crashing it.\n", 291 | "To change this limit, set the config variable\n", 292 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 293 | "\n", 294 | "Current values:\n", 295 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 296 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 297 | "\n" 298 | ] 299 | }, 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "# 抽取 train_query 关键帧特征\n", 310 | "train_query_features = img2vec.get_vec(train_query_imgs_path[:])" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 44, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "data": { 320 | "application/vnd.jupyter.widget-view+json": { 321 | "model_id": "853479c94710467284251203f760c30a", 322 | "version_major": 2, 323 | "version_minor": 0 324 | }, 325 | "text/plain": [ 326 | "HBox(children=(IntProgress(value=0, max=18738), HTML(value='')))" 327 | ] 328 | }, 329 | "metadata": {}, 330 | "output_type": "display_data" 331 | }, 332 | { 333 | "name": "stderr", 334 | "output_type": "stream", 335 | "text": [ 336 | "IOPub message rate exceeded.\n", 337 | "The notebook server will temporarily stop sending output\n", 338 | "to the client in order to avoid crashing it.\n", 339 | "To change this limit, set the config variable\n", 340 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 341 | "\n", 342 | "Current values:\n", 343 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 344 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 345 | "\n", 346 | "IOPub message rate exceeded.\n", 347 | "The notebook server will temporarily stop sending output\n", 348 | "to the client in order to avoid crashing it.\n", 349 | "To change this limit, set the config variable\n", 350 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 351 | "\n", 352 | "Current values:\n", 353 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 354 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 355 | "\n" 356 | ] 357 | } 358 | ], 359 | "source": [ 360 | "# 抽取 refer 关键帧特征\n", 361 | "refer_features = img2vec.get_vec(list(refer_imgs_path[:]))" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 45, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "def normalize(x, copy = False):\n", 371 | " \"\"\"\n", 372 | " A helper function that wraps the function of the same name in sklearn.\n", 373 | " This helper handles the case of a single column vector.\n", 374 | " \"\"\"\n", 375 | " if type(x) == np.ndarray and len(x.shape) == 1:\n", 376 | " return np.squeeze(sknormalize(x.reshape(1, -1), copy = copy))\n", 377 | " #return np.squeeze(x / np.sqrt((x ** 2).sum(-1))[..., np.newaxis])\n", 378 | " else:\n", 379 | " return sknormalize(x, copy = copy)\n", 380 | " #return x / np.sqrt((x ** 2).sum(-1))[..., np.newaxis]" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 46, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "'\\npca = PCA(n_components=512)\\n\\ntrain_query_features = pca.fit_transform(train_query_features)\\ntest_query_features = 
pca.fit_transform(test_query_features)\nrefer_features = pca.fit_transform(refer_features)\n'" 392 | ] 393 | }, 394 | "execution_count": 46, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "# PCA dimensionality reduction\n", 401 | "'''\n", 402 | "pca = PCA(n_components=512)\n", 403 | "\n", 404 | "train_query_features = pca.fit_transform(train_query_features)\n", 405 | "test_query_features = pca.fit_transform(test_query_features)\n", 406 | "refer_features = pca.fit_transform(refer_features)\n", 407 | "'''\n" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 47, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "# L2 normalization\n", 417 | "train_query_features = normalize(train_query_features)\n", 418 | "test_query_features = normalize(test_query_features)\n", 419 | "refer_features = normalize(refer_features)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 48, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "# Save test_query frame features\n", 429 | "\n", 430 | "with open(PATH + 'var/test_query_features_res50_uni.pk', 'wb') as pk_file:\n", 431 | "    pickle.dump(test_query_features, pk_file)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 49, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "# Save train_query frame features\n", 441 | "\n", 442 | "with open(PATH + 'var/train_query_features_res50_uni.pk', 'wb') as pk_file:\n", 443 | "    pickle.dump(train_query_features, pk_file)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 51, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "# Save refer frame features\n", 453 | "\n", 454 | "with open(PATH + 'var/refer_features_res50_uni.pk', 'wb') as pk_file:\n", 455 | "    pickle.dump(refer_features, pk_file, protocol = 4)" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [] 464 | } 465 | ], 466 | "metadata": { 467 | "file_extension": ".py", 468 | "kernelspec": { 469 | "display_name": "Python 3", 470 | "language": "python", 471 | "name": "python3" 472 | }, 473 | "language_info": { 474 | "codemirror_mode": { 475 | "name": "ipython", 476 | "version": 3 477 | }, 478 | "file_extension": ".py", 479 | "mimetype": "text/x-python", 480 | "name": "python", 481 | "nbconvert_exporter": "python", 482 | "pygments_lexer": "ipython3", 483 | "version": "3.7.4" 484 | }, 485 | "mimetype": "text/x-python", 486 | "name": "python", 487 | "nbconvert_exporter": "python", 488 | "pygments_lexer": "ipython3", 489 | "version": 3 490 | }, 491 | "nbformat": 4, 492 | "nbformat_minor": 4 493 | } 494 | -------------------------------------------------------------------------------- /frame_extract.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","execution_count":null,"metadata":{},"outputs":[],"source":["## Keyframe Extraction"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":"Populating the interactive namespace from numpy and matplotlib\n"}],"source":["import os\n","import sys\n","import glob\n","import shutil\n","import codecs\n","from tqdm import tqdm_notebook as tqdm\n","\n","import pandas as pd\n","import numpy as np\n","import time\n","from multiprocessing import Pool\n","\n","%pylab inline\n","from PIL import 
Image"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["PATH = '/home/wx/work/video_copy_detection/'"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["# 抽取关键帧\n","class FrameExtractor():\n"," # key uniform scene\n"," def __init__(self, path=PATH):\n"," self.train_path = path + 'train/'\n"," self.test_path = path + 'test/'\n"," self.train_query_path = self.train_path + 'query/'\n"," self.refer_path = self.train_path + 'refer/'\n"," #self.test_query_path = self.test_path + 'query/'\n"," self.test_query_path = self.test_path + 'query2/'\n"," self.train_df = pd.read_csv(self.train_path + 'train.csv')\n"," self.train_query_paths = self._get_videos(self.train_query_path)\n"," self.test_query_paths = self._get_videos(self.test_query_path)\n"," self.refer_paths = self._get_videos(self.refer_path)\n","\n"," def _get_videos(self, path):\n"," video_paths = glob.glob(path + '*.mp4')\n"," return video_paths\n"," \n"," def extract_keyframe(self, video_path, frame_path):\n"," video_id = video_path.split('/')[-1][:-4]\n"," if not os.path.exists(frame_path + video_id):\n"," os.mkdir(frame_path + video_id)\n","\n"," # 抽取关键帧(I帧)\n"," command = ['ffmpeg', '-i', video_path,\n"," '-vf', '\"select=eq(pict_type\\,I)\"',\n"," '-vsync', 'vfr', '-qscale:v', '2',\n"," '-f', 'image2',\n"," frame_path + '{0}/{0}_%05d.jpg'.format(video_id)]\n"," os.system(' '.join(command))\n","\n"," # 抽取视频关键帧时间\n"," command = ['ffprobe', '-i', video_path,\n"," '-v', 'quiet', '-select_streams',\n"," 'v', '-show_entries', 'frame=pkt_pts_time,pict_type|grep',\n"," '-B', '1', 'pict_type=I|grep pkt_pts_time', '>',\n"," frame_path + '{0}/{0}.log'.format(video_id)]\n"," os.system(' '.join(command))\n"," \n"," def _extract_keyframe(self, param):\n"," self.extract_keyframe(param[0], param[1])\n","\n"," def extract_uniformframe(self, video_path, frame_path, frame_per_sec=1):\n"," video_id = video_path.split('/')[-1][:-4]\n"," if not os.path.exists(frame_path + video_id):\n"," os.mkdir(frame_path + video_id)\n"," \n"," # -r 指定抽取的帧率,即从视频中每秒钟抽取图片的数量。1代表每秒抽取一帧。\n"," command = ['ffmpeg', '-i', video_path,\n"," '-r', str(frame_per_sec),\n"," '-q:v', '2', '-f', 'image2',\n"," frame_path + '{0}/{0}_%08d.000000.jpg'.format(video_id)]\n"," os.system(' '.join(command))\n"," \n"," def _extract_uniformframe(self, param):\n"," self.extract_uniformframe(param[0], param[1], param[2])\n"," \n"," # 关键帧用时间戳重命名\n"," def _rename(self, video_paths, frame_path, mode='key', frame_per_sec=1):\n"," for path in video_paths[:]:\n"," video_id = path.split('/')[-1][:-4]\n"," id_files = glob.glob(frame_path + video_id + '/*.jpg')\n"," # IMPORTANT!!!\n"," id_files.sort()\n"," if mode == 'key':\n"," id_times = codecs.open(frame_path + '{0}/{0}.log'.format(video_id)).readlines()\n"," id_times = [x.strip().split('=')[1] for x in id_times]\n","\n"," for id_file, id_time in zip(id_files, id_times):\n"," shutil.move(id_file, id_file[:-9] + id_time.zfill(15) + '.jpg')\n"," else:\n"," id_time = 0.0\n"," for id_file in id_files:\n"," shutil.move(id_file, id_file[:-19] + '{:0>15.4f}'.format(id_time) + '.jpg')\n"," id_time += 1.0 / frame_per_sec\n","\n"," def extract(self, mode='key', num_worker=5, frame_per_sec_q=1, frame_per_sec_r=1):\n"," if mode == 'key':\n"," pool = Pool(processes=num_worker)\n"," for path in self.train_query_paths:\n"," pool.apply_async(self._extract_keyframe, ((path, self.train_path + 'query_keyframe/'),))\n","\n"," for path in self.test_query_paths:\n"," # pool.apply_async(self._extract_keyframe, 
((path, self.test_path + 'query_keyframe/'),))\n"," pool.apply_async(self._extract_keyframe, ((path, self.test_path + 'query2_keyframe/'),))\n","\n"," for path in self.refer_paths:\n"," pool.apply_async(self._extract_keyframe, ((path, self.train_path + 'refer_keyframe/'),))\n","\n"," pool.close()\n"," pool.join()\n"," \n"," self._rename(self.train_query_paths, self.train_path + 'query_keyframe/')\n"," # self._rename(self.test_query_paths, self.test_path + 'query_keyframe/')\n"," self._rename(self.test_query_paths, self.test_path + 'query2_keyframe/')\n"," self._rename(self.refer_paths, self.train_path + 'refer_keyframe/')\n","\n"," elif mode == 'uniform':\n"," \n"," pool = Pool(processes=num_worker)\n"," for path in self.train_query_paths:\n"," pool.apply_async(self._extract_uniformframe, ((path, self.train_path + 'query_uniformframe/', frame_per_sec_q),))\n","\n"," for path in self.test_query_paths:\n"," # pool.apply_async(self._extract_uniformframe, ((path, self.test_path + 'query_uniformframe/', frame_per_sec_q),))\n"," pool.apply_async(self._extract_uniformframe, ((path, self.test_path + 'query2_uniformframe/', frame_per_sec_q),))\n","\n"," for path in self.refer_paths:\n"," pool.apply_async(self._extract_uniformframe, ((path, self.train_path + 'refer_uniformframe/', frame_per_sec_r),))\n","\n"," pool.close()\n"," pool.join()\n"," \n"," self._rename(self.train_query_paths, self.train_path + 'query_uniformframe/', \n"," mode='uniform', frame_per_sec=frame_per_sec_q)\n","# self._rename(self.test_query_paths, self.test_path + 'query_uniformframe/',\n","# mode='uniform', frame_per_sec=frame_per_sec_q)\n"," self._rename(self.test_query_paths, self.test_path + 'query2_uniformframe/',\n"," mode='uniform', frame_per_sec=frame_per_sec_q)\n"," self._rename(self.refer_paths, self.train_path + 'refer_uniformframe/',\n"," mode='uniform', frame_per_sec=frame_per_sec_r)\n"," else:\n"," None"]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[],"source":["frame_extractor = FrameExtractor(PATH)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[],"source":["frame_extractor.extract(mode='uniform', num_worker=16, frame_per_sec_q=1, frame_per_sec_r=1)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"name":"python","codemirror_mode":{"name":"ipython","version":3}},"orig_nbformat":2,"file_extension":".py","mimetype":"text/x-python","name":"python","npconvert_exporter":"python","pygments_lexer":"ipython3","version":3}} -------------------------------------------------------------------------------- /video_retrieval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 相似视频检索\n", 8 | "\n", 9 | "视频级相似匹配 -> 帧级匹配" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 25, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import glob\n", 19 | "import pandas as pd\n", 20 | "import pickle\n", 21 | "import time\n", 22 | "\n", 23 | "import cv2\n", 24 | "import imagehash\n", 25 | "import numpy as np\n", 26 | "import networkx as nx\n", 27 | "from tqdm.notebook import tqdm\n", 28 | "from PIL import Image\n", 29 | "from scipy.spatial.distance import cdist\n", 30 | "from scipy.spatial.distance import cosine\n", 31 | "from networkx.algorithms.dag import dag_longest_path\n", 32 | "\n", 33 | "PATH = 
'/home/wx/work/video_copy_detection/'\n", 34 | "TRAIN_PATH = PATH + 'train/'\n", 35 | "TEST_PATH = PATH + 'test/'\n", 36 | "TRAIN_QUERY_PATH = TRAIN_PATH + 'query/'\n", 37 | "REFER_PATH = TRAIN_PATH + 'refer/'\n", 38 | "TRAIN_QUERY_FRAME_PATH = TRAIN_PATH + 'query_uniformframe/'\n", 39 | "REFER_FRAME_PATH = TRAIN_PATH + 'refer_uniformframe/'\n", 40 | "TEST_QUERY_PATH = TEST_PATH + 'query2/'\n", 41 | "TEST_QUERY_FRAME_PATH = TEST_PATH + 'query2_uniformframe/'\n", 42 | "CODE_DIR = PATH + 'code/'" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 26, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# 读取特征文件\n", 52 | "with open(PATH + 'var/train_query_features_uni.pk', 'rb') as pk_file:\n", 53 | " train_query_features = pickle.load(pk_file)\n", 54 | "\n", 55 | "with open(PATH + 'var/test_query_features_uni.pk', 'rb') as pk_file:\n", 56 | " test_query_features = pickle.load(pk_file)\n", 57 | "\n", 58 | "with open(PATH + 'var/refer_features_uni.pk', 'rb') as pk_file:\n", 59 | " refer_features = pickle.load(pk_file)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 68, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# 读取特征文件\n", 69 | "with open(PATH + 'var/train_query_features_res50_uni.pk', 'rb') as pk_file:\n", 70 | " train_query_features = pickle.load(pk_file)\n", 71 | " \n", 72 | "with open(PATH + 'var/test_query_features_res50_uni.pk', 'rb') as pk_file:\n", 73 | " test_query_features = pickle.load(pk_file)\n", 74 | "\n", 75 | "with open(PATH + 'var/refer_features_res50_uni.pk', 'rb') as pk_file:\n", 76 | " refer_features = pickle.load(pk_file)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 69, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "(375702, 2048)" 88 | ] 89 | }, 90 | "execution_count": 69, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "train_query_features.shape" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 70, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# 读取 train_query 视频的关键帧\n", 106 | "# 按照视频和关键帧时间进行排序\n", 107 | "# 预处理工具 dict\n", 108 | "train_query_imgs_path = []\n", 109 | "train_query_vids = []\n", 110 | "train_query_vid2idx = {}\n", 111 | "train_query_idx2vid = {}\n", 112 | "train_query_vid2baseaddr = {}\n", 113 | "train_query_fid2path = {}\n", 114 | "train_query_fid2vid = {}\n", 115 | "train_query_fid2time = {}\n", 116 | "\n", 117 | "for id in pd.read_csv(TRAIN_PATH + 'train.csv')['query_id']:\n", 118 | " train_query_imgs_path += glob.glob(TRAIN_QUERY_FRAME_PATH + id + '/*.jpg')\n", 119 | " train_query_vids += [id]\n", 120 | "\n", 121 | "train_query_imgs_path.sort(key = lambda x: x.lower())\n", 122 | "train_query_vids.sort(key = lambda x: x.lower())\n", 123 | "\n", 124 | "\n", 125 | "idx = 0\n", 126 | "for vid in train_query_vids:\n", 127 | " train_query_vid2idx[vid] = idx\n", 128 | " train_query_idx2vid[idx] = vid\n", 129 | " idx += 1\n", 130 | "fid = 0\n", 131 | "pre_vid = \"\"\n", 132 | "cur_base = 0\n", 133 | "for idx, path in enumerate(train_query_imgs_path):\n", 134 | " cur_vid = path.split('/')[-1][:-20]\n", 135 | " train_query_fid2vid[fid] = cur_vid\n", 136 | " train_query_fid2path[fid] = path\n", 137 | " train_query_fid2time[fid] = float(path.split('/')[-1].split('_')[-1][:-4])\n", 138 | " if pre_vid != cur_vid:\n", 139 | " cur_base = idx\n", 140 | " pre_vid = cur_vid\n", 141 | " train_query_vid2baseaddr[cur_vid] = cur_base\n", 
142 | " fid += 1" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 71, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# path.split('/')[-1][:-20]\n", 152 | "# float(path.split('/')[-1].split('_')[-1][:-4])\n" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 72, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# 读取 test_query 视频的关键帧\n", 162 | "# 按照视频和关键帧时间进行排序\n", 163 | "# 预处理工具 dict\n", 164 | "test_query_imgs_path = []\n", 165 | "test_query_vids = []\n", 166 | "test_query_vid2idx = {}\n", 167 | "test_query_idx2vid = {}\n", 168 | "test_query_vid2baseaddr = {}\n", 169 | "test_query_fid2path = {}\n", 170 | "test_query_fid2vid = {}\n", 171 | "test_query_fid2time = {}\n", 172 | "\n", 173 | "for id in pd.read_csv(TEST_PATH + 'submit_example2.csv')['query_id']:\n", 174 | " test_query_imgs_path += glob.glob(TEST_QUERY_FRAME_PATH + id + '/*.jpg')\n", 175 | " test_query_vids += [id]\n", 176 | "\n", 177 | "test_query_imgs_path.sort(key = lambda x: x.lower())\n", 178 | "test_query_vids.sort(key = lambda x: x.lower())\n", 179 | "\n", 180 | "idx = 0\n", 181 | "for vid in test_query_vids:\n", 182 | " test_query_vid2idx[vid] = idx\n", 183 | " test_query_idx2vid[idx] = vid\n", 184 | " idx += 1\n", 185 | "fid = 0\n", 186 | "pre_vid = \"\"\n", 187 | "cur_base = 0\n", 188 | "for idx, path in enumerate(test_query_imgs_path):\n", 189 | " cur_vid = path.split('/')[-1][:-20]\n", 190 | " test_query_fid2vid[fid] = cur_vid\n", 191 | " test_query_fid2path[fid] = path\n", 192 | " test_query_fid2time[fid] = float(path.split('/')[-1].split('_')[-1][:-4])\n", 193 | " if pre_vid != cur_vid:\n", 194 | " cur_base = idx\n", 195 | " pre_vid = cur_vid\n", 196 | " test_query_vid2baseaddr[cur_vid] = cur_base\n", 197 | " fid += 1" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 73, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "# 读取 refer_query 视频的关键帧\n", 207 | "# 按照视频和关键帧时间进行排序\n", 208 | "# 预处理工具 dict\n", 209 | "\n", 210 | "refer_imgs_path = glob.glob(REFER_FRAME_PATH + '*/*.jpg')\n", 211 | "refer_imgs_path.sort(key = lambda x: x.lower())\n", 212 | "\n", 213 | "refer_vids = []\n", 214 | "refer_vid2idx = {}\n", 215 | "refer_idx2vid = {}\n", 216 | "refer_vid2baseaddr = {}\n", 217 | "refer_fid2path = {}\n", 218 | "refer_fid2vid = {}\n", 219 | "refer_fid2time = {}\n", 220 | "\n", 221 | "for path in refer_imgs_path:\n", 222 | " vid = path.split('/')[-2]\n", 223 | " refer_vids += [vid]\n", 224 | "\n", 225 | "refer_vids = list(set(refer_vids))\n", 226 | "refer_vids.sort(key = lambda x: x.lower())\n", 227 | "\n", 228 | "idx = 0\n", 229 | "for vid in refer_vids:\n", 230 | " refer_vid2idx[vid] = idx\n", 231 | " refer_idx2vid[idx] = vid\n", 232 | " idx += 1\n", 233 | "fid = 0\n", 234 | "pre_vid = \"\"\n", 235 | "cur_base = 0\n", 236 | "for idx, path in enumerate(refer_imgs_path):\n", 237 | " cur_vid = path.split('/')[-1][:-20]\n", 238 | " refer_fid2vid[fid] = cur_vid\n", 239 | " refer_fid2path[fid] = path\n", 240 | " refer_fid2time[fid] = float(path.split('/')[-1].split('_')[-1][:-4])\n", 241 | " if pre_vid != cur_vid:\n", 242 | " cur_base = idx\n", 243 | " pre_vid = cur_vid\n", 244 | " refer_vid2baseaddr[cur_vid] = cur_base\n", 245 | " fid += 1" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 74, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "vids = np.concatenate((train_query_vids, test_query_vids, refer_vids), axis=0)" 255 | ] 256 | }, 257 | { 
258 | "cell_type": "code", 259 | "execution_count": 82, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "# 特征按视频归类\n", 264 | "if False:\n", 265 | " vid2features = {}\n", 266 | " for (path, cur_feat) in tqdm(zip(train_query_imgs_path, train_query_features)):\n", 267 | " vid = path.split('/')[-2]\n", 268 | " if(not vid in vid2features):\n", 269 | " vid2features[vid] = [cur_feat]\n", 270 | " else:\n", 271 | " vid2features[vid] = np.concatenate((vid2features[vid], [cur_feat]), axis=0)\n", 272 | "\n", 273 | " for (path, cur_feat) in tqdm(zip(test_query_imgs_path, test_query_features)):\n", 274 | " vid = path.split('/')[-2]\n", 275 | " if(not vid in vid2features):\n", 276 | " vid2features[vid] = [cur_feat]\n", 277 | " else:\n", 278 | " vid2features[vid] = np.concatenate((vid2features[vid], [cur_feat]), axis=0)\n", 279 | "\n", 280 | " for (path, cur_feat) in tqdm(zip(refer_imgs_path, refer_features)):\n", 281 | " vid = path.split('/')[-2]\n", 282 | " if(not vid in vid2features):\n", 283 | " vid2features[vid] = [cur_feat]\n", 284 | " else:\n", 285 | " vid2features[vid] = np.concatenate((vid2features[vid], [cur_feat]), axis=0)\n", 286 | " \n", 287 | " # with open(PATH + 'var/vid2features_uni.pk', 'wb') as pk_file:\n", 288 | " with open(PATH + 'var/vid2features_res50_uni.pk', 'wb') as pk_file:\n", 289 | " pickle.dump(vid2features, pk_file)\n", 290 | "else:\n", 291 | " with open(PATH + 'var/vid2features_uni.pk', 'rb') as pk_file:\n", 292 | " # with open(PATH + 'var/vid2features_res50_uni.pk', 'rb') as pk_file:\n", 293 | " vid2features = pickle.load(pk_file)\n" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 83, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "(1008, 512)" 305 | ] 306 | }, 307 | "execution_count": 83, 308 | "metadata": {}, 309 | "output_type": "execute_result" 310 | } 311 | ], 312 | "source": [ 313 | "vid2features[refer_vids[0]].shape" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 84, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "def compute_similarities(query_features, refer_features):\n", 323 | " \"\"\"\n", 324 | " 用于计算两组特征(已经做过l2-norm)之间的相似度\n", 325 | " Args:\n", 326 | " query_features: shape: [N, D]\n", 327 | " refer_features: shape: [M, D]\n", 328 | " Returns:\n", 329 | " sorted_sims: shape: [N, M]\n", 330 | " unsorted_sims: shape: [N, M]\n", 331 | " \"\"\"\n", 332 | " sorted_sims = []\n", 333 | " unsorted_sims = []\n", 334 | " # 计算待查询视频和所有视频的距离\n", 335 | " dist = np.nan_to_num(cdist(query_features, refer_features, metric='cosine'))\n", 336 | " for i, v in enumerate(query_features):\n", 337 | " # 归一化,将距离转化成相似度\n", 338 | " # sim = np.round(1 - dist[i] / dist[i].max(), decimals=6)\n", 339 | " sim = 1 - dist[i]\n", 340 | " # 按照相似度的从大到小排列,输出index\n", 341 | " unsorted_sims += [sim]\n", 342 | " sorted_sims += [[(s, sim[s]) for s in sim.argsort()[::-1] if not np.isnan(sim[s])]]\n", 343 | " return sorted_sims, unsorted_sims\n", 344 | "\n", 345 | "def compute_dists(query_features, refer_features):\n", 346 | " \"\"\"\n", 347 | " 用于计算两组特征(已经做过l2-norm)之间的余弦距离\n", 348 | " Args:\n", 349 | " query_features: shape: [N, D]\n", 350 | " refer_features: shape: [M, D]\n", 351 | " Returns:\n", 352 | " idxs: shape [N, M]\n", 353 | " unsorted_dists: shape: [N, M]\n", 354 | " sorted_dists: shape: [N, M]\n", 355 | " \"\"\"\n", 356 | " sims = np.dot(query_features, refer_features.T)\n", 357 | " unsorted_dists = 1 - sims # sort 不好改降序\n", 358 | " # unsorted_dist = 
np.nan_to_num(cdist(query_features, refer_features, metric='cosine'))\n", 359 | " idxs = np.argsort(unsorted_dists)\n", 360 | " rows = np.dot(np.arange(idxs.shape[0]).reshape((idxs.shape[0], 1)), np.ones((1, idxs.shape[1]))).astype(int)\n", 361 | " sorted_dists = unsorted_dists[rows, idxs]\n", 362 | " # sorted_dists = np.sort(unsorted_dists)\n", 363 | " return idxs, unsorted_dists, sorted_dists" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 85, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "def get_frame_alignment(query_features, refer_features, top_K=5, min_sim=0.80, max_step=10):\n", 373 | " \"\"\"\n", 374 | " Compute frame-level matches between two sets of (already l2-normalized) features\n", 375 | " Args:\n", 376 | " query_features: shape: [N, D]\n", 377 | " refer_features: shape: [M, D]\n", 378 | " top_K: number of top refer frames kept per query frame\n", 379 | " min_sim: minimum similarity required between a query frame and a refer frame\n", 380 | " max_step: maximum step between nodes joined by an edge\n", 381 | " Returns:\n", 382 | " path_query: shape: [1, L]\n", 383 | " path_refer: shape: [1, L]\n", 384 | " \"\"\"\n", 385 | " node_pair2id = {}\n", 386 | " node_id2pair = {}\n", 387 | " node_id2pair[0] = (-1, -1) # source\n", 388 | " node_pair2id[(-1, -1)] = 0\n", 389 | " node_num = 1\n", 390 | "\n", 391 | " DG = nx.DiGraph()\n", 392 | " DG.add_node(0)\n", 393 | "\n", 394 | " idxs, unsorted_dists, sorted_dists = compute_dists(query_features, refer_features)\n", 395 | "\n", 396 | " # add nodes\n", 397 | " for qf_idx in range(query_features.shape[0]):\n", 398 | " for k in range(top_K):\n", 399 | " rf_idx = idxs[qf_idx][k]\n", 400 | " sim = 1 - sorted_dists[qf_idx][k]\n", 401 | " if sim < min_sim:\n", 402 | " break\n", 403 | " node_id2pair[node_num] = (qf_idx, rf_idx)\n", 404 | " node_pair2id[(qf_idx, rf_idx)] = node_num\n", 405 | " DG.add_node(node_num)\n", 406 | " node_num += 1\n", 407 | " \n", 408 | " node_id2pair[node_num] = (query_features.shape[0], refer_features.shape[0]) # sink\n", 409 | " node_pair2id[(query_features.shape[0], refer_features.shape[0])] = node_num\n", 410 | " DG.add_node(node_num)\n", 411 | " node_num += 1\n", 412 | "\n", 413 | " # link nodes\n", 414 | "\n", 415 | " for i in range(0, node_num - 1):\n", 416 | " for j in range(i + 1, node_num - 1):\n", 417 | " \n", 418 | " pair_i = node_id2pair[i]\n", 419 | " pair_j = node_id2pair[j]\n", 420 | "\n", 421 | " if(pair_j[0] > pair_i[0] and pair_j[1] > pair_i[1] and\n", 422 | " pair_j[0] - pair_i[0] <= max_step and pair_j[1] - pair_i[1] <= max_step):\n", 423 | " qf_idx = pair_j[0]\n", 424 | " rf_idx = pair_j[1]\n", 425 | " DG.add_edge(i, j, weight=1 - unsorted_dists[qf_idx][rf_idx])\n", 426 | "\n", 427 | " for i in range(0, node_num - 1):\n", 428 | " j = node_num - 1\n", 429 | "\n", 430 | " pair_i = node_id2pair[i]\n", 431 | " pair_j = node_id2pair[j]\n", 432 | "\n", 433 | " if(pair_j[0] > pair_i[0] and pair_j[1] > pair_i[1] and\n", 434 | " pair_j[0] - pair_i[0] <= max_step and pair_j[1] - pair_i[1] <= max_step):\n", 435 | " qf_idx = pair_j[0]\n", 436 | " rf_idx = pair_j[1]\n", 437 | " DG.add_edge(i, j, weight=0)\n", 438 | "\n", 439 | " longest_path = dag_longest_path(DG)\n", 440 | " if 0 in longest_path:\n", 441 | " longest_path.remove(0) # remove source node\n", 442 | " if node_num - 1 in longest_path:\n", 443 | " longest_path.remove(node_num - 1) # remove sink node\n", 444 | " path_query = [node_id2pair[node_id][0] for node_id in longest_path]\n", 445 | " path_refer = [node_id2pair[node_id][1] for node_id in longest_path]\n", 446 | "\n", 447 | " score = 0.0\n", 448 | " for (qf_idx, rf_idx) in zip(path_query, 
path_refer):\n", 449 | " score += 1 - unsorted_dists[qf_idx][rf_idx]\n", 450 | "\n", 451 | " return path_query, path_refer, score" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 86, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "totally cost 0.12647390365600586\n" 464 | ] 465 | } 466 | ], 467 | "source": [ 468 | "time_start=time.time()\n", 469 | "qf = vid2features[train_query_vids[0]]\n", 470 | "rf = vid2features['1226686400']\n", 471 | "idxs, unsorted_dists, sorted_dists = compute_dists(qf, rf)\n", 472 | "time_end=time.time()\n", 473 | "print('totally cost',time_end-time_start)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 87, 479 | "metadata": {}, 480 | "outputs": [ 481 | { 482 | "name": "stdout", 483 | "output_type": "stream", 484 | "text": [ 485 | "query_time_range(ms): 57000|79000\n", 486 | "refer_time_range(ms): 3227000|3248000\n", 487 | "score: 7.64414119720459\n", 488 | " query_id query_time_range(ms) refer_id \\\n", 489 | "1308 00530630-b8c8-11e9-930e-fa163ee49799 48290|116410 3184886800 \n", 490 | "\n", 491 | " refer_time_range(ms) \n", 492 | "1308 3217530|3285650 \n", 493 | "totally cost 0.2443249225616455\n" 494 | ] 495 | } 496 | ], 497 | "source": [ 498 | "time_start=time.time()\n", 499 | "q_vid = '00530630-b8c8-11e9-930e-fa163ee49799'\n", 500 | "r_vid = '3184886800'\n", 501 | "query = vid2features[q_vid]\n", 502 | "refer = vid2features[r_vid]\n", 503 | "q_baseaddr = train_query_vid2baseaddr[q_vid]\n", 504 | "r_baseaddr = refer_vid2baseaddr[r_vid]\n", 505 | "path_query, path_refer, score = get_frame_alignment(query, refer) # local address\n", 506 | "\n", 507 | "time_query = [int(train_query_fid2time[q_baseaddr + qf_id] * 1000) for qf_id in path_query]\n", 508 | "time_refer = [int(refer_fid2time[r_baseaddr + rf_id] * 1000) for rf_id in path_refer]\n", 509 | "print(\"query_time_range(ms): {}|{}\".format(time_query[0], time_query[-1]))\n", 510 | "print(\"refer_time_range(ms): {}|{}\".format(time_refer[0], time_refer[-1]))\n", 511 | "print(\"score: {}\".format(score))\n", 512 | "#print(time_query)\n", 513 | "#print(time_refer)\n", 514 | "train_df = pd.read_csv(TRAIN_PATH + 'train.csv')\n", 515 | "print(train_df.loc[train_df['query_id'] == q_vid])\n", 516 | "time_end=time.time()\n", 517 | "print('totally cost',time_end-time_start)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 88, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "idxs, unsorted_dists, sorted_dists = compute_dists(vid2features[q_vid], vid2features[r_vid])" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": 89, 532 | "metadata": {}, 533 | "outputs": [ 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "'\\nfor i in range(len(sorted_dists)):\\n print(i)\\n for j in range(5):\\n print(idxs[i][j], 1-sorted_dists[i][j])\\n'" 538 | ] 539 | }, 540 | "execution_count": 89, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "# debug\n", 547 | "'''\n", 548 | "for i in range(len(sorted_dists)):\n", 549 | " print(i)\n", 550 | " for j in range(5):\n", 551 | " print(idxs[i][j], 1-sorted_dists[i][j])\n", 552 | "'''" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 98, 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "name": "stdout", 562 | "output_type": "stream", 563 | "text": [ 564 | "(51.53578978776932, '1224903000', 92000, 175000, 
486000, 575000)\n", 565 | "(42.96090793609619, '2274916400', 18000, 86000, 1284000, 1352000)\n", 566 | "(63.607283651828766, '1356122300', 0, 69000, 343000, 411000)\n", 567 | "(49.72458457946777, '1398481500', 0, 60000, 809000, 868000)\n", 568 | "(14.793390035629272, '2509505900', 82000, 111000, 1041000, 1070000)\n", 569 | "(3.442767322063446, '1500872700', 91000, 95000, 249000, 262000)\n", 570 | "(26.741500854492188, '2666192100', 28000, 64000, 4542000, 4593000)\n", 571 | "(70.90073662996292, '1176745900', 0, 79000, 1289000, 1368000)\n", 572 | "(49.04690796136856, '1659203000', 32000, 92000, 688000, 760000)\n", 573 | "(6.891523063182831, '1419310000', 9000, 32000, 3931000, 3954000)\n", 574 | "(24.807477474212646, '1332713900', 34000, 101000, 1995000, 2063000)\n", 575 | "(65.31166815757751, '2436435900', 7000, 86000, 1166000, 1245000)\n", 576 | "(33.87100577354431, '2342638000', 0, 50000, 6824000, 6875000)\n", 577 | "(20.107609510421753, '1887729500', 47000, 93000, 464000, 524000)\n", 578 | "(31.705184996128082, '1596058300', 0, 67000, 429000, 495000)\n", 579 | "(67.98666608333588, '1723849300', 52000, 120000, 2480000, 2548000)\n", 580 | "(48.84355956315994, '3043930400', 0, 49000, 1155000, 1216000)\n", 581 | "(30.62875211238861, '2845332600', 0, 50000, 3461000, 3509000)\n", 582 | "(6.086414098739624, '1534060400', 17000, 45000, 3671000, 3688000)\n", 583 | "(56.07742738723755, '3203967000', 0, 61000, 1272000, 1344000)\n", 584 | "(70.84944242238998, '2342638000', 0, 105000, 1508000, 1611000)\n", 585 | "(6.202283024787903, '1901179600', 50000, 75000, 96000, 120000)\n", 586 | "(10.616593658924103, '3166859000', 63000, 105000, 1405000, 1447000)\n", 587 | "(4.430018126964569, '1534060400', 3000, 24000, 3668000, 3683000)\n", 588 | "(73.7065578699112, '1629260900', 93000, 179000, 5352000, 5438000)\n", 589 | "(51.001447439193726, '1601278800', 25000, 81000, 6757000, 6813000)\n", 590 | "(20.778279781341553, '2666192100', 34000, 82000, 1537000, 1601000)\n", 591 | "(26.051497995853424, '2367850000', 28000, 79000, 3571000, 3623000)\n", 592 | "(43.826499819755554, '1887729500', 0, 55000, 2495000, 2551000)\n", 593 | "(47.443311750888824, '1500872700', 0, 99000, 3441000, 3550000)\n", 594 | "(65.99086838960648, '1402364300', 36000, 104000, 2235000, 2303000)\n", 595 | "(10.400666773319244, '1598981800', 57000, 92000, 2683000, 2722000)\n", 596 | "(5.219767153263092, '1534060400', 74000, 97000, 3671000, 3687000)\n", 597 | "(8.667783915996552, '2333805400', 63000, 103000, 4858000, 4890000)\n", 598 | "(46.378190100193024, '3009055500', 2000, 53000, 1413000, 1474000)\n", 599 | "(8.748582065105438, '2303359200', 92000, 120000, 3114000, 3141000)\n", 600 | "(4.342910885810852, '1627286600', 17000, 31000, 138000, 156000)\n", 601 | "(2.6282519698143005, '1398481500', 31000, 34000, 47000, 50000)\n", 602 | "(53.984080612659454, '2620315400', 32000, 98000, 3945000, 4011000)\n", 603 | "(12.621381044387817, '1600623900', 38000, 80000, 468000, 512000)\n", 604 | "(33.809822618961334, '1234417600', 17000, 79000, 2853000, 2937000)\n", 605 | "(50.233526170253754, '2274699000', 0, 52000, 2139000, 2192000)\n", 606 | "(37.673711478710175, '2929626300', 45000, 83000, 3347000, 3385000)\n", 607 | "(90.90127158164978, '1684068900', 56000, 147000, 1163000, 1254000)\n", 608 | "(70.06904345750809, '1804154000', 9000, 86000, 4394000, 4471000)\n", 609 | "(34.03557014465332, '1260706600', 54000, 109000, 2264000, 2319000)\n", 610 | "(6.189080715179443, '1782169900', 33000, 51000, 106000, 117000)\n", 611 | "(68.13213807344437, '1443620200', 
0, 80000, 3680000, 3774000)\n", 612 | "(15.171736061573029, '1928274600', 107000, 132000, 3127000, 3152000)\n", 613 | "(34.015252470970154, '2400411900', 73000, 133000, 550000, 611000)\n", 614 | "(6.113804221153259, '1804132300', 105000, 124000, 1391000, 1408000)\n", 615 | "(4.361481070518494, '3166859000', 81000, 109000, 59000, 81000)\n", 616 | "(39.306980311870575, '1443620200', 1000, 87000, 1152000, 1254000)\n" 617 | ] 618 | }, 619 | { 620 | "ename": "KeyboardInterrupt", 621 | "evalue": "", 622 | "output_type": "error", 623 | "traceback": [ 624 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 625 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 626 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr_vid\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrefer_vids\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mr_feat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvid2features\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mr_vid\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0midxs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munsorted_dists\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_dists\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_dists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq_feat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_feat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msorted_dists\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mr_scores\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_vid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 627 | "\u001b[0;32m\u001b[0m in \u001b[0;36mcompute_dists\u001b[0;34m(query_features, refer_features)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0midxs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margsort\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munsorted_dists\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mrows\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midxs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midxs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midxs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0msorted_dists\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munsorted_dists\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mrows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midxs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;31m# sorted_dists = np.sort(unsorted_dists)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0midxs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munsorted_dists\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_dists\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 628 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 629 | ] 630 | } 631 | ], 632 | "source": [ 633 | "train_query_ans = {}\n", 634 | "for q_vid in train_query_vids:\n", 635 | " q_feat = vid2features[q_vid]\n", 636 | " q_baseaddr = train_query_vid2baseaddr[q_vid]\n", 637 | " q_ans = []\n", 638 | " # 初筛\n", 639 | " r_scores = []\n", 640 | " for r_vid in refer_vids:\n", 641 | " r_feat = vid2features[r_vid]\n", 642 | " idxs, unsorted_dists, sorted_dists = compute_dists(q_feat, r_feat)\n", 643 | " score = np.sum(sorted_dists[:, 0])\n", 644 | " r_scores.append((score, r_vid))\n", 645 | " r_scores.sort(key = lambda x: x[0], reverse=False)\n", 646 | " # 细筛\n", 647 | " top_K = 20\n", 648 | " for k, (_, r_vid) in enumerate(r_scores):\n", 649 | " if(k >= top_K):\n", 650 | " break\n", 651 | " r_feat = vid2features[r_vid]\n", 652 | " r_baseaddr = refer_vid2baseaddr[r_vid]\n", 653 | " path_q, path_r, score = get_frame_alignment(q_feat, r_feat, top_K=3, min_sim=0.85, max_step=10)\n", 654 | " if len(path_q) > 0:\n", 655 | " time_q = [int(train_query_fid2time[q_baseaddr + qf_id] * 1000) for qf_id in path_q]\n", 656 | " time_r = [int(refer_fid2time[r_baseaddr + rf_id] * 1000) for rf_id in path_r]\n", 657 | " q_ans.append((score, r_vid, time_q[0], time_q[-1], time_r[0], time_r[-1]))\n", 658 | " \n", 659 | " q_ans.sort(key = lambda x: x[0], reverse=True)\n", 660 | " train_query_ans[q_vid] = q_ans[0][1:]\n", 661 | " print(q_ans[0])\n" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "execution_count": 91, 667 | "metadata": {}, 668 | "outputs": [], 669 | "source": [ 670 | "# 读取 train.csv\n", 671 | "train_df = pd.read_csv(TRAIN_PATH + 'train.csv')\n", 672 | "train_query_label = {}\n", 673 | "for vid in train_query_vids:\n", 674 | " row = train_df.loc[train_df['query_id'] == vid]\n", 675 | " time_q = (int(row.iloc[0, 1].split('|')[0]), int(row.iloc[0, 1].split('|')[1]))\n", 676 | " time_r = (int(row.iloc[0, 3].split('|')[0]), int(row.iloc[0, 3].split('|')[1]))\n", 677 | " train_query_label[vid] = (str(row.iloc[0, 2]), time_q[0], time_q[1], time_r[0], time_r[1])" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 92, 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "# 计算分数\n", 687 | "def compute_precision_recall(y_true, y_pred, pr=False):\n", 688 | " \"\"\"\n", 689 | " 
Compute precision and recall for the test results\n", 690 | " Args:\n", 691 | " y_true: dict shape: [N, 5]\n", 692 | " y_pred: dict shape: [M, 5]\n", 693 | " pr: need precision and recall\n", 694 | " Returns:\n", 695 | " f1_score\n", 696 | " precision\n", 697 | " recall\n", 698 | " \"\"\"\n", 699 | " tp = fp = fn = 0\n", 700 | " threshold = 3000\n", 701 | "\n", 702 | "# for q_vid in y_true:\n", 703 | " for q_vid in y_pred:\n", 704 | " q_ans = y_pred[q_vid]\n", 705 | " q_label = y_true[q_vid]\n", 706 | "\n", 707 | " if(len(q_ans) == 5):\n", 708 | " if(q_ans[0] == q_label[0] and abs(q_ans[1] - q_label[1]) <= threshold and abs(q_ans[2] - q_label[2]) <= threshold \n", 709 | " and abs(q_ans[3] - q_label[3]) <= threshold and abs(q_ans[4] - q_label[4]) <= threshold):\n", 710 | " tp += 1\n", 711 | " else:\n", 712 | " fp += 1\n", 713 | " else:\n", 714 | " fn += 1\n", 715 | " precision = tp / (tp + fp)\n", 716 | " recall = tp / (tp + fn)\n", 717 | " f1_score = 2 * precision * recall / (precision + recall)\n", 718 | " if(pr):\n", 719 | " return f1_score, precision, recall\n", 720 | " else:\n", 721 | " return f1_score" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 99, 727 | "metadata": {}, 728 | "outputs": [ 729 | { 730 | "data": { 731 | "text/plain": [ 732 | "0.6233766233766235" 733 | ] 734 | }, 735 | "execution_count": 99, 736 | "metadata": {}, 737 | "output_type": "execute_result" 738 | } 739 | ], 740 | "source": [ 741 | "compute_precision_recall(train_query_label, train_query_ans)" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": null, 747 | "metadata": {}, 748 | "outputs": [ 749 | { 750 | "name": "stdout", 751 | "output_type": "stream", 752 | "text": [ 753 | "0 (24.256401896476746, '1568027000', 70000, 137000, 455000, 527000)\n" 754 | ] 755 | } 756 | ], 757 | "source": [ 758 | "# Prepare the submission\n", 759 | "test_query_ans = {}\n", 760 | "for i, q_vid in enumerate(test_query_vids):\n", 761 | " q_feat = vid2features[q_vid]\n", 762 | " q_baseaddr = test_query_vid2baseaddr[q_vid]\n", 763 | " q_ans = []\n", 764 | " # Coarse filtering: video level\n", 765 | " r_scores = []\n", 766 | " for r_vid in refer_vids:\n", 767 | " r_feat = vid2features[r_vid]\n", 768 | " idxs, unsorted_dists, sorted_dists = compute_dists(q_feat, r_feat)\n", 769 | " score = np.sum(sorted_dists[:, 0])\n", 770 | " r_scores.append((score, r_vid))\n", 771 | " r_scores.sort(key = lambda x: x[0], reverse=False)\n", 772 | " # Fine filtering: frame level\n", 773 | " top_K = 20\n", 774 | " for k, (_, r_vid) in enumerate(r_scores):\n", 775 | " if(k >= top_K):\n", 776 | " break\n", 777 | " r_feat = vid2features[r_vid]\n", 778 | " r_baseaddr = refer_vid2baseaddr[r_vid]\n", 779 | " path_q, path_r, score = get_frame_alignment(q_feat, r_feat, top_K=3, min_sim=0.85, max_step=10)\n", 780 | " if len(path_q) > 0:\n", 781 | " time_q = [int(test_query_fid2time[q_baseaddr + qf_id] * 1000) for qf_id in path_q]\n", 782 | " time_r = [int(refer_fid2time[r_baseaddr + rf_id] * 1000) for rf_id in path_r]\n", 783 | " q_ans.append((score, r_vid, time_q[0], time_q[-1], time_r[0], time_r[-1]))\n", 784 | " \n", 785 | " q_ans.sort(key = lambda x: x[0], reverse=True)\n", 786 | " test_query_ans[q_vid] = q_ans[0][1:]\n", 787 | " print(i, q_ans[0])\n", 788 | " if i % 10 == 0:\n", 789 | " with open(PATH + 'var/test_query_ans_uni.pk', 'wb') as pk_file:\n", 790 | " pickle.dump(test_query_ans, pk_file)\n" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": null, 796 | "metadata": {}, 797 | "outputs": [], 798 | "source": [ 799 | "# Submit the simplest possible result\n", 800 | "submit_df = pd.read_csv(TEST_PATH + 
'submit_example2.csv')\n", 801 | "for vid in test_query_vids:\n", 802 | " q_pred = test_query_ans[vid]\n", 803 | " time_q = str(q_pred[1]) + '|' + str(q_pred[2])\n", 804 | " time_r = str(q_pred[3]) + '|' + str(q_pred[4])\n", 805 | " submit_df.loc[submit_df['query_id'] == vid, ['query_time_range(ms)', 'refer_id', 'refer_time_range(ms)']] = [time_q, q_pred[0], time_r]\n", 806 | "\n", 807 | "submit_df.to_csv(TEST_PATH + 'result2.csv', index = None, sep=',')" 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "execution_count": null, 813 | "metadata": {}, 814 | "outputs": [], 815 | "source": [] 816 | } 817 | ], 818 | "metadata": { 819 | "file_extension": ".py", 820 | "kernelspec": { 821 | "display_name": "Python 3", 822 | "language": "python", 823 | "name": "python3" 824 | }, 825 | "language_info": { 826 | "codemirror_mode": { 827 | "name": "ipython", 828 | "version": 3 829 | }, 830 | "file_extension": ".py", 831 | "mimetype": "text/x-python", 832 | "name": "python", 833 | "nbconvert_exporter": "python", 834 | "pygments_lexer": "ipython3", 835 | "version": "3.7.4" 836 | }, 837 | "mimetype": "text/x-python", 838 | "name": "python", 839 | "npconvert_exporter": "python", 840 | "pygments_lexer": "ipython3", 841 | "version": 3 842 | }, 843 | "nbformat": 4, 844 | "nbformat_minor": 4 845 | } 846 | --------------------------------------------------------------------------------
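As a closing note on video_retrieval.ipynb: `get_frame_alignment` treats each candidate (query_frame, refer_frame) match as a graph node, adds an edge only when both frame indices move forward by at most `max_step`, and reads the temporal alignment off the heaviest path of the resulting DAG. A toy illustration of that idea on made-up match pairs (the real code weights edges by frame similarity and adds source/sink nodes):

```python
import networkx as nx
from networkx.algorithms.dag import dag_longest_path

# hypothetical (query_frame, refer_frame) candidate matches
matches = [(0, 5), (1, 6), (1, 9), (2, 7), (3, 8)]
max_step = 10

DG = nx.DiGraph()
DG.add_nodes_from(range(len(matches)))
for i, (qi, ri) in enumerate(matches):
    for j, (qj, rj) in enumerate(matches):
        # edges only run forward in both videos, within max_step
        if qj > qi and rj > ri and qj - qi <= max_step and rj - ri <= max_step:
            DG.add_edge(i, j, weight=1.0)  # real code: frame similarity

best = dag_longest_path(DG)        # heaviest monotone chain of matches
print([matches[i] for i in best])  # [(0, 5), (1, 6), (2, 7), (3, 8)]
```

The stray candidate (1, 9) is skipped because no later match moves forward in both videos from it, which is exactly how the longest-path formulation suppresses isolated false matches.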