├── .gitignore
├── LICENSE
├── README.md
├── TEMPLATE_eval.sh
├── augmentation.ipynb
├── data.py
├── data_augmenter.py
├── do_all_frame_level.sh
├── do_all_video_level.sh
├── do_feature_convert.sh
├── evaluation.py
├── feature_convert.py
├── fig
│   ├── frame_aug.jpg
│   └── video_aug.jpg
├── gene_aug_feat.py
├── model.py
├── requirements.txt
├── simpleknn
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── bigfile.py
│   ├── build.sh
│   ├── cpp
│   │   ├── Makefile
│   │   ├── Makefile.win64
│   │   ├── build.win64.bat
│   │   ├── search.cpp
│   │   ├── search.cpp.bak
│   │   ├── search.def
│   │   └── search.h
│   ├── demo.py
│   ├── do_norm_feat.sh
│   ├── im2fea.py
│   ├── lib
│   │   ├── linux
│   │   │   └── libsearch.so
│   │   ├── mac
│   │   │   └── libsearch.so
│   │   └── win64
│   │       └── libsearch.dll
│   ├── merge_feat.py
│   ├── norm_feat.py
│   ├── simpleknn.py
│   ├── test.bat
│   ├── test.sh
│   ├── testbigfile.py
│   ├── toydata
│   │   └── FeatureData
│   │       └── f1
│   │           ├── feature.bin
│   │           ├── id.feature.txt
│   │           ├── id.txt
│   │           └── shape.txt
│   └── txt2bin.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── cbvrp_eval.py
    ├── common.py
    ├── generic_utils.py
    └── util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *~
3 | *.swp
4 | *.swo
5 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Feature Re-Learning with Data Augmentation for Video Relevance Prediction
2 | 
3 | 
4 | The source code of our TKDE paper [Feature Re-Learning with Data Augmentation for Content-based Video Recommendation](https://dl.acm.org/doi/abs/10.1145/3240508.3266441). We propose a feature re-learning model that works for both frame-level and video-level features, enhanced by data augmentation and a negative-enhanced triplet ranking loss. It is also our winning entry for the [Hulu Content-based Video Relevance Prediction Challenge](https://github.com/cbvrp-acmmm-2018/cbvrp-acmmm-2018) at the ACM Multimedia 2018 conference.
5 | 
6 | 
7 | ## Requirements
8 | #### Required Packages
9 | * **python** 2.7
10 | * **PyTorch** 0.3.1
11 | * **tensorboard_logger** for tensorboard visualization
12 | 
13 | We used virtualenv to set up a deep learning workspace that supports PyTorch.
14 | Run the following script to install the required packages.
15 | ```shell
16 | virtualenv --system-site-packages ~/cbvr
17 | source ~/cbvr/bin/activate
18 | pip install -r requirements.txt
19 | deactivate
20 | ```
21 | 
22 | #### Required Data
23 | 1. Download the track_1_shows (6G) and track_2_movies (9.0G) datasets from [Google Drive](https://drive.google.com/open?id=1V9eZbbVEV6AQlTYqqjfrz0Lcpeqhk6Xn) or [Baidu Pan](https://pan.baidu.com/s/1v86WP7u-tcuO2qzh0CVAqQ#list/path=%2Fcbvr_data) or [here](http://39.104.114.128/cbvr_mm_2018/). If you have already downloaded the datasets provided by the Hulu organizers, use the script [do_feature_convert.sh](do_feature_convert.sh) to convert them into the format our code expects.
24 | 2. Run the following script to extract the downloaded data. The extracted data is placed in `$HOME/VisualSearch/`.
25 | ```shell
26 | ROOTPATH=$HOME/VisualSearch
27 | mkdir -p $ROOTPATH
28 | # extract track_1_shows and track_2_movies datasets
29 | tar zxf track_1_shows.tar.gz -C $ROOTPATH
30 | tar zxf track_2_movies.tar.gz -C $ROOTPATH
31 | ```
32 | 
33 | 
34 | ## Getting started
35 | #### Augmentation for frame-level features
36 | ![image](fig/frame_aug.jpg)
37 | 
38 | Run the following script to train and evaluate the model with augmentation for frame-level features and the negative-enhanced triplet ranking loss.
39 | ```shell
40 | source ~/cbvr/bin/activate
41 | # on track_1_shows and track_2_movies with stride=12
42 | stride=12
43 | loss=netrl # use trl if you would like to use the common triplet ranking loss
44 | ./do_all_frame_level.sh track_1_shows inception-pool3 $stride $loss
45 | ./do_all_frame_level.sh track_2_movies inception-pool3 $stride $loss
46 | deactivate
47 | ```
48 | Running the script will do the following things:
49 | 1. Generate augmented frame-level features and apply mean pooling to obtain video-level features in advance.
50 | 2. Train the feature re-learning model with augmentation for frame-level features and select the checkpoint that performs best on the validation set as the final model.
51 | 3. Evaluate the final model on the validation set and generate predicted results on the test set. Both relevance prediction strategies are performed. Note that we as participants have no access to the ground truth of the test set. Please contact the [task organizers](https://github.com/cbvrp-acmmm-2018/cbvrp-acmmm-2018) in case you want to evaluate our model or your own model on the test set.
52 | 
53 | 
54 | #### Augmentation for video-level features
55 | ![image](fig/video_aug.jpg)
56 | 
57 | Run the following script to train and evaluate the model with augmentation for video-level features.
58 | ```shell
59 | source ~/cbvr/bin/activate
60 | # on track_1_shows
61 | ./do_all_video_level.sh track_1_shows c3d-pool5 netrl
62 | # on track_2_movies
63 | ./do_all_video_level.sh track_2_movies c3d-pool5 netrl
64 | deactivate
65 | ```
66 | Running the script will do the following things:
67 | 1. Train the feature re-learning model with augmentation for video-level features and select the checkpoint that performs best on the validation set as the final model. (The augmented video-level features are generated on the fly; a sketch follows below.)
68 | 2. Evaluate the final model on the validation set and generate predicted results on the test set.
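
For reference, the on-the-fly augmentation adds a masked random perturbation to each video-level feature vector. The following minimal sketch mirrors `Video_Level_Augmenter.get_aug_feat` in [data_augmenter.py](data_augmenter.py); the per-dimension `mean`/`std` statistics and the toy feature vector are made up here purely for illustration:
```python
import numpy as np

feat_dim = 5
mean, std = np.zeros(feat_dim), np.ones(feat_dim)  # per-dimension statistics of the collection
perturb_intensity, perturb_prob = 1.0, 0.5

# perturb a randomly chosen perturb_prob fraction of the dimensions
mask = np.zeros(feat_dim)
mask[:int(feat_dim * perturb_prob)] = 1
np.random.shuffle(mask)

vid_feat = np.random.randn(feat_dim)
aug_feat = vid_feat + (np.random.randn(feat_dim) * std + mean) * perturb_intensity * mask
```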
69 | 
70 | 
71 | ## How to perform the proposed augmentation for other video-related tasks?
72 | The proposed augmentation can essentially be used for other video-related tasks.
73 | [This note](augmentation.ipynb) shows
74 | * How to perform data augmentation over frame-level features?
75 | * How to perform data augmentation over video-level features?
76 | 
77 | 
78 | ## Citation
79 | If you find the package useful, please consider citing our following papers:
80 | ```
81 | @inproceedings{mm2018-cbvrp-dong,
82 |   title = {Feature Re-Learning with Data Augmentation for Content-based Video Recommendation},
83 |   author = {Jianfeng Dong and Xirong Li and Chaoxi Xu and Gang Yang and Xun Wang},
84 |   doi = {10.1145/3240508.3266441},
85 |   year = {2018},
86 |   booktitle = {ACM Multimedia},
87 | }
88 | 
89 | @article{dong2019feature,
90 |   title={Feature Re-Learning with Data Augmentation for Video Relevance Prediction},
91 |   author={Dong, Jianfeng and Wang, Xun and Zhang, Leimin and Xu, Chaoxi and Yang, Gang and Li, Xirong},
92 |   journal={IEEE Transactions on Knowledge and Data Engineering},
93 |   doi={10.1109/TKDE.2019.2947442},
94 |   year={2019},
95 |   publisher={IEEE}
96 | }
97 | 
98 | ```
99 | 
100 | 
101 | 
102 | 
103 | ## Acknowledgements
104 | We are grateful to the Hulu organizers for the challenge organization effort.
105 | ```
106 | @article{liu2018content,
107 |   title={Content-based Video Relevance Prediction Challenge: Data, Protocol, and Baseline},
108 |   author={Liu, Mengyi and Xie, Xiaohui and Zhou, Hanning},
109 |   journal={arXiv preprint arXiv:1806.00737},
110 |   year={2018}
111 | }
112 | ```
113 | 

--------------------------------------------------------------------------------
/TEMPLATE_eval.sh:
--------------------------------------------------------------------------------
1 | rootpath=@@@rootpath@@@
2 | collection=@@@collection@@@
3 | overwrite=@@@overwrite@@@
4 | 
5 | checkpoint_path=@@@model_path@@@/model_best.pth.tar
6 | 
7 | gpu=0
8 | for test_set in val
9 | do
10 |     for strategy in 1 2
11 |     do
12 |         CUDA_VISIBLE_DEVICES=$gpu python test.py --rootpath $rootpath --collection $collection --checkpoint_path $checkpoint_path --test_set $test_set --overwrite $overwrite --strategy $strategy
13 |     done
14 | done
15 | 

--------------------------------------------------------------------------------
/augmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Augmentation for frame-level features and video-level features"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "This note answers the following two questions:\n",
15 |     "\n",
16 |     "1. How to perform data augmentation over frame-level features?\n",
17 |     "2. How to perform data augmentation over video-level features?"
18 |    ]
19 |   },
20 |   {
21 |    "cell_type": "markdown",
22 |    "metadata": {},
23 |    "source": [
24 |     "### 0. Setup"
25 |    ]
26 |   },
27 |   {
28 |    "cell_type": "code",
29 |    "execution_count": 1,
30 |    "metadata": {},
31 |    "outputs": [],
32 |    "source": [
33 |     "import numpy as np\n",
34 |     "from data_augmenter import Frame_Level_Augmenter, Video_Level_Augmenter"
35 |    ]
36 |   },
37 |   {
38 |    "cell_type": "markdown",
39 |    "metadata": {},
40 |    "source": [
41 |     "### 1. 
Augmentation for frame-level features" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "[[-0.6235536 -0.90535705 -0.44822978 0.50550119 -0.02782648]\n", 54 | " [-1.02859073 0.072869 -0.50732775 -0.9271956 1.05620006]\n", 55 | " [ 1.07593828 1.03311453 0.34317861 0.16404785 -1.0236284 ]\n", 56 | " [-1.25972658 0.38853909 0.98516284 -0.09839638 -1.29890457]\n", 57 | " [ 1.29644845 -1.27334989 0.13306375 -0.39547421 -0.49235768]\n", 58 | " [ 0.40335207 -0.72771485 -0.83376385 -0.26617051 1.025612 ]\n", 59 | " [ 1.91081159 1.62004108 0.25886422 0.45212089 -1.13554757]\n", 60 | " [ 1.0761965 0.56507028 -0.79111638 -0.77533702 -0.65562929]\n", 61 | " [ 1.41041613 0.59244662 -0.71805043 0.12487997 0.99022644]]\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "# randomly generate a frame-level feature\n", 67 | "n_frms = 9 # number of frames\n", 68 | "frm_feat_dim = 5 # dimensionality of frame feature vector\n", 69 | "frm_feat = np.random.randn(n_frms, frm_feat_dim)\n", 70 | "print frm_feat" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "[[array([-0.6235536 , -0.90535705, -0.44822978, 0.50550119, -0.02782648]),\n", 82 | " array([ 1.07593828, 1.03311453, 0.34317861, 0.16404785, -1.0236284 ]),\n", 83 | " array([ 1.29644845, -1.27334989, 0.13306375, -0.39547421, -0.49235768]),\n", 84 | " array([ 1.91081159, 1.62004108, 0.25886422, 0.45212089, -1.13554757]),\n", 85 | " array([ 1.41041613, 0.59244662, -0.71805043, 0.12487997, 0.99022644])],\n", 86 | " [array([-1.02859073, 0.072869 , -0.50732775, -0.9271956 , 1.05620006]),\n", 87 | " array([-1.25972658, 0.38853909, 0.98516284, -0.09839638, -1.29890457]),\n", 88 | " array([ 0.40335207, -0.72771485, -0.83376385, -0.26617051, 1.025612 ]),\n", 89 | " array([ 1.0761965 , 0.56507028, -0.79111638, -0.77533702, -0.65562929])]]" 90 | ] 91 | }, 92 | "execution_count": 3, 93 | "metadata": {}, 94 | "output_type": "execute_result" 95 | } 96 | ], 97 | "source": [ 98 | "# generate augmented feature\n", 99 | "f_auger = Frame_Level_Augmenter(stride=2, n_frame_threshold=5)\n", 100 | "f_auger.get_aug_feat(frm_feat)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "### 2. 
Augmentation for video-level features"
108 |    ]
109 |   },
110 |   {
111 |    "cell_type": "code",
112 |    "execution_count": 4,
113 |    "metadata": {},
114 |    "outputs": [],
115 |    "source": [
116 |     "# randomly generate a set of video-level features for mean and std calculation\n",
117 |     "n_vid_feats = 1000 # number of video-level features\n",
118 |     "vid_feat_dim = 5 # dimensionality of video feature vector\n",
119 |     "vid_feats = np.random.randn(n_vid_feats, vid_feat_dim)"
120 |    ]
121 |   },
122 |   {
123 |    "cell_type": "code",
124 |    "execution_count": 5,
125 |    "metadata": {},
126 |    "outputs": [],
127 |    "source": [
128 |     "# calculate mean and std\n",
129 |     "mean = np.mean(vid_feats, axis=0)\n",
130 |     "std = np.std(vid_feats, axis=0)"
131 |    ]
132 |   },
133 |   {
134 |    "cell_type": "code",
135 |    "execution_count": 6,
136 |    "metadata": {},
137 |    "outputs": [
138 |     {
139 |      "name": "stdout",
140 |      "output_type": "stream",
141 |      "text": [
142 |       "[ 0.11204779 -3.61119416 -0.09824153 0.80992366 -0.1196311 ]\n",
143 |       "-->\n",
144 |       "[ 0.26974022 -3.61119416 -0.09824153 0.3636637 -0.1196311 ]\n"
145 |      ]
146 |     }
147 |    ],
148 |    "source": [
149 |     "v_auger = Video_Level_Augmenter(perturb_intensity=1, perturb_prob=0.5, mean=mean, std=std)\n",
150 |     "vid_feat = np.random.randn(vid_feat_dim)\n",
151 |     "aug_feat = v_auger.get_aug_feat(vid_feat)\n",
152 |     "print vid_feat\n",
153 |     "print \"-->\"\n",
154 |     "print aug_feat"
155 |    ]
156 |   },
157 |   {
158 |    "cell_type": "code",
159 |    "execution_count": null,
160 |    "metadata": {},
161 |    "outputs": [],
162 |    "source": []
163 |   }
164 |  ],
165 |  "metadata": {
166 |   "kernelspec": {
167 |    "display_name": "Python 2",
168 |    "language": "python",
169 |    "name": "python2"
170 |   },
171 |   "language_info": {
172 |    "codemirror_mode": {
173 |     "name": "ipython",
174 |     "version": 2
175 |    },
176 |    "file_extension": ".py",
177 |    "mimetype": "text/x-python",
178 |    "name": "python",
179 |    "nbconvert_exporter": "python",
180 |    "pygments_lexer": "ipython2",
181 |    "version": "2.7.15"
182 |   }
183 |  },
184 |  "nbformat": 4,
185 |  "nbformat_minor": 2
186 | }
187 | 

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import random
4 | import numpy as np
5 | from data_augmenter import Frame_Level_Augmenter, Video_Level_Augmenter
6 | 
7 | import torch
8 | import torch.utils.data as data
9 | 
10 | 
11 | def read_videopair(input_file):
12 |     print 'reading data from:', input_file
13 |     videopairlist = []
14 |     reader = csv.reader(open(input_file, 'r'))
15 |     for data in reader:
16 |         if data[1] == "":
17 |             continue
18 |         video = data[0]
19 |         for video2 in data[1:]:
20 |             videopairlist.append((video, video2))
21 |     return videopairlist
22 | 
23 | 
24 | def collate_fn(data):
25 |     videos_1, videos_2, inds = zip(*data)
26 | 
27 |     # Merge videos (convert tuple of 2D tensor to 3D tensor)
28 |     videos_1 = torch.stack(videos_1, 0)
29 |     videos_2 = torch.stack(videos_2, 0)
30 | 
31 |     return videos_1, videos_2, inds
32 | 
33 | 
34 | # data augmentation on the fly (training was too slow, so we discarded it)
35 | # class Dataset_frame_da(data.Dataset):
36 | 
37 | #     def __init__(self, data_path, frame_feats, video2frames, stride=2):
38 | #         self.videopairlist = read_videopair(data_path)
39 | #         self.frame_feats = frame_feats
40 | #         self.video2frames = video2frames
41 | #         self.sub_length = len(self.videopairlist)
42 | 
43 | #         if type(stride) is int:
44 | #             self.length = self.sub_length * stride
45 | #         else:
46 | #             self.length = self.sub_length * sum(stride)
47 | 
48 | #         self.f_auger = Frame_Level_Augmenter(stride)
49 | 
50 | 
51 | #     def get_aug_pool_feat(self, video_id):
52 | #         frm_feat = [self.frame_feats.read_one(fid) for fid in self.video2frames[video_id]]
53 | #         frm_feat = self.f_auger.aug_feat_choice(frm_feat)
54 | #         return np.array(frm_feat).mean(axis=0)
55 | 
56 | 
57 | #     def __getitem__(self, index):
58 | #         video_id_1, video_id_2 = self.videopairlist[index % self.sub_length]
59 | 
60 | #         video_1 = torch.Tensor(self.get_aug_pool_feat(video_id_1))
61 | #         video_2 = torch.Tensor(self.get_aug_pool_feat(video_id_2))
62 | 
63 | #         return video_1, video_2, index
64 | 
65 | #     def __len__(self):
66 | #         return self.length
67 | 
68 | 
69 | # def get_frame_da_loader(data_path, frame_feats, opt, batch_size=100, shuffle=True, num_workers=2, video2frames=None, stride=5):
70 | #     """Returns torch.utils.data.DataLoader for the video-pair dataset."""
71 | #     dset = Dataset_frame_da(data_path, frame_feats, video2frames, stride)
72 | 
73 | #     data_loader = torch.utils.data.DataLoader(dataset=dset,
74 | #                                               batch_size=batch_size,
75 | #                                               shuffle=shuffle,
76 | #                                               pin_memory=True,
77 | #                                               collate_fn=collate_fn)
78 | #     return data_loader
79 | 
80 | 
81 | 
82 | 
83 | 
84 | class PrecompDataset_video_da(data.Dataset):
85 | 
86 |     def __init__(self, data_path, video_feats, video2subvideo, n_subs, aug_prob=0, perturb_intensity=0.01, perturb_prob=0.5, feat_path=None):
87 |         self.videopairlist = read_videopair(data_path)
88 |         self.video_feats = video_feats
89 |         self.sub_length = len(self.videopairlist)
90 |         self.video2subvideo = video2subvideo
91 |         self.length = self.sub_length * n_subs
92 |         self.n_subs = n_subs
93 | 
94 |         self.aug_prob = aug_prob
95 |         self.perturb_intensity = perturb_intensity
96 |         self.perturb_prob = perturb_prob
97 |         if self.aug_prob > 0:
98 |             self.length = int(self.length / aug_prob)
99 |             self.v_auger = Video_Level_Augmenter(feat_path, video_feats, perturb_intensity=perturb_intensity, perturb_prob=perturb_prob)
100 | 
101 |     def __getitem__(self, index):
102 |         video_id_1, video_id_2 = self.videopairlist[index % self.sub_length]
103 | 
104 |         if self.n_subs > 1:
105 |             video_id_1 = random.choice(self.video2subvideo[video_id_1])
106 |             video_id_2 = random.choice(self.video2subvideo[video_id_2])
107 | 
108 |         video_1 = self.video_feats.read_one(video_id_1)
109 |         video_2 = self.video_feats.read_one(video_id_2)
110 | 
111 |         if self.aug_prob > 0:  # add tiny perturbations for data augmentation
112 |             if random.random() < self.aug_prob:
113 |                 video_1 = self.v_auger.get_aug_feat(video_1)
114 |                 video_2 = self.v_auger.get_aug_feat(video_2)
115 | 
116 |         video_1 = torch.Tensor(video_1)
117 |         video_2 = torch.Tensor(video_2)
118 | 
119 |         return video_1, video_2, index
120 | 
121 |     def __len__(self):
122 |         return self.length
123 | 
124 | 
125 | 
126 | 
127 | def get_video_da_loader(data_path, video_feats, opt, batch_size=100, shuffle=True, num_workers=2, video2subvideo=None, n_subs=1, feat_path=""):
128 |     dset = PrecompDataset_video_da(data_path, video_feats, video2subvideo, n_subs,
129 |                                    aug_prob=opt.aug_prob, perturb_intensity=opt.perturb_intensity, perturb_prob=opt.perturb_prob, feat_path=feat_path)
130 | 
131 |     data_loader = torch.utils.data.DataLoader(dataset=dset,
132 |                                               batch_size=batch_size,
133 |                                               shuffle=shuffle,
134 |                                               pin_memory=True,
135 |                                               collate_fn=collate_fn)
136 |     return data_loader
137 | 
138 | 
139 | 
140 | 
141 | 
142 | # for validation and test
143 | class FeatDataset(data.Dataset):
144 |     """
145 |     Load precomputed video features
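    and serve (feature, video_id, index) tuples, one video at a time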
""" 147 | 148 | def __init__(self, videolist, video_feats): 149 | self.video_feats = video_feats 150 | self.videolist = videolist 151 | self.length = len(videolist) 152 | 153 | def __getitem__(self, index): 154 | vidoe_id = self.videolist[index] 155 | video = torch.Tensor(self.video_feats.read_one(vidoe_id)) 156 | return video, vidoe_id, index 157 | 158 | def __len__(self): 159 | return self.length 160 | 161 | 162 | def collate_fn_feat(data): 163 | 164 | videos, ids, idxs = zip(*data) 165 | 166 | # Merge videos (convert tuple of 2D tensor to 3D tensor) 167 | videos = torch.stack(videos, 0) 168 | 169 | return videos, ids, idxs 170 | 171 | 172 | def get_feat_loader(videolist, video_feats, batch_size=100, shuffle=False, num_workers=2): 173 | 174 | dset = FeatDataset(videolist, video_feats) 175 | data_loader = torch.utils.data.DataLoader(dataset=dset, 176 | batch_size=batch_size, 177 | shuffle=shuffle, 178 | pin_memory=True, 179 | collate_fn=collate_fn_feat) 180 | return data_loader -------------------------------------------------------------------------------- /data_augmenter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from simpleknn.bigfile import BigFile 5 | 6 | # Augmentation for frame-level features 7 | class Frame_Level_Augmenter(object): 8 | 9 | def __init__(self, stride=2, n_frame_threshold=20): 10 | self.stride = stride 11 | self.n_frame_threshold = n_frame_threshold 12 | 13 | def get_aug_index(self, n_vecs): 14 | if type(self.stride) is int: 15 | self.stride = [self.stride] 16 | 17 | aug_index = [] 18 | # aug_index.append(range(n_vecs)) # keep original frame level features 19 | if n_vecs < self.n_frame_threshold: 20 | return aug_index 21 | for stride in self.stride: 22 | for i in range(stride): 23 | sub_index = range(n_vecs)[i::stride] 24 | aug_index.append(sub_index) 25 | return aug_index 26 | 27 | def get_aug_feat(self, frm_feat): 28 | n_vecs = len(frm_feat) 29 | aug_index = self.get_aug_index(n_vecs) 30 | 31 | aug_feats = [] 32 | for index in aug_index: 33 | org_feat = [frm_feat[x] for x in index] 34 | aug_feats.append(org_feat) 35 | return aug_feats 36 | 37 | 38 | def aug_index_choice(self, n_vecs): 39 | return random.choice(self.get_aug_index(n_vecs)) 40 | 41 | 42 | def aug_feat_choice(self, frm_feat): 43 | n_vecs = len(frm_feat) 44 | aug_index_choice = self.aug_index_choice(n_vecs) 45 | aug_feat = [frm_feat[x] for x in aug_index_choice] 46 | return aug_feat 47 | 48 | 49 | 50 | 51 | # Augmentation for video-level features 52 | class Video_Level_Augmenter(object): 53 | 54 | def __init__(self, feat_path=None, feat_reader=None, perturb_intensity=1, perturb_prob=0.5, n_sample=10000, step_size=500, mean=None, std=None): 55 | self.feat_reader = feat_reader 56 | self.perturb_intensity = perturb_intensity 57 | self.perturb_prob = perturb_prob 58 | self.step_size = step_size 59 | 60 | if mean is None or std is None: 61 | self.n_dims = feat_reader.ndims 62 | mean_std_file = os.path.join(feat_path, "mean_std.txt") 63 | if not os.path.exists(mean_std_file): 64 | # calculate the mean and std 65 | print "calculating the mean and std ..." 
66 | if len(feat_reader.names) <= n_sample: 67 | self.sampled_videos = feat_reader.names 68 | else: 69 | self.sampled_videos = random.sample(feat_reader.names, n_sample) 70 | self.mean, self.std = self.__get_mean_std() 71 | if not os.path.exists(mean_std_file): 72 | with open(mean_std_file, 'w') as fout: 73 | fout.write(" ".join(map(str, self.mean)) + "\n") 74 | fout.write(" ".join(map(str, self.std)) + "\n") 75 | else: 76 | with open(mean_std_file) as fin: 77 | self.mean = map(float, fin.readline().strip().split(" ")) 78 | self.std = map(float, fin.readline().strip().split(" ")) 79 | else: 80 | self.n_dims = len(mean) 81 | self.mean = mean 82 | self.std = std 83 | 84 | # initialize mask 85 | self.__init_mask() 86 | 87 | 88 | def __get_mean_std(self): 89 | mean = [] 90 | std = [] 91 | for i in range(0, self.n_dims, self.step_size): 92 | vec_list = [] 93 | for video in self.sampled_videos: 94 | feat_vec = self.feat_reader.read_one(video) 95 | # using the subvec to accelerate calculation 96 | vec_list.append(feat_vec[i:min(self.step_size+i, self.n_dims)]) 97 | mean.extend(np.mean(vec_list, 0)) 98 | std.extend(np.std(vec_list, 0)) 99 | return np.array(mean), np.array(std) 100 | 101 | def __init_mask(self): 102 | self.mask = np.zeros(self.n_dims) 103 | self.mask[:int(self.n_dims*self.perturb_prob)] = 1 104 | 105 | def __shuffle_mask(self): 106 | random.shuffle(self.mask) 107 | 108 | def get_aug_feat(self, vid_feat): 109 | self.__shuffle_mask() 110 | perturbation = (np.random.randn(self.n_dims)*self.std + self.mean) * self.perturb_intensity * self.mask 111 | aug_feat = vid_feat + perturbation 112 | return aug_feat 113 | 114 | 115 | 116 | 117 | if __name__ == "__main__": 118 | 119 | # test frame level augmentation 120 | feats = np.random.randn(11, 4) 121 | n_vecs = feats.shape[0] 122 | for stride in [2, [2,3]]: 123 | f_auger = Frame_Level_Augmenter(stride) 124 | print f_auger.get_aug_index(n_vecs) 125 | # print f_auger.get_aug_feat(feats) 126 | print [len(a) for a in f_auger.get_aug_feat(feats)] 127 | 128 | # test video level augmentation 129 | rootpath = '/home/daniel/VisualSearch/hulu' 130 | collection = 'track_1_shows' 131 | feature = 'c3d-pool5' 132 | feat_path = os.path.join(rootpath, collection, "FeatureData", feature) 133 | feat_reader = BigFile(feat_path) 134 | 135 | v_auger = Video_Level_Augmenter(feat_path, feat_reader, perturb_intensity=1, perturb_prob=0.5) 136 | vid_feat = feat_reader.read_one(random.choice(feat_reader.names)) 137 | aug_feat = v_auger.get_aug_feat(vid_feat) -------------------------------------------------------------------------------- /do_all_frame_level.sh: -------------------------------------------------------------------------------- 1 | collection=$1 2 | feature=$2 3 | overwrite=0 4 | 5 | # Augmentation for frame-level features 6 | stride=$3 7 | 8 | # Loss 9 | loss=$4 10 | 11 | # Generate augmented frame-level features and operate mean pooling to obtain video-level features in advance 12 | python gene_aug_feat.py --collection $collection --feature $feature --stride $stride --overwrite $overwrite 13 | 14 | gpu=0 15 | CUDA_VISIBLE_DEVICES=$gpu python train.py --collection $collection --feature $feature --stride $stride --loss $loss --overwrite $overwrite -------------------------------------------------------------------------------- /do_all_video_level.sh: -------------------------------------------------------------------------------- 1 | collection=$1 2 | feature=$2 3 | overwrite=0 4 | 5 | # Augmentation for video-level features 6 | aug_prob=0.5 7 | 8 | # 
Loss
9 | loss=$3
10 | 
11 | gpu=0
12 | CUDA_VISIBLE_DEVICES=$gpu python train.py --collection $collection --feature $feature --aug_prob $aug_prob --loss $loss --overwrite $overwrite

--------------------------------------------------------------------------------
/do_feature_convert.sh:
--------------------------------------------------------------------------------
1 | for collection in track_1_shows track_2_movies
2 | do
3 |     for feat in inception-pool3 c3d-pool5
4 |     do
5 |         python feature_convert.py $collection $feat
6 |     done
7 | done

--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import csv
5 | import pickle
6 | 
7 | import numpy
8 | import time
9 | import numpy as np
10 | import torch
11 | from collections import OrderedDict
12 | 
13 | from utils.common import makedirsforfile, checkToSkip
14 | from simpleknn.bigfile import BigFile
15 | 
16 | class AverageMeter(object):
17 |     """Computes and stores the average and current value"""
18 | 
19 |     def __init__(self):
20 |         self.reset()
21 | 
22 |     def reset(self):
23 |         self.val = 0
24 |         self.avg = 0
25 |         self.sum = 0
26 |         self.count = 0
27 | 
28 |     def update(self, val, n=0):
29 |         self.val = val
30 |         self.sum += val * n
31 |         self.count += n
32 |         self.avg = self.sum / (.0001 + self.count)
33 | 
34 |     def __str__(self):
35 |         """String representation for logging
36 |         """
37 |         # for values that should be recorded exactly e.g. iteration number
38 |         if self.count == 0:
39 |             return str(self.val)
40 |         # for stats
41 |         return '%.4f (%.4f)' % (self.val, self.avg)
42 | 
43 | 
44 | class LogCollector(object):
45 |     """A collection of logging objects that can change from train to val"""
46 | 
47 |     def __init__(self):
48 |         # to keep the order of logged variables deterministic
49 |         self.meters = OrderedDict()
50 | 
51 |     def update(self, k, v, n=0):
52 |         # create a new meter if previously not recorded
53 |         if k not in self.meters:
54 |             self.meters[k] = AverageMeter()
55 |         self.meters[k].update(v, n)
56 | 
57 |     def __str__(self):
58 |         """Concatenate the meters in one log line
59 |         """
60 |         s = ''
61 |         for i, (k, v) in enumerate(self.meters.iteritems()):
62 |             if i > 0:
63 |                 s += '  '
64 |             s += k + ' ' + str(v)
65 |         return s
66 | 
67 |     def tb_log(self, tb_logger, prefix='', step=None):
68 |         """Log using tensorboard
69 |         """
70 |         for k, v in self.meters.iteritems():
71 |             tb_logger.log_value(prefix + k, v.val, step=step)
72 | 
73 | 
74 | def encode_data(model, data_loader, log_step=10, logging=print):
75 |     """Encode all videos loadable by `data_loader`
76 |     """
77 |     batch_time = AverageMeter()
78 | 
79 |     # switch to evaluate mode
80 |     model.val_start()
81 | 
82 |     end = time.time()
83 | 
84 |     # numpy array to keep all the embeddings
85 |     video_embs = None
86 |     video_ids_list = []
87 |     for i, (videos, ids, idxs) in enumerate(data_loader):
88 | 
89 |         # compute the embeddings
90 |         video_emb = model.forward_emb(videos, volatile=True)
91 | 
92 |         # initialize the numpy arrays given the size of the embeddings
93 |         if video_embs is None:
94 |             video_embs = np.zeros((len(data_loader.dataset), video_emb.size(1)))
95 | 
96 |         # preserve the embeddings by copying from gpu and converting to numpy
97 |         video_embs[list(idxs)] = video_emb.data.cpu().numpy().copy()
98 |         video_ids_list.extend(ids)
99 | 
100 |         # measure elapsed time
101 |         batch_time.update(time.time() - end)
102 |         end = time.time()
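        # progress is reported once every log_step batches below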
103 | 104 | if i % log_step == 0: 105 | logging('Test: [{0}/{1}]\t' 106 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 107 | .format( 108 | i, len(data_loader), batch_time=batch_time)) 109 | del videos 110 | 111 | return video_embs, video_ids_list 112 | 113 | 114 | def cal_rel_index(rele_file): 115 | with open(rele_file,'r') as csvfile: 116 | data=[] 117 | csv_reader=csv.reader(csvfile) 118 | for line in csv_reader: 119 | data.append(line) 120 | return data 121 | 122 | 123 | def re_cal_scores(scores, rel_index, n, sumofdata): 124 | 125 | scores_list = [] 126 | for i, index in enumerate(rel_index): 127 | data = [] 128 | data.append(scores[i]) 129 | if i < sumofdata: 130 | for j in index[:n]: 131 | data.append(scores[int(j)]) 132 | scores_list.append(sum(data) / len(data)) 133 | else: 134 | scores_list.append(data) 135 | 136 | # test video has no available relations 137 | scores_list.append(scores[len(rel_index): len(scores)]) 138 | 139 | return scores_list 140 | 141 | 142 | 143 | def score2result(scores, test_video_list, cand_video_list, rel_index, n): 144 | video2predrank = {} 145 | n_rows, n_column = scores.shape 146 | assert n_rows == len(test_video_list) 147 | assert n_column == len(cand_video_list) 148 | for i, test_video in enumerate(test_video_list): 149 | score_list = scores[i] 150 | if rel_index is not None: 151 | sumofdata = n_column - n_rows 152 | score_list = re_cal_scores(score_list, rel_index, n, sumofdata) 153 | cand_video_score_list = zip(cand_video_list, score_list) 154 | sorted_cand_video_score_list = sorted(cand_video_score_list, key=lambda v:v[1], reverse=True) 155 | #video2predrank[test_video] = [x[0] for x in sorted_cand_video_score_list] 156 | predrank = [x[0] for x in sorted_cand_video_score_list] 157 | predrank.remove(test_video) 158 | video2predrank[test_video] = predrank 159 | return video2predrank 160 | 161 | 162 | def score2result_fusion(scores, test_video_list, cand_video_list): 163 | video2predrank = {} 164 | n_rows, n_column = scores.shape 165 | assert n_rows == len(test_video_list) 166 | assert n_column == len(cand_video_list) 167 | for i, test_video in enumerate(test_video_list): 168 | score_list = scores[i] 169 | cand_video_score_list = zip(cand_video_list, score_list) 170 | sorted_cand_video_score_list = sorted(cand_video_score_list, key=lambda v:v[1], reverse=True) 171 | predrank = [x[0] for x in sorted_cand_video_score_list] 172 | predrank.remove(test_video) 173 | video2predrank[test_video] = predrank 174 | return video2predrank 175 | 176 | 177 | def do_predict(test_video_emd, test_video_list, cand_video_emd, cand_video_list, rel_index=None, n=5, output_dir=None, overwrite=0, no_imgnorm=False): 178 | 179 | if no_imgnorm: 180 | scores = cal_score(test_video_emd, cand_video_emd, measure='cosine') 181 | else: 182 | scores = cal_score(test_video_emd, cand_video_emd, measure='dot') 183 | 184 | video2predrank = score2result(scores, test_video_list, cand_video_list, rel_index, n) 185 | 186 | if output_dir is not None: 187 | output_file = os.path.join(output_dir, 'pred_scores_matrix.pth.tar') 188 | if checkToSkip(output_file, overwrite): 189 | sys.exit(0) 190 | makedirsforfile(output_file) 191 | torch.save({'scores': scores, 'test_videos': test_video_list, 'cand_videos': cand_video_list}, output_file) 192 | print("write score matrix into: " + output_file) 193 | 194 | return video2predrank 195 | 196 | 197 | # def cal_error(images, captions, measure='cosine', n_caption=2): 198 | # """ 199 | # Images->Text (Image Annotation) 200 | # Images: (5N, K) matrix 
of images 201 | # Captions: (5N, K) matrix of captions 202 | # """ 203 | # idx = range(0, images.shape[0], n_caption) 204 | # im = images[idx, :] 205 | # if measure == 'cosine': 206 | # errors = -1*numpy.dot(captions, im.T) 207 | 208 | # return errors 209 | 210 | 211 | def cal_score(video_1, video_2, measure='cosine'): 212 | if measure == 'cosine': 213 | # l2 normalization 214 | import sklearn.preprocessing as preprocessing 215 | video_1 = preprocessing.normalize(video_1, norm='l2') 216 | video_2 = preprocessing.normalize(video_2, norm='l2') 217 | scores = numpy.dot(video_1, video_2.T) 218 | elif measure == 'dot': 219 | scores = numpy.dot(video_1, video_2.T) 220 | 221 | return scores 222 | -------------------------------------------------------------------------------- /feature_convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | from utils.generic_utils import Progbar 6 | from utils.common import ROOT_PATH, checkToSkip, makedirsforfile 7 | from simpleknn.txt2bin import process as text2bin 8 | 9 | 10 | def process(options, collection, feat_name): 11 | overwrite = options.overwrite 12 | rootpath = options.rootpath 13 | 14 | feature_dir = os.path.join(rootpath, collection, 'feature') 15 | resdir = os.path.join(rootpath, collection, 'FeatureData', feat_name) 16 | 17 | train_csv = os.path.join(rootpath, collection, 'split', 'train.csv') 18 | val_csv = os.path.join(rootpath, collection, 'split', 'val.csv') 19 | test_csv = os.path.join(rootpath, collection, 'split', 'test.csv') 20 | 21 | train_val_test_set = [] 22 | train_val_test_set.extend(map(str.strip, open(train_csv).readlines())) 23 | train_val_test_set.extend(map(str.strip, open(val_csv).readlines())) 24 | train_val_test_set.extend(map(str.strip, open(test_csv).readlines())) 25 | 26 | target_feat_file = os.path.join(resdir, 'id.feature.txt') 27 | if checkToSkip(os.path.join(resdir,'feature.bin'), overwrite): 28 | sys.exit(0) 29 | makedirsforfile(target_feat_file) 30 | 31 | frame_count = [] 32 | print 'Processing %s - %s' % (collection, feat_name) 33 | with open(target_feat_file, 'w') as fw_feat: 34 | progbar = Progbar(len(train_val_test_set)) 35 | for d in train_val_test_set: 36 | feat_file = os.path.join(feature_dir, d, '%s-%s.npy'%(d,feat_name)) 37 | feats = np.load(feat_file) 38 | if len(feats.shape) == 1: # video level feature 39 | dim = feats.shape[0] 40 | fw_feat.write('%s %s\n' % (d, ' '.join(['%.6f'%x for x in feats]))) 41 | elif len(feats.shape) == 2: # frame level feature 42 | frames, dim = feats.shape 43 | frame_count.append(frames) 44 | for i in range(frames): 45 | frame_id = d+'_'+str(i) 46 | fw_feat.write('%s %s\n' % (frame_id, ' '.join(['%.6f'%x for x in feats[i]]))) 47 | progbar.add(1) 48 | 49 | text2bin(dim, [target_feat_file], resdir, 1) 50 | os.system('rm %s' % target_feat_file) 51 | 52 | 53 | def main(argv=None): 54 | if argv is None: 55 | argv = sys.argv[1:] 56 | 57 | from optparse import OptionParser 58 | parser = OptionParser(usage="""usage: %prog [options] collection featname""") 59 | parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)") 60 | parser.add_option("--rootpath", default=ROOT_PATH, type="string", help="rootpath (default: %s)" % ROOT_PATH) 61 | 62 | (options, args) = parser.parse_args(argv) 63 | if len(args) < 2: 64 | parser.print_help() 65 | return 1 66 | 67 | return process(options, args[0], args[1]) 68 | 69 | if __name__ == '__main__': 70 | sys.exit(main()) 
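
# Usage (mirrors do_feature_convert.sh):
#   python feature_convert.py track_1_shows inception-pool3
# This reads <rootpath>/<collection>/feature/<video_id>/<video_id>-<featname>.npy
# and packs the vectors into the binary BigFile format under FeatureData/<featname>.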
71 | 
72 | 

--------------------------------------------------------------------------------
/fig/frame_aug.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/fig/frame_aug.jpg

--------------------------------------------------------------------------------
/fig/video_aug.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/fig/video_aug.jpg

--------------------------------------------------------------------------------
/gene_aug_feat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | 
5 | from utils.generic_utils import Progbar
6 | from utils.common import ROOT_PATH, checkToSkip, makedirsforfile
7 | from simpleknn.bigfile import BigFile
8 | from simpleknn.txt2bin import process as text2bin
9 | from data_augmenter import Frame_Level_Augmenter
10 | 
11 | 
12 | def process(opt):
13 | 
14 |     rootpath = opt.rootpath
15 |     collection = opt.collection
16 |     feature = opt.feature
17 |     stride = opt.stride
18 |     overwrite = opt.overwrite
19 |     pooling_style = opt.pooling_style
20 | 
21 | 
22 |     feat_path = os.path.join(rootpath, collection, "FeatureData", feature)
23 | 
24 |     output_dir = os.path.join(rootpath, collection, "FeatureData", '%s-' % pooling_style + feature + "-stride%s" % stride)
25 |     feat_combined_file = os.path.join(output_dir, "id_feat.txt")
26 |     if checkToSkip(os.path.join(output_dir, "feature.bin"), overwrite):
27 |         sys.exit(0)
28 |     makedirsforfile(feat_combined_file)
29 | 
30 |     print "Generating augmented frame-level features and pooling them into video-level features ..."
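    # frame ids follow the pattern "<video_id>_<frame_no>"; group frame numbers per video below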
31 | 
32 |     feat_data = BigFile(feat_path)
33 |     video2fmnos = {}
34 |     for frame_id in feat_data.names:
35 |         data = frame_id.strip().split("_")
36 |         video_id = '_'.join(data[:-1])
37 |         fm_no = data[-1]
38 |         video2fmnos.setdefault(video_id, []).append(int(fm_no))
39 | 
40 |     video2frames = {}
41 |     for video_id, fmnos in video2fmnos.iteritems():
42 |         for fm_no in sorted(fmnos):
43 |             video2frames.setdefault(video_id, []).append(video_id + "_" + str(fm_no))
44 | 
45 | 
46 |     stride = map(int, stride.strip().split('-'))
47 |     f_auger = Frame_Level_Augmenter(stride)
48 | 
49 |     video2subvideo = {}
50 |     fout = open(feat_combined_file, 'w')
51 |     progbar = Progbar(len(video2frames))
52 |     for video in video2frames:
53 |         frame_ids = video2frames[video]
54 | 
55 |         # output the whole video-level feature
56 |         video2subvideo.setdefault(video, []).append(video)
57 |         renamed, feats = feat_data.read(frame_ids)
58 |         if pooling_style == 'avg':
59 |             feat_vec = np.array(feats).mean(axis=0)
60 |         elif pooling_style == 'max':
61 |             feat_vec = np.array(feats).max(axis=0)
62 |         fout.write(video + " " + " ".join(map(str, feat_vec)) + '\n')
63 | 
64 | 
65 |         # output the sub-video-level features
66 |         counter = 0
67 |         aug_index = f_auger.get_aug_index(len(frame_ids))  # get augmented frame index lists
68 |         for sub_index in aug_index:
69 |             sub_frames = [frame_ids[idx] for idx in sub_index]
70 |             renamed, sub_feats = feat_data.read(sub_frames)
71 | 
72 |             if pooling_style == 'avg':
73 |                 feat_vec = np.array(sub_feats).mean(axis=0)
74 |             elif pooling_style == 'max':
75 |                 feat_vec = np.array(sub_feats).max(axis=0)
76 | 
77 |             video2subvideo.setdefault(video, []).append(video + "_sub%d" % counter)
78 |             fout.write(video + "_sub%d" % counter + " " + " ".join(map(str, feat_vec)) + '\n')
79 |             counter += 1
80 |         progbar.add(1)
81 | 
82 |     fout.close()
83 | 
84 |     f = open(os.path.join(output_dir, "video2subvideo.txt"), 'w')
85 |     f.write(str(video2subvideo))
86 |     f.close()
87 | 
88 |     text2bin(len(feat_vec), [feat_combined_file], output_dir, 1)
89 |     os.system('rm %s' % feat_combined_file)
90 | 
91 | 
92 | 
93 | def main(argv=None):
94 |     if argv is None:
95 |         argv = sys.argv[1:]
96 | 
97 |     from optparse import OptionParser
98 |     parser = OptionParser(usage="""usage: %prog [options]""")
99 |     parser.add_option("--rootpath", default=ROOT_PATH, type="string", help="rootpath (default: %s)" % ROOT_PATH)
100 |     parser.add_option("--collection", default="", type="string", help="collection name")
101 |     parser.add_option("--feature", default="", type="string", help="feature name")
102 |     parser.add_option("--stride", default="2", type="str", help="stride for frame-level data augmentation")
103 |     parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)")
104 |     parser.add_option("--pooling_style", default='avg', type="str", help="pooling style: avg (average pooling), max (max pooling)")
105 | 
106 |     (options, args) = parser.parse_args(argv)
107 |     return process(options)
108 | 
109 | if __name__ == "__main__":
110 |     sys.exit(main())

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init
4 | import torchvision.models as models
5 | from torch.autograd import Variable
6 | import torch.backends.cudnn as cudnn
7 | from torch.nn.utils.clip_grad import clip_grad_norm
8 | import numpy as np
9 | from collections import OrderedDict
10 | 
11 | 
12 | def l2norm(X):
"""L2-normalize columns of X 14 | """ 15 | norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt() 16 | X = torch.div(X, norm) 17 | return X 18 | 19 | ################################################### 20 | ########## Model Structure ########## 21 | ################################################### 22 | class BaseModel(nn.Module): 23 | def __init__(self): 24 | super(BaseModel, self).__init__() 25 | 26 | def load_state_dict(self, state_dict): 27 | """Copies parameters. overwritting the default one to 28 | accept state_dict from Full model 29 | """ 30 | own_state = self.state_dict() 31 | new_state = OrderedDict() 32 | for name, param in state_dict.items(): 33 | if name in own_state: 34 | new_state[name] = param 35 | 36 | super(BaseModel, self).load_state_dict(new_state) 37 | 38 | 39 | 40 | class EncoderVideo(BaseModel): 41 | 42 | def __init__(self, opt): 43 | super(EncoderVideo, self).__init__() 44 | self.embed_size = opt.embed_size 45 | self.no_imgnorm = opt.no_imgnorm 46 | 47 | self.fc = nn.Linear(opt.feature_dim, opt.embed_size) 48 | 49 | 50 | self.init_weights() 51 | 52 | def init_weights(self): 53 | """Xavier initialization for the fully connected layer 54 | """ 55 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 56 | self.fc.out_features) 57 | self.fc.weight.data.uniform_(-r, r) 58 | self.fc.bias.data.fill_(0) 59 | 60 | 61 | def forward(self, images): 62 | """Extract image feature vectors.""" 63 | # assuming that the precomputed features are already l2-normalized 64 | 65 | features = self.fc(images) 66 | if not self.no_imgnorm: 67 | features = l2norm(features) 68 | 69 | return features 70 | 71 | 72 | 73 | def cosine_sim(im, s): 74 | """Cosine similarity between all the image and sentence pairs 75 | """ 76 | return im.mm(s.t()) 77 | 78 | 79 | 80 | class ContrastiveLoss(nn.Module): 81 | """ 82 | Compute contrastive loss 83 | """ 84 | 85 | def __init__(self, margin=0, max_violation=False, cost_style='sum'): 86 | super(ContrastiveLoss, self).__init__() 87 | self.margin = margin 88 | self.cost_style = cost_style 89 | self.max_violation = max_violation 90 | 91 | def forward(self, scores): 92 | # compute image-sentence score matrix 93 | # scores = self.sim(im, s) 94 | diagonal = scores.diag().view(scores.size(0), 1) 95 | d1 = diagonal.expand_as(scores) 96 | d2 = diagonal.t().expand_as(scores) 97 | 98 | # compare every diagonal score to scores in its column 99 | # caption retrieval 100 | cost_s = (self.margin + scores - d1).clamp(min=0) 101 | # compare every diagonal score to scores in its row 102 | # image retrieval 103 | cost_im = (self.margin + scores - d2).clamp(min=0) 104 | 105 | # clear diagonals 106 | mask = torch.eye(scores.size(0)) > .5 107 | I = Variable(mask) 108 | if torch.cuda.is_available(): 109 | I = I.cuda() 110 | cost_s = cost_s.masked_fill_(I, 0) 111 | cost_im = cost_im.masked_fill_(I, 0) 112 | 113 | # keep the maximum violating negative for each query 114 | if self.max_violation: 115 | cost_s = cost_s.max(1)[0] 116 | cost_im = cost_im.max(0)[0] 117 | 118 | if self.cost_style == 'sum': 119 | cost = cost_s.sum() + cost_im.sum() 120 | elif self.cost_style == 'mean': 121 | cost = cost_s.mean() + cost_im.mean() 122 | return cost 123 | 124 | 125 | class IrrelevantLoss(nn.Module): 126 | """ 127 | Compute contrastive loss 128 | """ 129 | 130 | def __init__(self, margin, cost_style='sum'): 131 | super(IrrelevantLoss, self).__init__() 132 | self.margin = margin 133 | self.cost_style = cost_style 134 | 135 | def forward(self, scores): 136 | 137 | # clear diagonals 138 | mask = 
139 |         I = Variable(mask)
140 |         if torch.cuda.is_available():
141 |             I = I.cuda()
142 |         scores = scores.masked_fill_(I, 0)
143 | 
144 |         cost = (scores - self.margin).clamp(min=0)
145 | 
146 |         if self.cost_style == 'sum':
147 |             cost = cost.sum()
148 |         elif self.cost_style == 'mean':
149 |             cost = cost.mean()
150 |         return cost
151 | 
152 | 
153 | 
154 | class ReLearning(object):
155 | 
156 |     def __init__(self, opt):
157 |         self.grad_clip = opt.grad_clip
158 |         self.img_enc = EncoderVideo(opt)
159 |         self.loss = opt.loss
160 | 
161 | 
162 |         if opt.measure == 'cosine':
163 |             self.sim = cosine_sim
164 |         else:
165 |             print('measure %s is not supported' % opt.measure)
166 | 
167 |         print(self.img_enc)
168 |         if torch.cuda.is_available():
169 |             self.img_enc.cuda()
170 |             cudnn.benchmark = True
171 | 
172 |         # Loss and Optimizer
173 |         if opt.loss == 'trl':
174 |             self.criterion = ContrastiveLoss(margin=opt.margin,
175 |                                              max_violation=opt.max_violation,
176 |                                              cost_style=opt.cost_style)
177 | 
178 | 
179 |         elif opt.loss == 'netrl':
180 |             self.criterion_1 = ContrastiveLoss(margin=opt.margin,
181 |                                                max_violation=opt.max_violation,
182 |                                                cost_style=opt.cost_style)
183 |             self.criterion_2 = IrrelevantLoss(margin=opt.margin_irel,
184 |                                               cost_style=opt.cost_style)
185 |             self.alpha = opt.alpha
186 | 
187 | 
188 |         params = list(self.img_enc.parameters())
189 |         self.params = params
190 | 
191 |         if opt.optimizer == 'adam':
192 |             self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
193 |         elif opt.optimizer == 'rmsprop':
194 |             self.optimizer = torch.optim.RMSprop(params, lr=opt.learning_rate)
195 |         else:
196 |             print('optimizer %s is not supported' % opt.optimizer)
197 | 
198 |         self.Eiters = 0
199 | 
200 |     def state_dict(self):
201 |         state_dict = [self.img_enc.state_dict()]
202 |         return state_dict
203 | 
204 |     def load_state_dict(self, state_dict):
205 |         self.img_enc.load_state_dict(state_dict[0])
206 | 
207 |     def train_start(self):
208 |         """switch to train mode
209 |         """
210 |         self.img_enc.train()
211 | 
212 |     def val_start(self):
213 |         """switch to evaluate mode
214 |         """
215 |         self.img_enc.eval()
216 | 
217 | 
218 |     def forward_emb(self, videos, volatile=False):
219 |         """Compute the video embeddings
220 |         """
221 |         # Set mini-batch dataset
222 |         videos = Variable(videos, volatile=volatile)
223 |         if torch.cuda.is_available():
224 |             videos = videos.cuda()
225 | 
226 |         # Forward
227 |         videos_emb = self.img_enc(videos)
228 |         return videos_emb
229 | 
230 | 
231 |     def forward_loss(self, videos_emb_1, videos_emb_2, **kwargs):
232 |         """Compute the loss given pairs of video embeddings
233 |         """
234 |         scores = self.sim(videos_emb_1, videos_emb_2)
235 | 
236 |         if self.loss == 'trl':
237 |             loss = self.criterion(scores)
238 | 
239 |         elif self.loss == 'netrl':
240 |             loss_1 = self.criterion_1(scores)
241 |             loss_2 = self.criterion_2(scores)
242 |             # print loss_1, loss_2
243 |             loss = loss_1 + self.alpha * loss_2
244 | 
245 |         # loss = self.criterion(videos_emb_1, videos_emb_2)
246 |         # self.logger.update('Le', loss.data[0], videos_emb_1.size(0))
247 |         return loss
248 | 
249 |     def train_emb(self, videos_1, videos_2, ids=None, *args):
250 |         """One training step given a batch of paired videos.
251 |         """
252 |         self.Eiters += 1
253 | 
254 |         # zero the gradient buffers
255 |         self.optimizer.zero_grad()
256 | 
257 |         # compute the embeddings
258 |         videos_emb_1 = self.forward_emb(videos_1)
259 |         videos_emb_2 = self.forward_emb(videos_2)
260 | 
261 |         # measure accuracy and record loss
262 |         # self.optimizer.zero_grad()
263 |         loss = self.forward_loss(videos_emb_1, videos_emb_2)
264 |         # loss_value = loss.item()
265 |         if torch.__version__ in ['1.0.0', '1.1.0', '1.0.1']:
266 |             loss_value = loss.item()
267 |         else:
268 |             loss_value = loss.data[0]
269 | 
270 |         # compute gradient and do SGD step
271 |         loss.backward()
272 |         if self.grad_clip > 0:
273 |             clip_grad_norm(self.params, self.grad_clip)
274 |         self.optimizer.step()
275 | 
276 |         return videos_emb_1.size(0), loss_value

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.16.6  # last numpy release that supports Python 2.7
2 | scipy==1.1.0
3 | six==1.11.0
4 | tensorboard==1.8.0
5 | tensorboard-logger==0.1.0
6 | torch==0.3.1
7 | torchvision==0.2.1
8 | 

--------------------------------------------------------------------------------
/simpleknn/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2014 li-xirong
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
23 | 

--------------------------------------------------------------------------------
/simpleknn/README.md:
--------------------------------------------------------------------------------
1 | simpleknn
2 | =========
3 | 
4 | Find k nearest neighbors by an exhaustive search, used for content-based image retrieval
5 | 
6 | 
7 | Given an image collection, say ``toydata``, with ``n`` images, we presume that a specific visual feature, named ``f1``, has been extracted and stored as ``toydata/FeatureData/f1/id.feature.txt``, where each line starts with a unique image id followed by its feature vector. Given a test image and its feature vector, simpleknn finds the ``k`` nearest neighbors from ``toydata`` by computing a given distance, namely ``l1`` or ``l2``, between the feature vectors.
8 | 
9 | See ``test.sh`` for usage.
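
For example, once features are available in the binary format (see ``txt2bin.py``), they can be read back through ``bigfile.BigFile``. The snippet below mirrors the self-test at the bottom of ``bigfile.py`` and uses the ids from the bundled ``toydata`` collection; it is only a usage sketch:

```python
from bigfile import BigFile

bigfile = BigFile('toydata/FeatureData/f1')
renamed, vectors = bigfile.read(str.split('b z a a b c'))  # unknown ids such as 'z' are skipped
for name, vec in zip(renamed, vectors):
    print name, vec
```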
10 | -------------------------------------------------------------------------------- /simpleknn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/simpleknn/__init__.py -------------------------------------------------------------------------------- /simpleknn/bigfile.py: -------------------------------------------------------------------------------- 1 | import os, sys, array 2 | import numpy as np 3 | 4 | class BigFile: 5 | 6 | def __init__(self, datadir): 7 | self.nr_of_images, self.ndims = map(int, open(os.path.join(datadir,'shape.txt')).readline().split()) 8 | id_file = os.path.join(datadir, "id.txt") 9 | self.names = open(id_file).read().strip().split() 10 | assert(len(self.names) == self.nr_of_images) 11 | self.name2index = dict(zip(self.names, range(self.nr_of_images))) 12 | self.binary_file = os.path.join(datadir, "feature.bin") 13 | print ("[%s] %dx%d instances loaded from %s" % (self.__class__.__name__, self.nr_of_images, self.ndims, datadir)) 14 | 15 | 16 | def read(self, requested, isname=True): 17 | requested = set(requested) 18 | if isname: 19 | index_name_array = [(self.name2index[x], x) for x in requested if x in self.name2index] 20 | else: 21 | assert(min(requested)>=0) 22 | assert(max(requested)= self.nr_of_images: 89 | self.close() 90 | raise StopIteration 91 | else: 92 | res = array.array('f') 93 | res.fromfile(self.fr, self.ndims) 94 | _id = self.names[self.current] 95 | self.current += 1 96 | return _id, res.tolist() 97 | 98 | 99 | if __name__ == '__main__': 100 | bigfile = BigFile('toydata/FeatureData/f1') 101 | 102 | imset = str.split('b z a a b c') 103 | renamed, vectors = bigfile.read(imset) 104 | 105 | 106 | for name,vec in zip(renamed, vectors): 107 | print name, vec 108 | 109 | -------------------------------------------------------------------------------- /simpleknn/build.sh: -------------------------------------------------------------------------------- 1 | cd cpp 2 | make clean 3 | make 4 | cd .. 5 | 6 | -------------------------------------------------------------------------------- /simpleknn/cpp/Makefile: -------------------------------------------------------------------------------- 1 | CXX ?= g++ 2 | CFLAGS = -Wall -Wconversion -O3 -fPIC -ftree-vectorize -msse3 -ffast-math -fassociative-math -I/usr/local/Cellar/boost/1.58.0/include/ -static 3 | 4 | SEARCHLIB_NAME = libsearch 5 | 6 | 7 | all: libsearch 8 | 9 | 10 | libsearch: search.o 11 | $(CXX) -shared -dynamiclib search.o -o $(SEARCHLIB_NAME).so 12 | rm -f *.o 13 | 14 | search.o: search.cpp search.h 15 | $(CXX) $(CFLAGS) -c search.cpp 16 | 17 | clean: 18 | rm -f *~ *.o *.so 19 | -------------------------------------------------------------------------------- /simpleknn/cpp/Makefile.win64: -------------------------------------------------------------------------------- 1 | #You must ensure nmake.exe, cl.exe, link.exe are in system path. 
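#(e.g. set up the environment with vcvarsall.bat first, as build.win64.bat does)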
2 | #VCVARS32.bat 3 | #Under dosbox prompt 4 | #nmake -f Makefile.win 5 | 6 | ########################################## 7 | CXX = cl.exe 8 | CFLAGS = -nologo -O2 -EHsc -IC:\local\boost_1_55_0 -D __WIN32__ -D _CRT_SECURE_NO_DEPRECATE 9 | TARGET = win64 10 | 11 | all: libsearch 12 | 13 | libsearch: search.cpp search.h search.def 14 | $(CXX) $(CFLAGS) -LD search.cpp -Fe..\lib\$(TARGET)\libsearch -link /LIBPATH:C:\local\boost_1_55_0\lib64-msvc-11.0 -DEF:search.def 15 | -erase ..\lib\$(TARGET)\*.exp ..\lib\$(TARGET)\*.lib 16 | 17 | clean: 18 | -erase /Q *.obj ..\lib\$(TARGET)\. 19 | 20 | -------------------------------------------------------------------------------- /simpleknn/cpp/build.win64.bat: -------------------------------------------------------------------------------- 1 | call "C:\Program Files (x86)\Microsoft Visual Studio 11.0\VC\vcvarsall.bat" x86_amd64 2 | 3 | nmake -f Makefile.win64 clean 4 | nmake -f Makefile.win64 all 5 | @pause 6 | 7 | -------------------------------------------------------------------------------- /simpleknn/cpp/search.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | using namespace std; 15 | 16 | #include 17 | #include 18 | 19 | #include "search.h" 20 | 21 | #ifndef __S_IFMT 22 | #define __S_IFMT 0170000 23 | #endif 24 | 25 | #ifndef __S_IFDIR 26 | #define __S_IFDIR 0040000 27 | #endif 28 | 29 | inline bool operator<(const search_result &a, const search_result &b) 30 | { 31 | return a.value < b.value; 32 | } 33 | 34 | /* 35 | search_model *load_model(const char *model_file_name, const UInt64 dim, const UInt64 nimages) 36 | { 37 | using namespace boost::interprocess; 38 | 39 | FILE *fp = fopen(model_file_name, "rb"); 40 | if (0 == fp) { 41 | fprintf(stderr, "[search.load_model] failed to open model_file %s\n", model_file_name); 42 | return 0; 43 | } 44 | 45 | struct stat info; 46 | stat(model_file_name, &info); 47 | if( (info.st_mode & __S_IFMT ) == __S_IFDIR) { 48 | fprintf(stderr, "[search.load_model] %s is a directory\n", model_file_name); 49 | fclose(fp); 50 | return 0; 51 | } 52 | 53 | //printf("%d %d %d %d\n", sizeof(dim), sizeof(nimages), dim, nimages); 54 | 55 | UInt64 count = dim * nimages; 56 | search_model *model = new search_model; 57 | 58 | //printf("%d %d\n", sizeof(count), count); 59 | 60 | //fprintf(stdout, "[search.load model] requesting %llu bytes memory ...\n", count * sizeof(DataType)); 61 | model->feature_ptr = new DataType[count]; 62 | 63 | if (0 == model->feature_ptr) 64 | { 65 | fprintf(stderr, "[search.load_model] Memory error!\n"); 66 | fclose(fp); 67 | free_model(&model); 68 | return 0; 69 | } 70 | 71 | fread((char *)(model->feature_ptr), sizeof(DataType), count, fp); 72 | fclose(fp); 73 | model->dim = dim; 74 | model->nimages = nimages; 75 | 76 | //print_model(model); 77 | 78 | return model; 79 | } 80 | */ 81 | 82 | search_model *load_model(const char *model_file_name, const UInt64 dim, const UInt64 nimages) 83 | { 84 | using namespace boost::interprocess; 85 | 86 | FILE *fp = fopen(model_file_name, "rb"); 87 | if (0 == fp) { 88 | fprintf(stderr, "[search.load_model] failed to open model_file %s\n", model_file_name); 89 | return 0; 90 | } 91 | fclose(fp); 92 | 93 | struct stat info; 94 | stat(model_file_name, &info); 95 | if( (info.st_mode & __S_IFMT ) == __S_IFDIR) { 96 | fprintf(stderr, "[search.load_model] %s is a directory\n", 
model_file_name); 97 | fclose(fp); 98 | return 0; 99 | } 100 | 101 | file_mapping *m_file = new file_mapping(model_file_name, read_only); 102 | mapped_region *region = new mapped_region(*m_file, read_only); 103 | 104 | UInt64 count = dim * nimages * sizeof(DataType); 105 | 106 | search_model *model = new search_model; 107 | model->m_file = m_file; 108 | model->region = region; 109 | //Get the address of the mapped region 110 | model->feature_ptr = (DataType*)region->get_address(); 111 | UInt64 region_size = region->get_size(); 112 | 113 | if (0 == model->feature_ptr) 114 | { 115 | fprintf(stderr, "[search.load_model] Memory error!\n"); 116 | free_model(&model); 117 | return 0; 118 | } 119 | 120 | /* if (count != region_size) { 121 | fprintf(stderr, "[search.load_model] File size mis-match number of images!\n"); 122 | printf("nimages: %llu\n", nimages); 123 | printf("dim: %llu\n", dim); 124 | printf("region_size: %llu\n", region_size); 125 | printf("count: %llu\n", count); 126 | free_model(&model); 127 | return 0; 128 | } 129 | */ 130 | model->dim = dim; 131 | model->nimages = nimages; 132 | 133 | return model; 134 | } 135 | 136 | void free_model_contents(search_model* model_ptr) 137 | { 138 | /* 139 | if (0 != model_ptr->feature_ptr) 140 | { 141 | //fprintf(stdout, "[search.free_model_contents]\n"); 142 | delete [] model_ptr->feature_ptr; 143 | model_ptr->feature_ptr = 0; 144 | } 145 | */ 146 | if (0 != model_ptr->feature_ptr) 147 | { 148 | //fprintf(stdout, "[search.free_model_contents]\n"); 149 | delete model_ptr->region; 150 | delete model_ptr->m_file; 151 | model_ptr->feature_ptr = 0; 152 | } 153 | } 154 | 155 | 156 | void free_model(search_model** model_ptr_ptr) 157 | { 158 | search_model* model_ptr = *model_ptr_ptr; 159 | 160 | if(0 != model_ptr) 161 | { 162 | free_model_contents(model_ptr); 163 | delete model_ptr; 164 | model_ptr = 0; 165 | //fprintf(stdout, "[search.free_model]\n"); 166 | } 167 | } 168 | 169 | void print_model(const search_model* model_ptr) 170 | { 171 | fprintf(stdout, "[search.print_model] %llu images, %llu dims\n", model_ptr->nimages, model_ptr->dim); 172 | } 173 | 174 | 175 | UInt64 get_dim(const search_model* model_ptr) 176 | { 177 | return model_ptr->dim; 178 | } 179 | 180 | UInt64 get_nr_images(const search_model* model_ptr) 181 | { 182 | return model_ptr->nimages; 183 | } 184 | 185 | // 1 - ((xi * yi) / (norm(x) * norm(y))) 186 | 187 | void compute_cosine_distance(const search_model *model, const DataType* query_ptr, double* dist_values) 188 | { 189 | const DataType *ptr = model->feature_ptr; 190 | 191 | for (UInt64 i=0; inimages; i++) 192 | { 193 | double norm_query = 0; 194 | double norm_ptr = 0; 195 | double dist = 0; 196 | for (UInt64 j=0; jdim; j++) 197 | { 198 | norm_query += query_ptr[j] * query_ptr[j]; 199 | norm_ptr += ptr[j] * ptr[j]; 200 | dist += query_ptr[j] * ptr[j]; 201 | } 202 | ptr += model->dim; 203 | //fprintf(stdout, "%d %f %f\n", i, dist, sqrt(dist)); 204 | if (norm_query < 1e-8 || norm_ptr < 1e-8) { 205 | dist_values[i] = (norm_query < 1e-8 && norm_ptr < 1e-8) ? 0 : 1; // define a value for zero input 206 | } else { 207 | dist_values[i] = 1. - (dist / (sqrt(norm_query) * sqrt(norm_ptr))); 208 | } 209 | //dist_values[i] = 1. 
- (dist / (sqrt(norm_query) * sqrt(norm_ptr))); 210 | } 211 | } 212 | 213 | void compute_l2_distance(const search_model *model, const DataType* query_ptr, double* dist_values) 214 | { 215 | const DataType *ptr = model->feature_ptr; 216 | 217 | for (UInt64 i=0; inimages; i++) 218 | { 219 | double dist = 0; 220 | for (UInt64 j=0; jdim; j++) 221 | { 222 | double d = query_ptr[j] - ptr[j]; 223 | dist += (d * d); 224 | //if (0 == i) fprintf(stdout, "%d %f %f %f\n", j, query_ptr[j], ptr[j], d); 225 | } 226 | ptr += model->dim; 227 | //fprintf(stdout, "%d %f %f\n", i, dist, sqrt(dist)); 228 | dist_values[i] = sqrt(dist); 229 | } 230 | } 231 | 232 | void compute_l1_distance(const search_model *model, const DataType* query_ptr, double* dist_values) 233 | { 234 | const DataType *ptr = model->feature_ptr; 235 | 236 | for (UInt64 i=0; inimages; i++) 237 | { 238 | double dist = 0; 239 | for (UInt64 j=0; jdim; j++) 240 | { 241 | double d = query_ptr[j] - ptr[j]; 242 | dist += fabs(d); 243 | } 244 | ptr += model->dim; 245 | //fprintf(stdout, "%d %f %f\n", i, dist, sqrt(dist)); 246 | dist_values[i] = dist; 247 | } 248 | } 249 | 250 | /* 251 | * chi2(x,y) = sum( (xi-yi)^2 / (xi+yi) ) / 2 252 | */ 253 | void compute_chi2_distance(const search_model *model, const DataType* query_ptr, double* dist_values) 254 | { 255 | const DataType *ptr = model->feature_ptr; 256 | double dist = 0; 257 | double d = 0; 258 | double s = 0; 259 | 260 | for (UInt64 i=0; inimages; i++) 261 | { 262 | dist = 0.0; 263 | for (UInt64 j=0; jdim; j++) 264 | { 265 | d = query_ptr[j] - ptr[j]; 266 | s = query_ptr[j] + ptr[j]; 267 | if (s > 1e-8) { 268 | dist += (d*d)/s; 269 | } 270 | } 271 | dist /= 2; 272 | ptr += model->dim; 273 | //fprintf(stdout, "%d %f %f\n", i, dist, sqrt(dist)); 274 | dist_values[i] = dist; 275 | } 276 | } 277 | 278 | 279 | void search_knn(const struct search_model *model, const DataType* query_ptr, const UInt64 k, const int dfunc, struct search_result *results) 280 | { 281 | double * dist_values = new double[model->nimages]; 282 | 283 | switch (dfunc) { 284 | case 0: 285 | compute_l1_distance(model, query_ptr, dist_values); 286 | break; 287 | case 2: 288 | compute_chi2_distance(model, query_ptr, dist_values); 289 | break; 290 | case 4: 291 | compute_cosine_distance(model, query_ptr, dist_values); 292 | break; 293 | default: 294 | compute_l2_distance(model, query_ptr, dist_values); 295 | break; 296 | } 297 | /*if (1 == l2) { 298 | compute_distance(model, query_ptr, dist_values); 299 | } else { 300 | compute_l1_distance(model, query_ptr, dist_values); 301 | } */ 302 | struct search_result *tosort = new search_result[model->nimages]; 303 | for (UInt64 i=0; inimages; i++) 304 | { 305 | tosort[i].index = i; 306 | tosort[i].value = dist_values[i]; 307 | } 308 | delete [] dist_values; 309 | 310 | if (k <= (model->nimages >> 1)) { 311 | partial_sort(tosort, tosort + k, tosort + model->nimages); 312 | } 313 | else { 314 | sort(tosort, tosort + model->nimages); 315 | } 316 | 317 | for (UInt64 i=0; i 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | using namespace std; 15 | 16 | #include "search.h" 17 | 18 | #ifndef __S_IFMT 19 | #define __S_IFMT 0170000 20 | #endif 21 | 22 | #ifndef __S_IFDIR 23 | #define __S_IFDIR 0040000 24 | #endif 25 | 26 | inline bool operator<(const search_result &a, const search_result &b) 27 | { 28 | return a.value < b.value; 29 | } 30 | 31 | search_model *load_model(const char *model_file_name, const 
UInt64 dim, const UInt64 nimages) 32 | { 33 | FILE *fp = fopen(model_file_name, "rb"); 34 | if (0 == fp) { 35 | fprintf(stderr, "[search.load_model] failed to open model_file %s\n", model_file_name); 36 | return 0; 37 | } 38 | 39 | struct stat info; 40 | stat(model_file_name, &info); 41 | if( (info.st_mode & __S_IFMT ) == __S_IFDIR) { 42 | fprintf(stderr, "[search.load_model] %s is a directory\n", model_file_name); 43 | fclose(fp); 44 | return 0; 45 | } 46 | 47 | //printf("%d %d %d %d\n", sizeof(dim), sizeof(nimages), dim, nimages); 48 | 49 | UInt64 count = dim * nimages; 50 | search_model *model = new search_model; 51 | 52 | //printf("%d %d\n", sizeof(count), count); 53 | 54 | fprintf(stdout, "[search.load model] requesting %llu bytes memory ...\n", count * sizeof(DataType)); 55 | model->feature_ptr = new DataType[count]; 56 | 57 | if (0 == model->feature_ptr) 58 | { 59 | fprintf(stderr, "[search.load_model] Memory error!\n"); 60 | fclose(fp); 61 | free_model(&model); 62 | return 0; 63 | } 64 | 65 | fread((char *)(model->feature_ptr), sizeof(DataType), count, fp); 66 | fclose(fp); 67 | model->dim = dim; 68 | model->nimages = nimages; 69 | 70 | print_model(model); 71 | 72 | return model; 73 | } 74 | 75 | void free_model_contents(search_model* model_ptr) 76 | { 77 | if (0 != model_ptr->feature_ptr) 78 | { 79 | //fprintf(stdout, "[search.free_model_contents]\n"); 80 | delete [] model_ptr->feature_ptr; 81 | model_ptr->feature_ptr = 0; 82 | } 83 | } 84 | 85 | 86 | void free_model(search_model** model_ptr_ptr) 87 | { 88 | search_model* model_ptr = *model_ptr_ptr; 89 | if(0 != model_ptr) 90 | { 91 | free_model_contents(model_ptr); 92 | delete model_ptr; 93 | model_ptr = 0; 94 | fprintf(stdout, "[search.free_model]\n"); 95 | } 96 | } 97 | 98 | void print_model(const search_model* model_ptr) 99 | { 100 | fprintf(stdout, "[search.print_model] %llu images, %llu dims\n", model_ptr->nimages, model_ptr->dim); 101 | } 102 | 103 | 104 | UInt64 get_dim(const search_model* model_ptr) 105 | { 106 | return model_ptr->dim; 107 | } 108 | 109 | UInt64 get_nr_images(const search_model* model_ptr) 110 | { 111 | return model_ptr->nimages; 112 | } 113 | 114 | 115 | void compute_distance(const search_model *model, const DataType* query_ptr, double* dist_values) 116 | { 117 | const DataType *ptr = model->feature_ptr; 118 | 119 | for (UInt64 i=0; inimages; i++) 120 | { 121 | double dist = 0; 122 | for (UInt64 j=0; jdim; j++) 123 | { 124 | double d = query_ptr[j] - ptr[j]; 125 | dist += (d * d); 126 | //if (0 == i) fprintf(stdout, "%d %f %f %f\n", j, query_ptr[j], ptr[j], d); 127 | } 128 | ptr += model->dim; 129 | //fprintf(stdout, "%d %f %f\n", i, dist, sqrt(dist)); 130 | dist_values[i] = sqrt(dist); 131 | } 132 | } 133 | 134 | void compute_l1_distance(const search_model *model, const DataType* query_ptr, double* dist_values) 135 | { 136 | const DataType *ptr = model->feature_ptr; 137 | 138 | for (UInt64 i=0; inimages; i++) 139 | { 140 | double dist = 0; 141 | for (UInt64 j=0; jdim; j++) 142 | { 143 | double d = query_ptr[j] - ptr[j]; 144 | dist += fabs(d); 145 | } 146 | ptr += model->dim; 147 | //fprintf(stdout, "%d %f %f\n", i, dist, sqrt(dist)); 148 | dist_values[i] = dist; 149 | } 150 | } 151 | 152 | 153 | 154 | void search_knn(const struct search_model *model, const DataType* query_ptr, const UInt64 k, const int l2, struct search_result *results) 155 | { 156 | double * dist_values = new double[model->nimages]; 157 | if (1 == l2) { 158 | compute_distance(model, query_ptr, dist_values); 159 | } else { 160 | 
compute_l1_distance(model, query_ptr, dist_values); 161 | } 162 | 163 | struct search_result *tosort = new search_result[model->nimages]; 164 | for (UInt64 i=0; inimages; i++) 165 | { 166 | tosort[i].index = i; 167 | tosort[i].value = dist_values[i]; 168 | } 169 | delete [] dist_values; 170 | 171 | if (k <= (model->nimages >> 1)) { 172 | partial_sort(tosort, tosort + k, tosort + model->nimages); 173 | } 174 | else { 175 | sort(tosort, tosort + model->nimages); 176 | } 177 | 178 | for (UInt64 i=0; i', " ".join(["%s %.3f" % (v[0],v[1]) for v in visualNeighbors[:3]]) 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /simpleknn/do_norm_feat.sh: -------------------------------------------------------------------------------- 1 | for collection in tgif-msrvtt10k-msvd tv2016test tv2016train tv2017test tv2018test 2 | do 3 | for feat in pyresnet-152_imagenet11k,flatten0_output,os pyresnext-101_rbps13k,flatten0_output,os 4 | do 5 | python norm_feat.py /home/daniel/VisualSearch/trecvid2018/${collection}/FeatureData/$feat 6 | done 7 | done 8 | -------------------------------------------------------------------------------- /simpleknn/im2fea.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | from basic.common import makedirsforfile, checkToSkip, printStatus 5 | from bigfile import BigFile 6 | 7 | INFO = __file__ 8 | 9 | def process(options, feat_dir, imsetfile, result_dir): 10 | 11 | resultfile = os.path.join(result_dir, 'feature.bin') 12 | if checkToSkip(resultfile, options.overwrite): 13 | sys.exit(0) 14 | 15 | imset = map(str.strip, open(imsetfile).readlines()) 16 | print "requested", len(imset) 17 | 18 | feat_file = BigFile(feat_dir) 19 | 20 | makedirsforfile(resultfile) 21 | fw = open(resultfile, 'wb') 22 | 23 | done = [] 24 | start = 0 25 | 26 | while start < len(imset): 27 | end = min(len(imset), start + options.blocksize) 28 | printStatus(INFO, 'processing images from %d to %d' % (start, end-1)) 29 | toread = imset[start:end] 30 | if len(toread) == 0: 31 | break 32 | renamed, vectors = feat_file.read(toread) 33 | for vec in vectors: 34 | vec = np.array(vec, dtype=np.float32) 35 | vec.tofile(fw) 36 | done += renamed 37 | start = end 38 | fw.close() 39 | 40 | assert(len(done) == len(set(done))) 41 | with open(os.path.join(result_dir, 'id.txt'), 'w') as fw: 42 | fw.write(' '.join(done)) 43 | fw.close() 44 | 45 | with open(os.path.join(result_dir,'shape.txt'), 'w') as fw: 46 | fw.write('%d %d' % (len(done), feat_file.ndims)) 47 | fw.close() 48 | print '%d requested, %d obtained' % (len(imset), len(done)) 49 | 50 | 51 | def main(argv=None): 52 | if argv is None: 53 | argv = sys.argv[1:] 54 | 55 | from optparse import OptionParser 56 | parser = OptionParser(usage="""usage: %prog [options] feat_dir imsetfile result_dir""") 57 | parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)") 58 | parser.add_option("--blocksize", default=1000, type="int", help="nr of feature vectors loaded per time (default: 1000)") 59 | 60 | 61 | (options, args) = parser.parse_args(argv) 62 | if len(args) < 3: 63 | parser.print_help() 64 | return 1 65 | 66 | return process(options, args[0], args[1], args[2]) 67 | 68 | if __name__ == "__main__": 69 | sys.exit(main()) 70 | 71 | -------------------------------------------------------------------------------- /simpleknn/lib/linux/libsearch.so: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/simpleknn/lib/linux/libsearch.so -------------------------------------------------------------------------------- /simpleknn/lib/mac/libsearch.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/simpleknn/lib/mac/libsearch.so -------------------------------------------------------------------------------- /simpleknn/lib/win64/libsearch.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/simpleknn/lib/win64/libsearch.dll -------------------------------------------------------------------------------- /simpleknn/merge_feat.py: -------------------------------------------------------------------------------- 1 | import os, sys, array 2 | import numpy as np 3 | 4 | from basic.constant import ROOT_PATH 5 | from basic.common import checkToSkip, makedirsforfile, printStatus 6 | from basic.util import readImageSet 7 | 8 | INFO = __file__ 9 | 10 | def process(options, feature, srcCollections, newCollection): 11 | assert(type(srcCollections) == list) 12 | 13 | temp = [] 14 | [x for x in srcCollections if x not in temp and temp.append(x)] # unique source collections 15 | srcCollections = temp 16 | 17 | rootpath = options.rootpath 18 | 19 | resfile = os.path.join(rootpath, newCollection, 'FeatureData', feature, 'feature.bin') 20 | if checkToSkip(resfile, options.overwrite): 21 | return 0 22 | 23 | querysetfile = os.path.join(rootpath, newCollection, 'ImageSets', '%s.txt' % newCollection) 24 | try: 25 | query_set = set(map(str.strip, open(querysetfile).readlines())) 26 | printStatus(INFO, '%d images wanted' % len(query_set)) 27 | except IOError: 28 | printStatus(INFO, 'failed to load %s, will merge all features in %s' % (querysetfile, ';'.join(srcCollections))) 29 | query_set = None 30 | 31 | makedirsforfile(resfile) 32 | fw = open(resfile, 'wb') 33 | printStatus(INFO, 'writing results to %s' % resfile) 34 | seen = set() 35 | newimset = [] 36 | 37 | for collection in srcCollections: 38 | feat_dir = os.path.join(rootpath, collection, 'FeatureData', feature) 39 | with open(os.path.join(feat_dir, 'shape.txt')) as fr: 40 | nr_of_images, feat_dim = map(int, fr.readline().strip().split()) 41 | fr.close() 42 | 43 | srcimset = open(os.path.join(feat_dir,'id.txt')).readline().strip().split() 44 | res = array.array('f') 45 | fr = open(os.path.join(feat_dir,'feature.bin'), 'rb') 46 | 47 | for i,im in enumerate(srcimset): 48 | res.fromfile(fr, feat_dim) 49 | if im not in seen: 50 | seen.add(im) 51 | if not query_set or im in query_set: 52 | vec = res 53 | vec = np.array(vec, dtype=np.float32) 54 | vec.tofile(fw) 55 | newimset.append(im) 56 | del res[:] 57 | if i%1e5 == 0: 58 | printStatus(INFO, '%d parsed, %d obtained' % (len(seen), len(newimset))) 59 | fr.close() 60 | printStatus(INFO, '%d parsed, %d obtained' % (len(seen), len(newimset))) 61 | 62 | fw.close() 63 | printStatus(INFO, '%d parsed, %d obtained' % (len(seen), len(newimset))) 64 | 65 | idfile = os.path.join(os.path.split(resfile)[0], 'id.txt') 66 | with open(idfile, 'w') as fw: 67 | fw.write(' '.join(newimset)) 68 | fw.close() 69 | 70 | shapefile = os.path.join(os.path.split(resfile)[0], 'shape.txt') 71 | with open(shapefile, 'w') as fw: 72 
| fw.write('%d %d' % (len(newimset), feat_dim)) 73 | fw.close() 74 | 75 | 76 | 77 | def main(argv=None): 78 | if argv is None: 79 | argv = sys.argv[1:] 80 | 81 | from optparse import OptionParser 82 | parser = OptionParser(usage="""usage: %prog [options] feature srcCollections newCollection""") 83 | parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)") 84 | parser.add_option("--rootpath", default=ROOT_PATH, type="string", help="rootpath where data are stored (default: %s)" % ROOT_PATH) 85 | 86 | (options, args) = parser.parse_args(argv) 87 | if len(args) < 3: 88 | parser.print_help() 89 | return 1 90 | return process(options, args[0], args[1].split(','), args[2]) 91 | 92 | 93 | if __name__ == "__main__": 94 | sys.exit(main()) 95 | -------------------------------------------------------------------------------- /simpleknn/norm_feat.py: -------------------------------------------------------------------------------- 1 | import os, sys, array, shutil 2 | import numpy as np 3 | from basic.common import checkToSkip, makedirsforfile 4 | 5 | def process(options, feat_dir): 6 | newname = '' 7 | if options.ssr: 8 | newname = 'ssr' 9 | newname += 'l%d' % options.p 10 | resfile = os.path.join(feat_dir.rstrip('/\\') + newname, 'feature.bin') 11 | if checkToSkip(resfile, options.overwrite): 12 | return 0 13 | 14 | with open(os.path.join(feat_dir, 'shape.txt')) as fr: 15 | nr_of_images, feat_dim = map(int, fr.readline().strip().split()) 16 | fr.close() 17 | 18 | offset = np.float32(1).nbytes * feat_dim 19 | res = array.array('f') 20 | 21 | fr = open(os.path.join(feat_dir,'feature.bin'), 'rb') 22 | makedirsforfile(resfile) 23 | fw = open(resfile, 'wb') 24 | print ('>>> writing results to %s' % resfile) 25 | 26 | 27 | for i in xrange(nr_of_images): 28 | res.fromfile(fr, feat_dim) 29 | vec = res 30 | if options.ssr: 31 | vec = [np.sign(x) * np.sqrt(abs(x)) for x in vec] 32 | if options.p == 1: 33 | Z = sum(abs(x) for x in vec) + 1e-9 34 | else: 35 | Z = np.sqrt(sum([x**2 for x in vec])) + 1e-9 36 | if i % 1e4 == 0: 37 | print ('image_%d, norm_%d=%g' % (i, options.p, Z)) 38 | vec = [x/Z for x in vec] 39 | del res[:] 40 | vec = np.array(vec, dtype=np.float32) 41 | vec.tofile(fw) 42 | fr.close() 43 | fw.close() 44 | print ('>>> %d lines parsed' % nr_of_images) 45 | shutil.copyfile(os.path.join(feat_dir,'id.txt'), os.path.join(os.path.split(resfile)[0], 'id.txt')) 46 | 47 | shapefile = os.path.join(os.path.split(resfile)[0], 'shape.txt') 48 | with open(shapefile, 'w') as fw: 49 | fw.write('%d %d' % (nr_of_images, feat_dim)) 50 | fw.close() 51 | 52 | 53 | 54 | def main(argv=None): 55 | if argv is None: 56 | argv = sys.argv[1:] 57 | 58 | from optparse import OptionParser 59 | parser = OptionParser(usage="""usage: %prog [options] feat_dir""") 60 | parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)") 61 | parser.add_option("--ssr", default=0, type="int", help="do signed square root per dim (default=0)") 62 | parser.add_option("--p", default=2, type="int", help="L_p normalization (default p=2)") 63 | 64 | 65 | (options, args) = parser.parse_args(argv) 66 | if len(args) < 1: 67 | parser.print_help() 68 | return 1 69 | assert(options.p in [1, 2]) 70 | return process(options, args[0]) 71 | 72 | 73 | if __name__ == "__main__": 74 | sys.exit(main()) -------------------------------------------------------------------------------- /simpleknn/simpleknn.py: 
-------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | 4 | import time 5 | from ctypes import * 6 | from ctypes.util import find_library 7 | 8 | import sys 9 | import os 10 | import platform 11 | 12 | LIB_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'lib') 13 | 14 | if sys.platform.startswith('linux'): 15 | filename = os.path.join(LIB_PATH, 'linux', 'libsearch.so') 16 | libsearch = CDLL(filename) 17 | elif sys.platform.startswith('darwin'): 18 | filename = os.path.join(LIB_PATH, 'mac', 'libsearch.so') 19 | libsearch = CDLL(filename) 20 | else: 21 | libdir = 'win32' if platform.architecture()[0] == '32bit' else 'win64' 22 | filename = os.path.join(LIB_PATH, libdir,'libsearch.dll') 23 | libsearch = cdll.LoadLibrary(filename) 24 | 25 | 26 | DFUNC_MAPPING = {'l1':0, 'l2':1, 'chi2':2, 'cosine':4} 27 | 28 | def fillprototype(f, restype, argtypes): 29 | f.restype = restype 30 | f.argtypes = argtypes 31 | 32 | def genFields(names, types): 33 | return list(zip(names, types)) 34 | 35 | 36 | class search_result(Structure): 37 | _names = ["index", "value"] 38 | _types = [c_uint64, c_double] 39 | _fields_ = genFields(_names, _types) 40 | 41 | 42 | class search_model(Structure): 43 | def __init__(self): 44 | self.__createfrom__ = 'python' 45 | 46 | def __del__(self): 47 | # free memory created by C to avoid memory leak 48 | if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C': 49 | if pointer(self) is not None: 50 | libsearch.free_model(pointer(self)) 51 | 52 | def load_ids(self, idfile): 53 | self.ids = str.split(open(idfile).readline().strip()) 54 | assert(len(self.ids) == self.get_nr_images()) 55 | 56 | def get_dim(self): 57 | return libsearch.get_dim(self) 58 | 59 | def get_nr_images(self): 60 | return libsearch.get_nr_images(self) 61 | 62 | def set_distance(self, dfunc): 63 | self.dfunc = DFUNC_MAPPING[dfunc] 64 | #print ('[%s] use %s distance' % (self.__class__.__name__, self.get_distance_name())) 65 | 66 | ''' 67 | def useL2Distance(self): 68 | print ('[%s] use L2 distance' % self.__class__.__name__) 69 | self.l2 = 1 70 | ''' 71 | 72 | def get_distance_name(self): 73 | NAMES = ['l1', 'l2', 'chi2'] 74 | return NAMES[self.dfunc] 75 | 76 | def search_knn(self, query, max_hits): 77 | assert(len(query) == self.get_dim()) 78 | topn = min(self.get_nr_images(), max_hits) 79 | 80 | query_ptr = (c_float * len(query))() 81 | for i in range(len(query)): 82 | query_ptr[i] = query[i] 83 | 84 | results = (search_result * topn)() 85 | s_time = time.time() 86 | libsearch.search_knn(self, query_ptr, topn, self.dfunc, results) 87 | knn_time = time.time() - s_time 88 | 89 | #print "search %d-nn, %.4f seconds" % (topn, knn_time) 90 | return [(self.ids[x.index], x.value) for x in results] 91 | 92 | 93 | def toPyModel(model_ptr): 94 | """ 95 | toPyModel(model_ptr) -> search_model 96 | 97 | Convert a ctypes POINTER(search_model) to a Python search_model 98 | """ 99 | if bool(model_ptr) == False: 100 | raise ValueError("Null pointer") 101 | m = model_ptr.contents 102 | m.__createfrom__ = 'C' 103 | return m 104 | 105 | 106 | #def load_model(model_file_name, dim, nimages, id_file_name): 107 | # model = libsearch.load_model(model_file_name, dim, nimages) 108 | # if not model: 109 | # print("failed to load a search model from %s" % model_file_name) 110 | # return None 111 | # 112 | # model = toPyModel(model) 113 | # model.load_ids(id_file_name) 114 | # model.set_distance('l2') 115 | # return model 116 | 117 | def 
load_model(feat_dir): 118 | shapefile = os.path.join(feat_dir, 'shape.txt') 119 | nr_of_images, feat_dim = map(int, open(shapefile).readline().split()) 120 | model_file_name = os.path.join(feat_dir, 'feature.bin') 121 | model = libsearch.load_model(model_file_name, feat_dim, nr_of_images) 122 | if not model: 123 | print("failed to load a search model from %s" % model_file_name) 124 | return None 125 | 126 | model = toPyModel(model) 127 | model.load_ids(os.path.join(feat_dir, 'id.txt')) 128 | model.set_distance('l2') 129 | print ('[%s] %dx%d' % (__file__, nr_of_images, feat_dim)) 130 | return model 131 | 132 | 133 | 134 | fillprototype(libsearch.load_model, POINTER(search_model), [c_char_p, c_uint64, c_uint64]) 135 | fillprototype(libsearch.get_dim, c_uint64, [POINTER(search_model)]) 136 | fillprototype(libsearch.get_nr_images, c_uint64, [POINTER(search_model)]) 137 | fillprototype(libsearch.free_model, None, [POINTER(POINTER(search_model))]) 138 | fillprototype(libsearch.search_knn, None, [POINTER(search_model), POINTER(c_float), c_uint64, c_int, POINTER(search_result)]) 139 | -------------------------------------------------------------------------------- /simpleknn/test.bat: -------------------------------------------------------------------------------- 1 | 2 | ::# Step 1. prepare data 3 | SET dim=3 4 | SET featurefile=toydata/FeatureData/f1/id.feature.txt 5 | SET resultdir=toydata/FeatureData/f1 6 | 7 | python txt2bin.py %dim% %featurefile% 0 %resultdir% 8 | 9 | ::# Step 2. search 10 | python demo.py 11 | 12 | @pause 13 | 14 | -------------------------------------------------------------------------------- /simpleknn/test.sh: -------------------------------------------------------------------------------- 1 | 2 | # Step 1. prepare data 3 | dim=3 4 | featurefile=toydata/FeatureData/f1/id.feature.txt 5 | resultdir=toydata/FeatureData/f1 6 | python txt2bin.py $dim $featurefile 0 $resultdir 7 | 8 | # Step 2. 
search 9 | python demo.py 10 | 11 | -------------------------------------------------------------------------------- /simpleknn/testbigfile.py: -------------------------------------------------------------------------------- 1 | import os, random 2 | 3 | import simpleknn 4 | from bigfile import BigFile 5 | 6 | rootpath = '/Users/xirong/VisualSearch' 7 | collection = 'train10k' 8 | nr_of_images = 10000 9 | feature = 'color64' 10 | dim = 64 11 | 12 | feature_dir = os.path.join(rootpath,collection,'FeatureData',feature) 13 | feature_file = BigFile(feature_dir) 14 | imset = map(str.strip, open(os.path.join(rootpath,collection,'ImageSets','%s.txt'%collection)).readlines()) 15 | imset = random.sample(imset, 10) 16 | 17 | searcher = simpleknn.load_model(feature_dir) 18 | searcher.set_distance('l1') 19 | renamed,vectors = feature_file.read(imset) 20 | 21 | for name,vec in zip(renamed,vectors): 22 | visualNeighbors = searcher.search_knn(vec, max_hits=100) 23 | print name, visualNeighbors[:3] 24 | -------------------------------------------------------------------------------- /simpleknn/toydata/FeatureData/f1/feature.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/simpleknn/toydata/FeatureData/f1/feature.bin -------------------------------------------------------------------------------- /simpleknn/toydata/FeatureData/f1/id.feature.txt: -------------------------------------------------------------------------------- 1 | a 1.00000001 2.33 3.89935 2 | b 0.5 3 4.35394 3 | -------------------------------------------------------------------------------- /simpleknn/toydata/FeatureData/f1/id.txt: -------------------------------------------------------------------------------- 1 | a b -------------------------------------------------------------------------------- /simpleknn/toydata/FeatureData/f1/shape.txt: -------------------------------------------------------------------------------- 1 | 2 3 2 | -------------------------------------------------------------------------------- /simpleknn/txt2bin.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # convert one or multiple feature files from txt format to binary (float32) format 3 | ''' 4 | 5 | import os, sys, math 6 | import numpy as np 7 | from optparse import OptionParser 8 | 9 | 10 | def checkToSkip(filename, overwrite): 11 | if os.path.exists(filename): 12 | print ("%s exists."
% filename), 13 | if overwrite: 14 | print ("overwrite") 15 | return 0 16 | else: 17 | print ("skip") 18 | return 1 19 | return 0 20 | 21 | 22 | def process(feat_dim, inputTextFiles, resultdir, overwrite): 23 | res_binary_file = os.path.join(resultdir, 'feature.bin') 24 | res_id_file = os.path.join(resultdir, 'id.txt') 25 | 26 | if checkToSkip(res_binary_file, overwrite): 27 | return 0 28 | 29 | if os.path.isdir(resultdir) is False: 30 | os.makedirs(resultdir) 31 | 32 | fw = open(res_binary_file, 'wb') 33 | processed = set() 34 | imset = [] 35 | count_line = 0 36 | failed = 0 37 | 38 | for filename in inputTextFiles: 39 | print ('>>> Processing %s' % filename) 40 | for line in open(filename): 41 | count_line += 1 42 | elems = line.strip().split() 43 | if not elems: 44 | continue 45 | name = elems[0] 46 | if name in processed: 47 | continue 48 | processed.add(name) 49 | 50 | del elems[0] 51 | vec = np.array(map(float, elems), dtype=np.float32) 52 | okay = True 53 | for x in vec: 54 | if math.isnan(x): 55 | okay = False 56 | break 57 | if not okay: 58 | failed += 1 59 | continue 60 | 61 | assert(len(vec) == feat_dim), "dimensionality mismatch: required %d, input %d, id=%s, inputfile=%s" % (feat_dim, len(vec), name, filename) 62 | vec.tofile(fw) 63 | #print name, vec 64 | imset.append(name) 65 | fw.close() 66 | 67 | fw = open(res_id_file, 'w') 68 | fw.write(' '.join(imset)) 69 | fw.close() 70 | fw = open(os.path.join(resultdir,'shape.txt'), 'w') 71 | fw.write('%d %d' % (len(imset), feat_dim)) 72 | fw.close() 73 | print ('%d lines parsed, %d ids, %d failed -> %d unique ids' % (count_line, len(processed), failed, len(imset))) 74 | 75 | 76 | 77 | def main(argv=None): 78 | if argv is None: 79 | argv = sys.argv[1:] 80 | 81 | parser = OptionParser(usage="""usage: %prog [options] nDims inputTextFile isFileList resultDir""") 82 | parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)") 83 | 84 | (options, args) = parser.parse_args(argv) 85 | if len(args) < 4: 86 | parser.print_help() 87 | return 1 88 | 89 | fea_dim = int(args[0]) 90 | inputTextFile = args[1] 91 | if int(args[2]) == 1: 92 | inputTextFiles = [x.strip() for x in open(inputTextFile).readlines() if x.strip() and not x.strip().startswith('#')] 93 | else: 94 | inputTextFiles = [inputTextFile] 95 | return process(fea_dim, inputTextFiles, args[3], options.overwrite) 96 | 97 | if __name__ == "__main__": 98 | sys.exit(main()) 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import json 5 | 6 | import torch 7 | 8 | import data 9 | from model import ReLearning 10 | from evaluation import encode_data, do_predict, cal_rel_index 11 | 12 | import argparse 13 | import logging 14 | import tensorboard_logger as tb_logger 15 | 16 | from simpleknn.bigfile import BigFile 17 | from utils.generic_utils import Progbar 18 | from utils.common import makedirsforfile, checkToSkip, ROOT_PATH 19 | from utils.util import read_video_set, write_csv, read_dict, write_csv_video2rank, get_count 20 | from utils.cbvrp_eval import read_csv_to_dict, hit_k_own, recall_k_own 21 | 22 | 23 | def main(): 24 | # Hyper Parameters 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--rootpath", default=ROOT_PATH, type=str, help="rootpath (default: %s)" % ROOT_PATH) 27 | parser.add_argument('--collection', default='track_1_shows', 
type=str, help='collection') 28 | parser.add_argument('--checkpoint_path', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 29 | parser.add_argument("--test_set", default="val", type=str, help="val or test") 30 | parser.add_argument('--batch_size', default=128, type=int, help='Size of a training mini-batch.') 31 | parser.add_argument("--overwrite", default=0, type=int, help="overwrite existing file (default: 0)") 32 | parser.add_argument('--strategy', default=1, type=int, help='1: use Strategy 1, 2: use Strategy 2') 33 | parser.add_argument('--n', default=5, type=int, help='top n relevant videos of a candidate video') 34 | 35 | opt = parser.parse_args() 36 | print(json.dumps(vars(opt), indent = 2)) 37 | 38 | 39 | assert opt.test_set in ['val', 'test'] 40 | output_dir = os.path.dirname(opt.checkpoint_path.replace('/cv/', '/results/%s/' % opt.test_set )) 41 | if opt.strategy == 2: 42 | output_dir = os.path.join(output_dir, 'strategy_%d_n_%d' % (opt.strategy, opt.n)) 43 | output_file = os.path.join(output_dir,'pred_video2rank.csv') 44 | if checkToSkip(output_file, opt.overwrite): 45 | sys.exit(0) 46 | makedirsforfile(output_file) 47 | 48 | 49 | if opt.strategy == 2: 50 | rele_index_path = os.path.join(opt.rootpath, opt.collection, 'rel_index.csv') 51 | if not os.path.exists(rele_index_path): 52 | get_count(os.path.join(opt.rootpath, opt.collection)) 53 | rel_index = cal_rel_index(rele_index_path) 54 | else: 55 | rel_index = None 56 | 57 | # reading data 58 | train_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'train.csv') 59 | val_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'val.csv') 60 | train_video_list = read_video_set(train_video_set_file) 61 | val_video_list = read_video_set(val_video_set_file) 62 | if opt.test_set == 'test': 63 | test_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'test.csv' ) 64 | test_video_list = read_video_set(test_video_set_file) 65 | 66 | 67 | # optionally resume from a checkpoint 68 | print("=> loading checkpoint '{}'".format(opt.checkpoint_path)) 69 | checkpoint = torch.load(opt.checkpoint_path) 70 | options = checkpoint['opt'] 71 | 72 | # set feature reader 73 | video_feat_path = os.path.join(opt.rootpath, opt.collection, 'FeatureData', options.feature) 74 | video_feats = BigFile(video_feat_path) 75 | 76 | 77 | # Construct the model 78 | if opt.test_set == 'val': 79 | val_rootpath = os.path.join(opt.rootpath, opt.collection, 'relevance_val.csv') 80 | val_video2gtrank = read_csv_to_dict(val_rootpath) 81 | val_feat_loader = data.get_feat_loader(val_video_list, video_feats, opt.batch_size, False, 1) 82 | cand_feat_loader = data.get_feat_loader(train_video_list + val_video_list, video_feats, opt.batch_size, False, 1) 83 | elif opt.test_set == 'test': 84 | val_feat_loader = data.get_feat_loader(test_video_list, video_feats, opt.batch_size, False, 1) 85 | cand_feat_loader = data.get_feat_loader(train_video_list + val_video_list + test_video_list, video_feats, opt.batch_size, False, 1) 86 | 87 | model = ReLearning(options) 88 | model.load_state_dict(checkpoint['model']) 89 | val_video_embs, val_video_ids_list = encode_data(model, val_feat_loader, options.log_step, logging.info) 90 | cand_video_embs, cand_video_ids_list = encode_data(model, cand_feat_loader, options.log_step, logging.info) 91 | 92 | 93 | video2predrank = do_predict(val_video_embs, val_video_ids_list, cand_video_embs, cand_video_ids_list, rel_index, opt.n, output_dir=output_dir, overwrite=1, 
no_imgnorm=options.no_imgnorm) 94 | write_csv_video2rank(output_file, video2predrank) 95 | 96 | if opt.test_set == 'val': 97 | hit_top_k = [5, 10, 20, 30] 98 | recall_top_k = [50, 100, 200, 300] 99 | hit_k_scores = hit_k_own(val_video2gtrank, video2predrank, top_k=hit_top_k) 100 | recall_K_scores = recall_k_own(val_video2gtrank, video2predrank, top_k=recall_top_k) 101 | 102 | # output val performance 103 | 104 | print('# Using Strategy %d for relevance prediction:' % (opt.strategy)) 105 | print('best performance on validation:') 106 | print('hit_top_k', [round(x,3) for x in hit_k_scores]) 107 | print('recall_top_k', [round(x,3) for x in recall_K_scores]) 108 | with open(os.path.join(output_dir,'perf.txt'), 'w') as fout: 109 | fout.write('best performance on validation:') 110 | fout.write('\nhit_top_k: ' + ", ".join(map(str, [round(x,3) for x in hit_k_scores]))) 111 | fout.write('\nrecall_top_k: ' + ", ".join(map(str, [round(x,3) for x in recall_K_scores]))) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import json 5 | import shutil 6 | 7 | import torch 8 | 9 | import data 10 | from model import ReLearning 11 | from evaluation import AverageMeter, LogCollector, encode_data, do_predict 12 | 13 | import argparse 14 | import logging 15 | import tensorboard_logger as tb_logger 16 | 17 | from simpleknn.bigfile import BigFile 18 | from utils.generic_utils import Progbar 19 | from utils.util import read_video_set, write_csv, read_dict 20 | from utils.common import ROOT_PATH, checkToSkip, makedirsforfile 21 | from utils.cbvrp_eval import read_csv_to_dict, hit_k_own, recall_k_own 22 | 23 | 24 | 25 | def main(): 26 | # Hyper Parameters 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--rootpath", default=ROOT_PATH, type=str, help="rootpath (default: %s)" % ROOT_PATH) 29 | parser.add_argument("--overwrite", default=0, type=int, help="overwrite existing file (default: 0)") 30 | parser.add_argument('--collection', default='track_1_shows', type=str, help='collection') 31 | 32 | parser.add_argument('--feature', default='inception-pool3', type=str, help="video feature.") 33 | parser.add_argument('--embed_size', default=1024, type=int, help='Dimensionality of the video embedding.') 34 | 35 | parser.add_argument('--loss', default='trl', type=str, help='loss function. (trl|netrl)') 36 | parser.add_argument('--alpha', default=1.0, type=float, help='loss weight for irrelevant loss') 37 | parser.add_argument("--cost_style", default='sum', type=str, help="cost_style (sum|mean)") 38 | parser.add_argument('--max_violation', action='store_true', help='Use max instead of sum in the rank loss.') 39 | parser.add_argument('--margin', default=0.2, type=float, help='Rank loss margin.') 40 | parser.add_argument('--margin_irel', default=0.05, type=float, help='Irrelevant loss margin.') 41 | parser.add_argument('--grad_clip', default=2., type=float, help='Gradient clipping threshold.') 42 | parser.add_argument('--optimizer', default='adam', type=str, help='optimizer.
(adam|rmsprop)') 43 | parser.add_argument('--learning_rate', default=.001, type=float, help='Initial learning rate.') 44 | parser.add_argument('--lr_decay', default=0.99, type=float, help='learning rate decay after each epoch') 45 | 46 | parser.add_argument('--num_epochs', default=50, type=int, help='Number of training epochs.') 47 | parser.add_argument('--batch_size', default=32, type=int, help='Size of a training mini-batch.') 48 | parser.add_argument('--workers', default=2, type=int, help='Number of data loader workers.') 49 | parser.add_argument('--log_step', default=100, type=int, help='Number of steps to print and record the log.') 50 | 51 | parser.add_argument('--measure', default='cosine', help='Similarity measure used (cosine|order)') 52 | parser.add_argument('--no_imgnorm', action='store_true', help='Do not normalize the image embeddings.') 53 | parser.add_argument('--postfix', default='run_0', type=str, help='') 54 | 55 | # augmentation for frame-level features 56 | parser.add_argument('--stride', default='1', type=str, help='stride=1 means no frame-level data augmentation (default: 1)') 57 | # augmentation for video-level features 58 | parser.add_argument('--aug_prob', default=0.0, type=float, 59 | help='aug_prob=0 means no frame-level data augmentation, aug_prob=0.5 means half of video use augmented features(default: 0.0)') 60 | parser.add_argument('--perturb_intensity', default=1.0, type=float, help='perturbation intensity, epsilon in Eq.2 (default: 1.0)') 61 | parser.add_argument('--perturb_prob', default=0.5, type=float, help='perturbation probability, p in Eq.2 (default: 0.5)') 62 | 63 | 64 | opt = parser.parse_args() 65 | print json.dumps(vars(opt), indent = 2) 66 | 67 | visual_info = 'feature_%s_embed_size_%d_no_imgnorm_%s' % (opt.feature, opt.embed_size, opt.no_imgnorm) 68 | loss_info = '%s_%s_margin_%.1f_max_violation_%s_%s' % (opt.loss, opt.measure, opt.margin, opt.max_violation, opt.cost_style) 69 | if opt.loss == 'netrl': 70 | loss_info += '_alpha_%.1f_margin_irel_%.2f' % (opt.alpha, opt.margin_irel) 71 | optimizer_info = '%s_lr_%.5f_%.2f_bs_%d' % ( opt.optimizer, opt.learning_rate, opt.lr_decay, opt.batch_size) 72 | data_argumentation_info = 'frame_stride_%s_video_prob_%.1f_perturb_intensity_%.5f_perturb_prob_%.2f' % (opt.stride, opt.aug_prob, opt.perturb_intensity, opt.perturb_prob) 73 | 74 | 75 | opt.logger_name = os.path.join(opt.rootpath, opt.collection, 'cv', 'ReLearning', visual_info, loss_info, optimizer_info, data_argumentation_info, opt.postfix) 76 | if checkToSkip(os.path.join(opt.logger_name,'model_best.pth.tar'), opt.overwrite): 77 | sys.exit(0) 78 | if checkToSkip(os.path.join(opt.logger_name,'val_perf.txt'), opt.overwrite): 79 | sys.exit(0) 80 | makedirsforfile(os.path.join(opt.logger_name,'model_best.pth.tar')) 81 | 82 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 83 | tb_logger.configure(opt.logger_name, flush_secs=5) 84 | 85 | 86 | # reading data 87 | train_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'train.csv' ) 88 | val_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'val.csv' ) 89 | train_video_list = read_video_set(train_video_set_file) 90 | val_video_list = read_video_set(val_video_set_file) 91 | 92 | train_rootpath = os.path.join(opt.rootpath, opt.collection, 'relevance_train.csv') 93 | val_rootpath = os.path.join(opt.rootpath, opt.collection, 'relevance_val.csv') 94 | val_video2gtrank = read_csv_to_dict(val_rootpath) 95 | 96 | stride_list = map(int, 
opt.stride.strip().split('-')) 97 | opt.sum_subs = sum(stride_list) 98 | if opt.aug_prob <= 0: 99 | opt.feature = "avg-" + opt.feature + "-stride%s" % opt.stride 100 | 101 | video_feat_path = os.path.join(opt.rootpath, opt.collection, 'FeatureData', opt.feature) 102 | video_feats = BigFile(video_feat_path) 103 | opt.feature_dim = video_feats.ndims 104 | 105 | 106 | # Load data loaders 107 | if opt.sum_subs > 1: 108 | video2subvideo_path = os.path.join(video_feat_path, 'video2subvideo.txt') 109 | video2subvideo = read_dict(video2subvideo_path) 110 | train_loader = data.get_video_da_loader(train_rootpath, video_feats, opt, opt.batch_size, True, opt.workers, 111 | video2subvideo, opt.sum_subs, feat_path=video_feat_path) 112 | else: 113 | train_loader = data.get_video_da_loader(train_rootpath, video_feats, opt, opt.batch_size, True, opt.workers, feat_path=video_feat_path) 114 | val_feat_loader = data.get_feat_loader(val_video_list, video_feats, opt.batch_size, False, 1) 115 | cand_feat_loader = data.get_feat_loader(train_video_list + val_video_list, video_feats, opt.batch_size, False, 1) 116 | 117 | # Construct the model 118 | model = ReLearning(opt) 119 | 120 | # Train the Model 121 | best_rsum = 0 122 | best_hit_k_scores = 0 123 | best_recall_K_scoress = 0 124 | no_impr_counter = 0 125 | lr_counter = 0 126 | fout_val_perf_hist = open(os.path.join(opt.logger_name,'val_perf_hist.txt'), 'w') 127 | 128 | for epoch in range(opt.num_epochs): 129 | 130 | # train for one epoch 131 | print "\nEpoch: ", epoch + 1 132 | print "learning rate: ", get_learning_rate(model.optimizer) 133 | train(opt, train_loader, model, epoch) 134 | 135 | # evaluate on validation set 136 | rsum, hit_k_scores, recall_K_scores = validate(val_feat_loader, cand_feat_loader, model, val_video2gtrank, log_step=opt.log_step, opt=opt) 137 | 138 | # remember best R@ sum and save checkpoint 139 | is_best = rsum > best_rsum 140 | best_rsum = max(rsum, best_rsum) 141 | if is_best: 142 | best_hit_k_scores = hit_k_scores 143 | best_recall_K_scoress = recall_K_scores 144 | print 'current perf: ', rsum 145 | print 'best perf: ', best_rsum 146 | print 'current hit_top_k: ', [round(x,3) for x in hit_k_scores] 147 | print 'current recall_top_k: ', [round(x,3) for x in recall_K_scores] 148 | fout_val_perf_hist.write("epoch_%d %f\n" % (epoch, rsum)) 149 | fout_val_perf_hist.flush() 150 | 151 | save_checkpoint({ 152 | 'epoch': epoch + 1, 153 | 'model': model.state_dict(), 154 | 'best_rsum': best_rsum, 155 | 'opt': opt, 156 | 'Eiters': model.Eiters, 157 | }, is_best, filename='checkpoint_epoch_%s.pth.tar' % epoch, prefix=opt.logger_name + '/') 158 | 159 | lr_counter += 1 160 | decay_learning_rate(opt, model.optimizer, opt.lr_decay) 161 | if not is_best: 162 | # Early stop occurs if the validation performance 163 | # does not improve in ten consecutive epochs. 
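# (Schedule sketch, as coded below: the rate already decays by opt.lr_decay every epoch; a non-improving epoch increments the early-stop counter, and once the current rate has been kept for at least 3 epochs it is additionally halved. Training stops after more than ten consecutive non-improving epochs.)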
164 | no_impr_counter += 1 165 | if no_impr_counter > 10: 166 | print ("Early stopping happened") 167 | break 168 | 169 | # when the validation performance has decreased after an epoch, 170 | # we divide the learning rate by 2 and continue training; 171 | # but we use each learning rate for at least 3 epochs 172 | if lr_counter > 2: 173 | decay_learning_rate(opt, model.optimizer, 0.5) 174 | lr_counter = 0 175 | else: 176 | # lr_counter = 0 177 | no_impr_counter = 0 178 | 179 | fout_val_perf_hist.close() 180 | # output val performance 181 | print json.dumps(vars(opt), indent = 2) 182 | print '\nbest performance on validation:' 183 | print 'hit_top_k', [round(x,3) for x in best_hit_k_scores] 184 | print 'recall_top_k', [round(x,3) for x in best_recall_K_scoress] 185 | with open(os.path.join(opt.logger_name,'val_perf.txt'), 'w') as fout: 186 | fout.write('best performance on validation:') 187 | fout.write('\nhit_top_k: ' + ", ".join(map(str, [round(x,3) for x in best_hit_k_scores]))) 188 | fout.write('\nrecall_top_k: ' + ", ".join(map(str, [round(x,3) for x in best_recall_K_scoress]))) 189 | 190 | 191 | 192 | # generate and run the shell script for test 193 | template = ''.join(open( 'TEMPLATE_eval.sh' ).readlines()) 194 | scriptStr = template.replace('@@@rootpath@@@', opt.rootpath) 195 | scriptStr = scriptStr.replace('@@@collection@@@', opt.collection) 196 | scriptStr = scriptStr.replace('@@@overwrite@@@', str(opt.overwrite)) 197 | scriptStr = scriptStr.replace('@@@model_path@@@', opt.logger_name) 198 | 199 | runfile = 'do_eval_%s.sh' % opt.collection 200 | open( runfile, 'w' ).write(scriptStr+'\n') 201 | os.system('chmod +x %s' % runfile) 202 | os.system('./%s' % runfile) 203 | 204 | 205 | def train(opt, train_loader, model, epoch): 206 | # average meters to record the training statistics 207 | batch_time = AverageMeter() 208 | data_time = AverageMeter() 209 | train_logger = LogCollector() 210 | 211 | # switch to train mode 212 | model.train_start() 213 | 214 | progbar = Progbar(train_loader.dataset.length) 215 | end = time.time() 216 | for i, train_data in enumerate(train_loader): 217 | 218 | # measure data loading time 219 | data_time.update(time.time() - end) 220 | 221 | # make sure train logger is used 222 | model.logger = train_logger 223 | 224 | # Update the model 225 | b_size, loss = model.train_emb(*train_data) 226 | # print loss 227 | progbar.add(b_size, values=[("loss", loss)]) 228 | 229 | # measure elapsed time 230 | batch_time.update(time.time() - end) 231 | end = time.time() 232 | 233 | # Record logs in tensorboard 234 | tb_logger.log_value('epoch', epoch, step=model.Eiters) 235 | tb_logger.log_value('step', i, step=model.Eiters) 236 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters) 237 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters) 238 | model.logger.tb_log(tb_logger, step=model.Eiters) 239 | 240 | 241 | 242 | def validate(val_feat_loader, cand_feat_loader, model, video2gtrank, log_step=100, opt=None): 243 | # compute the embeddings of the validation videos and of all candidate videos 244 | val_video_embs, val_video_ids_list = encode_data(model, val_feat_loader, log_step, logging.info) 245 | cand_video_embs, cand_video_ids_list = encode_data(model, cand_feat_loader, log_step, logging.info) 246 | 247 | video2predrank = do_predict(val_video_embs, val_video_ids_list, cand_video_embs, cand_video_ids_list, output_dir=None, overwrite=0, no_imgnorm=opt.no_imgnorm) 248 | hit_top_k = [5, 10, 20, 30] 249 | recall_top_k = [50, 100, 200, 300] 250 |
hit_k_scores = hit_k_own(video2gtrank, video2predrank, top_k=hit_top_k) 251 | recall_K_scores = recall_k_own(video2gtrank, video2predrank, top_k=recall_top_k) 252 | 253 | for i, k in enumerate(hit_top_k): 254 | tb_logger.log_value('hit_%d' % k, hit_k_scores[i], step=model.Eiters) 255 | for i, k in enumerate(recall_top_k): 256 | tb_logger.log_value('recall_%d' % k, recall_K_scores[i], step=model.Eiters) 257 | currscore = recall_K_scores[1] 258 | 259 | return currscore, hit_k_scores, recall_K_scores 260 | 261 | 262 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix=''): 263 | torch.save(state, prefix + filename) 264 | if is_best: 265 | shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar') 266 | 267 | 268 | def decay_learning_rate(opt, optimizer, decay): 269 | """multiply the learning rate of every param group by decay""" 270 | for param_group in optimizer.param_groups: 271 | param_group['lr'] = param_group['lr']*decay 272 | 273 | 274 | def get_learning_rate(optimizer): 275 | """return the current learning rate of every param group""" 276 | lr_list = [] 277 | for param_group in optimizer.param_groups: 278 | lr_list.append(param_group['lr']) 279 | return lr_list 280 | 281 | 282 | if __name__ == '__main__': 283 | main() 284 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljf24/cbvr/f604588a4d25523cc9a667f620d364793924e877/utils/__init__.py -------------------------------------------------------------------------------- /utils/cbvrp_eval.py: -------------------------------------------------------------------------------- 1 | # __author__ = 'Hulu_Research' 2 | 3 | import csv 4 | import numpy as np 5 | 6 | def read_csv(filepath): 7 | """ 8 | read csv file. 9 | :param filepath: the path of the csv file 10 | :return: list of rows, each a list of strings 11 | """ 12 | reader = csv.reader(open(filepath, 'r')) 13 | data = [] 14 | for x in reader: 15 | data.append(x) 16 | return data 17 | 18 | def read_csv_to_dict(filepath): 19 | """ 20 | read csv file into a dict. 21 | :param filepath: the path of the csv file 22 | :return: dict mapping each video id (first column) to its list of relevant video ids 23 | """ 24 | reader = csv.reader(open(filepath, 'r')) 25 | video2gtrank = {} 26 | for x in reader: 27 | video2gtrank[x[0]] = x[1:] 28 | return video2gtrank 29 | 30 | 31 | def eval_recall(ground, predict, top_k): 32 | """ 33 | Compute the recall metric used in the CBVRP-ACMMM-2018 Challenge. 34 | :param ground: a list of the truly relevant show ids for the current show. 35 | :param predict: a list of the predicted relevant show ids for the current show. 36 | :param top_k: max top_k = 500 37 | :return: recall, a float. 38 | 39 | """ 40 | predict = predict[:top_k] 41 | intersect = [x for x in predict if x in ground] 42 | recall = float(len(intersect)) / len(ground) 43 | return recall 44 | 45 | def mean_recall_hit_k(gdir, pdir, top_k): 46 | """ 47 | Compute the mean recall@k & mean hit@k metric over a whole val/test set. 48 | :param gdir: the path of the ground-truth file. 49 | :param pdir: the path of the prediction file. 50 | :param top_k: max top_k = 500 51 | :return: mean_recall_k, mean_hit_k, both are floats.
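Worked example (illustrative ids, cf. eval_recall above): for one query with ground = [2, 5, 9] and predict = [5, 1, 2, ...], recall@2 = 1/3 since only show 5 is recovered in the top-2, and the query counts as a hit because recall > 0.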
52 | """ 53 | recall_k = 0.0 54 | hit_k = 0 55 | predict_set = read_csv(pdir) 56 | ground_set = read_csv(gdir) 57 | 58 | for i in range(len(predict_set)): 59 | predict = [int(x) for x in predict_set[i]] 60 | ground = [int(x) for x in ground_set[i]] 61 | recall = eval_recall(ground[1:], predict[1:], top_k) 62 | recall_k = recall_k + recall 63 | if recall > 0: 64 | hit_k = hit_k + 1 65 | mean_recall_k = float(recall_k) / len(predict_set) 66 | mean_hit_k = float(hit_k) / len(predict_set) 67 | return mean_recall_k, mean_hit_k 68 | 69 | 70 | 71 | def recall_k(gdir, pdir, top_k=[50, 100, 200, 300]): 72 | mean_recall_k_list = [] 73 | for recall_k in top_k: 74 | mean_recall_k, _ = mean_recall_hit_k(gdir, pdir, recall_k) 75 | mean_recall_k_list.append(round(mean_recall_k,3)) 76 | return mean_recall_k_list 77 | 78 | 79 | 80 | def hit_k(gdir, pdir, top_k=[5, 10, 20, 30]): 81 | mean_hit_k_list = [] 82 | for hit_k in top_k: 83 | _, mean_hit_k = mean_recall_hit_k(gdir, pdir, hit_k) 84 | mean_hit_k_list.append(round(mean_hit_k,3)) 85 | return mean_hit_k_list 86 | 87 | 88 | def mean_recall_hit_k_own(gdict, pdict, top_k): 89 | """ 90 | Compute the mean recall@k & mean hit@k metric over a whole val/test set. 91 | :param gdict: dict mapping each video id to its ground-truth relevance ranking. 92 | :param pdict: dict mapping each video id to its predicted ranking. 93 | :param top_k: max top_k = 500 94 | :return: mean_recall_k, mean_hit_k, both are floats. 95 | """ 96 | recall_k = 0.0 97 | hit_k = 0 98 | assert len(gdict) == len(pdict), '%d != %d' % (len(gdict), len(pdict)) 99 | 100 | for video in gdict.keys(): 101 | predict = [int(x) for x in pdict[video]] 102 | ground = [int(x) for x in gdict[video]] 103 | recall = eval_recall(ground, predict, top_k) 104 | recall_k = recall_k + recall 105 | if recall > 0: 106 | hit_k = hit_k + 1 107 | mean_recall_k = float(recall_k) / len(gdict.keys()) 108 | mean_hit_k = float(hit_k) / len(gdict.keys()) 109 | return mean_recall_k, mean_hit_k 110 | 111 | def recall_k_own(gdict, pdict, top_k=[50, 100, 200, 300]): 112 | mean_recall_k_list = [] 113 | for recall_k in top_k: 114 | mean_recall_k, _ = mean_recall_hit_k_own(gdict, pdict, recall_k) 115 | mean_recall_k_list.append(mean_recall_k) 116 | return mean_recall_k_list 117 | 118 | 119 | def hit_k_own(gdict, pdict, top_k=[5, 10, 20, 30]): 120 | mean_hit_k_list = [] 121 | for hit_k in top_k: 122 | _, mean_hit_k = mean_recall_hit_k_own(gdict, pdict, hit_k) 123 | mean_hit_k_list.append(mean_hit_k) 124 | return mean_hit_k_list 125 | 126 | 127 | # Evaluation script example.
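# (The track/fname and gdir/pdir paths below are illustrative; point gdir at a ground-truth CSV such as relevance_val.csv and pdir at a prediction CSV with the same "video_id, ranked relevant ids ..." row layout that read_csv expects.)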
128 | if __name__ == "__main__": 129 | track = 'track_1_shows' 130 | fname = 'c3d-pool5' 131 | 132 | gdir = './%s/relevance_val.csv'%(track) 133 | pdir = './%s/predict_val_%s.csv'%(track, fname) 134 | 135 | print('hit_k rate for %s'%(fname)) 136 | mean_hit_k_list = [] 137 | for hit_k in [5, 10, 20, 30, 40, 50]: 138 | _, mean_hit_k = mean_recall_hit_k(gdir, pdir, hit_k) 139 | mean_hit_k_list.append(round(mean_hit_k,3)) 140 | #print('%.3f' % mean_hit_k) 141 | print(mean_hit_k_list) 142 | 143 | 144 | print('recall_k rate for %s'%(fname)) 145 | mean_recall_k_list = [] 146 | for recall_k in [50, 100, 200, 300, 400, 500]: 147 | mean_recall_k, _ = mean_recall_hit_k(gdir, pdir, recall_k) 148 | mean_recall_k_list.append(round(mean_recall_k,3)) 149 | #print('%.3f' % mean_recall_k) 150 | print(mean_recall_k_list) 151 | 152 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | 6 | #ROOT_PATH = os.path.join(os.environ['HOME'], 'VisualSearch') 7 | ROOT_PATH = '/media/daniel/D/zlm/VisualSearch/cbvr_tkde' 8 | 9 | 10 | def makedirsforfile(filename): 11 | try: 12 | os.makedirs(os.path.split(filename)[0]) 13 | except: 14 | pass 15 | 16 | 17 | def niceNumber(v, maxdigit=6): 18 | """Nicely format a number, with a maximum of 6 digits.""" 19 | assert(maxdigit >= 0) 20 | 21 | if maxdigit == 0: 22 | return "%.0f" % v 23 | 24 | fmt = '%%.%df' % maxdigit 25 | s = fmt % v 26 | 27 | if len(s) > maxdigit: 28 | return s.rstrip("0").rstrip(".") 29 | elif len(s) == 0: 30 | return "0" 31 | else: 32 | return s 33 | 34 | 35 | 36 | def checkToSkip(filename, overwrite): 37 | if os.path.exists(filename): 38 | print ("%s exists." 
% filename), 39 | if overwrite: 40 | print ("overwrite") 41 | return 0 42 | else: 43 | print ("skip") 44 | return 1 45 | return 0 46 | 47 | 48 | def printMessage(message_type, trace, message): 49 | print ('%s %s [%s] %s' % (time.strftime('%d/%m/%Y %H:%M:%S'), message_type, trace, message)) 50 | 51 | def printStatus(trace, message): 52 | printMessage('INFO', trace, message) 53 | 54 | def printError(trace, message): 55 | printMessage('ERROR', trace, message) 56 | 57 | 58 | class CmdOptions: 59 | def __init__(self): 60 | self.value = {} 61 | self.addOption("rootpath", ROOT_PATH) 62 | self.addOption("overwrite", 0) 63 | self.addOption("dryrun", 0) 64 | self.addOption("numjobs", 1) 65 | self.addOption("job", 1) 66 | 67 | def printHelp(self): 68 | print (""" 69 | --rootpath [default: %s] 70 | --numjobs [default: 1] 71 | --job [default: 1] 72 | --overwrite [default: %d]""" % (self.getString("rootpath"), self.getInt("overwrite"))) 73 | 74 | def addOption(self, param, val): 75 | self.value[param] = val 76 | 77 | def getString(self, param): 78 | return self.value[param] 79 | 80 | def getInt(self, param): 81 | return int(self.getDouble(param)) 82 | 83 | def getDouble(self, param): 84 | return float(self.getString(param)) 85 | 86 | def getBool(self, param): 87 | return self.getInt(param) == 1 88 | 89 | def parseArgs(self, argv): 90 | i = 0 91 | while i < len(argv) -1: 92 | if argv[i].startswith("--"): 93 | if argv[i+1].startswith("--"): 94 | i += 1 95 | continue 96 | param = argv[i][2:] 97 | if param in self.value: 98 | self.value[param] = argv[i+1] 99 | i += 2 100 | else: 101 | i += 1 102 | else: 103 | i += 1 104 | okay = self.checkArgs() 105 | if not okay: 106 | self.printHelp() 107 | return okay 108 | 109 | 110 | def printArgs(self): 111 | for key in self.value.keys(): 112 | print ("--%s %s" % (key, self.getString(key))) 113 | 114 | def checkArgs(self): 115 | paramsNeeded = [param for (param,value) in self.value.iteritems() if value is ""] 116 | 117 | if paramsNeeded: 118 | printError(self.__class__.__name__,"Need more arguments: %s" % " ".join(paramsNeeded)) 119 | return False 120 | 121 | if self.getInt("numjobs") < self.getInt("job"): 122 | printError(self.__class__.__name__, "numjobs cannot be smaller than job") 123 | return False 124 | 125 | return True 126 | 127 | 128 | def total_seconds(td): 129 | return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 1e6) / 1e6 130 | 131 | if __name__ == "__main__": 132 | cmdOpts = CmdOptions() 133 | cmdOpts.printHelp() 134 | cmdOpts.parseArgs(sys.argv[1:]) 135 | print niceNumber(1.0/3, 4) 136 | for i in range(0, 15): 137 | print niceNumber(8.17717824342e-10, i) 138 | -------------------------------------------------------------------------------- /utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | """Python utilities required by Keras.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import binascii 7 | import numpy as np 8 | 9 | import time 10 | import sys 11 | import six 12 | import marshal 13 | import types as python_types 14 | import inspect 15 | import codecs 16 | import collections 17 | 18 | _GLOBAL_CUSTOM_OBJECTS = {} 19 | 20 | 21 | class CustomObjectScope(object): 22 | """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. 23 | 24 | Code within a `with` statement will be able to access custom objects 25 | by name. 
    Changes to global custom objects persist
    within the enclosing `with` statement. At end of the `with` statement,
    global custom objects are reverted to state
    at beginning of the `with` statement.

    # Example

    Consider a custom object `MyObject` (e.g. a class):

    ```python
    with CustomObjectScope({'MyObject':MyObject}):
        layer = Dense(..., kernel_regularizer='MyObject')
        # save, load, etc. will recognize custom object by name
    ```
    """

    def __init__(self, *args):
        self.custom_objects = args
        self.backup = None

    def __enter__(self):
        self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
        for objects in self.custom_objects:
            _GLOBAL_CUSTOM_OBJECTS.update(objects)
        return self

    def __exit__(self, *args, **kwargs):
        _GLOBAL_CUSTOM_OBJECTS.clear()
        _GLOBAL_CUSTOM_OBJECTS.update(self.backup)


def custom_object_scope(*args):
    """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.

    Convenience wrapper for `CustomObjectScope`.
    Code within a `with` statement will be able to access custom objects
    by name. Changes to global custom objects persist
    within the enclosing `with` statement. At end of the `with` statement,
    global custom objects are reverted to state
    at beginning of the `with` statement.

    # Example

    Consider a custom object `MyObject`

    ```python
    with custom_object_scope({'MyObject':MyObject}):
        layer = Dense(..., kernel_regularizer='MyObject')
        # save, load, etc. will recognize custom object by name
    ```

    # Arguments
        *args: Variable length list of dictionaries of name,
            class pairs to add to custom objects.

    # Returns
        Object of type `CustomObjectScope`.
    """
    return CustomObjectScope(*args)


def get_custom_objects():
    """Retrieves a live reference to the global dictionary of custom objects.

    Updating and clearing custom objects using `custom_object_scope`
    is preferred, but `get_custom_objects` can
    be used to directly access `_GLOBAL_CUSTOM_OBJECTS`.

    # Example

    ```python
    get_custom_objects().clear()
    get_custom_objects()['MyObject'] = MyObject
    ```

    # Returns
        Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
    """
    return _GLOBAL_CUSTOM_OBJECTS


def serialize_keras_object(instance):
    if instance is None:
        return None
    if hasattr(instance, 'get_config'):
        return {
            'class_name': instance.__class__.__name__,
            'config': instance.get_config()
        }
    if hasattr(instance, '__name__'):
        return instance.__name__
    else:
        raise ValueError('Cannot serialize', instance)


def deserialize_keras_object(identifier, module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
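        # For orientation, such a dictionary has the shape (hypothetical layer):
        #     {'class_name': 'MyLayer', 'config': {'units': 64}}
        # i.e. exactly what serialize_keras_object() above produces.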
        config = identifier
        if 'class_name' not in config or 'config' not in config:
            raise ValueError('Improper config format: ' + str(config))
        class_name = config['class_name']
        if custom_objects and class_name in custom_objects:
            cls = custom_objects[class_name]
        elif class_name in _GLOBAL_CUSTOM_OBJECTS:
            cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
        else:
            module_objects = module_objects or {}
            cls = module_objects.get(class_name)
            if cls is None:
                raise ValueError('Unknown ' + printable_module_name +
                                 ': ' + class_name)
        if hasattr(cls, 'from_config'):
            custom_objects = custom_objects or {}
            if has_arg(cls.from_config, 'custom_objects'):
                return cls.from_config(config['config'],
                                       custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                                                           list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(config['config'])
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**config['config'])
    elif isinstance(identifier, six.string_types):
        function_name = identifier
        if custom_objects and function_name in custom_objects:
            fn = custom_objects.get(function_name)
        elif function_name in _GLOBAL_CUSTOM_OBJECTS:
            fn = _GLOBAL_CUSTOM_OBJECTS[function_name]
        else:
            fn = module_objects.get(function_name)
            if fn is None:
                raise ValueError('Unknown ' + printable_module_name +
                                 ':' + function_name)
        return fn
    else:
        raise ValueError('Could not interpret serialized ' +
                         printable_module_name + ': ' + identifier)


def func_dump(func):
    """Serializes a user defined function.

    # Arguments
        func: the function to serialize.

    # Returns
        A tuple `(code, defaults, closure)`.
    """
    raw_code = marshal.dumps(func.__code__)
    code = codecs.encode(raw_code, 'base64').decode('ascii')
    defaults = func.__defaults__
    if func.__closure__:
        closure = tuple(c.cell_contents for c in func.__closure__)
    else:
        closure = None
    return code, defaults, closure


def func_load(code, defaults=None, closure=None, globs=None):
    """Deserializes a user defined function.

    # Arguments
        code: bytecode of the function.
        defaults: defaults of the function.
        closure: closure of the function.
        globs: dictionary of global objects.

    # Returns
        A function object.
    """
    if isinstance(code, (tuple, list)):  # unpack previous dump
        code, defaults, closure = code
        if isinstance(defaults, list):
            defaults = tuple(defaults)

    def ensure_value_to_cell(value):
        """Ensures that a value is converted to a python cell object.

        # Arguments
            value: Any value that needs to be casted to the cell type

        # Returns
            A value wrapped as a cell object (see function "func_load")

        """
        def dummy_fn():
            value  # just access it so it gets captured in .__closure__

        cell_value = dummy_fn.__closure__[0]
        if not isinstance(value, type(cell_value)):
            return cell_value
        else:
            return value

    if closure is not None:
        closure = tuple(ensure_value_to_cell(_) for _ in closure)
    try:
        raw_code = codecs.decode(code.encode('ascii'), 'base64')
        code = marshal.loads(raw_code)
    except (UnicodeEncodeError, binascii.Error, ValueError):
        # backwards compatibility for models serialized prior to 2.1.2
        raw_code = code.encode('raw_unicode_escape')
        code = marshal.loads(raw_code)
    if globs is None:
        globs = globals()
    return python_types.FunctionType(code, globs,
                                     name=code.co_name,
                                     argdefs=defaults,
                                     closure=closure)


def has_arg(fn, name, accept_all=False):
    """Checks if a callable accepts a given keyword argument.

    For Python 2, checks if there is an argument with the given name.

    For Python 3, checks if there is an argument with the given name, and
    also whether this argument can be called with a keyword (i.e. if it is
    not a positional-only argument).

    # Arguments
        fn: Callable to inspect.
        name: Check if `fn` can be called with `name` as a keyword argument.
        accept_all: What to return if there is no parameter called `name`
            but the function accepts a `**kwargs` argument.

    # Returns
        bool, whether `fn` accepts a `name` keyword argument.
    """
    if sys.version_info < (3,):
        arg_spec = inspect.getargspec(fn)
        if accept_all and arg_spec.keywords is not None:
            return True
        return (name in arg_spec.args)
    elif sys.version_info < (3, 3):
        arg_spec = inspect.getfullargspec(fn)
        if accept_all and arg_spec.varkw is not None:
            return True
        return (name in arg_spec.args or
                name in arg_spec.kwonlyargs)
    else:
        signature = inspect.signature(fn)
        parameter = signature.parameters.get(name)
        if parameter is None:
            if accept_all:
                for param in signature.parameters.values():
                    if param.kind == inspect.Parameter.VAR_KEYWORD:
                        return True
            return False
        return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                   inspect.Parameter.KEYWORD_ONLY))


class Progbar(object):
    """Displays a progress bar.

    # Arguments
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over time. Metrics in this list
            will be displayed as-is. All others will be averaged
            by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
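
    # Example

    An illustrative sketch (values are made up):

    ```python
    progbar = Progbar(target=100)
    for step in range(100):
        # report the running loss observed at this step
        progbar.update(step + 1, values=[('loss', 0.25)])
    ```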
297 | """ 298 | 299 | def __init__(self, target, width=30, verbose=1, interval=0.05, 300 | stateful_metrics=None): 301 | self.target = target 302 | self.width = width 303 | self.verbose = verbose 304 | self.interval = interval 305 | if stateful_metrics: 306 | self.stateful_metrics = set(stateful_metrics) 307 | else: 308 | self.stateful_metrics = set() 309 | 310 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 311 | sys.stdout.isatty()) or 312 | 'ipykernel' in sys.modules) 313 | self._total_width = 0 314 | self._seen_so_far = 0 315 | self._values = collections.OrderedDict() 316 | self._start = time.time() 317 | self._last_update = 0 318 | 319 | def update(self, current, values=None): 320 | """Updates the progress bar. 321 | 322 | # Arguments 323 | current: Index of current step. 324 | values: List of tuples: 325 | `(name, value_for_last_step)`. 326 | If `name` is in `stateful_metrics`, 327 | `value_for_last_step` will be displayed as-is. 328 | Else, an average of the metric over time will be displayed. 329 | """ 330 | values = values or [] 331 | for k, v in values: 332 | if k not in self.stateful_metrics: 333 | if k not in self._values: 334 | self._values[k] = [v * (current - self._seen_so_far), 335 | current - self._seen_so_far] 336 | else: 337 | self._values[k][0] += v * (current - self._seen_so_far) 338 | self._values[k][1] += (current - self._seen_so_far) 339 | else: 340 | self._values[k] = v 341 | self._seen_so_far = current 342 | 343 | now = time.time() 344 | info = ' - %.0fs' % (now - self._start) 345 | if self.verbose == 1: 346 | if (now - self._last_update < self.interval and 347 | self.target is not None and current < self.target): 348 | return 349 | 350 | prev_total_width = self._total_width 351 | if self._dynamic_display: 352 | sys.stdout.write('\b' * prev_total_width) 353 | sys.stdout.write('\r') 354 | else: 355 | sys.stdout.write('\n') 356 | 357 | if self.target is not None: 358 | numdigits = int(np.floor(np.log10(self.target))) + 1 359 | barstr = '%%%dd/%d [' % (numdigits, self.target) 360 | bar = barstr % current 361 | prog = float(current) / self.target 362 | prog_width = int(self.width * prog) 363 | if prog_width > 0: 364 | bar += ('=' * (prog_width - 1)) 365 | if current < self.target: 366 | bar += '>' 367 | else: 368 | bar += '=' 369 | bar += ('.' 
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta

                info = ' - ETA: %s' % eta_format
            else:
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)

            for k in self._values:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self._values[k]

            self._total_width += len(info)
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            if self.target is not None and current >= self.target:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            if self.target is None or current >= self.target:
                for k in self._values:
                    info += ' - %s:' % k
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'

                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        self.update(self._seen_so_far + n, values)
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
import os
import csv
import ast


def read_video_set(filepath):
    """Read the first column (video ids) of a csv file."""
    reader = csv.reader(open(filepath, 'r'))
    data = []
    for x in reader:
        data.append(x[0])
    return data


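# A minimal round-trip sketch for these two helpers (hypothetical file name,
# made-up ids):
#   write_csv('val_pred.csv', [['show_1', 'show_9', 'show_4']])   # see below
#   read_video_set('val_pred.csv')  ->  ['show_1']                # column 0 only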
def write_csv(filepath, data):
    """
    Write a csv file.
    :param filepath: the path of the csv file
    :param data: writing data, one list per row
    """
    with open(filepath, 'w') as csvfile:
        csv_writer = csv.writer(csvfile)
        for line in data:
            csv_writer.writerow(line)


def read_csv_video2rank(filepath):
    video2rank = {}
    with open(filepath, 'r') as csvfile:
        reader = csv.reader(csvfile)
        for data in reader:
            video = data[0]
            assert video not in video2rank
            video2rank[video] = data[1:]
    return video2rank


def write_csv_video2rank(filepath, video2rank, topk=500):
    result_data = []
    for video, ranks in video2rank.items():
        result_data.append([video] + ranks[:topk])
    write_csv(filepath, result_data)


def write_csv_video2rank_fusion(videolist, filepath, video2rank, topk=500):
    result_data = []
    assert len(videolist) == len(video2rank)
    for video in videolist:
        ranks = video2rank[video]
        result_data.append([video] + ranks[:topk])
    write_csv(filepath, result_data)


def read_dict(filepath):
    with open(filepath, 'r') as f:
        # ast.literal_eval is a safe replacement for eval on the dict
        # literals written out by write_dict below.
        dict_data = ast.literal_eval(f.read())
    return dict_data


def write_dict(filepath, dict_data):
    with open(filepath, 'w') as f:
        f.write(str(dict_data))


def get_count(path):
    train = os.path.join(path, 'split/train.csv')
    val = os.path.join(path, 'split/val.csv')
    id_list = []
    train_reader = csv.reader(open(train))
    val_reader = csv.reader(open(val))

    for vid in train_reader:
        id_list.append(vid[0])
    for vid in val_reader:
        id_list.append(vid[0])
    # Map each id to the index of its first occurrence; this avoids the
    # quadratic cost of calling id_list.index() inside the loops below.
    id_index = {}
    for i, vid in enumerate(id_list):
        id_index.setdefault(vid, i)

    rele_train = os.path.join(path, 'relevance_train.csv')
    rele_val = os.path.join(path, 'relevance_val.csv')
    rele_train_reader = csv.reader(open(rele_train))
    rele_val_reader = csv.reader(open(rele_val))

    output_file = os.path.join(path, 'rel_index.csv')
    with open(output_file, 'w') as csvfile:
        writew = csv.writer(csvfile)
        for rele in rele_train_reader:
            index_list = []
            for rele_id in rele[1:]:
                if rele_id in id_index:
                    index_list.append(id_index[rele_id])
            writew.writerow(index_list)

        for rele in rele_val_reader:
            index_list = []
            for rele_id in rele[1:]:
                if rele_id in id_index:
                    index_list.append(id_index[rele_id])
            writew.writerow(index_list)

    print('write out: %s' % output_file)
--------------------------------------------------------------------------------
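
For orientation, a minimal sketch of how the ranking helpers in `utils/util.py` compose; the prediction file name below follows the pattern used in the evaluation script and is hypothetical:

```python
from utils.util import read_csv_video2rank, write_csv_video2rank

# each csv row: a query video id followed by its ranked candidate ids
video2rank = read_csv_video2rank('track_1_shows/predict_val_c3d-pool5.csv')
write_csv_video2rank('predict_val_top500.csv', video2rank, topk=500)
```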