├── LICENSE
├── README.md
├── annotations
│   ├── ptd_api.py
│   ├── script_test_annotation.py
│   └── utils.py
├── data
│   ├── __init__.py
│   ├── data_loader.py
│   └── dictForDb_vid_v2.pd
├── fun
│   ├── __init__.py
│   ├── classSST.py
│   ├── create_word2vec_for_dataset.py
│   ├── dashed_rect.py
│   ├── datasetLoader.py
│   ├── datasetParser.py
│   ├── eval.py
│   ├── evalDet.py
│   ├── image_toolbox.py
│   ├── logInfo.py
│   ├── lossPackage.py
│   ├── modelArc.py
│   ├── netUtil.py
│   ├── optimizers.py
│   ├── train.py
│   ├── vidDataset.py
│   ├── vidDatasetParser.py
│   └── wsParamParser.py
├── images
│   ├── frm.png
│   └── task.png
├── readme.md
├── scripts
│   ├── test_video_emb_att.sh
│   └── train_video_emb_att.sh
└── util
    ├── __init__.py
    ├── __init__.pyc
    ├── base_parser.py
    ├── base_parser.pyc
    ├── get_image_size.py
    ├── mytoolbox.py
    └── mytoolbox.pyc
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Weakly-Supervised Spatio-Temporally Grounding Natural Sentence in Video
2 |
3 | This repo contains the main baselines for the VID-sentence dataset introduced in WSSTG.
4 | Please refer to [our paper](https://arxiv.org/abs/1906.02549) and the [repo](https://github.com/JeffCHEN2017/VID-sentence) for information about the VID-sentence dataset.
5 |
6 |
7 | ### Task
8 |
9 | ![task](images/task.png)
10 | Description: "A brown and white dog is lying on the grass and then it stands up."
11 |
12 |
13 |
14 |
15 |
16 | The proposed WSSTG task aims to localize a spatio-temporal tube (i.e., the sequence of green bounding boxes) in the video that semantically corresponds to the given sentence, without relying on any spatio-temporal annotations during training.
17 |
18 |
19 | ### Architecture
20 |
21 |
22 | The architecture of the proposed method.
23 |
24 | ![architecture](images/frm.png)
25 |
26 |
27 |
28 | ### Contents
29 | 1. [Requirements: software](#requirements-software)
30 | 2. [Installation](#installation)
31 | 3. [Training](#training)
32 | 4. [Testing](#testing)
33 |
34 | ### Requirements: software
35 |
36 | - PyTorch (version 0.4.0)
37 | - Python 2.7
38 | - numpy
39 | - scipy
40 | - magic
41 | - easydict
42 | - dill
43 | - matplotlib
44 | - tensorboardX
45 |
46 |
47 | ### Installation
48 |
49 | 1. Clone the WSSTG and VID-sentence repositories
50 |
51 | ```Shell
52 | git clone https://github.com/JeffCHEN2017/WSSTG.git
53 | git clone https://github.com/JeffCHEN2017/VID-Sentence.git
54 | ln -s VID-sentence_ROOT/data/ILSVRC WSSTG_ROOT/data
55 | ```
56 | 2. Download [tube proposals](https://drive.google.com/file/d/1SHwXtlb7V8PH4_60-0-VZYL-7kXEG_Wj/view?usp=sharing), [RGB features](https://drive.google.com/file/d/1ll_AkiByvQsJTdPNVt1TH6BUPbQxjvG_/view?usp=sharing) and [I3D features](https://drive.google.com/file/d/1SwPGweipeuREZrAXGzu7nPACy9vXNmmp/view?usp=sharing) from Google Drive.
57 |
58 | 3. Extract the *.tar files and make symlinks between the downloaded data and the desired data folders
59 |
60 | ```Shell
61 | tar xvf tubePrp.tar
62 | ln -s tubePrp $WSSTG_ROOT/data/tubePrp
63 |
64 | tar xvf vid_i3d.tar vid_i3d
65 | ln -s vid_i3d $WSSTG_ROOT/data/vid_i3d
66 | ln -s $WSSTG_ROOT/data/vid_i3d/val test
67 |
68 | tar xvf vid_rgb.tar vid_rgb
69 | ln -s vid_rgb $WSSTG_ROOT/data/vid_rgb
70 | ln -s $WSSTG_ROOT/data/vid_rgb/vidTubeCacheFtr/val test
71 | ```
72 |
73 | Note: We extract the tube proposals using the method proposed by [Gkioxari and Malik](https://arxiv.org/abs/1411.6031). A Python implementation provided by Yamaguchi et al. is available [here](https://www.mi.t.u-tokyo.ac.jp/projects/person_search/).
74 | We extract single-frame proposals and RGB features for each frame using a [Faster R-CNN](https://arxiv.org/abs/1506.01497) model pretrained on the COCO dataset, provided by [Jianwei Yang](https://github.com/jwyang/faster-rcnn.pytorch).
75 | We extract [I3D-RGB and I3D-flow features](https://arxiv.org/abs/1705.07750) using the model provided by [Carreira and Zisserman](https://github.com/deepmind/kinetics-i3d.git).
76 |
77 |
78 | ### Training
79 | ```Shell
80 | cd $WSSTG_ROOT
81 | sh scripts/train_video_emb_att.sh
82 | ```
83 | Note: Because of changes to the batch size and the random seed, the performance may differ slightly from our submission. We provide a checkpoint (see Testing below) which achieves performance similar to the model reported in the paper (38.1 vs. 38.2 on accuracy@0.5).
84 |
85 | ### Testing
86 | Download the checkpoint from [Google Drive](https://drive.google.com/file/d/1oM0J4jIbcd4SYo9T29ydk3gugoOjFCKA/view?usp=sharing), put it in $WSSTG_ROOT/data/models and run
87 | ```Shell
88 | cd $WSSTG_ROOT
89 | sh scripts/test_video_emb_att.sh
90 | ```
91 |
92 | ### License
93 |
94 | WSSTG is released under the CC-BY-NC 4.0 LICENSE (refer to the LICENSE file for details).
95 |
96 | ### Citing WSSTG
97 |
98 | If you find this repo useful in your research, please consider citing:
99 |
100 | @inproceedings{chen2019weakly,
101 | Title={Weakly-Supervised Spatio-Temporally Grounding Natural Sentence in Video},
102 | Author={Chen, Zhenfang and Ma, Lin and Luo, Wenhan and Wong, Kwan-Yee~K},
103 | Booktitle={ACL},
104 | year={2019}
105 | }
106 |
107 | ### Contact
108 |
109 | You can contact Zhenfang Chen by sending an email to chenzhenfang2013@gmail.com
110 |
--------------------------------------------------------------------------------
/annotations/ptd_api.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | import os
4 | from collections import defaultdict
5 | from utils import *
6 | import pdb
7 |
8 | ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), __file__)))
9 | DESCS_JSON = os.path.join(ROOT_DIR, 'data', '%s_descriptions.json')
10 | PEOPLE_JSON = os.path.join(ROOT_DIR, 'data', '%s_people.json')
11 | SHOTS_JSON = os.path.join(ROOT_DIR, 'data', '%s_shots.json')
12 | if os.path.exists(os.path.join(ROOT_DIR, 'data', 'video2fps.json')):
13 | VID2FPS_JSON = os.path.join(ROOT_DIR, 'data', 'video2fps.json')
14 | else:
15 | VID2FPS_JSON = os.path.join(ROOT_DIR, 'data', 'video2fps.json.example')
16 | DS = ['train', 'val', 'test']
17 | AN_JSON = os.path.join(ROOT_DIR, 'activity_net.v1-3.min.json')
18 |
19 | class ActivityNet(object):
20 | _data = None
21 | _video_ids = None
22 | def __init__(self):
23 | if not os.path.exists(AN_JSON):
24 | raise Exception('Download the annotation file of ActivityNet v1.3, and place it to "%s".' % AN_JSON)
25 | self._data = jsonload(AN_JSON)
26 | assert self._data['version'] == 'VERSION 1.3'
27 | self._video_ids = sorted(self._data['database'].keys())
28 |
29 | @property
30 | def video_ids(self):
31 | return self._video_ids
32 |
33 | @property
34 | def data(self):
35 | return self._data
36 |
37 | class PTD(object):
38 | _descriptions = None
39 | _shots = None
40 | _people = None
41 | _person2desc = None
42 | _desc2person = None
43 | _person2shot = None
44 | _shot2person = None
45 | _id2shot = None
46 | _id2person = None
47 | _id2desc = None
48 | _vid2fps = None
49 | _an_data = None
50 |
51 | def __init__(self, dataset):
52 | assert dataset in DS
53 | self._dataset = dataset
54 |
55 | def description(self, index):
56 | return Description(self.id2desc[index], self)
57 |
58 | def person(self, index):
59 | return Person(self.id2person[index], self)
60 |
61 | def shot(self, index):
62 | return Shot(self.id2shot[index], self)
63 |
64 | @property
65 | def descriptions(self):
66 | if self._descriptions is None:
67 | self._descriptions = jsonload(DESCS_JSON % self._dataset)
68 | return self._descriptions
69 |
70 | @property
71 | def id2desc(self):
72 | if self._id2desc is None:
73 | self._id2desc = {}
74 | for desc in self.descriptions:
75 | self._id2desc[desc['id']] = desc
76 | return self._id2desc
77 |
78 | @property
79 | def people(self):
80 | if self._people is None:
81 | self._people = jsonload(PEOPLE_JSON % self._dataset)
82 | return self._people
83 |
84 | @property
85 | def id2person(self):
86 | if self._id2person is None:
87 | self._id2person = {}
88 | for person in self.people:
89 | self._id2person[person['id']] = person
90 | return self._id2person
91 |
92 | @property
93 | def shots(self):
94 | if self._shots is None:
95 | self._shots = jsonload(SHOTS_JSON % self._dataset)
96 | return self._shots
97 |
98 | @property
99 | def id2shot(self):
100 | if self._id2shot is None:
101 | self._id2shot = {}
102 | for shot in self.shots:
103 | self._id2shot[shot['id']] = shot
104 | return self._id2shot
105 |
106 | ### id2id ###
107 | @property
108 | def person2desc(self):
109 | if self._person2desc is None:
110 | self._person2desc = {}
111 | for person in self.people:
112 | self._person2desc[person['id']] = person['descriptions']
113 | return self._person2desc
114 |
115 | @property
116 | def desc2person(self):
117 | if self._desc2person is None:
118 | self._desc2person = {}
119 | for person in self.people:
120 | for desc_id in person['descriptions']:
121 | self._desc2person[desc_id] = person['id']
122 | return self._desc2person
123 |
124 | @property
125 | def person2shot(self):
126 | if self._person2shot is None:
127 | self._person2shot = {}
128 | for person in self.people:
129 | self._person2shot[person['id']] = person['shot_id']
130 | return self._person2shot
131 |
132 | @property
133 | def shot2person(self):
134 | if self._shot2person is None:
135 | self._shot2person = defaultdict(list)
136 | for person in self.people:
137 | self._shot2person[person['shot_id']].append(person['id'])
138 | return self._shot2person
139 |
140 | ### others
141 | @property
142 | def vid2fps(self):
143 | if self._vid2fps is None:
144 | self._vid2fps = jsonload(VID2FPS_JSON)
145 | return self._vid2fps
146 |
147 | @property
148 | def an_data(self):
149 | if self._an_data is None:
150 | self._an_data = ActivityNet()
151 | return self._an_data
152 |
153 | class BaseInstance(object):
154 | def __init__(self, info, parent):
155 | self._info = info
156 | self._parent = parent
157 |
158 | def __getitem__(self, index):
159 | return self._info[index]
160 |
161 | @property
162 | def id(self):
163 | return self._info['id']
164 |
165 | class Description(BaseInstance):
166 | @property
167 | def description(self):
168 | return self._info['description']
169 |
170 | @property
171 | def person(self):
172 | person_id = self._parent.desc2person[self._info['id']]
173 | return self._parent.person(person_id)
174 |
175 | @property
176 | def shot(self):
177 | return self.person.shot
178 |
179 | class Shot(BaseInstance):
180 | @property
181 | def video_id(self):
182 | return self._parent.an_data.video_ids[self._info['an_video_id']]
183 |
184 | @property
185 | def first_second(self):
186 | return min(self.annotated_seconds)
187 |
188 | @property
189 | def last_second(self):
190 | return max(self.annotated_seconds)
191 |
192 | @property
193 | def annotated_seconds(self):
194 | return self._info['annotated_seconds']
195 |
196 | @property
197 | def first_frame(self):
198 | return self.sec2frame(self.first_second)
199 |
200 | @property
201 | def last_frame(self):
202 | if self.video_id == 'll91M5topgU':
203 | return 161
204 | if self.video_id == 'uOmCwWVJnLQ':
205 | return 2999
206 | return self.sec2frame(self.last_second)
207 |
208 | @property
209 | def annotated_frames(self):
210 | return [self.sec2frame(sec) for sec in self._info['annotated_seconds']]
211 |
212 | @property
213 | def fully_annotated(self):
214 | return self._info['fully_annotated']
215 |
216 | @property
217 | def people(self):
218 | person_ids = self._parent.shot2person[self._info['id']]
219 | return [self._parent.person(person_id) for person_id in person_ids]
220 |
221 | @property
222 | def descriptions(self):
223 | descs = []
224 | for person in self.people:
225 | descs += person.descriptions
226 | return descs
227 |
228 | @property
229 | def fps(self):
230 | return self._parent.vid2fps[self.video_id]
231 |
232 | def sec2frame(self, sec):
233 | return int(round(sec * self.fps))
234 |
235 | class Person(BaseInstance):
236 | @property
237 | def boxes(self):
238 | return self._info['boxes']
239 |
240 | @property
241 | def descriptions(self):
242 | return [self._parent.description(i) for i in self._info['descriptions']]
243 |
244 | @property
245 | def shot(self):
246 | shot_id = self._parent.person2shot[self._info['id']]
247 | return self._parent.shot(shot_id)
248 |
249 | def demo():
250 | ptd = PTD('test')
251 |     print 'Showing information of the shot whose ID is 1...'
252 | shot = ptd.shot(1)
253 | print '[SHOT] ID: %d' % shot.id
254 | print '[SHOT] VIDEO URL: %s' % shot.video_id
255 | print '[SHOT] START TIME: %s' % shot.first_second
256 | print '[SHOT] END TIME: %s' % shot.last_second
257 | print
258 |
259 |     print 'Showing information of the person whose ID is 1...'
260 | person = ptd.person(1)
261 | print '[PERSON] ID: %d' % person.id
262 | print '[PERSON] PARENT SHOT ID: %s' % person.shot.id
263 | print '[PERSON] DESCRIPTIONS: %s' % ', '.join(['"%s"' % d.description for d in person.descriptions])
264 | #print
265 |
266 |     print 'Showing information of the description whose ID is 1...'
267 | description = ptd.description(1)
268 | print '[DESCRIPTION] ID: %d' % description.id
269 | print '[DESCRIPTION] PARENT PERSON ID: %d' % description.person.id
270 | print '[DESCRIPTION] DESCRIPTION: %s' % description.description
271 |
272 | if __name__ == '__main__':
273 | demo()
274 |
--------------------------------------------------------------------------------
/annotations/script_test_annotation.py:
--------------------------------------------------------------------------------
1 | from ptd_api import *
2 | import cv2
3 | import copy
4 | import numpy as np
5 | import shutil
6 | import commands
7 | import sys
8 | sys.path.append('../')
9 | from util.mytoolbox import *
10 |
11 | TMP_DIR = '.tmp'
12 | FFMPEG = 'ffmpeg'
13 | SAVE_VIDEO = FFMPEG + ' -y -r %d -i %s/%s.jpg %s'
14 |
15 | def draw_rectangle(img, bbox, color=(0,0,255), thickness=3):
16 | img = imread_if_str(img)
17 | if isinstance(bbox, dict):
18 | bbox = [
19 | bbox['x1'],
20 | bbox['y1'],
21 | bbox['x2'],
22 | bbox['y2'],
23 | ]
24 | assert bbox[2] >= bbox[0]
25 | assert bbox[3] >= bbox[1]
26 | assert bbox[0] >= 0
27 | assert bbox[1] >= 0
28 | assert bbox[2] <= img.shape[1]
29 | assert bbox[3] <= img.shape[0]
30 | cur_img = copy.deepcopy(img)
31 | cv2.rectangle(
32 | cur_img,
33 | (int(bbox[0]), int(bbox[1])),
34 | (int(bbox[2]), int(bbox[3])),
35 | color,
36 | thickness)
37 | return cur_img
38 |
39 | def images2video(image_list, frame_rate, video_path, max_edge=None):
40 | if os.path.exists(TMP_DIR):
41 | shutil.rmtree(TMP_DIR)
42 | os.mkdir(TMP_DIR)
43 | img_size = None
44 | for cur_num, cur_img in enumerate(image_list):
45 | cur_fname = os.path.join(TMP_DIR, '%08d.jpg' % cur_num)
46 | if max_edge is not None:
47 | cur_img = imread_if_str(cur_img)
48 | if isinstance(cur_img, str) or isinstance(cur_img, unicode):
49 | shutil.copyfile(cur_img, cur_fname)
50 | elif isinstance(cur_img, np.ndarray):
51 | max_len = max(cur_img.shape[:2])
52 | if max_len > max_edge and img_size is None and max_edge is not None:
53 | magnif = float(max_edge) / float(max_len)
54 | img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif))
55 | cur_img = cv2.resize(cur_img, img_size)
56 | elif max_edge is not None:
57 | if img_size is None:
58 | magnif = float(max_edge) / float(max_len)
59 | img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif))
60 | cur_img = cv2.resize(cur_img, img_size)
61 | cv2.imwrite(cur_fname, cur_img)
62 | else:
63 |             raise NotImplementedError()
64 | print commands.getoutput(SAVE_VIDEO % (frame_rate, TMP_DIR, '%08d', video_path))
65 | shutil.rmtree(TMP_DIR)
66 |
67 | def video2shot(vdName, ptd_list = {}):
68 | setNameList = ['train', 'val', 'test']
69 | shotInfo = list()
70 | for set_name in setNameList:
71 | if set_name in ptd_list.keys():
72 | ptd = ptd_list[set_name]
73 | else:
74 | ptd = PTD(set_name)
75 | shotLgh = len(ptd.id2shot)
76 | for i in range(shotLgh):
77 | vdNameShot = ptd.shot(i+1).video_id
78 | if(vdName==vdNameShot):
79 | shotInfo.append([set_name, i+1])
80 | return shotInfo
81 |
82 | def imread_if_str(img):
83 | if isinstance(img, basestring):
84 | img = cv2.imread(img)
85 | return img
86 |
87 | def get_shot_frames(shot):
88 | annFrmSt = shot.first_frame
89 | annFrmLs = shot.last_frame
90 | tmpFrmList = list(range(annFrmSt, annFrmLs+1))
91 | frmList = list()
92 | for i, frmIdx in enumerate(tmpFrmList):
93 | strIdx = '%05d' %(frmIdx)
94 | frmList.append(strIdx)
95 | return frmList
96 |
97 | def get_shot_frames_full_path(shot, preFd, fn_ext='.jpg'):
98 | vd_name = shot.video_id
99 | vd_subPath = os.path.join(preFd, 'v_' + vd_name)
100 | framList = get_shot_frames(shot)
101 | framListFull = list()
102 | for frm in framList:
103 | tmpFrm = os.path.join(vd_subPath, frm + fn_ext)
104 | framListFull.append(tmpFrm)
105 | return framListFull
106 |
107 | # shot_proposals : [tubes, frameList]
108 | #
109 | def evaluate_tube_recall(shot_proposals, shot, person_in_shot, thre=0.5, topKOri=20):
110 | # pdb.set_trace()
111 | topK = min(topKOri, len(shot_proposals[0][0]))
112 | recall_k = [0.0] * (topK + 1)
113 | boxes = {}
114 | for frame_ind, box in zip(shot.annotated_frames, person_in_shot['boxes']):
115 | keyName = '%05d' %(frame_ind)
116 | boxes[keyName] = box
117 |
118 | #pdb.set_trace()
119 | tube_list, frame_list = shot_proposals
120 | assert(len(tube_list[0][0])== len(frame_list))
121 | is_person_annotated = False
122 | for i in range(topK):
123 | recall_k[i+1] = recall_k[i]
124 | if is_person_annotated:
125 | continue
126 | curTubeOri = tube_list[0][i]
127 | tube_key_bbxList = {}
128 | for frame_ind, gt_box in boxes.iteritems():
129 | try:
130 | index_tmp = frame_list.index(frame_ind)
131 | tube_key_bbxList[frame_ind] = curTubeOri[index_tmp]
132 |             except ValueError:
133 |                 print('key %s does not exist in shot' %(frame_ind))
134 | #pdb.set_trace()
135 | ol = compute_LS(tube_key_bbxList, boxes)
136 | if ol < thre:
137 | continue
138 | else:
139 | recall_k[i+1] += 1.0
140 | is_person_annotated = True
141 | return recall_k
142 |
143 | def vis_ann_tube():
144 | pngFolder ='/data1/zfchen/data/actNet/actNetJpgs'
145 | annFolder ='/data1/zfchen/data/actNet/actNetAnn'
146 | #vdName = '2Peh_gdQCjg'
147 | #vdName = 'jqKK2KH6l4Q'
148 | #vdName = 'W4tmb8RwzQM'
149 | vdName = 'ydRycaBjMVw'
150 | shotList = video2shot(vdName)
151 |
152 |     #pdb.set_trace()
153 |
154 | txtFn = os.path.join(annFolder, 'v_'+vdName + '.txt')
155 | fH = open(txtFn)
156 | vdInfoStr = fH.readlines()[0]
157 | vdInfo = [float(ele) for ele in vdInfoStr.split(' ')]
158 | insNum = 1
159 |
160 | for i, shotInfo in enumerate(shotList):
161 | set_name, shotId = shotInfo
162 | ptd = PTD(set_name)
163 | shot = ptd.shot(shotId)
164 | annFtrList = shot.annotated_frames
165 | for ii, person in enumerate(shot.people):
166 | imgList = list()
167 | print(person.descriptions[0].description)
168 | for iii in range(len(annFtrList)):
169 | frmName = '%05d' %(annFtrList[iii]+1)
170 | imFn = pngFolder + '/v_' + shot.video_id + '/' + frmName + '.jpg'
171 | img = cv2.imread(imFn)
172 | h, w, c = img.shape
173 | bbox = person.boxes[iii]
174 | #pdb.set_trace()
175 | bbox[0] = bbox[0]*w
176 | bbox[2] = bbox[2]*w
177 | bbox[1] = bbox[1]*h
178 | bbox[3] = bbox[3]*h
179 | #pdb.set_trace()
180 | imgNew = draw_rectangle(img, bbox, color=(0,0,255), thickness=3)
181 | imgList.append(imgNew)
182 | images2video(imgList, 10, './' + str(insNum)+'.gif')
183 | insNum +=1
184 |
185 | def extract_instance_caption_list():
186 | set_list = ['train', 'test', 'val']
187 | outPre = './data/cap_list_'
188 | for i, set_name in enumerate(set_list):
189 | ptd = PTD(set_name)
190 | outFn = outPre + set_name + '.txt'
191 | desption_list_dict = ptd.descriptions
192 | description_list = [ des_dict['description'] for des_dict in desption_list_dict]
193 | textdump(outFn, description_list)
194 |
195 | if __name__=='__main__':
196 | extract_instance_caption_list()
197 |
--------------------------------------------------------------------------------
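
A minimal sketch of calling `evaluate_tube_recall` from the file above, using hypothetical toy data (a single perfect proposal tube against one annotated person); it assumes it is run from inside `annotations/` so the module's imports resolve:

```python
# Toy call of evaluate_tube_recall; all data below is hypothetical.
from script_test_annotation import evaluate_tube_recall

class FakeShot(object):            # stand-in for a ptd_api Shot
    annotated_frames = [1, 2]      # ground-truth boxes exist on frames 1 and 2

person = {'boxes': [[0, 0, 10, 10], [0, 0, 10, 10]]}   # one gt box per annotated frame
tube = [[0, 0, 10, 10], [0, 0, 10, 10]]                # one proposal box per frame
shot_proposals = [[[tube]], ['00001', '00002']]        # [tube_list, frame_list]
print(evaluate_tube_recall(shot_proposals, FakeShot(), person))
# -> [0.0, 1.0]: the top-ranked tube localizes the person at IoU >= 0.5
```
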
/annotations/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 |
4 | import json
5 | import os
6 | import pdb
7 |
8 | EPS = 1e-10
9 |
10 |
11 | def compute_IoU(bbox1, bbox2):
12 | bbox1_area = float((bbox1[2] - bbox1[0] + EPS) * (bbox1[3] - bbox1[1] + EPS))
13 | bbox2_area = float((bbox2[2] - bbox2[0] + EPS) * (bbox2[3] - bbox2[1] + EPS))
14 | w = max(0.0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]) + EPS)
15 | h = max(0.0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]) + EPS)
16 | inter = float(w * h)
17 | ovr = inter / (bbox1_area + bbox2_area - inter)
18 | return ovr
19 |
20 | def is_annotated(traj, frame_ind):
21 | if not frame_ind in traj:
22 | return False
23 | box = traj[frame_ind]
24 | if box[0] < 0:
25 | #for el_val in box[1:]:
26 | #assert el_val < 0
27 | return False
28 | #for el_val in box[1:]:
29 | # assert el_val >= 0
30 | return True
31 |
32 | def compute_LS(traj, gt_traj):
33 | # see http://jvgemert.github.io/pub/jain-tubelets-cvpr2014.pdf
34 | assert isinstance(traj.keys()[0], type(gt_traj.keys()[0]))
35 | IoU_list = []
36 | for frame_ind, gt_box in gt_traj.iteritems():
37 |         # make sure the gt_box is valid
38 |
39 | gt_is_annotated = is_annotated(gt_traj, frame_ind)
40 | pr_is_annotated = is_annotated(traj, frame_ind)
41 | if (not gt_is_annotated) and (not pr_is_annotated):
42 | continue
43 | if (not gt_is_annotated) or (not pr_is_annotated):
44 | IoU_list.append(0.0)
45 | continue
46 | box = traj[frame_ind]
47 | IoU_list.append(compute_IoU(box, gt_box))
48 | return sum(IoU_list) / len(IoU_list)
49 |
50 | def jsonload(path):
51 | f = open(path)
52 | json_data = json.load(f)
53 | f.close()
54 | return json_data
55 |
56 | def get_abs_path():
57 | return os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), __file__)))
58 |
--------------------------------------------------------------------------------
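
`compute_LS` above implements the localization score from the linked Jain et al. paper: the per-frame IoU between predicted and ground-truth trajectories, averaged over every frame where at least one of them is annotated. A minimal sketch with hypothetical toy boxes (Python 2, matching the `iteritems`/`keys()[0]` usage above; run from inside `annotations/`):

```python
# Toy localization-score computation; boxes are [x1, y1, x2, y2].
from utils import compute_IoU, compute_LS

gt_traj   = {'00001': [10, 10, 50, 50], '00002': [12, 12, 52, 52]}
pred_traj = {'00001': [10, 10, 50, 50]}   # frame '00002' is missed

print(compute_IoU(gt_traj['00001'], pred_traj['00001']))  # 1.0: identical boxes
print(compute_LS(pred_traj, gt_traj))  # 0.5: IoU 1.0 on one frame, 0.0 on the missed one
```
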
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/data/__init__.py
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 | def creatDataloader(opt):
2 | from data.custom_data_loader import customDataLoader
3 | data_loader = customDataLoader()
4 | data_loader.initialize(opt)
5 | return data_loader
6 |
--------------------------------------------------------------------------------
/fun/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/fun/__init__.py
--------------------------------------------------------------------------------
/fun/classSST.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import os
7 | #import ipdb   # unused, and not listed in the README requirements
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.autograd import *
14 |
15 | #import opts   # unused; no opts module is included in this repo
16 | import pdb
17 |
18 |
19 | class SST(nn.Module):
20 | def __init__(self, opt):
21 | super(SST, self).__init__()
22 |
23 | self.word_cnt = opt.word_cnt
24 | self.fc_feat_size = opt.fc_feat_size
25 | self.video_embedding_size = opt.video_embedding_size
26 | self.word_embedding_size = opt.word_embedding_size
27 | self.lstm_hidden_size = opt.lstm_hidden_size
28 | self.video_time_step = opt.video_time_step
29 | self.caption_time_step = opt.caption_time_step
30 | self.dropout_prob = opt.dropout_prob
31 | self.att_hidden_size = opt.att_hidden_size
32 |
33 | self.video_embedding = nn.Linear(self.fc_feat_size, self.video_embedding_size)
34 |
35 | self.lstm_video = torch.nn.LSTMCell(self.video_embedding_size, self.lstm_hidden_size, bias=True)
36 | self.lstm_caption = torch.nn.LSTMCell(self.word_embedding_size, self.lstm_hidden_size, bias=True)
37 |
38 |
39 | self.vid_linear = nn.Linear(self.lstm_hidden_size, self.att_hidden_size)
40 | self.sen_linear = nn.Linear(self.lstm_hidden_size, self.att_hidden_size)
41 |
42 | self.att_linear = nn.Linear(self.att_hidden_size, 1)
43 |
44 | self.word_embedding = nn.Linear(300, self.word_embedding_size)
45 | self.init_weights()
46 |
47 | def init_weights(self):
48 | initrange = 0.1
49 | self.video_embedding.weight.data.uniform_(-initrange, initrange)
50 | self.video_embedding.bias.data.fill_(0)
51 |
52 | self.word_embedding.weight.data.uniform_(-initrange, initrange)
53 | self.word_embedding.bias.data.fill_(0)
54 |
55 | self.vid_linear.weight.data.uniform_(-initrange, initrange)
56 | self.vid_linear.bias.data.fill_(0)
57 |
58 | self.sen_linear.weight.data.uniform_(-initrange, initrange)
59 | self.sen_linear.bias.data.fill_(0)
60 |
61 | self.att_linear.weight.data.uniform_(-initrange, initrange)
62 | self.att_linear.bias.data.fill_(0)
63 |
64 |
65 | def init_hidden(self, batch_size):
66 | weight = next(self.parameters()).data
67 | init_h = Variable(weight.new(1, batch_size, self.lstm_hidden_size).zero_())
68 | init_c = Variable(weight.new(1, batch_size, self.lstm_hidden_size).zero_())
69 | init_state = (init_h, init_c)
70 |
71 | return init_state
72 |
73 | def init_hidden_new(self, batch_size):
74 | weight = next(self.parameters()).data
75 | init_h = Variable(weight.new(batch_size, self.lstm_hidden_size).zero_())
76 | init_c = Variable(weight.new(batch_size, self.lstm_hidden_size).zero_())
77 | init_state = (init_h, init_c)
78 |
79 | return init_state
80 |
81 | def forward(self, video_fc_feat, video_caption, cap_length_list=None):
82 | if self.training:
83 | return self.forward_training(video_fc_feat, video_caption, cap_length_list)
84 | else:
85 | return self.forward_val(video_fc_feat, video_caption, cap_length_list)
86 |
87 | # video_fc_feats: batch * encode_time_step * fc_feat_size
88 | def forward_training(self, video_fc_feat, video_caption, cap_length_list=None):
89 | #pdb.set_trace()
90 | batch_size = video_fc_feat.size(0)
91 | batch_size_caption = video_caption.size(0)
92 |
93 | video_state = self.init_hidden_new(batch_size)
94 | caption_state = self.init_hidden_new(batch_size_caption)
95 |
96 | caption_outputs = []
97 | caption_time_step = video_caption.size(1)
98 | for i in range(caption_time_step):
99 | word = video_caption[:, i].clone()
100 | #if video_caption[:, i].data.sum() == 0:
101 | # break
102 | #import ipdb
103 | #pdb.set_trace()
104 | caption_xt = self.word_embedding(word)
105 | #caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state)
106 | caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state)
107 | caption_outputs.append(caption_output)
108 | caption_state = (caption_output, caption_state)
109 | # caption_outputs: batch * caption_time_step * lstm_hidden_size
110 | caption_outputs = torch.cat([_.unsqueeze(1) for _ in caption_outputs], 1).contiguous()
111 |
112 | # LSTM encoding
113 | video_outputs = []
114 | for i in range(self.video_time_step):
115 | video_xt = self.video_embedding(video_fc_feat[:, i, :])
116 | video_output, video_state = self.lstm_video.forward(video_xt, video_state)
117 | video_outputs.append(video_output)
118 | video_state = (video_output, video_state)
119 | # video_outputs: batch * video_time_step * lstm_hidden_size
120 | video_outputs = torch.cat([_.unsqueeze(1) for _ in video_outputs], 1).contiguous()
121 |
122 | # soft attention for caption based on each video
123 | output_list = list()
124 | for i in range(self.video_time_step):
125 | # part 1
126 | video_outputs_linear = self.vid_linear(video_outputs[:, i, :])
127 | video_outputs_linear_expand = video_outputs_linear.expand(caption_outputs.size(1), video_outputs_linear.size(0),
128 | video_outputs_linear.size(1)).transpose(0, 1)
129 |
130 | # part 2
131 | caption_outputs_flatten = caption_outputs.view(-1, self.lstm_hidden_size)
132 | caption_outputs_linear = self.sen_linear(caption_outputs_flatten)
133 | caption_outputs_linear = caption_outputs_linear.view(batch_size_caption, caption_outputs.size(1), self.att_hidden_size)
134 |
135 | # part 1 and part 2 attention
136 | sig_probs = []
137 | for cap_id in range(batch_size_caption):
138 | #pdb.set_trace()
139 | cap_length = max(cap_length_list[cap_id], 1)
140 | caption_output_linear_cap_id = caption_outputs_linear[cap_id, : cap_length, :]
141 | video_outputs_linear_expand_clip = video_outputs_linear_expand[:, :cap_length, :]
142 | caption_outputs_linear_cap_id_exp = caption_output_linear_cap_id.expand_as(video_outputs_linear_expand_clip)
143 | video_caption = F.tanh(video_outputs_linear_expand_clip \
144 | + caption_outputs_linear_cap_id_exp)
145 |
146 | video_caption_view = video_caption.contiguous().view(-1, self.att_hidden_size)
147 | video_caption_out = self.att_linear(video_caption_view)
148 | video_caption_out_view = video_caption_out.view(-1, cap_length)
149 | atten_weights = nn.Softmax(dim=1)(video_caption_out_view).unsqueeze(2)
150 |
151 | caption_output_cap_id = caption_outputs[cap_id, : cap_length, :]
152 | caption_output_cap_id_exp = caption_output_cap_id.expand(batch_size,\
153 | caption_output_cap_id.size(0), caption_output_cap_id.size(1))
154 | atten_caption = torch.bmm(caption_output_cap_id_exp.transpose(1, 2), atten_weights).squeeze(2)
155 |
156 | video_caption_hidden = torch.cat((atten_caption, video_outputs[:, i, :]), dim=1)
157 | cos = nn.CosineSimilarity(dim=1, eps=1e-6)
158 | cur_probs = cos(atten_caption, video_outputs[: ,i, :]).unsqueeze(1)
159 |
160 |
161 | sig_probs.append(cur_probs)
162 |
163 | sig_probs = torch.cat([_ for _ in sig_probs], 1).contiguous()
164 | output_list.append(sig_probs)
165 | simMM = torch.stack(output_list, dim=2).mean(2)
166 | return simMM
167 |
168 | # video_fc_feats: batch * encode_time_step * fc_feat_size
169 | def forward_val(self, video_fc_feat, video_caption, cap_length_list=None):
170 | #pdb.set_trace()
171 | batch_size = video_fc_feat.size(0)
172 | batch_size_caption = video_caption.size(0)
173 | num_tube_per_video = int(batch_size / batch_size_caption)
174 |
175 | video_state = self.init_hidden_new(batch_size)
176 | caption_state = self.init_hidden_new(batch_size_caption)
177 |
178 | caption_outputs = []
179 | caption_time_step = video_caption.size(1)
180 | for i in range(caption_time_step):
181 | word = video_caption[:, i].clone()
182 | #if video_caption[:, i].data.sum() == 0:
183 | # break
184 | #import ipdb
185 | #pdb.set_trace()
186 | caption_xt = self.word_embedding(word)
187 | #caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state)
188 | caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state)
189 | caption_outputs.append(caption_output)
190 | caption_state = (caption_output, caption_state)
191 | # caption_outputs: batch * caption_time_step * lstm_hidden_size
192 | caption_outputs = torch.cat([_.unsqueeze(1) for _ in caption_outputs], 1).contiguous()
193 |
194 | # LSTM encoding
195 | video_outputs = []
196 | for i in range(self.video_time_step):
197 | video_xt = self.video_embedding(video_fc_feat[:, i, :])
198 | video_output, video_state = self.lstm_video.forward(video_xt, video_state)
199 | video_outputs.append(video_output)
200 | video_state = (video_output, video_state)
201 | # video_outputs: batch * video_time_step * lstm_hidden_size
202 | video_outputs = torch.cat([_.unsqueeze(1) for _ in video_outputs], 1).contiguous()
203 |
204 |         # batch_size_caption * num_word * att_hidden_size
205 | caption_outputs_linear = self.sen_linear(caption_outputs)
206 |
207 | # soft attention for caption based on each video
208 | output_list = list()
209 | for i in range(self.video_time_step):
210 | # batch_size * lstm_hidden_size
211 | video_outputs_linear = self.vid_linear(video_outputs[:, i, :])
212 |             # batch_size * num_word * att_hidden_size
213 | video_outputs_linear_expand = video_outputs_linear.expand(caption_outputs.size(1), video_outputs_linear.size(0),
214 | video_outputs_linear.size(1)).transpose(0, 1)
215 |
216 | # part 1 and part 2 attention
217 | sig_probs = []
218 |
219 | for cap_id in range(batch_size_caption):
220 |
221 | tube_id_st = cap_id*num_tube_per_video
222 | tube_id_ed = (cap_id+1)*num_tube_per_video
223 |
224 | cap_length = max(cap_length_list[cap_id], 1)
225 |             # num_tube_per_video * cap_length * att_hidden_size
226 | tube_outputs_aligned = video_outputs_linear_expand[tube_id_st:tube_id_ed, : cap_length, :]
227 |             # cap_length * att_hidden_size
228 | caption_output_linear_cap_id = caption_outputs_linear[cap_id, : cap_length, :]
229 |             # num_tube_per_video * cap_length * att_hidden_size
230 | caption_outputs_linear_cap_id_exp = caption_output_linear_cap_id.expand_as(tube_outputs_aligned)
231 |             # num_tube_per_video * cap_length * att_hidden_size
232 | video_caption = F.tanh(tube_outputs_aligned \
233 | + caption_outputs_linear_cap_id_exp)
234 |
235 | video_caption_view = video_caption.contiguous().view(-1, self.att_hidden_size)
236 | video_caption_out = self.att_linear(video_caption_view)
237 | video_caption_out_view = video_caption_out.view(-1, cap_length)
238 |
239 | # num_tube_per_video * cap_length
240 | atten_weights = nn.Softmax(dim=1)(video_caption_out_view).unsqueeze(2)
241 |
242 | # cap_length * lstm_hidden_size
243 | caption_output_cap_id = caption_outputs[cap_id, : cap_length, :]
244 | # num_tube_per_video * cap_length * lstm_hidden_size
245 | caption_output_cap_id_exp = caption_output_cap_id.expand(num_tube_per_video,\
246 | caption_output_cap_id.size(0), caption_output_cap_id.size(1))
247 | # num_tube_per_video * lstm_hidden_size
248 | atten_caption = torch.bmm(caption_output_cap_id_exp.transpose(1, 2), atten_weights).squeeze(2)
249 | cos = nn.CosineSimilarity(dim=1, eps=1e-6)
250 | # num_tube_per_video
251 | cur_probs = cos(atten_caption, video_outputs[tube_id_st:tube_id_ed, i, :]).unsqueeze(1)
252 | # num_tube_per_video
253 | sig_probs.append(cur_probs)
254 |
255 | # num_tube_per_video * batch_size_caption
256 | sig_probs = torch.cat([_ for _ in sig_probs], 1).contiguous()
257 | output_list.append(sig_probs)
258 | # num_tube_per_video * batch_size_caption
259 | simMM = torch.stack(output_list, dim=2).mean(2)
260 | # batch_size_caption * num_tube_per_video
261 | simMM = simMM.transpose(0, 1)
262 | #pdb.set_trace()
263 | return simMM
264 |
265 |
--------------------------------------------------------------------------------
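
A minimal smoke test for the `SST` module above, with dummy inputs and assumed hyper-parameter values (placeholders, not the paper's settings); it exercises the training path and checks the shape of the returned similarity matrix:

```python
# Shape check for SST.forward_training (dummy data, torch 0.4-era API).
import torch
from torch.autograd import Variable
from classSST import SST

class Opt(object):                 # hypothetical stand-in for the real options object
    word_cnt = 1000; fc_feat_size = 2048
    video_embedding_size = 512; word_embedding_size = 512
    lstm_hidden_size = 512; att_hidden_size = 256
    video_time_step = 20; caption_time_step = 15; dropout_prob = 0.5

model = SST(Opt())
model.train()                      # selects forward_training
tubes = Variable(torch.randn(4, Opt.video_time_step, Opt.fc_feat_size))
caps = Variable(torch.randn(2, Opt.caption_time_step, 300))   # word2vec inputs
sim = model(tubes, caps, cap_length_list=[15, 12])
print(sim.size())                  # (4, 2): one similarity per (tube, caption) pair
```
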
/fun/create_word2vec_for_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('../')
4 | sys.path.append('../util')
5 |
6 | from gensim.models import KeyedVectors
7 | from datasetParser import *
8 | import numpy as np
9 | from mytoolbox import pickledump, set_debugger
10 | from util.base_parser import BaseParser
11 | from vidDatasetParser import *
12 |
13 | set_debugger()
14 |
15 | class dictParser(BaseParser):
16 | def __init__(self, *arg_list, **arg_dict):
17 | super(dictParser, self).__init__(*arg_list, **arg_dict)
18 | self.add_argument('--dictOutPath', default='../data/dictForDb', type=str)
19 | self.add_argument('--annoteFd', default='/disk2/zfchen/data/OTB_sentences', type=str)
20 | self.add_argument('--dictPath', default='/disk2/zfchen/data/dict/GoogleNews-vectors-negative300.bin', type=str)
21 | self.add_argument('--setName', default='otb', type=str)
22 | self.add_argument('--setOutPath', default='../data/annForDb', type=str)
23 | self.add_argument('--annFn', default='', type=str)
24 | self.add_argument('--annFd', default='', type=str)
25 | self.add_argument('--annIgListFn', default='', type=str)
26 | self.add_argument('--annOriFn', default='', type=str)
27 |
28 |
29 |
30 | def parse_args():
31 | parser = dictParser()
32 | args = parser.parse_args()
33 | return args
34 |
35 | def buildVoc(concaList):
36 | vocList =[]
37 | for capList in concaList:
38 | for subCapList in capList:
39 | for ele in subCapList:
40 | if ele not in vocList:
41 | vocList.append(ele)
42 | word2idx={}
43 | idx2word={}
44 | for i, ele in enumerate(vocList):
45 | word2idx[ele]=i
46 | idx2word[i] =ele
47 | return word2idx, idx2word
48 |
49 | def buildVocA2d(concaList):
50 | vocList =[]
51 | for capList in concaList:
52 | for ele in capList:
53 | if ele not in vocList:
54 | vocList.append(ele)
55 | word2idx={}
56 | idx2word={}
57 | for i, ele in enumerate(vocList):
58 | word2idx[ele]=i
59 | idx2word[i] =ele
60 | return word2idx, idx2word
61 |
62 | def buildVocActNet(vocListOri):
63 | vocList = list()
64 | for ele in vocListOri:
65 | if ele not in vocList:
66 | vocList.append(ele)
67 | word2idx={}
68 | idx2word={}
69 | for i, ele in enumerate(vocList):
70 | word2idx[ele]=i
71 | idx2word[i] =ele
72 | return word2idx, idx2word
73 |
74 | def build_word_vec(word_list, model_word2vec):
75 | matrix_word2vec = []
76 | igNoreList = list()
77 | for i, word in enumerate(word_list):
78 | print(i, word)
79 | try:
80 | matrix_word2vec.append(model_word2vec[word])
81 | except:
82 | igNoreList.append(word)
83 | #matrix_word2vec.append(np.zeros((300), dtype=np.float32))
84 | randArray=np.random.rand((300)).astype('float32')
85 | matrix_word2vec.append(randArray)
86 | try:
87 |                 print('%s is not in the vocabulary' % word)
88 | except:
89 | print('fail to print the word!')
90 | pdb.set_trace()
91 | return matrix_word2vec, igNoreList
92 |
93 | if __name__ == '__main__':
94 | opt = parse_args()
95 |
96 | if opt.setName=='actNet':
97 | print('begin parsing dataset: %s\n' %(opt.setName))
98 | word_list = build_actNet_word_list()
99 | print(len(word_list))
100 |         #pdb.set_trace()
101 | word2idx, idx2word= buildVocActNet(word_list)
102 | model_word2vec = KeyedVectors.load_word2vec_format(opt.dictPath, binary=True)
103 | matrix_word2vec, igNoreList = build_word_vec(word2idx.keys(), model_word2vec)
104 | matrix_word2vec = np.asarray(matrix_word2vec).astype(np.float32)
105 |         #pdb.set_trace()
106 |
107 | outDict = {'idx2word': idx2word, 'word2idx': word2idx, 'word2vec': matrix_word2vec, 'out_voca': igNoreList}
108 | pickledump(opt.dictOutPath+'_'+opt.setName+'.pd', outDict)
109 | print('Finish constructing dictionary for dataset: %s\n' %(opt.setName))
110 |
111 | elif opt.setName=='vid':
112 | print('begin parsing dataset: %s\n' %(opt.setName))
113 | word_list = build_vid_word_list()
114 | print(len(word_list))
115 | word2idx, idx2word= buildVocActNet(word_list)
116 | model_word2vec = KeyedVectors.load_word2vec_format(opt.dictPath, binary=True)
117 | matrix_word2vec, igNoreList = build_word_vec(word2idx.keys(), model_word2vec)
118 | matrix_word2vec = np.asarray(matrix_word2vec).astype(np.float32)
119 |         #pdb.set_trace()
120 |
121 | outDict = {'idx2word': idx2word, 'word2idx': word2idx, 'word2vec': matrix_word2vec, 'out_voca': igNoreList}
122 | pickledump(opt.dictOutPath+'_'+opt.setName+'_v2.pd', outDict)
123 | print('Finish constructing dictionary for dataset: %s\n' %(opt.setName))
124 |
125 |
126 |
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
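
`build_word_vec` above returns one 300-d vector per word, falling back to a random vector for out-of-vocabulary words and recording those in the ignore list. A toy sketch (a plain dict stands in for the gensim `KeyedVectors` model; run from inside `fun/` with the file's imports available):

```python
# Toy build_word_vec run; fake_model is a hypothetical stand-in for KeyedVectors.
import numpy as np
from create_word2vec_for_dataset import build_word_vec

fake_model = {'dog': np.ones(300, dtype=np.float32)}   # dict lookup mimics model[word]
vectors, ignored = build_word_vec(['dog', 'zzqx'], fake_model)
print(len(vectors))   # 2: every word gets a vector
print(ignored)        # ['zzqx']: the unknown word got a random vector
```
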
/fun/dashed_rect.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | def drawline(img,pt1,pt2,color,thickness=1,style='dotted',gap=10):
4 | dist =((pt1[0]-pt2[0])**2+(pt1[1]-pt2[1])**2)**.5
5 | pts= []
6 | for i in np.arange(0,dist,gap):
7 | r=i/dist
8 | x=int((pt1[0]*(1-r)+pt2[0]*r)+.5)
9 | y=int((pt1[1]*(1-r)+pt2[1]*r)+.5)
10 | p = (x,y)
11 | pts.append(p)
12 |
13 | if style=='dotted':
14 | for p in pts:
15 | cv2.circle(img,p,thickness,color,-1)
16 | else:
17 | s=pts[0]
18 | e=pts[0]
19 | i=0
20 | for p in pts:
21 | s=e
22 | e=p
23 | if i%2==1:
24 | cv2.line(img,s,e,color,thickness)
25 | i+=1
26 |
27 | def drawpoly(img,pts,color,thickness=1,style='dotted',):
28 | s=pts[0]
29 | e=pts[0]
30 | pts.append(pts.pop(0))
31 | for p in pts:
32 | s=e
33 | e=p
34 | drawline(img,s,e,color,thickness,style)
35 |
36 | def drawrect(img,pt1,pt2,color,thickness=1,style='dotted'):
37 | pts = [pt1,(pt2[0],pt1[1]),pt2,(pt1[0],pt2[1])]
38 | drawpoly(img,pts,color,thickness,style)
39 |
40 | if __name__ == '__main__':  # demo: draw a dotted rectangle only when run directly
41 |     im = np.zeros((800,800,3),dtype='uint8')
42 |     s, e = (234,222), (500,700)
43 |     drawrect(im,s,e,(0,255,255),1,'dotted')
44 |
--------------------------------------------------------------------------------
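
A short usage sketch for `drawrect` above (the output filename is hypothetical): draw a dashed rectangle on a blank canvas and write it to disk instead of relying on the in-file demo.

```python
# Any style other than 'dotted' makes drawline draw dashed segments.
import cv2
import numpy as np
from dashed_rect import drawrect

canvas = np.zeros((200, 200, 3), dtype='uint8')
drawrect(canvas, (20, 20), (180, 180), (0, 255, 255), 1, 'dashed')
cv2.imwrite('dashed_demo.png', canvas)   # hypothetical output path
```
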
/fun/datasetLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('..')
4 | import torch.utils.data as data
5 | import cv2
6 | import numpy as np
7 | from util.mytoolbox import *
8 | import random
9 | import scipy.io as sio
10 | import copy
11 | import torch
12 | from util.get_image_size import get_image_size
13 | from evalDet import *
14 | from datasetParser import extAllFrmFn
15 | import pdb
16 | from netUtil import *
17 | from wsParamParser import parse_args
18 | from vidDataset import *
19 | import h5py
20 |
21 | def build_dataloader(opt):
22 | if opt.dbSet=='vid':
23 | ftrPath = '' # Path for frame level feature
24 | tubePath = '../data/tubePrp' # Path for information of each tube proposals
25 | dictFile = '../data/dictForDb_vid_v2.pd' # Path for word embedding
26 | out_cached_folder = '../data/vid_rgb/vidTubeCacheFtr' # Path for RGB features of tubes
27 | ann_folder = '../data/ILSVRC' # Path for tube annotations
28 | i3d_ftr_path ='../data/vid_i3d' # Path for I3D features
29 | prp_type = 'coco_30_2' # frame-level proposal extractors
30 | mean_cache_ftr_path = '../data/vid/meanFeature'
31 | ftr_context_path = '../data/vid/coco32/context/Data/VID'
32 |
33 | set_name = opt.set_name
34 | dataset = vidDataloader(ann_folder, prp_type, set_name, dictFile, tubePath \
35 | , ftrPath, out_cached_folder)
36 | capNum = opt.capNum
37 | maxWordNum = opt.maxWordNum
38 | rpNum = opt.rpNum
39 | pos_emb_dim = opt.pos_emb_dim
40 | pos_type = opt.pos_type
41 | vis_ftr_type = opt.vis_ftr_type
42 | use_mean_cache_flag = opt.use_mean_cache_flag
43 | context_flag = opt.context_flag
44 | frm_level_flag = opt.frm_level_flag
45 | frm_num = opt.frm_num
46 | dataset.image_samper_set_up(rpNum= rpNum, capNum = capNum, \
47 | maxWordNum= maxWordNum, usedBadWord=False, \
48 | pos_emb_dim=pos_emb_dim, pos_type=pos_type, vis_ftr_type=vis_ftr_type, \
49 | i3d_ftr_path=i3d_ftr_path, use_mean_cache_flag=use_mean_cache_flag,\
50 | mean_cache_ftr_path=mean_cache_ftr_path, context_flag=context_flag, ftr_context_path=ftr_context_path, frm_level_flag=frm_level_flag, frm_num=frm_num)
51 |
52 | shuffle_flag = not opt.no_shuffle_flag
53 |
54 | data_loader = data.DataLoader(dataset, opt.batchSize, \
55 | num_workers=opt.num_workers, collate_fn=dis_collate_vid, \
56 | shuffle=shuffle_flag, pin_memory=True)
57 | return data_loader, dataset
58 |
59 |         # unreachable remnant of a removed 'actNet' branch:
60 |         #data_loader = data.DataLoader(dataset, opt.batchSize, \
61 |         #        num_workers=opt.num_workers, collate_fn=dis_collate_actNet, \
62 |         #        shuffle=shuffle_flag, pin_memory=True)
63 |
64 | else:
65 | print('Not implemented for dataset %s\n' %(opt.dbSet))
66 | return
67 |
68 | if __name__=='__main__':
69 | opt = parse_args()
70 | opt.dbSet = 'vid'
71 | opt.set_name = 'train'
72 | opt.vis_ftr_type = 'i3d'
73 | out_pre = '/data1/zfchen/data/'
74 | opt.cache_flag = False
75 | opt.pos_type = 'none'
76 | data_loader, dataset_ori = build_dataloader(opt)
77 |
--------------------------------------------------------------------------------
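
`build_dataloader` above reads a number of fields off `opt` beyond those set in the `__main__` block. A hypothetical option bundle collecting the fields the function actually touches (values are illustrative placeholders, not recommended settings):

```python
# Hypothetical opt object for build_dataloader; field names are taken from the
# function body above, values are placeholders only.
class Opt(object):
    dbSet = 'vid'; set_name = 'train'
    capNum = 1; maxWordNum = 20; rpNum = 30
    pos_emb_dim = 64; pos_type = 'none'; vis_ftr_type = 'rgb'
    use_mean_cache_flag = False; context_flag = False
    frm_level_flag = False; frm_num = 30
    no_shuffle_flag = False; batchSize = 4; num_workers = 4
```
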
/fun/datasetParser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('../')
4 | sys.path.append('../util')
5 | from util.mytoolbox import get_specific_file_list_from_fd, textread, split_carefully, parse_mul_num_lines, pickleload, pickledump, get_list_dir
6 | import pdb
7 | import h5py
8 | import csv
9 |
10 | def build_idx_from_list(vdList):
11 | vdListU = list(set(vdList))
12 | vd2idx = {}
13 | idx2vd = {}
14 |     for i, ele in enumerate(vdListU):
15 | vd2idx[ele] = i
16 | idx2vd[i] = ele
17 | return vd2idx, idx2vd
18 |
19 | def get_word_list(full_name_list):
20 | capList =[]
21 | for filePath in full_name_list:
22 | lines = textread(filePath)
23 | subCapList=[]
24 | for line in lines:
25 |             wordList = line.lower().split(' ')
26 |             subCapList.append(wordList)
27 | capList.append(subCapList)
28 | return capList
29 |
30 | def fix_otb_frm_bbxList(annFn, outFn):
31 | annDict = pickleload(annFn)
32 |
33 | isTrainSet = True
34 | if isTrainSet:
35 | bbxListDict=annDict['train_bbx_list']
36 | vidList = annDict['trainName']
37 | frmListDict=annDict['trainImg']
38 | else:
39 | bbxListDict=annDict['test_bbx_list']
40 | vidList = annDict['testName']
41 | frmListDict=annDict['testImg']
42 |
43 | for i, vidName in enumerate(vidList):
44 | frmList = frmListDict[i]
45 | bbxList = bbxListDict[i]
46 | frmList.sort()
47 | if(vidName=='David'):
48 | annDict['trainImg'][i] = frmList[299:]
49 |
50 | isTrainSet = False
51 | if isTrainSet:
52 | bbxListDict=annDict['train_bbx_list']
53 | vidList = annDict['trainName']
54 | frmListDict=annDict['trainImg']
55 | else:
56 | bbxListDict=annDict['test_bbx_list']
57 | vidList = annDict['testName']
58 | frmListDict=annDict['testImg']
59 |
60 | for i, vidName in enumerate(vidList):
61 | frmList = frmListDict[i]
62 | bbxList = bbxListDict[i]
63 | frmList.sort()
64 | if(vidName=='David'):
65 | annDict['testImg'][i] = frmList[299:]
66 |
67 | pickledump(outFn, annDict)
68 | return annDict
69 |
70 | # parse OTB 99
71 | def get_otb_data(inFd):
72 | test_text_fd = inFd + '/OTB_query_test'
73 | train_text_fd = inFd + '/OTB_query_train'
74 | video_fd = inFd + '/OTB_videos'
75 | test_name_list = get_specific_file_list_from_fd(test_text_fd, '.txt')
76 | train_name_list = get_specific_file_list_from_fd(train_text_fd, '.txt')
77 | test_name_listFull = get_specific_file_list_from_fd(test_text_fd, '.txt', False)
78 | train_name_listFull = get_specific_file_list_from_fd(train_text_fd, '.txt', False)
79 |
80 | test_cap_list = get_word_list(test_name_listFull)
81 | train_cap_list = get_word_list(train_name_listFull)
82 |
83 | test_im_list =[]
84 | test_bbx_list =[]
85 | for i, vdName in enumerate(test_name_list):
86 | full_vd_path = video_fd + '/'+ vdName
87 | vd_frame_path = full_vd_path + '/img'
88 | imgList = get_specific_file_list_from_fd(vd_frame_path, '.jpg', True)
89 | test_im_list.append(imgList)
90 |
91 | gtBoxFn = full_vd_path +'/groundtruth_rect.txt'
92 | bbxList= parse_mul_num_lines(gtBoxFn)
93 | test_bbx_list.append(bbxList)
94 |
95 |
96 | train_im_list =[]
97 | train_bbx_list =[]
98 | for i, vdName in enumerate(train_name_list):
99 | full_vd_path = video_fd + '/'+ vdName
100 | vd_frame_path = full_vd_path + '/img'
101 | imgList = get_specific_file_list_from_fd(vd_frame_path, '.jpg')
102 | train_im_list.append(imgList)
103 |
104 | gtBoxFn = full_vd_path +'/groundtruth_rect.txt'
105 | bbxList= parse_mul_num_lines(gtBoxFn)
106 | train_bbx_list.append(bbxList)
107 |
108 | otb_info_raw= {'trainName': train_name_list, 'testName': test_name_list, 'trainCap': train_cap_list, 'testCap': test_cap_list, 'trainImg': train_im_list, 'testImg': test_im_list, 'train_bbx_list': train_bbx_list, 'test_bbx_list': test_bbx_list}
109 | return otb_info_raw
110 |
111 | def otbPCK2List(pckFn):
112 | otbDict = pickleload(pckFn)
113 | testVdList= otbDict['testName']
114 | testImList = otbDict['testImg']
115 | trainVdList= otbDict['trainName']
116 | trainImList = otbDict['trainImg']
117 | imgList = list()
118 | for i, vdName in enumerate(testVdList):
119 | vd_frame_path = vdName + '/img'
120 | for j, imName in enumerate(testImList[i]):
121 | imNameFull = vd_frame_path+'/'+imName+'.jpg'
122 | imgList.append(imNameFull)
123 |
124 | for i, vdName in enumerate(trainVdList):
125 | vd_frame_path = vdName + '/img'
126 | for j, imName in enumerate(trainImList[i]):
127 | imNameFull = vd_frame_path+'/'+imName+'.jpg'
128 | imgList.append(imNameFull)
129 |
130 | return imgList
131 |
132 | def a2dSetParser(annFn, annFd, annIgListFn, annFnOri):
133 | fLineList = textread(annFn)
134 | videoList = list()
135 | capList = list()
136 | insList = list()
137 | bbxList = list()
138 | frmNameList = list()
139 | splitDict = {}
140 |
141 | with open(annFnOri, 'rb') as csvFile:
142 | lines = csv.reader(csvFile)
143 | for line in lines:
144 | eleSegs = split_carefully(line, ',')
145 | splitDict[eleSegs[0][0]]=int(eleSegs[0][-1])
146 |
147 | for i, line in enumerate(fLineList):
148 | if(i<=0):
149 | continue
150 | splitSegs= split_carefully(line, ',')
151 | annFdSub = annFd + '/' + splitSegs[0]
152 | annNameList = get_specific_file_list_from_fd(annFdSub, '.h5')
153 | tmpFrmList = list()
154 | tmpBbxList = list()
155 | #pdb.set_trace()
156 | insIdx = int(splitSegs[1])
157 | print(splitSegs[2])
158 | print('%s %d %d\n' %(splitSegs[0], i, insIdx))
159 | for ii, annName in enumerate(annNameList):
160 | annSubFullPath = annFdSub + '/' + annName +'.h5'
161 | annIns = h5py.File(annSubFullPath)
162 | tmpInsList = list(annIns['instance'][:])
163 | if(insIdx in tmpInsList):
164 | tmpFrmList.append(annName)
165 | bxIdx = tmpInsList.index(insIdx)
166 | tmpBbxList.append(annIns['reBBox'][:, bxIdx])
167 | frmNameList.append(tmpFrmList)
168 | bbxList.append(tmpBbxList)
169 | videoList.append(splitSegs[0])
170 | insList.append(int(splitSegs[1]))
171 | capSegs = splitSegs[2].lower().split(' ')
172 | capList.append(capSegs)
173 |
174 | vd2idx, idx2vd = build_idx_from_list(videoList)
175 | igNameList =textread(annIgListFn)
176 |
177 | a2d_info_raw= {'cap': capList, 'vd': videoList, 'bbxList': bbxList, 'frmList': frmNameList, 'insList' : insList, 'igList': igNameList, 'splitDict': splitDict, 'vd2idx': vd2idx, 'idx2vd': idx2vd}
178 | return a2d_info_raw
179 |
180 | #data = h5py.file()
181 | def a2dPCK2List(pckFn):
182 | a2dDict = pickleload(pckFn)
183 | imgList = list()
184 | testCapList = a2dDict['cap']
185 | for i, cap in enumerate(testCapList):
186 | vdName = a2dDict['vd'][i]
187 | frmList = a2dDict['frmList'][i]
188 | for j, imName in enumerate(frmList):
189 | imNameFull = vdName+'/'+imName+'.png'
190 | imgList.append(imNameFull)
191 | imgList= list(set(imgList))
192 | return imgList
193 |
194 | def extAllFrm(annFn, fdPre):
195 | annDict = pickleload(annFn)
196 | videoListU = list(set(annDict['vd']))
197 | frmFullList = list()
198 | for vdName in videoListU:
199 | subPre = fdPre + '/' + vdName
200 | frmNameList = get_specific_file_list_from_fd(subPre, '.png')
201 | for frm in frmNameList:
202 | frmFullList.append(vdName+'/'+frm+'.png')
203 | return frmFullList
204 |
205 | def extAllFrmFn(videoList, fdPre):
206 | frmvDict = list()
207 | for vdName in videoList:
208 | subPre = fdPre + '/' + vdName
209 | frmNameList = get_specific_file_list_from_fd(subPre, '.png')
210 | frmNameList.sort()
211 | frmvDict.append(frmNameList)
212 | #pdb.set_trace()
213 | return frmvDict
214 |
215 | def getFrmFn(fdPre, extFn='.jpg'):
216 | frmListFull = list()
217 | sub_fd_list = get_list_dir(fdPre)
218 | for i, vdName in enumerate(sub_fd_list):
219 | frmList = get_specific_file_list_from_fd(vdName, extFn, nameOnly=False)
220 | frmListFull +=frmList
221 | return frmListFull
222 |
223 |
224 |
225 |
226 | if __name__=='__main__':
227 |
228 | fdListFn = '/data1/zfchen/data/actNet/actNetJpgs/'
229 | #frmList = getFrmFn(fdListFn)
230 | #pdb.set_trace()
231 |
232 | #pckFn = '../data/annoted_a2d.pd'
233 | #imPre ='/data1/zfchen/data/A2D/Release/pngs320H'
234 | #dataAnn = pickleload(pckFn)
235 | #frmFullList = extAllFrm(pckFn, imPre)
236 | #pdb.set_trace()
237 | #dataAnn['frmListGt'] = dataAnn['frmList']
238 | #dataAnn['frmList'] = frmFullList
239 | #pickledump('../data/annoted_a2dV2.pd', dataAnn)
240 | #pdb.set_trace()
241 | #print('finish')
242 | #imgList = a2dPCK2List(pckFn)
243 | #print('finish')
244 | #annFd = '/disk2/zfchen/data/A2D/Release/sentenceAnno/a2d_annotation_with_instances'
245 | #annFn = '/disk2/zfchen/data/A2D/Release/sentenceAnno/a2d_annotation.txt'
246 | #annIgListFn = '/disk2/zfchen/data/A2D/Release/sentenceAnno/a2d_missed_videos.txt'
247 | #annOriFn = '/disk2/zfchen/data/A2D/Release/videoset.csv'
248 | #a2dSetParser(annFn, annFd, annIgListFn, annOriFn)
249 | #otbPKFile ='../data/annForDb_otb.pd'
250 | #otbPKFileV2 ='../data/annForDb_otbV2.pd'
251 | #otbNew= fix_otb_frm_bbxList(otbPKFile, otbPKFileV2)
252 | #imList = otbPCK2List(otbPKFile)
253 | #print(imList[:10])
254 | #print(imList[10000])
255 | #print(len(imList))
256 |
--------------------------------------------------------------------------------
/fun/eval.py:
--------------------------------------------------------------------------------
1 | from wsParamParser import parse_args
2 | from data.data_loader import *
3 | from datasetLoader import *
4 | from modelArc import *
5 | from optimizers import *
6 | from logInfo import logInF
7 | from lossPackage import *
8 | from netUtil import *
9 | from tensorboardX import SummaryWriter
10 | import time
11 |
12 | import pdb
13 |
14 | if __name__=='__main__':
15 | opt = parse_args()
16 | # build dataloader
17 | dataLoader, datasetOri= build_dataloader(opt)
18 | # build network
19 | model = build_network(opt)
20 | # build_optimizer
21 | logger = logInF(opt.logFd)
22 | ep = 0
23 | #thre_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
24 | thre_list = [0.5]
25 | acc_list = list()
26 | more_detailed_flag = True
27 | for thre in thre_list:
28 | acc_list.append([])
29 |
30 | if opt.eval_val_flag:
31 | model.eval()
32 | resultList = list()
33 | vIdList = list()
34 | set_name_ori= opt.set_name
35 | opt.set_name = 'val'
36 | dataLoaderEval, datasetEvalOri = build_dataloader(opt)
37 | for itr_eval, inputData in enumerate(dataLoaderEval):
38 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData
39 | dataIdx = None
40 | b_size = tube_embedding.shape[0]
41 | # B*P*T*D
42 | imDis = tube_embedding.cuda()
43 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
44 | wordEmb = cap_embedding.cuda()
45 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
46 | imDis.requires_grad=False
47 | wordEmb.requires_grad=False
48 |
49 | if opt.wsMode =='coAtt':
50 | simMM = model(imDis, wordEmb, cap_length_list)
51 | #pdb.set_trace()
52 | simMM = simMM.view(b_size, opt.rpNum)
53 | for i, thre in enumerate(thre_list):
54 | acc_list[i] += evalAcc_att_test(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False, topK=1, thre_list=[thre], more_detailed_flag=more_detailed_flag)
55 |
56 | for i, thre in enumerate(thre_list):
57 | resultList = acc_list[i]
58 | accSum = 0
59 | for ele in resultList:
60 | recall_k= ele[1]
61 | accSum +=recall_k
62 |         logger('thre @ %f, Average accuracy on validation set is %.3f\n' %(thre, accSum/len(resultList)))
63 |
64 | out_result_fn = opt.logFd + 'result_val_' +os.path.basename(opt.initmodel).split('.')[0] + '.pk'
65 | pickledump(out_result_fn, acc_list)
66 |
67 | if opt.eval_test_flag:
68 | model.eval()
69 | resultList = list()
70 | vIdList = list()
71 | set_name_ori= opt.set_name
72 | opt.set_name = 'test'
73 | dataLoaderEval, datasetEvalOri = build_dataloader(opt)
74 | #pdb.set_trace()
75 | for itr_eval, inputData in enumerate(dataLoaderEval):
76 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData
77 | dataIdx = None
78 | b_size = tube_embedding.shape[0]
79 | # B*P*T*D
80 | imDis = tube_embedding.cuda()
81 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
82 | wordEmb = cap_embedding.cuda()
83 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
84 | imDis.requires_grad=False
85 | wordEmb.requires_grad=False
86 |
87 | if opt.wsMode =='coAtt':
88 | simMM = model(imDis, wordEmb, cap_length_list)
89 | simMM = simMM.view(b_size, opt.rpNum)
90 | for i, thre in enumerate(thre_list):
91 | acc_list[i] += evalAcc_att_test(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False, thre_list=[thre], more_detailed_flag=more_detailed_flag)
92 |
93 | for i, thre in enumerate(thre_list):
94 | resultList = acc_list[i]
95 | accSum = 0
96 | for ele in resultList:
97 | recall_k= ele[1]
98 | accSum +=recall_k
99 |         logger('thre @ %f, Average accuracy on testing set is %.3f\n' %(thre, accSum/len(resultList)))
100 | out_result_fn = opt.logFd + 'result_test_' +os.path.basename(opt.initmodel).split('.')[0] + '.pk'
101 | pickledump(out_result_fn, acc_list)
102 |
103 |
--------------------------------------------------------------------------------
/fun/evalDet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import magic
5 | import re
6 | sys.path.append('..')
7 | from util.mytoolbox import *
8 | from util.get_image_size import get_image_size
9 | import cv2
10 | import pdb
11 | import ipdb
12 | import copy
13 | import torch
14 | from fun.datasetLoader import *
15 | from vidDatasetParser import evaluate_tube_recall_vid, resize_tube_bbx
16 | from netUtil import *
17 |
18 | sys.path.append('../annotation')
19 | from script_test_annotation import evaluate_tube_recall
20 |
21 | def computeIoU(box1, box2):
22 | # each box is of [x1, y1, w, h]
23 | inter_x1 = max(box1[0], box2[0])
24 | inter_y1 = max(box1[1], box2[1])
25 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
26 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)
27 |
28 | if inter_x1 < inter_x2 and inter_y1 < inter_y2:
29 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
30 | else:
31 | inter = 0
32 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter
33 | return float(inter)/union
34 |
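# Illustrative check (added sketch, hypothetical boxes): computeIoU treats
# boxes as inclusive pixel spans in [x1, y1, w, h] form, hence the +1/-1 terms.
def _demo_compute_iou():
    box1 = [0, 0, 10, 10]   # x1=0, y1=0, w=10, h=10
    box2 = [5, 5, 10, 10]
    # intersection covers x in [5, 9], y in [5, 9] -> 5*5 = 25 pixels
    # union = 100 + 100 - 25 = 175, so IoU = 25/175 ~= 0.143
    return computeIoU(box1, box2)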
35 | # get original size of the proposals
36 | # the proposals are extracted at a larger size
37 | def transFormBbx(bbx, im_shape,reSize=800, maxSize=1200, im_scale=None):
38 | if im_scale is not None:
39 | bbx = [ ele/im_scale for ele in bbx]
40 | bbx[2] = bbx[2]-bbx[0]
41 | bbx[3] = bbx[3]-bbx[1]
42 | return bbx
43 | im_size_min = np.min(im_shape[0:2])
44 | im_size_max = np.max(im_shape[0:2])
45 | im_scale = float(reSize)/im_size_min
46 | if(np.round(im_size_max*im_scale) > maxSize):
47 | im_scale = float(maxSize)/ float(im_size_max)
48 | bbx = [ ele/im_scale for ele in bbx]
49 | bbx[2] = bbx[2]-bbx[0]
50 | bbx[3] = bbx[3]-bbx[1]
51 | return bbx
52 |
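# Worked example for transFormBbx (added sketch, hypothetical values): with
# im_shape=(300, 500), the short side gives im_scale = 800/300 ~= 2.67, but
# round(500*2.67) exceeds maxSize=1200, so im_scale falls back to
# 1200/500 = 2.4. A corner box [240, 120, 480, 360] at the enlarged scale then
# maps back to [100, 50, 200, 150] and is returned in [x1, y1, w, h] form as
# [100, 50, 100, 100].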
53 |
54 | class evalDetAcc(object):
55 | def __init__(self, gtBbxList=None, IoU=0.5, topK=1):
56 | self.gtList = gtBbxList
57 | self.thre = IoU
58 | self.topK= topK
59 |
60 | def evalList(self, prpList):
61 | imgNum = len(prpList)
62 | posNum = 0
63 | for i in range(imgNum):
64 | tmpRpList= prpList[i]
65 | gtBbx = self.gtList[i]
66 | for j in range(self.topK):
67 | iou = computeIoU(gtBbx, tmpRpList[j])
68 | if iou>self.thre:
69 | posNum +=1
70 | break
71 | return float(posNum)/imgNum
72 |
73 | def rpMatPreprocess(rpMatrix, imWH, isA2D=False):
74 | rpList= list()
75 | rpNum = rpMatrix.shape[0]
76 | for i in range(rpNum):
77 | tmpRp = list(rpMatrix[i, :])
78 | if not isA2D:
79 | bbx = transFormBbx(tmpRp, imWH)
80 | else:
81 | bbx = transFormBbx(tmpRp, imWH, im_scale=imWH)
82 | rpList.append(bbx)
83 | return rpList
84 |
85 | def evalAcc_att_test(simMM_batch, tubeInfo, indexOri, datasetOri, visRsFd, visFlag=False, topK = 1, thre_list=[0.5], more_detailed_flag=False):
86 | resultList= list()
87 | bSize = len(indexOri)
88 | tube_Prp_num = len(tubeInfo[0][0][0])
89 | stIdx = 0
90 | vid_parser = datasetOri.vid_parser
91 | for idx, lbl in enumerate(indexOri):
92 | simMM = simMM_batch[idx, :]
93 | simMMReshape = simMM.view(-1, tube_Prp_num)
94 |
95 | sortSim, simIdx = torch.sort(simMMReshape, dim=1, descending=True)
96 | simIdx = simIdx.data.cpu().numpy().squeeze(axis=0)
97 | sort_sim_np = sortSim.data.cpu().numpy().squeeze(axis=0)
98 |
99 | tube_Info_sub = tubeInfo[idx]
100 | tube_info_sub_prp, frm_info_list = tube_Info_sub
101 | tube_info_sub_prp_bbx, tube_info_sub_prp_score = tube_info_sub_prp
102 | #prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]], sort_sim_np[i] ]for i in range(topK)]
103 | prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]] for i in range(topK)], [sort_sim_np[i] for i in range(topK)] ]
104 | shot_proposals = [prpListSort, frm_info_list]
105 | for ii, thre in enumerate(thre_list):
106 | if more_detailed_flag:
107 | recall_tmp, iou_list = evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK, more_detailed_flag=more_detailed_flag)
108 | else:
109 | recall_tmp= evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK)
110 | if more_detailed_flag:
111 | resultList.append((lbl, recall_tmp[-1], simIdx, sort_sim_np, iou_list))
112 | else:
113 | resultList.append((lbl, recall_tmp[-1]))
114 |             print('accuracy for %d: %.3f' %(lbl, recall_tmp[-1]))
115 |
116 | if visFlag:
117 |
118 | print(vid_parser.tube_cap_dict[lbl])
119 | print(lbl)
120 | #continue
121 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl)
122 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list]
123 | frmImList = list()
124 | for fId, imPath in enumerate(frmImNameList):
125 | img = cv2.imread(imPath)
126 | frmImList.append(img)
127 | vis_frame_num = 30
128 | visIner =max(int(len(frmImList) /vis_frame_num), 1)
129 |
130 | for ii in range(topK):
131 | print('visualizing tube %d\n'%(ii))
132 | #pdb.set_trace()
133 | tube = prpListSort[0][ii]
134 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)]
135 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)]
136 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis)
137 | vd_name_raw = vd_name.split('/')[-1]
138 | makedirs_if_missing(visRsFd)
139 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, visRsFd+'/'+vd_name_raw+ '_' + str(ii)+'.gif')
140 | pdb.set_trace()
141 | return resultList
142 |
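# Note on the ranking above (added comment): torch.sort(..., descending=True)
# orders the rpNum proposal scores for each query; prpListSort keeps the top-K
# boxes together with their scores, and evaluate_tube_recall_vid then checks
# whether any kept tube overlaps the ground-truth tube above each threshold in
# thre_list.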
143 |
144 | def evalAcc_att(simMMFull, tubeInfo, indexOri, datasetOri, visRsFd, visFlag=False, topK = 1, thre_list=[0.5], more_detailed_flag=False):
145 | # pdb.set_trace()
146 | resultList= list()
147 | bSize = len(indexOri)
148 | tube_Prp_num = len(tubeInfo[0][0][0])
149 | stIdx = 0
150 | #thre_list = [0.2, 0.3, 0.4, 0.5]
151 | vid_parser = datasetOri.vid_parser
152 | for idx, lbl in enumerate(indexOri):
153 | simMM = simMMFull[idx, :, idx]
154 | simMMReshape = simMM.view(-1, tube_Prp_num)
155 | #pdb.set_trace()
156 | sortSim, simIdx = torch.sort(simMMReshape, dim=1, descending=True)
157 | simIdx = simIdx.data.cpu().numpy().squeeze(axis=0)
158 | sort_sim_np = sortSim.data.cpu().numpy().squeeze(axis=0)
159 |
160 | tube_Info_sub = tubeInfo[idx]
161 | tube_info_sub_prp, frm_info_list = tube_Info_sub
162 | tube_info_sub_prp_bbx, tube_info_sub_prp_score = tube_info_sub_prp
163 | #prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]], sort_sim_np[i] ]for i in range(topK)]
164 | prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]] for i in range(topK)], [sort_sim_np[i] for i in range(topK)] ]
165 | shot_proposals = [prpListSort, frm_info_list]
166 | for ii, thre in enumerate(thre_list):
167 | if more_detailed_flag:
168 | recall_tmp, iou_list = evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK, more_detailed_flag=more_detailed_flag)
169 | else:
170 | recall_tmp= evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK)
171 | if more_detailed_flag:
172 | resultList.append((lbl, recall_tmp[-1], simIdx, sort_sim_np, iou_list))
173 | else:
174 | resultList.append((lbl, recall_tmp[-1]))
175 |             print('accuracy for %d: %.3f' %(lbl, recall_tmp[-1]))
176 |
177 | #pdb.set_trace()
178 | if visFlag:
179 |
180 |
181 | # visualize sample results
182 | #if(recall_tmp[-1]<=0.5):
183 | # continue
184 | print(vid_parser.tube_cap_dict[lbl])
185 | print(lbl)
186 | #pdb.set_trace()
187 | #continue
188 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl)
189 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list]
190 | frmImList = list()
191 | for fId, imPath in enumerate(frmImNameList):
192 | img = cv2.imread(imPath)
193 | frmImList.append(img)
194 | vis_frame_num = 30
195 | visIner =max(int(len(frmImList) /vis_frame_num), 1)
196 |
197 | for ii in range(topK):
198 | print('visualizing tube %d\n'%(ii))
199 | #pdb.set_trace()
200 | tube = prpListSort[0][ii]
201 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)]
202 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)]
203 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis)
204 | vd_name_raw = vd_name.split('/')[-1]
205 | makedirs_if_missing(visRsFd)
206 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, visRsFd+'/'+vd_name_raw+ '_' + str(ii)+'.gif')
207 | pdb.set_trace()
208 | return resultList
209 |
210 |
211 |
212 |
213 | def evalAcc(imFtr, txtFtr, tubeInfo, indexOri, datasetOri, visRsFd, visFlag=False, topK = 1, thre_list=[0.5], more_detailed_flag=False):
214 | resultList= list()
215 | bSize = len(indexOri)
216 | tube_Prp_num = imFtr.shape[1]
217 | stIdx = 0
218 | #thre_list = [0.2, 0.3, 0.4, 0.5]
219 | vid_parser = datasetOri.vid_parser
220 | assert txtFtr.shape[1]==1
221 | for idx, lbl in enumerate(indexOri):
222 | imFtrSub = imFtr[idx]
223 | txtFtrSub = txtFtr[idx].view(-1,1)
224 | simMM = torch.mm(imFtrSub, txtFtrSub)
225 | #pdb.set_trace()
226 | simMMReshape = simMM.view(-1, tube_Prp_num)
227 | sortSim, simIdx = torch.sort(simMMReshape, dim=1, descending=True)
228 | simIdx = simIdx.data.cpu().numpy().squeeze(axis=0)
229 | sort_sim_np = sortSim.data.cpu().numpy().squeeze(axis=0)
230 |
231 | tube_Info_sub = tubeInfo[idx]
232 | tube_info_sub_prp, frm_info_list = tube_Info_sub
233 | tube_info_sub_prp_bbx, tube_info_sub_prp_score = tube_info_sub_prp
234 | #prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]], sort_sim_np[i] ]for i in range(topK)]
235 | prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]] for i in range(topK)], [sort_sim_np[i] for i in range(topK)] ]
236 | shot_proposals = [prpListSort, frm_info_list]
237 | for ii, thre in enumerate(thre_list):
238 | if more_detailed_flag:
239 | recall_tmp, iou_list= evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK, more_detailed_flag=more_detailed_flag)
240 | else:
241 | recall_tmp = evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK)
242 | if more_detailed_flag:
243 | resultList.append((lbl, recall_tmp[-1], simIdx, sort_sim_np, iou_list))
244 | else:
245 | resultList.append((lbl, recall_tmp[-1]))
246 | #print('accuracy for %d: %3f' %(lbl, recall_tmp[-1]))
247 |
248 | #pdb.set_trace()
249 | if visFlag:
250 | # visualize sample results
251 | #if(recall_tmp[-1]<=0.5):
252 | # continue
253 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl)
254 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list]
255 | frmImList = list()
256 | for fId, imPath in enumerate(frmImNameList):
257 | img = cv2.imread(imPath)
258 | frmImList.append(img)
259 | vis_frame_num = 30
260 | visIner =max(int(len(frmImList) /vis_frame_num), 1)
261 |
262 | for ii in range(topK):
263 | print('visualizing tube %d\n'%(ii))
264 | #pdb.set_trace()
265 | tube = prpListSort[0][ii]
266 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)]
267 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)]
268 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis)
269 | vd_name_raw = vd_name.split('/')[-1]
270 | makedirs_if_missing(visRsFd)
271 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, visRsFd+'/'+vd_name_raw+ '_' + str(ii)+'.gif')
272 | pdb.set_trace()
273 | return resultList
274 |
275 | if __name__=='__main__':
276 | pdb.set_trace()
277 |
--------------------------------------------------------------------------------
/fun/image_toolbox.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 |
4 | import commands
5 | import copy
6 | import cv2
7 | from easydict import EasyDict as edict
8 | import numpy as np
9 | import os
10 | import shutil
11 | from dashed_rect import drawrect
12 | #from settings import FFMPEG
13 | import pdb
14 | TMP_DIR = '.tmp'
15 | FFMPEG = 'ffmpeg'
16 |
17 | SAVE_VIDEO = FFMPEG + ' -y -r %d -i %s/%s.jpg %s'
18 |
19 | def images2video(image_list, frame_rate, video_path, max_edge=None):
20 | if os.path.exists(TMP_DIR):
21 | shutil.rmtree(TMP_DIR)
22 | os.mkdir(TMP_DIR)
23 | img_size = None
24 | for cur_num, cur_img in enumerate(image_list):
25 | cur_fname = os.path.join(TMP_DIR, '%08d.jpg' % cur_num)
26 | if max_edge is not None:
27 | cur_img = imread_if_str(cur_img)
28 | if isinstance(cur_img, str) or isinstance(cur_img, unicode):
29 | shutil.copyfile(cur_img, cur_fname)
30 | elif isinstance(cur_img, np.ndarray):
31 | max_len = max(cur_img.shape[:2])
32 | if max_len > max_edge and img_size is None and max_edge is not None:
33 | magnif = float(max_edge) / float(max_len)
34 | img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif))
35 | cur_img = cv2.resize(cur_img, img_size)
36 | elif max_edge is not None:
37 | if img_size is None:
38 | magnif = float(max_edge) / float(max_len)
39 | img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif))
40 | cur_img = cv2.resize(cur_img, img_size)
41 | cv2.imwrite(cur_fname, cur_img)
42 | else:
43 |             raise NotImplementedError()
44 | print commands.getoutput(SAVE_VIDEO % (frame_rate, TMP_DIR, '%08d', video_path))
45 | shutil.rmtree(TMP_DIR)
46 |
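# Usage sketch (hypothetical paths): frames are dumped to .tmp/%08d.jpg and
# the SAVE_VIDEO template expands to a plain ffmpeg call, e.g.
#   ffmpeg -y -r 10 -i .tmp/%08d.jpg out.gif
# so a typical invocation is:
#   images2video(frame_list, 10, 'out.gif')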
47 | def imread_if_str(img):
48 | if isinstance(img, basestring):
49 | img = cv2.imread(img)
50 | return img
51 |
52 | def draw_rectangle(img, bbox, color=(0,0,255), thickness=3, use_dashed_line=False):
53 | img = imread_if_str(img)
54 | if isinstance(bbox, dict):
55 | bbox = [
56 | bbox['x1'],
57 | bbox['y1'],
58 | bbox['x2'],
59 | bbox['y2'],
60 | ]
61 | bbox[0] = max(bbox[0], 0)
62 | bbox[1] = max(bbox[1], 0)
63 | bbox[0] = min(bbox[0], img.shape[1])
64 | bbox[1] = min(bbox[1], img.shape[0])
65 | bbox[2] = max(bbox[2], 0)
66 | bbox[3] = max(bbox[3], 0)
67 | bbox[2] = min(bbox[2], img.shape[1])
68 | bbox[3] = min(bbox[3], img.shape[0])
69 | assert bbox[2] >= bbox[0]
70 | assert bbox[3] >= bbox[1]
71 | assert bbox[0] >= 0
72 | assert bbox[1] >= 0
73 | assert bbox[2] <= img.shape[1]
74 | assert bbox[3] <= img.shape[0]
75 | cur_img = copy.deepcopy(img)
76 | #pdb.set_trace()
77 | if use_dashed_line:
78 | drawrect(
79 | cur_img,
80 | (int(bbox[0]), int(bbox[1])),
81 | (int(bbox[2]), int(bbox[3])),
82 | color,
83 | thickness,
84 | 'dotted'
85 | )
86 | else:
87 | #pdb.set_trace()
88 | cv2.rectangle(
89 | cur_img,
90 | (int(bbox[0]), int(bbox[1])),
91 | (int(bbox[2]), int(bbox[3])),
92 | color,
93 | thickness)
94 | return cur_img
95 |
96 | def rgb2gray(rgb):
97 | r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
98 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
99 | return gray
100 |
101 | def gray_background(img, bbox):
102 | img = imread_if_str(img)
103 | if isinstance(bbox, dict):
104 | bbox = [
105 | bbox['x1'],
106 | bbox['y1'],
107 | bbox['x2'],
108 | bbox['y2'],
109 | ]
110 | bbox[0] = int(max(bbox[0], 0))
111 | bbox[1] = int(max(bbox[1], 0))
112 | bbox[0] = min(bbox[0], img.shape[1])
113 | bbox[1] = min(bbox[1], img.shape[0])
114 | bbox[2] = int(max(bbox[2], 0))
115 | bbox[3] = int(max(bbox[3], 0))
116 | bbox[2] = min(bbox[2], img.shape[1])
117 | bbox[3] = min(bbox[3], img.shape[0])
118 | assert bbox[2] >= bbox[0]
119 | assert bbox[3] >= bbox[1]
120 | assert bbox[0] >= 0
121 | assert bbox[1] >= 0
122 | assert bbox[2] <= img.shape[1]
123 | assert bbox[3] <= img.shape[0]
124 | #gray_img = copy.deepcopy(img)
125 | #gray_img = np.stack((gray_img, gray_img, gray_img), axis=2)
126 | #gray_img = rgb2gray(gray_img)
127 | gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
128 | gray_img = np.stack((gray_img, gray_img, gray_img), axis=2)
129 | #pdb.set_trace()
130 | gray_img[bbox[1]:bbox[3], bbox[0]:bbox[2], ...] = img[bbox[1]:bbox[3], bbox[0]:bbox[2], ...]
131 | #pdb.set_trace()
132 | return gray_img
133 |
--------------------------------------------------------------------------------
/fun/logInfo.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | import os
4 | class logInF(object):
5 | def __init__(self, logPre):
6 | dirname = os.path.dirname(logPre)
7 | if not os.path.exists(dirname):
8 | os.makedirs(dirname)
9 | logFile=logPre+time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + 'log.txt'
10 | self.prHandle=open(logFile, 'w')
11 |
12 | def __call__(self, logData):
13 | print(logData)
14 | self.prHandle.write(logData)
15 | self.prHandle.flush()
16 |
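# Usage sketch (hypothetical prefix): logInF('../data/log/') creates the
# folder if needed and opens ../data/log/<timestamp>log.txt; each call prints
# the message and appends it to the file:
#   logger = logInF('../data/log/')
#   logger('loss: %.3f\n' % 0.25)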
--------------------------------------------------------------------------------
/fun/lossPackage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import pdb
5 | import time
6 |
7 | class HLoss(nn.Module):
8 | def __init__(self):
9 | super(HLoss, self).__init__()
10 |
11 | def forward(self, x):
12 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
13 | b = -1.0 * b.sum(dim=1)
14 | b = torch.mean(b)
15 | return b
16 |
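# Illustrative check (added sketch, hypothetical input): HLoss is the mean
# Shannon entropy of the row-wise softmax, so a uniform row attains the
# maximum value log(n).
def _demo_hloss():
    x = torch.zeros(1, 4)   # uniform softmax -> entropy log(4) ~= 1.386
    return HLoss()(x)       # a sharply peaked row would score near 0 instead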
17 | class lossEvaluator(nn.Module):
18 | def __init__(self, margin=0.1, biLossFlag=True, lossWFlag=False, lamda=0.8, struct_flag=False, struct_only_flag=False, entropy_regu_flag=False, lamda2=0.1, loss_type=''):
19 | super(lossEvaluator, self).__init__()
20 | self.margin =margin
21 | self.biLossFlag = biLossFlag
22 | self.lossWFlag = lossWFlag
23 | self.lamda= lamda
24 | self.struct_flag = struct_flag
25 | self.struct_only_flag = struct_only_flag
26 | self.entropy_regu_flag = entropy_regu_flag
27 | self.lamda2 = lamda2
28 | self.loss_type = loss_type
29 | if self.entropy_regu_flag:
30 | self.entropy_calculator = HLoss()
31 |
32 | def forward(self, imFtr=None, disFtr=None, lblList=None, simMM=None, region_gt_ori=None):
33 | if self.wsMode=='coAtt':
34 | loss = self.forwardCoAtt_efficient(simMM, lblList)
35 | return loss
36 |
37 | def forwardCoAtt_efficient(self, simMMRe, lblList):
38 |
39 | simMax, maxIdx= torch.max(simMMRe.squeeze(3), dim=1)
40 | loss = torch.zeros(1).cuda()
41 | pair_num = 0.000001
42 |
43 | t1 = time.time()
44 | #print(lblList)
45 | bSize = len(lblList)
46 | b_size = simMax.shape[0]
47 | pos_diag = torch.cat([simMax[i, i].unsqueeze(0) for i in range(b_size)])
48 | one_mat = torch.ones(b_size, b_size).cuda()
49 | pos_diag_mat = torch.mul(pos_diag, one_mat)
50 | pos_diag_trs = pos_diag_mat.transpose(0,1)
51 |
52 | mask_val = torch.ones(b_size, b_size).cuda()
53 |
54 | for i in range(b_size):
55 | lblI = lblList[i]
56 | for j in range(i, b_size):
57 | lblJ = lblList[j]
58 | if lblI==lblJ:
59 | mask_val[i, j]=0
60 | mask_val[j, i]=0
61 | pair_num = pair_num + torch.sum(mask_val)
62 |
63 | loss_mat_1 = simMax -pos_diag + self.margin
64 | loss_mask = (loss_mat_1>0).float()
65 | loss_mat_1_mask = loss_mat_1 *loss_mask * mask_val
66 | loss1 = torch.sum(loss_mat_1_mask)
67 |
68 | loss_mat_2 = simMax -pos_diag_trs +self.margin
69 | loss_mask_2 = (loss_mat_2>0).float()
70 | loss_mat_2_mask = loss_mat_2 *loss_mask_2 * mask_val
71 | loss2 = torch.sum(loss_mat_2_mask)
72 | loss = (loss1+loss2)/pair_num
73 |
74 | if self.entropy_regu_flag:
75 | simMMRe = simMMRe.squeeze()
76 | ftr_match_pair_list = list()
77 | ftr_unmatch_pair_list = list()
78 |
79 | for i in range(bSize):
80 | for j in range(bSize):
81 | if i==j:
82 | ftr_match_pair_list.append(simMMRe[i, ..., i])
83 | elif lblList[i]!=lblList[j]:
84 | ftr_unmatch_pair_list.append(simMMRe[i, ..., j])
85 |
86 | ftr_match_pair_mat = torch.stack(ftr_match_pair_list, 0)
87 | ftr_unmatch_pair_mat = torch.stack(ftr_unmatch_pair_list, 0)
88 | match_num = len(ftr_match_pair_list)
89 | unmatch_num = len(ftr_unmatch_pair_list)
90 | if match_num>0:
91 | entro_loss = self.entropy_calculator(ftr_match_pair_mat)
92 | loss +=self.lamda2*entro_loss
93 | print('entropy loss: %3f ' %(float(entro_loss)))
94 | #pdb.set_trace()
95 | print('\n')
96 | return loss
97 |
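# Worked example for the bidirectional hinge above (hypothetical values): with
# margin=0.1 and simMax = [[0.60, 0.50], [0.55, 0.40]], the diagonal holds the
# matched video-sentence pairs. loss_mat_1 penalizes s_ij - s_jj + margin and
# loss_mat_2 penalizes s_ij - s_ii + margin on unmatched (off-diagonal) pairs:
#   direction 1: max(0, 0.50-0.40+0.1) + max(0, 0.55-0.60+0.1) = 0.20 + 0.05
#   direction 2: max(0, 0.50-0.60+0.1) + max(0, 0.55-0.40+0.1) = 0    + 0.25
# pair_num ~= 2 off-diagonal pairs, so loss = (0.25 + 0.25) / 2 = 0.25.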
98 | def build_lossEval(opts):
99 |     if opts.wsMode in ('rankTube', 'coAtt', 'coAttV2', 'coAttV3', 'coAttV4', 'rankFrm', 'coAttBi', 'coAttBiV2', 'coAttBiV3'):
100 | loss_criterion = lossEvaluator(opts.margin, opts.biLoss, opts.lossW, \
101 | opts.lamda, opts.struct_flag, opts.struct_only, \
102 | opts.entropy_regu_flag, opts.lamda2, \
103 | loss_type=opts.loss_type)
104 | loss_criterion.wsMode =opts.wsMode
105 | return loss_criterion
106 |     elif opts.wsMode in ('rankGroundR', 'coAttGroundR', 'rankGroundRV2'):
107 | loss_criterion = lossGroundR(entropy_regu_flag=opts.entropy_regu_flag, \
108 | lamda2=opts.lamda2 )
109 | loss_criterion.wsMode = opts.wsMode
110 | return loss_criterion
111 |
--------------------------------------------------------------------------------
/fun/modelArc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pdb
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.rnn import pack_padded_sequence
6 | #from multiGraphAttention import *
7 | from netvlad import NetVLAD
8 | from classSST import *
9 |
10 | class wsEmb(nn.Module):
11 | def __init__(self, imEncoder, wordEncoder):
12 | super(wsEmb, self).__init__()
13 | self._train=True
14 | self.imEncoder = imEncoder
15 | self.wordEncoder = wordEncoder
16 | self._initialize_weights()
17 | self.wsMode = 'rank'
18 | self.vis_type = 'fc'
19 |
20 | def forward(self, imDis, wordEmb, capLengthsFull=None, frmListFull=None, rpListFull=None, dataIdx=None):
21 | if dataIdx is not None:
22 | numEle = len(dataIdx)
23 | if capLengthsFull is not None:
24 | capLengths = [capLengthsFull[int(dataIdx[i])] for i in range(numEle)]
25 | else:
26 | capLengths = capLengthsFull
27 | if frmListFull is not None:
28 | frmList = [frmListFull[int(dataIdx[i])] for i in range(numEle)]
29 | else:
30 | frmList = frmListFull
31 | if rpListFull is not None:
32 | rpList = [rpListFull[int(dataIdx[i])] for i in range(numEle)]
33 | else:
34 | rpList = rpListFull
35 | else:
36 | capLengths = capLengthsFull
37 | frmList = frmListFull
38 | rpList = rpListFull
39 |
40 | if self.wsMode == 'rankTube' or self.wsMode =='rankFrm':
41 | imEnDis, wordEnDis = self.forwardRank(imDis, wordEmb, capLengths)
42 | return imEnDis, wordEnDis
43 |
44 | #pdb.set_trace()
45 |         if self.wsMode in ('coAtt', 'coAttV2', 'coAttV3', 'coAttV4', 'coAttBi', 'coAttBiV2', 'coAttBiV3'):
46 | simMM = self.forwardCoAtt(imDis, wordEmb, capLengths)
47 | return simMM
48 |
49 | elif self.wsMode == 'rankGroundR' or self.wsMode=='rankGroundRV2':
50 | logDist, rpSS= self.forwardGroundR(imDis, wordEmb, capLengths)
51 | return logDist, rpSS
52 |
53 | elif self.wsMode == 'coAttGroundR':
54 | logDist, rpSS= self.forwardCoAttGroundR(imDis, wordEmb, capLengths)
55 | return logDist, rpSS
56 |
57 |
58 | def forwardCoAttGroundR(self, imDis, wordEmb, capLengths):
59 | b_size = imDis.shape[0]
60 | assert len(imDis.size())==4
61 | assert len(wordEmb.size())==4
62 |
63 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
64 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
65 | imEnDis, wordEnDis = self.imEncoder(imDis, wordEmb, capLengths)
66 |
67 | imEnDis = imEnDis.view(b_size , -1, imEnDis.shape[2])
68 | wordEnDis = wordEnDis.view(b_size, -1, wordEnDis.shape[2])
69 | visFtrEm, rpSS= self.visDecoder(imEnDis, wordEnDis)
70 | reCnsSen =None
71 | if self.training:
72 | assert visFtrEm.shape[2]==1
73 | visFtrEm = visFtrEm.squeeze(dim=2)
74 | reCnsSen= self.recntr(visFtrEm, wordEmb, capLengths)
75 | return reCnsSen, rpSS
76 |
77 |
78 | def forwardRank(self, imDis, wordEmb, capLengths):
79 | if self.vis_type =='fc':
80 | imDis = imDis.mean(1)
81 | else:
82 | assert len(imDis.size())==3
83 | assert len(wordEmb.size())==3
84 | #pdb.set_trace()
85 | imEnDis = self.imEncoder(imDis)
86 | wordEnDis = self.wordEncoder(wordEmb, capLengths)
87 | assert len(imEnDis.size())==2
88 | assert len(wordEnDis.size())==2
89 | imEnDis = F.normalize(imEnDis, p=2, dim=1)
90 | wordEnDis = F.normalize(wordEnDis, p=2, dim=1)
91 | return imEnDis, wordEnDis
92 |
93 | def forwardGroundR(self, imDis, wordEmb, capLengths, frmList=None):
94 | #pdb.set_trace()
95 | b_size = imDis.shape[0]
96 | assert len(imDis.size())==4
97 | assert len(wordEmb.size())==4
98 |
99 | if self.vis_type =='fc':
100 | imDis = imDis.mean(2)
101 | else:
102 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
103 | imEnDis = self.imEncoder(imDis)
104 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
105 | wordEnDis = self.wordEncoder(wordEmb, capLengths)
106 | #imEnDis = F.normalize(imEnDis, p=2, dim=2)
107 | #wordEnDis = F.normalize(wordEnDis, p=2, dim=2)
108 | # pdb.set_trace()
109 | if (len(imEnDis.shape)==3):
110 | imEnDis = imEnDis.view(b_size , -1, imEnDis.shape[2])
111 | elif(len(imEnDis.shape)==2):
112 | imEnDis = imEnDis.view(b_size , -1, imEnDis.shape[1])
113 | wordEnDis = wordEnDis.view(b_size, -1, wordEnDis.shape[1])
114 | visFtrEm, rpSS= self.visDecoder(imEnDis, wordEnDis)
115 |
116 | #pdb.set_trace()
117 | if hasattr(self, 'fc2recontr'):
118 | visFtrEm_trs = visFtrEm.transpose(1,2)
119 | visFtrEm_trs = self.fc2recontr(visFtrEm_trs)
120 | visFtrEm = visFtrEm_trs.transpose(1,2)
121 | reCnsSen =None
122 | # pdb.set_trace()
123 | if self.training:
124 | assert visFtrEm.shape[2]==1
125 | visFtrEm = visFtrEm.squeeze(dim=2)
126 | reCnsSen= self.recntr(visFtrEm, wordEmb, capLengths)
127 | return reCnsSen, rpSS
128 |
129 |
130 | def forwardCoAtt(self, imDis, wordEmb, capLengths):
131 | assert len(imDis.size())==3
132 | assert len(wordEmb.size())==3
133 | #pdb.set_trace()
134 | simMM = self.imEncoder(imDis, wordEmb, capLengths)
135 | return simMM
136 |
137 | def _initialize_weights(self):
138 | #pdb.set_trace()
139 | for m in self.modules():
140 | if isinstance(m, nn.Conv2d):
141 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
142 | if m.bias is not None:
143 | nn.init.constant_(m.bias, 0)
144 | elif isinstance(m, nn.BatchNorm2d):
145 | nn.init.constant_(m.weight, 1)
146 | nn.init.constant_(m.bias, 0)
147 | elif isinstance(m, nn.BatchNorm1d):
148 | nn.init.constant_(m.weight, 1)
149 | nn.init.constant_(m.bias, 0)
150 | elif isinstance(m, nn.Linear):
151 | nn.init.normal_(m.weight, 0, 0.01)
152 | if m.bias is not None:
153 | nn.init.constant_(m.bias, 0)
154 |
155 | def build_groundR(self, opts):
156 | self.recntr = build_recontructor(opts)
157 | self.visDecoder = build_visDecoder(opts)
158 | if opts.wsMode=='rankGroundRV2':
159 | self.fc2recontr = torch.nn.Linear(opts.dim_ftr, 300)
160 |
161 | class txtDecoder(nn.Module):
162 | def __init__(self, embedDim, hidden_dim, vocaSize):
163 | super(txtDecoder, self).__init__()
164 | self.hidden_dim = hidden_dim
165 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True)
166 | self.hidden = self.init_hidden()
167 | self.dec_log = torch.nn.Linear(hidden_dim, vocaSize)
168 |
169 | def init_hidden(self, batchSize=10):
170 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(),
171 | torch.zeros(1, batchSize, self.hidden_dim).cuda())
172 |
173 | def forward(self, visFtr, capFtr, capLght):
174 | #pdb.set_trace()
175 | inputs = torch.cat((visFtr.unsqueeze(1), capFtr), 1)
176 |
177 | # pack data (prepare it for pytorch model)
178 | # inputs_packed = pack_padded_sequence(inputs, capLght, batch_first=True)
179 | inputs_packed = inputs
180 | # run data through recurrent network
181 | hiddens, _ = self.lstm(inputs_packed)
182 | #pdb.set_trace()
183 | outputs = self.dec_log(hiddens)
184 | #pdb.set_trace()
185 | return outputs
186 |
187 | class visDecoder(nn.Module):
188 | def __init__(self, dim_ftr, hdSize, coAtt_flag=False):
189 | super(visDecoder, self).__init__()
190 | self.attNet = torch.nn.Sequential(
191 | torch.nn.Linear(dim_ftr*2, hdSize),
192 | torch.nn.ReLU(),
193 | torch.nn.Dropout(p=0.5),
194 | torch.nn.Linear(hdSize, 1),
195 | )
196 | self.coAtt_flag = coAtt_flag
197 |
198 | def _initialize_weights(self):
199 | for m in self.modules():
200 | if isinstance(m, nn.Conv2d):
201 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
202 | if m.bias is not None:
203 | nn.init.constant_(m.bias, 0)
204 | elif isinstance(m, nn.BatchNorm2d):
205 | nn.init.constant_(m.weight, 1)
206 | nn.init.constant_(m.bias, 0)
207 | elif isinstance(m, nn.BatchNorm1d):
208 | nn.init.constant_(m.weight, 1)
209 | nn.init.constant_(m.bias, 0)
210 | elif isinstance(m, nn.Linear):
211 | nn.init.normal_(m.weight, 0, 0.01)
212 | nn.init.constant_(m.bias, 0)
213 |
214 | def forward(self, vis_rp_ftr, cap_ftr):
215 | b_size, per_cap_num, dim_cap = cap_ftr.shape
216 |         b_size_v, rp_num, dim_vis = vis_rp_ftr.shape
217 |
218 | # for co-Attention
219 | if self.coAtt_flag:
220 | conca_ftr = torch.cat((cap_ftr, vis_rp_ftr), 2)
221 | rpScore = self.attNet(conca_ftr)
222 | rpScore = rpScore.view(b_size, rp_num)
223 | rpSS= F.softmax(rpScore, dim=1)
224 | rpSS = rpSS.view(b_size, rp_num, 1)
225 | visFtrAtt = torch.sum(torch.mul(vis_rp_ftr, rpSS), dim=1)
226 | return visFtrAtt.unsqueeze(2), rpSS
227 |
228 | # for independent embedding
229 | rp_ss_list = list()
230 | vis_att_list = list()
231 | for i in range(per_cap_num):
232 | cap_ftr_exp = cap_ftr[:, i, :].unsqueeze(dim=1).expand(b_size, rp_num, dim_cap)
233 | conca_ftr = torch.cat((cap_ftr_exp, vis_rp_ftr), 2)
234 | rpScore = self.attNet(conca_ftr)
235 | rpScore = rpScore.view(b_size, rp_num)
236 | assert(len(rpScore.shape)==2)
237 | rpSS= F.softmax(rpScore, dim=1)
238 | rpSS = rpSS.view(b_size, rp_num, 1)
239 | visFtrAtt = torch.sum(torch.mul(vis_rp_ftr, rpSS), dim=1)
240 | rp_ss_list.append(rpSS)
241 | vis_att_list.append(visFtrAtt.unsqueeze(dim=2))
242 | vis_ftr_mat = torch.cat(vis_att_list, dim=2)
243 | rp_ss_mat = torch.cat(rp_ss_list, dim=2)
244 | return vis_ftr_mat, rp_ss_mat
245 |
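# Shape sketch for the independent-embedding branch above (added comment):
# each caption vector is expanded to (b_size, rp_num, dim_cap), concatenated
# with the proposal features into (b_size, rp_num, dim_cap + dim_vis), scored
# by attNet, softmax-normalized over the rp_num proposals (rpSS), and finally
# used as attention weights to pool the proposals into one visual vector per
# caption.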
246 |
247 | class txtEncoderV2(nn.Module):
248 | def __init__(self, embedDim, hidden_dim, seq_type='lstmV2'):
249 | super(txtEncoderV2, self).__init__()
250 | self.hidden_dim = hidden_dim
251 | self.seq_type = seq_type
252 | self.fc1 = torch.nn.Linear(embedDim, hidden_dim)
253 | if seq_type =='lstmV2':
254 | self.lstm =nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
255 | elif seq_type =='gruV2':
256 | self.lstm =nn.GRU(hidden_dim, hidden_dim, batch_first=True)
257 | self.hidden = self.init_hidden()
258 |
259 | def init_hidden(self, batchSize=10):
260 | if self.seq_type=='lstmV2':
261 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(),
262 | torch.zeros(1, batchSize, self.hidden_dim).cuda())
263 | elif self.seq_type=='gruV2':
264 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda()
265 |
266 | def forward(self, wordMatrixOri, wordLeg=None):
267 | #pdb.set_trace()
268 | # shorten steps for faster training
269 | wordMatrix = self.fc1(wordMatrixOri)
270 | self.init_hidden(wordMatrix.shape[0])
271 | #pdb.set_trace()
272 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden)
273 | #pdb.set_trace()
274 | # lstmOut: B*T*D
275 |         if wordLeg is None:
276 | return self.hidden[0].squeeze()
277 | else:
278 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))]
279 | #pdb.set_trace()
280 | txtEmbMat = torch.stack(txtEMb)
281 | return txtEmbMat
282 |
283 |
284 | class txtEncoder(nn.Module):
285 | def __init__(self, embedDim, hidden_dim, seq_type='lstm'):
286 | super(txtEncoder, self).__init__()
287 | self.hidden_dim = hidden_dim
288 | self.seq_type = seq_type
289 | if seq_type =='lstm':
290 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True)
291 | elif seq_type =='gru':
292 | self.lstm =nn.GRU(embedDim, hidden_dim, batch_first=True)
293 | self.hidden = self.init_hidden()
294 |
295 | def init_hidden(self, batchSize=10):
296 | if self.seq_type=='lstm':
297 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(),
298 | torch.zeros(1, batchSize, self.hidden_dim).cuda())
299 | elif self.seq_type=='gru':
300 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda()
301 |
302 |
303 | def forward(self, wordMatrix, wordLeg=None):
304 | #pdb.set_trace()
305 | # shorten steps for faster training
306 | self.init_hidden(wordMatrix.shape[0])
307 | #pdb.set_trace()
308 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden)
309 | #pdb.set_trace()
310 | # lstmOut: B*T*D
311 |         if wordLeg is None:
312 | return self.hidden[0].squeeze()
313 | else:
314 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))]
315 | #pdb.set_trace()
316 | txtEmbMat = torch.stack(txtEMb)
317 | return txtEmbMat
318 |
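# Sketch of the length-aware read-out above (added comment): with padded
# captions and wordLeg=[3, 5], the encoder returns lstmOut[0, 2, :] and
# lstmOut[1, 4, :], i.e. the hidden state at each caption's last real word
# rather than at the padded final timestep.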
319 | def build_txt_encoder(opts):
320 | embedDim = 300
321 | if opts.txt_type=='lstmV2':
322 | txt_encoder = txtEncoderV2(embedDim, opts.dim_ftr, opts.txt_type)
323 | return txt_encoder
324 | else:
325 | txt_encoder = txtEncoder(embedDim, opts.dim_ftr, opts.txt_type)
326 | return txt_encoder
327 |
328 | class visSeqEncoder(nn.Module):
329 | def __init__(self, embedDim, hidden_dim, seq_Type='lstm'):
330 | super(visSeqEncoder, self).__init__()
331 | self.hidden_dim = hidden_dim
332 | self.seq_type = seq_Type
333 | if seq_Type =='lstm':
334 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True)
335 | elif seq_Type =='gru':
336 | self.lstm =nn.GRU(embedDim, hidden_dim, batch_first=True)
337 |
338 | self.hidden = self.init_hidden()
339 |
340 | def init_hidden(self, batchSize=10):
341 | if self.seq_type=='lstm':
342 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(),
343 | torch.zeros(1, batchSize, self.hidden_dim).cuda())
344 | elif self.seq_type=='gru':
345 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda()
346 |
347 | def forward(self, wordMatrix, wordLeg=None):
348 | # shorten steps for faster training
349 | self.init_hidden(wordMatrix.shape[0])
350 | #pdb.set_trace()
351 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden)
352 | #pdb.set_trace()
353 | # lstmOut: B*T*D
354 |         if wordLeg is None:
355 | return self.hidden[0].squeeze()
356 | else:
357 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))]
358 | #pdb.set_trace()
359 | txtEmbMat = torch.stack(txtEMb)
360 | return txtEmbMat
361 |
362 |
363 | class visSeqEncoderV2(nn.Module):
364 | def __init__(self, embedDim, hidden_dim, seq_Type='lstmV2'):
365 | super(visSeqEncoderV2, self).__init__()
366 | self.hidden_dim = hidden_dim
367 | self.seq_type = seq_Type
368 | if seq_Type =='lstm':
369 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True)
370 | elif seq_Type =='gru':
371 | self.lstm =nn.GRU(embedDim, hidden_dim, batch_first=True)
372 | if seq_Type =='lstmV2':
373 | self.lstm =nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
374 | self.fc1 = torch.nn.Linear(embedDim, hidden_dim)
375 |
376 | self.hidden = self.init_hidden()
377 |
378 | def init_hidden(self, batchSize=10):
379 | if self.seq_type=='lstm' or self.seq_type=='lstmV2':
380 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(),
381 | torch.zeros(1, batchSize, self.hidden_dim).cuda())
382 | elif self.seq_type=='gru':
383 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda()
384 |
385 | def forward(self, wordMatrixOri, wordLeg=None):
386 | # shorten steps for faster training
387 | if self.seq_type =='lstmV2':
388 | wordMatrix = self.fc1(wordMatrixOri)
389 | else:
390 | wordMatrix = wordMatrixOri
391 | self.init_hidden(wordMatrix.shape[0])
392 | #pdb.set_trace()
393 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden)
394 | #pdb.set_trace()
395 | # lstmOut: B*T*D
396 |         if wordLeg is None:
397 | return self.hidden[0].squeeze()
398 | else:
399 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))]
400 | #pdb.set_trace()
401 | txtEmbMat = torch.stack(txtEMb)
402 | return txtEmbMat
403 |
404 |
405 | def build_vis_fc_encoder(opts):
406 | inputDim = opts.vis_dim
407 | visNet = torch.nn.Sequential(
408 | torch.nn.Linear(inputDim, inputDim),
409 | torch.nn.ReLU(),
410 | torch.nn.Dropout(p=0.5),
411 | torch.nn.Linear(inputDim, opts.dim_ftr),
412 | )
413 | return visNet
414 |
415 | class vlad_encoder(nn.Module):
416 | def __init__(self, input_dim, out_dim, hidden_dim, centre_num, alpha=1.0):
417 | super(vlad_encoder, self).__init__()
418 | self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
419 | self.net_vlad = NetVLAD(num_clusters= centre_num, dim= hidden_dim, alpha=1.0)
420 | self.fc2 = torch.nn.Linear(centre_num*hidden_dim, out_dim)
421 |
422 | def forward(self, input_data):
423 | input_hidden = self.fc1(input_data)
424 | input_hidden = F.relu(input_hidden)
425 | input_hidden = input_hidden.unsqueeze(dim=3)
426 | input_hidden = input_hidden.transpose(dim0=1, dim1=2)
427 |         #pdb.set_trace()
428 | input_vlad = self.net_vlad(input_hidden)
429 | #pdb.set_trace()
430 | out_vlad = self.fc2(input_vlad)
431 | out_vlad = F.relu(out_vlad)
432 | return out_vlad
433 |
434 | def build_vis_vlad_encoder_v1(opts):
435 | input_dim = opts.vis_dim
436 | hidden_dim = opts.hidden_dim
437 | centre_num = opts.centre_num
438 | out_dim = opts.dim_ftr
439 | alpha = opts.vlad_alpha
440 | vis_encoder = vlad_encoder(input_dim, out_dim, hidden_dim, centre_num, alpha=1.0 )
441 | return vis_encoder
442 |
443 | def build_vis_seq_encoder(opts):
444 | embedDim = opts.vis_dim
445 | if opts.vis_type == 'lstm' or opts.vis_type=='gru':
446 | vis_seq_encoder = visSeqEncoder(embedDim, opts.dim_ftr, opts.vis_type)
447 | return vis_seq_encoder
448 | if opts.vis_type == 'lstmV2':
449 | vis_seq_encoder = visSeqEncoderV2(embedDim, opts.dim_ftr, opts.vis_type)
450 | return vis_seq_encoder
451 | elif opts.vis_type == 'fc':
452 | vis_avg_encoder = build_vis_fc_encoder(opts)
453 | return vis_avg_encoder
454 | elif opts.vis_type == 'vlad_v1':
455 | vis_vlad_encoder = build_vis_vlad_encoder_v1(opts)
456 | return vis_vlad_encoder
457 | elif opts.vis_type == 'avgMIL':
458 | vis_avg_encoder = build_vis_fc_encoder(opts)
459 | return vis_avg_encoder
460 |
461 | def build_recontructor(opts):
462 | if opts.wsMode=='coAttGroundR':
463 | return txtDecoder(opts.lstm_hidden_size, opts.hdSize, opts.vocaSize)
464 | else:
465 | return txtDecoder(opts.hdSize, opts.hdSize, opts.vocaSize)
466 |
467 | def build_visDecoder(opts):
468 | if opts.wsMode=='coAttGroundR':
469 | return visDecoder(opts.lstm_hidden_size, opts.hdSize, True)
470 | else:
471 | return visDecoder(opts.dim_ftr, opts.hdSize, False)
472 |
473 |
474 | def build_network(opts):
475 | # pdb.set_trace()
476 | if opts.wsMode == 'rankTube' or opts.wsMode=='rankFrm':
477 | imEncoder= build_vis_seq_encoder(opts)
478 | wordEncoder = build_txt_encoder(opts)
479 | wsEncoder = wsEmb(imEncoder, wordEncoder)
480 | elif opts.wsMode == 'coAtt':
481 | sst_Obj = SST(opts)
482 | wsEncoder = wsEmb(sst_Obj, None)
483 | elif opts.wsMode == 'coAttBi':
484 | sst_Obj = SSTBi(opts)
485 | wsEncoder = wsEmb(sst_Obj, None)
486 | elif opts.wsMode == 'coAttBiV2':
487 | sst_Obj = SSTBiV2(opts)
488 | wsEncoder = wsEmb(sst_Obj, None)
489 | elif opts.wsMode == 'coAttBiV3':
490 | sst_Obj = SSTBiV3(opts)
491 | wsEncoder = wsEmb(sst_Obj, None)
492 | elif opts.wsMode == 'rankGroundR':
493 | imEncoder= build_vis_seq_encoder(opts)
494 | wordEncoder = build_txt_encoder(opts)
495 | wsEncoder = wsEmb(imEncoder, wordEncoder)
496 | wsEncoder.build_groundR(opts)
497 | elif opts.wsMode == 'rankGroundRV2':
498 | imEncoder= build_vis_seq_encoder(opts)
499 | wordEncoder = build_txt_encoder(opts)
500 | wsEncoder = wsEmb(imEncoder, wordEncoder)
501 | wsEncoder.build_groundR(opts)
502 | elif opts.wsMode == 'coAttGroundR':
503 | sst_Gr = SSTGroundR(opts)
504 | wsEncoder = wsEmb(sst_Gr, None)
505 | wsEncoder.build_groundR(opts)
506 | elif opts.wsMode == 'coAttV2':
507 | sst_mul = SSTMul(opts)
508 | wsEncoder = wsEmb(sst_mul, None)
509 | elif opts.wsMode == 'coAttV3':
510 | sst_v3 = SSTV3(opts)
511 | wsEncoder = wsEmb(sst_v3, None)
512 | elif opts.wsMode == 'coAttV4':
513 | sst_v4 = SSTV4(opts)
514 | wsEncoder = wsEmb(sst_v4, None)
515 |
516 | wsEncoder.wsMode = opts.wsMode
517 | wsEncoder.vis_type = opts.vis_type
518 | if opts.gpu:
519 | wsEncoder= wsEncoder.cuda()
520 | if opts.initmodel is not None:
521 | md_stat = torch.load(opts.initmodel)
522 | wsEncoder.load_state_dict(md_stat, strict=False)
523 | if opts.isParal:
524 | wsEncoder = nn.DataParallel(wsEncoder).cuda()
525 |
526 | return wsEncoder
527 |
528 |
--------------------------------------------------------------------------------
/fun/netUtil.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | import sys
5 | import pdb
6 | #from evalDet import *
7 | sys.path.append('..')
8 | from util.mytoolbox import *
9 | from image_toolbox import *
10 |
11 | def visTube_from_image(frmList, tube, outName):
12 | image_list = list()
13 | for i, bbx in enumerate(tube):
14 | imName = frmList[i]
15 | img = draw_rectangle(imName, bbx)
16 | image_list.append(img)
17 | images2video(image_list, 10, outName)
18 |
19 | def vis_image_bbx(frmList, tube, color=(0,0,255), thickness=3, dotted=False):
20 | image_list = list()
21 | for i, bbx in enumerate(tube):
22 | imName = frmList[i]
23 | img = draw_rectangle(imName, bbx, color, thickness, dotted)
24 | image_list.append(img)
25 | return image_list
26 |
27 | def vis_gray_but_bbx(frmList, tube):
28 | image_list = list()
29 | for i, bbx in enumerate(tube):
30 | imName = frmList[i]
31 | img = gray_background(imName, bbx)
32 | image_list.append(img)
33 | return image_list
34 |
35 |
36 |
37 | KEYS = ['x1', 'y1', 'x2', 'y2']
38 | def compute_IoU(box1, box2):
39 | if isinstance(box1, list):
40 | box1 = {key: val for key, val in zip(KEYS, box1)}
41 | if isinstance(box2, list):
42 | box2 = {key: val for key, val in zip(KEYS, box2)}
43 | width = max(min(box1['x2'], box2['x2']) - max(box1['x1'], box2['x1']), 0)
44 | height = max(min(box1['y2'], box2['y2']) - max(box1['y1'], box2['y1']), 0)
45 | intersection = width * height
46 | box1_area = (box1['x2'] - box1['x1']) * (box1['y2'] - box1['y1'])
47 | box2_area = (box2['x2'] - box2['x1']) * (box2['y2'] - box2['y1'])
48 | union = box1_area + box2_area - intersection
49 |     return float(intersection) / (float(union) + 0.000001)  # avoid division by zero
50 |
51 |
52 | EPS = 1e-10
53 | def compute_IoU_v2(bbox1, bbox2):
54 | bbox1_area = float((bbox1[2] - bbox1[0] + EPS) * (bbox1[3] - bbox1[1] + EPS))
55 | bbox2_area = float((bbox2[2] - bbox2[0] + EPS) * (bbox2[3] - bbox2[1] + EPS))
56 | w = max(0.0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]) + EPS)
57 | h = max(0.0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]) + EPS)
58 | inter = float(w * h)
59 | ovr = inter / (bbox1_area + bbox2_area - inter)
60 | return ovr
61 |
62 | def is_annotated(traj, frame_ind):
63 | if not frame_ind in traj:
64 | return False
65 | box = traj[frame_ind]
66 | if box[0] < 0:
67 | for el_val in box[1:]:
68 | assert el_val < 0
69 | return False
70 | for el_val in box[1:]:
71 | assert el_val >= 0
72 | return True
73 |
74 | def compute_LS(traj, gt_traj):
75 | # see http://jvgemert.github.io/pub/jain-tubelets-cvpr2014.pdf
76 | assert isinstance(traj.keys()[0], type(gt_traj.keys()[0]))
77 | IoU_list = []
78 | for frame_ind, gt_box in gt_traj.iteritems():
79 | gt_is_annotated = is_annotated(gt_traj, frame_ind)
80 | pr_is_annotated = is_annotated(traj, frame_ind)
81 | if (not gt_is_annotated) and (not pr_is_annotated):
82 | continue
83 | if (not gt_is_annotated) or (not pr_is_annotated):
84 | IoU_list.append(0.0)
85 | continue
86 | box = traj[frame_ind]
87 | IoU_list.append(compute_IoU_v2(box, gt_box))
88 | return sum(IoU_list) / len(IoU_list)
89 |
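# Illustrative check (added sketch, hypothetical trajectories): compute_LS
# averages per-frame IoU over the ground-truth frames, scoring 0 whenever only
# one of the two trajectories is annotated on a frame.
def _demo_compute_ls():
    gt = {0: [0., 0., 10., 10.], 1: [0., 0., 10., 10.]}
    pred = {0: [0., 0., 10., 10.]}   # frame 1 missing from the prediction
    return compute_LS(pred, gt)      # (1.0 + 0.0) / 2 = 0.5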
90 | def get_tubes(det_list_org, alpha):
91 | det_list = copy.deepcopy(det_list_org)
92 | tubes = []
93 | continue_flg = True
94 | tube_scores = []
95 | while continue_flg:
96 | timestep = 0
97 | score_list = []
98 | score_list.append(np.zeros(det_list[timestep][0].shape[0]))
99 | prevind_list = []
100 | prevind_list.append([-1] * det_list[timestep][0].shape[0])
101 | timestep += 1
102 | while timestep < len(det_list):
103 | n_curbox = det_list[timestep][0].shape[0]
104 | n_prevbox = score_list[-1].shape[0]
105 | cur_scores = np.zeros(n_curbox) - np.inf
106 | prev_inds = [-1] * n_curbox
107 | for i_prevbox in range(n_prevbox):
108 | prevbox_coods = det_list[timestep-1][1][i_prevbox, :]
109 | prevbox_score = det_list[timestep-1][0][i_prevbox, 0]
110 | for i_curbox in range(n_curbox):
111 | curbox_coods = det_list[timestep][1][i_curbox, :]
112 | curbox_score = det_list[timestep][0][i_curbox, 0]
113 | try:
114 | e_score = compute_IoU(prevbox_coods.tolist(), curbox_coods.tolist())
115 | except:
116 | pdb.set_trace()
117 | link_score = prevbox_score + curbox_score + alpha * e_score
118 | cur_score = score_list[-1][i_prevbox] + link_score
119 | if cur_score > cur_scores[i_curbox]:
120 | cur_scores[i_curbox] = cur_score
121 | prev_inds[i_curbox] = i_prevbox
122 | score_list.append(cur_scores)
123 | prevind_list.append(prev_inds)
124 | timestep += 1
125 |
126 | # get path and remove used boxes
127 | cur_tube = [None] * len(det_list)
128 | tube_score = np.max(score_list[-1]) / len(det_list)
129 | prev_ind = np.argmax(score_list[-1])
130 | timestep = len(det_list) - 1
131 | while timestep >= 0:
132 | cur_tube[timestep] = det_list[timestep][1][prev_ind, :].tolist()
133 | det_list[timestep][0] = np.delete(det_list[timestep][0], prev_ind, axis=0)
134 | det_list[timestep][1] = np.delete(det_list[timestep][1], prev_ind, axis=0)
135 | prev_ind = prevind_list[timestep][prev_ind]
136 | if det_list[timestep][1].shape[0] == 0:
137 | continue_flg = False
138 | timestep -= 1
139 | assert prev_ind < 0
140 | tubes.append(cur_tube)
141 | tube_scores.append(tube_score)
142 | return tubes, tube_scores
143 |
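# Minimal sketch of the tube linking above (hypothetical detections): each
# det_list[t] holds [scores of shape (N, 1), boxes of shape (N, 4)];
# consecutive boxes are linked with prev_score + cur_score + alpha * IoU, the
# best-scoring path is read out by dynamic programming, and its boxes are
# removed before the next pass.
def _demo_get_tubes():
    det_list = [[np.array([[0.9]]), np.array([[0., 0., 10., 10.]])],
                [np.array([[0.8]]), np.array([[1., 1., 11., 11.]])]]
    return get_tubes(det_list, alpha=1.0)   # -> one tube of two linked boxes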
144 | def save_check_point(state, is_best=False, file_name='../data/models/checkpoint.pth'):
145 | fdName = os.path.dirname(file_name)
146 | makedirs_if_missing(fdName)
147 | torch.save(state, file_name)
148 | if is_best:
149 | shutil.copyfile(file_name, '../data/model/best_model.pth')
150 |
151 | def load_model_state(model, file_name):
152 | states = torch.load(file_name)
153 | model.load_state_dict(states)
154 |
155 |
156 |
--------------------------------------------------------------------------------
/fun/optimizers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def build_opt(args, model):
4 | if args.optimizer.lower() == 'adam':
5 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
6 | elif args.optimizer.lower() == 'rmsprop':
7 | optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.decay)
8 | elif args.optimizer.lower() == 'sgd':
9 | optimizer = torch.optim.SGD(model.parameters(),
10 | lr=args.lr,
11 | momentum=args.momentum, weight_decay=args.decay)
12 | else:
13 | raise Exception()
14 | return optimizer
15 |
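# Usage sketch (hypothetical option values; the field names match those
# consumed above):
#   from argparse import Namespace
#   args = Namespace(optimizer='sgd', lr=1e-3, momentum=0.9, decay=1e-4)
#   optimizer = build_opt(args, model)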
--------------------------------------------------------------------------------
/fun/train.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | import random
4 | def random_seeding(seed_value, use_cuda):
5 | numpy.random.seed(seed_value) # cpu vars
6 | torch.manual_seed(seed_value) # cpu vars
7 | random.seed(seed_value)
8 | if use_cuda:
9 | torch.cuda.manual_seed_all(seed_value) # gpu vars
10 |
11 |
12 | from wsParamParser import parse_args
13 | from data.data_loader import *
14 | from datasetLoader import *
15 | from modelArc import *
16 | from optimizers import *
17 | from logInfo import logInF
18 | from lossPackage import *
19 | from netUtil import *
20 | from tensorboardX import SummaryWriter
21 | import time
22 |
23 | import pdb
24 |
25 | if __name__=='__main__':
26 | #pdb.set_trace()
27 | opt = parse_args()
28 | random_seeding(opt.seed, True)
29 | # build dataloader
30 | dataLoader, datasetOri= build_dataloader(opt)
31 | # build network
32 | model = build_network(opt)
33 | # build_optimizer
34 | optimizer = build_opt(opt, model)
35 | # build loss layer
36 | lossEster = build_lossEval(opt)
37 | # build logger
38 | logger = logInF(opt.logFd)
39 | writer = SummaryWriter(opt.logFdTx+time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
40 | #pdb.set_trace()
41 | for ep in range(opt.stEp, opt.epSize):
42 | resultList_full = list()
43 | tBf = time.time()
44 | for itr, inputData in enumerate(dataLoader):
45 | #tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list, region_gt_ori = inputData
46 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData
47 | #pdb.set_trace()
48 | dataIdx = None
49 | tmp_bsize = tube_embedding.shape[0]
50 | imDis = tube_embedding.cuda()
51 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
52 | wordEmb = cap_embedding.cuda()
53 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
54 | tAf = time.time()
55 |
56 | if opt.wsMode =='coAtt':
57 | #pdb.set_trace()
58 | simMM = model(imDis, wordEmb, cap_length_list)
59 | # pdb.set_trace()
60 | simMM = simMM.view(tmp_bsize, opt.rpNum, tmp_bsize, opt.capNum)
61 | if len(set(vd_name_list))<2:
62 | continue
63 |
64 | loss = lossEster(simMM=simMM, lblList =vd_name_list)
65 | resultList = evalAcc_att(simMM, tubeInfo, indexOri, datasetOri, opt.visRsFd+str(ep), False)
66 | resultList_full +=resultList
67 | #pdb.set_trace()
68 | if loss<=0:
69 | continue
70 | loss = loss/opt.update_iter
71 | loss.backward(retain_graph=True )
72 | if (itr+1)%opt.update_iter==0:
73 | #optimizer.zero_grad()
74 | optimizer.step()
75 | optimizer.zero_grad()
76 | tNf = time.time()
77 | if(itr%opt.visIter==0):
78 | #resultList = evalAcc_att(simMM, tubeInfo, indexOri, datasetOri, opt.visRsFd+str(ep), False)
79 | #resultList_full +=resultList
80 | #tAf = time.time()
81 |                 logger('Ep: %d, Iter: %d, T1: %.3f, T2: %.3f, loss: %.3f\n' %(ep, itr, (tAf-tBf), (tNf-tAf)/opt.visIter, float(loss.data.cpu().numpy())))
82 | tBf = time.time()
83 |                 writer.add_scalar('loss', float(loss.data.cpu().numpy()), ep*len(datasetOri)+itr*opt.batchSize)
84 | accSum = 0
85 | resultList = list()
86 | for ele in resultList:
87 | index, recall_k= ele
88 | accSum +=recall_k
89 | #logger('Average accuracy on training batch is %3f\n' %(accSum/len(resultList)))
90 | tBf = time.time()
91 |
92 | ## evaluation within an epoch
93 | if(ep % opt.saveEp==0 and itr==0 and ep >0):
94 | checkName = opt.outPre+'_ep_'+str(ep) +'_itr_'+str(itr)+'.pth'
95 | save_check_point(model.state_dict(), file_name=checkName)
96 | model.eval()
97 | resultList = list()
98 | vIdList = list()
99 | set_name_ori= opt.set_name
100 | opt.set_name = 'val'
101 | batchSizeOri = opt.batchSize
102 | opt.batchSize = 1
103 | dataLoaderEval, datasetEvalOri = build_dataloader(opt)
104 | opt.batchSize = batchSizeOri
105 | #pdb.set_trace()
106 | for itr_eval, inputData in enumerate(dataLoaderEval):
107 | #tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list, region_gt_ori = inputData
108 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData
109 | #pdb.set_trace()
110 | dataIdx = None
111 | #pdb.set_trace()
112 | b_size = tube_embedding.shape[0]
113 | # B*P*T*D
114 | imDis = tube_embedding.cuda()
115 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
116 | wordEmb = cap_embedding.cuda()
117 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
118 | imDis.requires_grad=False
119 | wordEmb.requires_grad=False
120 | if opt.wsMode =='coAtt':
121 | simMM = model(imDis, wordEmb, cap_length_list)
122 | simMM = simMM.view(b_size, opt.rpNum, b_size)
123 | resultList += evalAcc_att(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False)
124 |
125 | accSum = 0
126 | for ele in resultList:
127 | index, recall_k= ele
128 | accSum +=recall_k
129 |                 logger('Average accuracy on validation set is %.3f\n' %(accSum/len(resultList)))
130 | writer.add_scalar('Average validation accuracy', accSum/len(resultList), ep*len(datasetOri)+ itr*opt.batchSize)
131 | #pdb.set_trace()
132 | model.train()
133 | opt.set_name = set_name_ori
134 |
135 | if(ep % opt.saveEp==0 and itr==0 and ep > 0):
136 | model.eval()
137 | resultList = list()
138 | vIdList = list()
139 | set_name_ori= opt.set_name
140 | opt.set_name = 'test'
141 | batchSizeOri = opt.batchSize
142 | opt.batchSize = 1
143 | dataLoaderEval, datasetEvalOri = build_dataloader(opt)
144 | opt.batchSize = batchSizeOri
145 | #pdb.set_trace()
146 | for itr_eval, inputData in enumerate(dataLoaderEval):
147 |                 tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData
148 | #pdb.set_trace()
149 | dataIdx = None
150 | #pdb.set_trace()
151 | b_size = tube_embedding.shape[0]
152 | # B*P*T*D
153 | imDis = tube_embedding.cuda()
154 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3])
155 | wordEmb = cap_embedding.cuda()
156 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3])
157 | imDis.requires_grad=False
158 | wordEmb.requires_grad=False
159 | if opt.wsMode =='coAtt':
160 | simMM = model(imDis, wordEmb, cap_length_list)
161 | simMM = simMM.view(b_size, opt.rpNum, b_size)
162 | resultList += evalAcc_att(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False)
163 |
164 |
165 | accSum = 0
166 | for ele in resultList:
167 | index, recall_k= ele
168 | accSum +=recall_k
169 |                 logger('Average accuracy on testing set is %.3f\n' %(accSum/len(resultList)))
170 | writer.add_scalar('Average testing accuracy', accSum/len(resultList), ep*len(datasetOri)+ itr*opt.batchSize)
171 | #pdb.set_trace()
172 | model.train()
173 | opt.set_name = set_name_ori
174 |
175 | accSum = 0
176 | for ele in resultList_full:
177 | index, recall_k= ele
178 | accSum +=recall_k
179 | #writer.add_scalar('Average training accuracy', accSum/len(resultList_full), ep*len(datasetOri)+ itr*opt.batchSize)
180 |
--------------------------------------------------------------------------------
/fun/vidDataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('..')
4 | sys.path.append('../../../WSSTL/fun')
5 | sys.path.append('../annotations')
6 | sys.path.append('../../annotations')
7 | import torch.utils.data as data
8 | import cv2
9 | import numpy as np
10 | from util.mytoolbox import *
11 | import random
12 | import scipy.io as sio
13 | import copy
14 | import torch
15 | from util.get_image_size import get_image_size
16 | from evalDet import *
17 | from datasetParser import extAllFrmFn
18 | import pdb
19 | from netUtil import *
20 | from wsParamParser import parse_args
21 | from ptd_api import *
22 | from vidDatasetParser import *
23 | from multiprocessing import Process, Pipe, cpu_count, Queue
24 | #from vidDatasetParser import vidInfoParser
25 | #from multiGraphAttention import extract_position_embedding
26 | import h5py
27 |
28 |
29 | class vidDataloader(data.Dataset):
30 | def __init__(self, ann_folder, prp_type, set_name, dictFile, tubePath, ftrPath, out_cached_folder):
31 | self.set_name = set_name
32 | self.dict = pickleload(dictFile)
33 | self.rpNum = 30
34 | self.maxWordNum =20
35 | self.maxTubelegth = 20
36 | self.tube_ftr_dim = 2048
37 | self.tubePath = tubePath
38 | self.ftrPath = ftrPath
39 | self.out_cache_folder = out_cached_folder
40 | self.prp_type = prp_type
41 | self.vid_parser = vidInfoParser(set_name, ann_folder)
42 | self.use_key_index = self.vid_parser.tube_cap_dict.keys()
43 | self.use_key_index.sort()
44 | self.online_cache ={}
45 | self.i3d_cache_flag = False
46 | self.cache_flag = False
47 | self.cache_ftr_dict = {}
48 | self.use_mean_cache_flag = False
49 | self.mean_cache_ftr_path = ''
50 | self.context_flag =False
51 | self.extracting_context =False
52 |
53 |
54 | def get_gt_embedding_i3d(self, index, maxTubelegth, out_cached_folder = ''):
55 | '''
56 |         get the ground-truth region embedding
57 | '''
58 | rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
59 | flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
60 | set_name = self.set_name
61 | i3d_ftr_path = os.path.join(self.i3d_ftr_path[:-4], 'gt/vid',set_name, str(index) +'.h5')
62 | if i3d_ftr_path in self.online_cache.keys() and self.i3d_cache_flag:
63 | tube_embedding = self.online_cache[i3d_ftr_path]
64 | return tube_embedding
65 | h5_handle = h5py.File(i3d_ftr_path, 'r')
66 | for tube_id in range(1):
67 | rgb_tube_ftr = h5_handle[str(tube_id)]['rgb_feature'][()].squeeze()
68 | flow_tube_ftr = h5_handle[str(tube_id)]['flow_feature'][()].squeeze()
69 | num_tube_ftr = h5_handle[str(tube_id)]['num_feature'][()].squeeze()
70 | seg_length = max(int(round(num_tube_ftr/maxTubelegth)), 1)
71 | tmp_rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
72 | tmp_flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
73 | for segId in range(maxTubelegth):
74 | #print('%d %d\n' %(tube_id, segId))
75 | start_id = segId*seg_length
76 | end_id = (segId+1)*seg_length
77 | if end_id > num_tube_ftr and num_tube_ftr < maxTubelegth:
78 | break
79 | end_id = min((segId+1)*(seg_length), num_tube_ftr)
80 | tmp_rgb_tube_embedding[segId, :] = np.mean(rgb_tube_ftr[start_id:end_id], axis=0)
81 | tmp_flow_tube_embedding[segId, :] = np.mean(flow_tube_ftr[start_id:end_id], axis=0)
82 |
83 | rgb_tube_embedding = tmp_rgb_tube_embedding
84 | flow_tube_embedding = tmp_flow_tube_embedding
85 |
86 | tube_embedding = np.concatenate((rgb_tube_embedding, flow_tube_embedding), axis=1)
87 | return tube_embedding
88 |
89 |
90 | def get_gt_embedding(self, index, maxTubelegth, out_cached_folder = ''):
91 | '''
92 |         get the ground truth region embedding
93 | '''
94 | gt_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim), dtype=np.float32)
95 | set_name = self.set_name
96 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index)
97 | tube_info_path = os.path.join(self.tubePath, set_name, self.prp_type, str(index)+'.pd')
98 | tubeInfo = pickleload(tube_info_path)
99 | tube_list, frame_list = tubeInfo
100 | frmNum = len(frame_list)
101 | seg_length = max(int(frmNum/maxTubelegth), 1)
102 |
103 | tube_to_prp_idx = list()
104 | ftr_tube_list = list()
105 | prp_range_num = len(tube_list[0])
106 | tmp_cache_gt_feature_path = os.path.join(out_cached_folder, \
107 | 'gt' , set_name, self.prp_type, str(index) + '.pk')
108 | if os.path.isfile(tmp_cache_gt_feature_path):
109 | tmp_gt_ftr_info = None
110 | try:
111 | tmp_gt_ftr_info = pickleload(tmp_cache_gt_feature_path)
112 | except:
113 | print('--------------------------------------------------')
114 | print(tmp_cache_gt_feature_path)
115 | print('--------------------------------------------------')
116 | if tmp_gt_ftr_info is not None:
117 | return tmp_gt_ftr_info
118 |
119 | # cache data for saving IO time
120 | cache_data_dict ={}
121 |
122 | for frmId, frmName in enumerate(frame_list):
123 | frmName = frame_list[frmId]
124 | img_prp_ftr_info_path = os.path.join(self.ftr_gt_path, self.set_name, str(index), frmName+ '.pd')
125 | img_prp_ftr_info = pickleload(img_prp_ftr_info_path)
126 | cache_data_dict[frmName] = img_prp_ftr_info
127 |
128 | for segId in range(maxTubelegth):
129 | start_id = segId*seg_length
130 | end_id = (segId+1)*seg_length
131 |             if end_id>frmNum and frmNum<maxTubelegth:
196 |             if(valCount>=self.maxWordNum):
197 | break
198 | idx = self.dict['word2idx'][word]
199 | wordEmbMatrix[valCount, :]= self.dict['word2vec'][idx]
200 | valCount +=1
201 | wordLbl.append(idx)
202 | return wordEmbMatrix, valCount, wordLbl
203 |
204 | def get_cap_emb(self, index, capNum):
205 | cap_list_index = self.vid_parser.tube_cap_dict[index]
206 | assert len(cap_list_index)>=capNum
207 | cap_sample_index = random.sample(range(len(cap_list_index)), capNum)
208 |
209 | # get word embedding
210 | wordEmbMatrix= np.zeros((capNum, self.maxWordNum, 300), dtype=np.float32)
211 | cap_length_list = list()
212 | word_lbl_list = list()
213 | for i, capIdx in enumerate(cap_sample_index):
214 | capString = cap_list_index[capIdx]
215 | wordEmbMatrix[i, ...], valid_length, wordLbl = self.get_word_emb_from_str(capString, self.maxWordNum)
216 | cap_length_list.append(valid_length)
217 | word_lbl_list.append(wordLbl)
218 | return wordEmbMatrix, cap_length_list, word_lbl_list
219 |
220 | def get_tube_embedding(self, index, maxTubelegth, out_cached_folder = ''):
221 | tube_embedding = np.zeros((self.rpNum, maxTubelegth, self.tube_ftr_dim), dtype=np.float32)
222 | set_name = self.set_name
223 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index)
224 | tube_info_path = os.path.join(self.tubePath, set_name, self.prp_type, str(index)+'.pd')
225 | tubeInfo = pickleload(tube_info_path)
226 | tube_list, frame_list = tubeInfo
227 | frmNum = len(frame_list)
228 | seg_length = max(int(frmNum/maxTubelegth), 1)
229 |
230 | tube_to_prp_idx = list()
231 | ftr_tube_list = list()
232 | prp_range_num = len(tube_list[0])
233 | tmp_cache_tube_feature_path = os.path.join(out_cached_folder, \
234 | set_name, self.prp_type, str(index) + '.pk')
235 | if os.path.isfile(tmp_cache_tube_feature_path):
236 | tmp_tube_ftr_info = None
237 | try:
238 | tmp_tube_ftr_info = pickleload(tmp_cache_tube_feature_path)
239 | except:
240 | print('--------------------------------------------------')
241 | print(tmp_cache_tube_feature_path)
242 | print('--------------------------------------------------')
243 | if tmp_tube_ftr_info is not None:
244 | tube_embedding, tubeInfo, tube_to_prp_idx = tmp_tube_ftr_info
245 |
246 | #if((tube_to_prp_idx[0])>maxTubelegth):
247 | if tube_embedding.shape[0]>=self.rpNum:
248 | return tube_embedding[:self.rpNum], tubeInfo, tube_to_prp_idx
249 | else:
250 | tube_embedding = np.zeros((self.rpNum, maxTubelegth, self.tube_ftr_dim), dtype=np.float32)
251 |
252 | # cache data for saving IO time
253 | cache_data_dict ={}
254 | for tubeId, tube in enumerate(tube_list[0]):
255 | if tubeId>= self.rpNum:
256 | continue
257 | tube_prp_map = list()
258 | # find proposals
259 | for frmId, bbox in enumerate(tube):
260 | frmName = frame_list[frmId]
261 | if frmName in cache_data_dict.keys():
262 | img_prp_ftr_info = cache_data_dict[frmName]
263 | else:
264 | img_prp_ftr_info_path = os.path.join(self.ftrPath, self.set_name, vd_name, frmName+ '.pd')
265 | img_prp_ftr_info = pickleload(img_prp_ftr_info_path)
266 | cache_data_dict[frmName] = img_prp_ftr_info
267 |
268 | tmp_bbx = copy.deepcopy(img_prp_ftr_info['rois'][:prp_range_num]) # to be modified
269 | tmp_info = img_prp_ftr_info['imFo'].squeeze()
270 | tmp_bbx[:, 0] = tmp_bbx[:, 0]/tmp_info[1]
271 | tmp_bbx[:, 2] = tmp_bbx[:, 2]/tmp_info[1]
272 | tmp_bbx[:, 1] = tmp_bbx[:, 1]/tmp_info[0]
273 | tmp_bbx[:, 3] = tmp_bbx[:, 3]/tmp_info[0]
274 | img_prp_res = tmp_bbx - bbox
275 | img_prp_res_sum = np.sum(img_prp_res, axis=1)
276 | for prpId in range(prp_range_num):
277 | if(abs(img_prp_res_sum[prpId])<0.00001):
278 | tube_prp_map.append(prpId)
279 | break
280 | #assert("fail to find proposals")
281 | if (len(tube_prp_map)!=len(tube)):
282 | pdb.set_trace()
283 | assert(len(tube_prp_map)==len(tube))
284 |
285 | tube_to_prp_idx.append(tube_prp_map)
286 |
287 | # extract features
288 | tmp_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim), dtype=np.float32)
289 | for segId in range(maxTubelegth):
290 | start_id = segId*seg_length
291 | end_id = (segId+1)*seg_length
292 |                     if end_id>frmNum and frmNum<maxTubelegth:
402 |                     if end_id > num_tube_ftr and num_tube_ftr < maxTubelegth:
403 | break
404 | end_id = min((segId+1)*(seg_length), num_tube_ftr)
405 | tmp_rgb_tube_embedding[segId, :] = np.mean(rgb_tube_ftr[start_id:end_id], axis=0)
406 | tmp_flow_tube_embedding[segId, :] = np.mean(flow_tube_ftr[start_id:end_id], axis=0)
407 |
408 | rgb_tube_embedding[tube_id, ...] = tmp_rgb_tube_embedding
409 | flow_tube_embedding[tube_id, ...] = tmp_flow_tube_embedding
410 |
411 | tube_embedding = np.concatenate((rgb_tube_embedding, flow_tube_embedding), axis=2)
412 | return tube_embedding
413 |
414 |
415 | def get_context_embedding_i3d(self, index, maxTubelegth, out_cached_folder = ''):
416 | rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
417 | flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
418 | set_name = self.set_name
419 | i3d_ftr_path = os.path.join(self.i3d_ftr_path, 'context/vid',set_name, str(index) +'.h5')
420 | if i3d_ftr_path in self.online_cache.keys() and self.i3d_cache_flag:
421 | tube_embedding = self.online_cache[i3d_ftr_path]
422 | return tube_embedding
423 | h5_handle = h5py.File(i3d_ftr_path, 'r')
424 | for tube_id in range(1):
425 | rgb_tube_ftr = h5_handle[str(tube_id)]['rgb_feature'][()].squeeze()
426 | flow_tube_ftr = h5_handle[str(tube_id)]['flow_feature'][()].squeeze()
427 | num_tube_ftr = h5_handle[str(tube_id)]['num_feature'][()].squeeze()
428 | seg_length = max(int(round(num_tube_ftr/maxTubelegth)), 1)
429 | tmp_rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
430 | tmp_flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32)
431 | for segId in range(maxTubelegth):
432 | start_id = segId*seg_length
433 | end_id = (segId+1)*seg_length
434 | if end_id > num_tube_ftr and num_tube_ftr < maxTubelegth:
435 | break
436 | end_id = min((segId+1)*(seg_length), num_tube_ftr)
437 | tmp_rgb_tube_embedding[segId, :] = np.mean(rgb_tube_ftr[start_id:end_id], axis=0)
438 | tmp_flow_tube_embedding[segId, :] = np.mean(flow_tube_ftr[start_id:end_id], axis=0)
439 |
440 | rgb_tube_embedding = tmp_rgb_tube_embedding
441 | flow_tube_embedding = tmp_flow_tube_embedding
442 |
443 | tube_embedding = np.concatenate((rgb_tube_embedding, flow_tube_embedding), axis=1)
444 | self.online_cache[i3d_ftr_path] = tube_embedding
445 | return tube_embedding
446 |
447 | def get_tube_pos_embedding(self, tubeInfo, tube_length, feat_dim=64, feat_type='aiayn'):
448 | tube_list, frame_list = tubeInfo
449 | position_mat_raw = torch.zeros((1, self.rpNum, tube_length, 4))
450 | if feat_type=='aiayn':
451 | bSize = 1
452 | prpSize = self.rpNum
453 | kNN = tube_length
454 | for tubeId, tube in enumerate(tube_list[0]):
455 | if tubeId>=self.rpNum:
456 | break
457 | tube_length_ori = len(tube)
458 | tube_seg_length = max(int(tube_length_ori/tube_length), 1)
459 |
460 | for tube_seg_id in range(0, tube_length):
461 | tube_seg_id_st = tube_seg_id*tube_seg_length
462 | tube_seg_id_end = min((tube_seg_id+1)*tube_seg_length, tube_length_ori)
463 | if(tube_seg_id_st)>=tube_length_ori:
464 | position_mat_raw[0, tubeId, tube_seg_id, :] = position_mat_raw[0, tubeId, tube_seg_id-1, :]
465 | continue
466 | bbox_list = tube[tube_seg_id_st:tube_seg_id_end]
467 | box_np = np.concatenate(bbox_list, axis=0)
468 | box_tf = torch.FloatTensor(box_np).view(-1, 4)
469 | position_mat_raw[0, tubeId, tube_seg_id, :]= box_tf.mean(0)
470 | position_mat_raw_v2 = copy.deepcopy(position_mat_raw)
471 |             position_mat_raw_v2[..., 0] = (position_mat_raw[..., 0] + position_mat_raw[..., 2])/2  # centre x: index the coordinate dim, not the proposal dim
472 |             position_mat_raw_v2[..., 1] = (position_mat_raw[..., 1] + position_mat_raw[..., 3])/2  # centre y
473 |             position_mat_raw_v2[..., 2] = position_mat_raw[..., 2] - position_mat_raw[..., 0]      # width
474 |             position_mat_raw_v2[..., 3] = position_mat_raw[..., 3] - position_mat_raw[..., 1]      # height
475 |
476 | pos_emb = extract_position_embedding(position_mat_raw_v2, feat_dim, wave_length=1000)
477 |
478 | return pos_emb.squeeze(0)
479 | else:
480 | raise ValueError('%s is not implemented!' %(feat_type))
481 |
482 |
483 | def get_visual_item(self, indexOri):
484 | index = self.use_key_index[indexOri]
485 | sumInd = 0
486 | tube_embedding = None
487 | cap_embedding = None
488 | cap_length_list = -1
489 | tAf = time.time()
490 | cap_embedding, cap_length_list, word_lbl_list = self.get_cap_emb(index, self.capNum)
491 | tBf = time.time()
492 |
493 | cache_str_shot_str = str(self.maxTubelegth) +'_' + str(index)
494 | tube_embedding_list = list()
495 | if cache_str_shot_str in self.cache_ftr_dict.keys():
496 | if self.vis_ftr_type=='rgb'or self.vis_ftr_type=='rgb_i3d':
497 | tube_embedding, tubeInfo, tube_to_prp_idx = self.cache_ftr_dict[cache_str_shot_str]
498 | elif self.vis_ftr_type=='i3d':
499 | tube_embedding, tubeInfo = self.cache_ftr_dict[cache_str_shot_str]
500 | # get visual tube embedding
501 | else:
502 | if self.vis_ftr_type =='rgb'or self.vis_ftr_type=='rgb_i3d':
503 | tube_embedding, tubeInfo, tube_to_prp_idx = self.get_tube_embedding(index, self.maxTubelegth, self.out_cache_folder)
504 | if self.context_flag:
505 | tube_embedding_context = self.get_context_embedding(index, self.maxTubelegth, self.out_cache_folder)
506 | tube_embedding_context_exp = np.expand_dims(tube_embedding_context, axis=0).repeat(self.rpNum, axis=0)
507 | tube_embedding = np.concatenate([tube_embedding, tube_embedding_context_exp], axis=2)
508 |
509 |
510 | if self.vis_ftr_type=='rgb_i3d':
511 | tube_embedding_list.append(tube_embedding)
512 | tube_embedding = torch.FloatTensor(tube_embedding)
513 |
514 | if self.vis_ftr_type =='i3d' or self.vis_ftr_type=='rgb_i3d':
515 | tube_embedding = self.get_tube_embedding_i3d(index, self.maxTubelegth, self.out_cache_folder)
516 | if self.context_flag:
517 | tube_embedding_context = self.get_context_embedding_i3d(index, self.maxTubelegth, self.out_cache_folder)
518 | tube_embedding_context_exp = np.expand_dims(tube_embedding_context, axis=0).repeat(self.rpNum, axis=0)
519 | tube_embedding = np.concatenate([tube_embedding, tube_embedding_context_exp], axis=2)
520 |
521 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index)
522 | tube_info_path = os.path.join(self.tubePath, self.set_name, self.prp_type, str(index)+'.pd')
523 | tubeInfo = pickleload(tube_info_path)
524 | if self.vis_ftr_type=='rgb_i3d':
525 | tube_embedding_list.append(tube_embedding)
526 | tube_embedding = torch.FloatTensor(tube_embedding)
527 |
528 | if self.vis_ftr_type=='rgb_i3d' and cache_str_shot_str not in self.cache_ftr_dict.keys():
529 | tube_embedding = np.concatenate(tube_embedding_list, axis=2)
530 | tube_embedding = torch.FloatTensor(tube_embedding)
531 |
532 | # get position embedding
533 | if self.pos_type !='none':
534 | tp1 = time.time()
535 | tube_embedding_pos = self.get_tube_pos_embedding(tubeInfo, tube_length=self.maxTubelegth, \
536 | feat_dim=self.pos_emb_dim, feat_type=self.pos_type)
537 | tp2 = time.time()
538 | tube_embedding = torch.cat((tube_embedding, tube_embedding_pos), dim=2)
539 |
540 | if self.cache_flag and self.vis_ftr_type=='rgb':
541 | self.cache_ftr_dict[cache_str_shot_str] = [tube_embedding, tubeInfo, tube_to_prp_idx]
542 | elif self.cache_flag and self.vis_ftr_type=='rgb_i3d':
543 | self.cache_ftr_dict[cache_str_shot_str] = [tube_embedding, tubeInfo, tube_to_prp_idx]
544 | elif self.cache_flag and self.vis_ftr_type=='i3d':
545 | self.cache_ftr_dict[cache_str_shot_str] = [tube_embedding, tubeInfo]
546 | tAf2 = time.time()
547 | vd_name, ins_in_vd = self.vid_parser.get_shot_info_from_index(index)
548 |
549 | return tube_embedding, cap_embedding, tubeInfo, index, cap_length_list, vd_name, word_lbl_list
550 |
551 | def get_tube_info(self, indexOri):
552 | index = self.use_key_index[indexOri]
553 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index)
554 | tube_info_path = os.path.join(self.tubePath, self.set_name, self.prp_type, str(index)+'.pd')
555 | tubeInfo = pickleload(tube_info_path)
556 | return tubeInfo, index
557 |
558 | def get_tube_info_gt(self, indexOri):
559 | '''
560 | get ground truth info
561 | '''
562 | index = self.use_key_index[indexOri]
563 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index)
564 | tube_info_path = os.path.join(self.tubePath, self.set_name, self.prp_type, str(index)+'.pd')
565 | tubeInfo = pickleload(tube_info_path)
566 | return ins_ann, index, vd_name
567 |
568 | def get_frm_embedding(self, index):
569 | set_name = self.set_name
570 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index)
571 | tube_info_path = os.path.join(self.tubePath, set_name, self.prp_type, str(index)+'.pd')
572 | tubeInfo = pickleload(tube_info_path)
573 | tube_list, frame_list = tubeInfo
574 | frmNum = len(frame_list)
575 | rpNum = self.rpNum
576 | tube_to_prp_idx = list()
577 | ftr_tube_list = list()
578 | if self.frm_num>0:
579 | sample_index_list = random.sample(range(frmNum), self.frm_num)
580 | else:
581 | sample_index_list = list(range(frmNum))
582 |
583 | frm_ftr_list = list()
584 | bbx_list = list()
585 | for i, frm_id in enumerate(sample_index_list):
586 | frmName = frame_list[frm_id]
587 | img_prp_ftr_info_path = os.path.join(self.ftrPath, self.set_name, vd_name, frmName+ '.pd')
588 | img_prp_ftr_info = pickleload(img_prp_ftr_info_path)
589 | tmp_frm_ftr = img_prp_ftr_info['roiFtr'][:rpNum]
590 | frm_ftr_list.append(np.expand_dims(tmp_frm_ftr, axis=0))
591 | tmp_bbx = copy.deepcopy(img_prp_ftr_info['rois'][:rpNum]) # to be modified
592 | tmp_info = img_prp_ftr_info['imFo'].squeeze()
593 | tmp_bbx[:, 0] = tmp_bbx[:, 0]/tmp_info[1]
594 | tmp_bbx[:, 2] = tmp_bbx[:, 2]/tmp_info[1]
595 | tmp_bbx[:, 1] = tmp_bbx[:, 1]/tmp_info[0]
596 | tmp_bbx[:, 3] = tmp_bbx[:, 3]/tmp_info[0]
597 | bbx_list.append(tmp_bbx)
598 | frm_embedding = np.concatenate(frm_ftr_list, axis=0)
599 | return frm_embedding, tubeInfo, sample_index_list, bbx_list
600 |
601 | def get_visual_frm_item(self, indexOri):
602 | #pdb.set_trace()
603 | index = self.use_key_index[indexOri]
604 |
605 | # testing for certain sample:
606 | sumInd = 0
607 | tube_embedding = None
608 | cap_embedding = None
609 | cap_length_list = -1
610 | tAf = time.time()
611 | cap_embedding, cap_length_list, word_lbl_list = self.get_cap_emb(index, self.capNum)
612 | tBf = time.time()
613 |
614 | cache_str_shot_str = str(self.maxTubelegth) +'_' + str(index)
615 | frm_embedding_list = list()
616 | if self.vis_ftr_type =='rgb'or self.vis_ftr_type=='rgb_i3d':
617 | frm_embedding, tubeInfo, frm_idx, bbx_list = self.get_frm_embedding(index)
618 | #pdb.set_trace()
619 | if self.vis_ftr_type=='rgb_i3d':
620 | frm_embedding_list.append(frm_embedding)
621 |             frm_embedding = np.concatenate(frm_embedding_list, axis=2)
622 |
623 | frm_embedding = torch.FloatTensor(frm_embedding)
624 |
625 | tAf2 = time.time()
626 | vd_name, ins_in_vd = self.vid_parser.get_shot_info_from_index(index)
627 | return frm_embedding, cap_embedding, tubeInfo, index, cap_length_list, vd_name, word_lbl_list, frm_idx, bbx_list
628 |
629 | def __getitem__(self, index):
630 | if not self.frm_level_flag:
631 | return self.get_visual_item(index)
632 | else:
633 | return self.get_visual_frm_item(index)
634 |
635 | def dis_collate_vid(batch):
636 | ftr_tube_list = list()
637 | ftr_cap_list = list()
638 | tube_info_list = list()
639 | cap_length_list = list()
640 | index_list = list()
641 | vd_name_list = list()
642 | word_lbl_list = list()
643 | max_length = 0
644 | frm_idx_list = list()
645 | bbx_list = list()
646 | region_gt_ori = list()
647 | for sample in batch:
648 | ftr_tube_list.append(sample[0])
649 | ftr_cap_list.append(torch.FloatTensor(sample[1]))
650 | tube_info_list.append(sample[2])
651 | index_list.append(sample[3])
652 | vd_name_list.append(sample[5])
653 |
654 | for tmp_length in sample[4]:
655 | if(tmp_length>max_length):
656 | max_length = tmp_length
657 | cap_length_list.append(tmp_length)
658 | word_lbl_list.append(sample[6])
659 | if len(sample)>8:
660 | frm_idx_list.append(sample[7])
661 | bbx_list.append(sample[8])
662 |
663 | capMatrix = torch.stack(ftr_cap_list, 0)
664 | capMatrix = capMatrix[:, :, :max_length, :]
665 | if len(frm_idx_list)>0:
666 | return torch.stack(ftr_tube_list, 0), capMatrix, tube_info_list, index_list, cap_length_list, vd_name_list, word_lbl_list, frm_idx_list, bbx_list
667 | else:
668 | return torch.stack(ftr_tube_list, 0), capMatrix, tube_info_list, index_list, cap_length_list, vd_name_list, word_lbl_list
669 |
670 |
671 | if __name__=='__main__':
672 | from datasetLoader import build_dataloader
673 | opt = parse_args()
674 | opt.dbSet = 'vid'
675 | opt.set_name ='train'
676 | opt.batchSize = 4
677 | opt.num_workers = 0
678 | opt.rpNum =30
679 | opt.vis_ftr_type = 'i3d'
680 | data_loader, dataset = build_dataloader(opt)
681 |
--------------------------------------------------------------------------------
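Most embedding routines in vidDataset.py above share one idea: a tube of `frmNum` frames is reduced to a fixed grid of `maxTubelegth` segments via `seg_length = max(int(frmNum/maxTubelegth), 1)`, and the per-frame features inside each segment are mean-pooled; tubes shorter than the grid leave the trailing segments as zeros. A minimal NumPy sketch of that pooling (function name and sizes are hypothetical, not the repo API):

```python
import numpy as np

def segment_mean_pool(frame_ftrs, max_tube_length):
    # Mean-pool per-frame features into at most max_tube_length segments,
    # mirroring the seg_length logic of vidDataset.py above.
    frm_num, ftr_dim = frame_ftrs.shape
    seg_length = max(int(frm_num / max_tube_length), 1)
    pooled = np.zeros((max_tube_length, ftr_dim), dtype=np.float32)
    for seg_id in range(max_tube_length):
        start_id = seg_id * seg_length
        end_id = min((seg_id + 1) * seg_length, frm_num)
        if start_id >= frm_num:  # short tube: leave the remaining segments at zero
            break
        pooled[seg_id] = frame_ftrs[start_id:end_id].mean(axis=0)
    return pooled

# Hypothetical tube: 45 frames of 2048-d RoI features pooled into 20 segments.
print(segment_mean_pool(np.random.rand(45, 2048).astype(np.float32), 20).shape)  # (20, 2048)
```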
/fun/vidDatasetParser.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | sys.path.append('../')
4 | sys.path.append('../util')
5 | from util.mytoolbox import *
6 | import pdb
7 | import h5py
8 | import csv
9 | import numpy as np
10 | from netUtil import *
11 | sys.path.append('../annotations')
12 |
13 | from itertools import izip
14 | import multiprocessing
15 | from multiprocessing import Pool
16 | import dill
17 | import matplotlib
18 | matplotlib.use('Agg')
19 | import matplotlib.pyplot as plt
20 | from wsParamParser import parse_args
21 | import random
22 | import operator
23 | from fun.datasetLoader import *
24 | import math
25 |
26 | class vidInfoParser(object):
27 | def __init__(self, set_name, annFd):
28 | self.tube_gt_path = os.path.join(annFd, 'Annotations/VID/tubeGt', set_name)
29 | self.tube_name_list_fn = os.path.join(annFd, 'Data/VID/annSamples/', set_name+'_valid_list.txt')
30 | self.jpg_folder = os.path.join(annFd, 'Data/VID/', set_name)
31 | self.info_lines = textread(self.tube_name_list_fn)
32 | self.set_name = set_name
33 | self.tube_ann_list_fn = os.path.join(annFd, 'Data/VID/annSamples/', set_name + '_ann_list_v2.txt')
34 | #pdb.set_trace()
35 | ins_lines = textread(self.tube_ann_list_fn)
36 | ann_dict_set_dict = {}
37 | for line in ins_lines:
38 | ins_id_str, caption = line.split(',', 1)
39 | ins_id = int(ins_id_str)
40 | if ins_id not in ann_dict_set_dict.keys():
41 | ann_dict_set_dict[ins_id] = list()
42 | ann_dict_set_dict[ins_id].append(caption)
43 | self.tube_cap_dict = ann_dict_set_dict
44 |
45 | def get_length(self):
46 | return len(self.info_lines)
47 |
48 | def get_shot_info_from_index(self, index):
49 | info_Str = self.info_lines[index]
50 | vd_name, ins_id_str = info_Str.split(',')
51 | return vd_name, ins_id_str
52 |
53 | def get_shot_anno_from_index(self, index):
54 | vd_name, ins_id_str = self.get_shot_info_from_index(index)
55 | jsFn = os.path.join(self.tube_gt_path, vd_name + '.js')
56 | annDict = jsonload(jsFn)
57 | ann = None
58 | for ii, ann in enumerate(annDict['annotations']):
59 | track = ann['track']
60 | trackId = ann['id']
61 | if(trackId!=ins_id_str):
62 | continue
63 | break;
64 | return ann, vd_name
65 |
66 | def get_shot_frame_list_from_index(self, index):
67 | ann, vd_name = self.get_shot_anno_from_index(index)
68 | frm_list = list()
69 | track = ann['track']
70 | trackId = ann['id']
71 | frmNum = len(track)
72 | for iii in range(frmNum):
73 | vdFrmInfo = track[iii]
74 | imPath = '%06d' %(vdFrmInfo['frame']-1)
75 | frm_list.append(imPath)
76 | return frm_list, vd_name
77 |
78 | def proposal_path_set_up(self, prpPath):
79 | self.propsal_path = os.path.join(prpPath, self.set_name)
80 |
81 |
82 | def get_all_instance_frames(set_name, annFd):
83 | tubeGtPath = os.path.join(annFd, 'Annotations/VID/tubeGt', set_name)
84 | tube_name_list_fn = os.path.join(annFd, 'Data/VID/annSamples/', set_name+'_valid_list.txt')
85 | jpg_folder = os.path.join(annFd, 'Data/VID/', set_name)
86 | all_frm_path_list = list()
87 |
88 | info_lines = textread(tube_name_list_fn)
89 | for i, vd_info in enumerate(info_lines):
90 | #if(i!=300):
91 | # continue
92 | vd_name, ins_id_str = vd_info.split(',')
93 | jsFn = os.path.join(tubeGtPath, vd_name + '.js')
94 | annDict = jsonload(jsFn)
95 | for ii, ann in enumerate(annDict['annotations']):
96 | track = ann['track']
97 | trackId = ann['id']
98 | if(trackId!=ins_id_str):
99 | continue
100 | frmNum = len(track)
101 | for iii in range(frmNum):
102 | vdFrmInfo = track[iii]
103 | imPath = jpg_folder + '/' + vd_name + '/' + '%06d.JPEG' %(vdFrmInfo['frame']-1)
104 | all_frm_path_list.append(imPath)
105 | break;
106 | all_frm_path_list_unique = list(set(all_frm_path_list))
107 | all_frm_path_list_unique.sort()
108 | return all_frm_path_list_unique
109 |
110 | def extract_shot_prp_list_from_pickle(vidParser, shot_index, prp_num=20,do_norm=1):
111 | frm_list, vd_name = vidParser.get_shot_frame_list_from_index(shot_index)
112 | prp_list = list()
113 |
114 | for i, frm_name in enumerate(frm_list):
115 | frm_name_raw = frm_name.split('.')[0]
116 | prp_path = os.path.join(vidParser.propsal_path, vd_name, frm_name_raw+'.pd')
117 | frm_prp_info = cPickleload(prp_path)
118 | tmp_bbx = frm_prp_info['rois'][:prp_num]
119 | tmp_score = frm_prp_info['roisS'][:prp_num]
120 | tmp_info = frm_prp_info['imFo'].squeeze()
121 |
122 | #pdb.set_trace()
123 | if do_norm==1:
124 | tmp_bbx[:, 0] = tmp_bbx[:, 0]/tmp_info[1]
125 | tmp_bbx[:, 2] = tmp_bbx[:, 2]/tmp_info[1]
126 | tmp_bbx[:, 1] = tmp_bbx[:, 1]/tmp_info[0]
127 | tmp_bbx[:, 3] = tmp_bbx[:, 3]/tmp_info[0]
128 | elif do_norm==2:
129 | tmp_bbx[:, 0] = tmp_bbx[:, 0]*tmp_info[2]/tmp_info[1]
130 | tmp_bbx[:, 2] = tmp_bbx[:, 2]*tmp_info[2]/tmp_info[1]
131 | tmp_bbx[:, 1] = tmp_bbx[:, 1]*tmp_info[2]/tmp_info[0]
132 | tmp_bbx[:, 3] = tmp_bbx[:, 3]*tmp_info[2]/tmp_info[0]
133 | else:
134 | tmp_bbx = tmp_bbx/tmp_info[2]
135 | tmp_score = np.expand_dims(tmp_score, axis=1 )
136 | prp_list.append([tmp_score, tmp_bbx])
137 | return prp_list, frm_list, vd_name
138 |
139 | def resize_tube_bbx(tube_vis, frmImList_vis):
140 | for prpId, frm in enumerate(tube_vis):
141 | h, w, c = frmImList_vis[prpId].shape
142 | tube_vis[prpId][0] = tube_vis[prpId][0]*w
143 | tube_vis[prpId][2] = tube_vis[prpId][2]*w
144 | tube_vis[prpId][1] = tube_vis[prpId][1]*h
145 | tube_vis[prpId][3] = tube_vis[prpId][3]*h
146 | return tube_vis
147 |
148 | def evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, thre=0.5 ,topKOri=20, more_detailed_flag=False):
149 | #pdb.set_trace()
150 | topK = min(topKOri, len(shot_proposals[0][0]))
151 | recall_k = [0.0] * (topK + 1)
152 | iou_list = list()
153 | ann, vid_name = vid_parser.get_shot_anno_from_index(tube_index)
154 | boxes = {}
155 | for i, ann_frame in enumerate(ann['track']):
156 | frame_ind = ann_frame['frame']
157 | box = ann_frame['bbox']
158 | h, w = ann_frame['frame_size']
159 | box[0] = box[0]*1.0/w
160 | box[2] = box[2]*1.0/w
161 | box[1] = box[1]*1.0/h
162 | box[3] = box[3]*1.0/h
163 | keyName = '%06d' %(frame_ind-1)
164 | boxes[keyName] = box
165 |
166 | # pdb.set_trace()
167 | tube_list, frame_list = shot_proposals
168 | assert(len(tube_list[0][0])== len(frame_list))
169 | is_instance_annotated = False
170 | for i in range(topK):
171 | recall_k[i+1] = recall_k[i]
172 | if is_instance_annotated:
173 | continue
174 | curTubeOri = tube_list[0][i]
175 | tube_key_bbxList = {}
176 | for frame_ind, gt_box in boxes.iteritems():
177 | try:
178 | index_tmp = frame_list.index(frame_ind)
179 | tube_key_bbxList[frame_ind] = curTubeOri[index_tmp]
180 | except:
181 |                 print('key %s does not exist in shot' %(frame_ind))
182 | #pdb.set_trace()
183 | ol = compute_LS(tube_key_bbxList, boxes)
184 | if ol < thre:
185 | if more_detailed_flag:
186 | iou_list.append(ol)
187 | continue
188 | else:
189 | recall_k[i+1] += 1.0
190 | is_instance_annotated = True
191 | if more_detailed_flag:
192 | iou_list.append(ol)
193 | if more_detailed_flag:
194 | return recall_k, iou_list
195 | else:
196 | return recall_k
197 |
198 | def multi_process_connect_tubes(param_list):
199 | tube_index, tube_save_path, prp_num, tube_model_name, connect_w, set_name, annFd, vid_parser = param_list
200 | #vid_parser = vidInfoParser(set_name, annFd)
201 | print(tube_save_path)
202 | if os.path.isfile(tube_save_path):
203 | print('\n file exist\n')
204 | return
205 | if tube_model_name =='coco' :
206 | prp_list, frmList, vd_name = extract_shot_prp_list_from_pickle(vid_parser, tube_index, prp_num, do_norm=1)
207 | else:
208 | prp_list, frmList, vd_name = extract_shot_prp_list_from_pickle(vid_parser, tube_index, prp_num, do_norm=2)
209 | results = get_tubes(prp_list, connect_w)
210 | shot_proposals = [results, frmList]
211 | makedirs_if_missing(os.path.dirname(tube_save_path))
212 | pickledump(tube_save_path, shot_proposals)
213 |
214 | def visual_tube_proposals(tube_save_path, vid_parser, tube_index, prp_num):
215 |
216 | topK = prp_num
217 | recall_k2_sum = np.array([0.0] * (topK + 1))
218 | recall_k3_sum = np.array([0.0] * (topK + 1))
219 | recall_k4_sum = np.array([0.0] * (topK + 1))
220 | recall_k5_sum = np.array([0.0] * (topK + 1))
221 |
222 | shot_proposals = pickleload(tube_save_path)
223 | results, frmList = shot_proposals
224 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(tube_index)
225 |
226 | recallK2 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.2 ,topKOri=prp_num)
227 | recallK3 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.3 ,topKOri=prp_num)
228 | recallK4 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.4 ,topKOri=prp_num)
229 | recallK5 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.5 ,topKOri=prp_num)
230 | recall_k2_sum += np.array(recallK2)
231 | recall_k3_sum += np.array(recallK3)
232 | recall_k4_sum += np.array(recallK4)
233 | recall_k5_sum += np.array(recallK5)
234 |
235 | print('%d/%d' %(tube_index, vid_parser.get_length()))
236 | print('thre: %f %f %f %f\n' %( 0.2,0.3,0.4,0.5))
237 | print((recall_k2_sum)*1.0/(tube_index+1))
238 | print((recall_k3_sum)*1.0/(tube_index+1))
239 | print((recall_k4_sum)*1.0/(tube_index+1))
240 | print((recall_k5_sum)*1.0/(tube_index+1))
241 |
242 | #continue
243 | # visualization
244 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frmList]
245 | frmImList = list()
246 | for fId, imPath in enumerate(frmImNameList):
247 | img = cv2.imread(imPath)
248 | frmImList.append(img)
249 | vis_frame_num = 30
250 | visIner = int(len(frmImList) /vis_frame_num)
251 | #pdb.set_trace()
252 | for ii in range(len(results[0])):
253 | print('visualizing tube %d\n'%(ii))
254 | tube = results[0][ii]
255 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)]
256 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)]
257 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis)
258 | vd_name_raw = vd_name.split('/')[-1]
259 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, 'sample/'+vd_name_raw+ '_'+str(prp_num) + str(ii)+'.gif')
260 |
261 |
262 | ###############################################################################################
263 | def get_recall_for_tube_proposals(tube_save_path, vid_parser, tube_index, prp_num, thre_list=[0.2, 0.3, 0.4, 0.5]):
264 |
265 | topK = prp_num
266 |
267 | shot_proposals = pickleload(tube_save_path)
268 | results, frmList = shot_proposals
269 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(tube_index)
270 |
271 | recall_list = list()
272 | for thre in thre_list:
273 | recallK = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, thre ,topKOri=prp_num)
274 | recall_list.append(np.array(recallK))
275 | return recall_list
276 |
277 | #####################################################################################
278 | def vis_im_prp():
279 | #set_name = 'val'
280 | set_name = 'train'
281 | annFd = '/data1/zfchen/data/ILSVRC'
282 | #prpFd = '/data1/zfchen/data/ILSVRC/vid_prp_zf/Data/VID'
283 | prpFd = '/mnt/ceph_cv/aicv_image_data/forestlma/zfchen/vidPrp/Data/VID'
284 | #prpFd = '/data1/zfchen/data/ILSVRC/vid_prp_vg/Data/VID'
285 | tube_model_name = 'coco'
286 | #tube_model_name = 'vg'
287 | #tube_model_name = 'zf'
288 | vis_frame_num = 30
289 | prp_num = 30
290 | connect_w =0.2
291 | tube_index = 0
292 | thre = 0.5
293 | vis_flag = True
294 | vid_parser = vidInfoParser(set_name, annFd)
295 | vid_parser.proposal_path_set_up(prpFd)
296 |
297 | for tube_index in range(100, vid_parser.get_length()):
298 | #for tube_index in range(0, 2):
299 | tube_save_path = os.path.join(annFd, 'tubePrp' ,set_name, tube_model_name + '_' + str(prp_num) +'_' + str(int(10*connect_w)) , str(tube_index) + '.pd')
300 |
301 | if not vis_flag:
302 | continue
303 | visual_tube_proposals(tube_save_path, vid_parser, tube_index, prp_num)
304 | pdb.set_trace()
305 |
306 | #####################################################################################
307 | def show_distribute_over_categories(recall_list, ann_list, thre_list):
308 | #pdb.set_trace()
309 | print('average instance level performance')
310 | recall_ins_sum = list()
311 | for ii, thre in enumerate(thre_list):
312 | tmp_ins_sum = np.array([0.0] * (recall_list[0][ii].shape[0]))
313 | for i, recall_ins in enumerate(recall_list) :
314 | tmp_ins_sum +=recall_ins[ii]
315 | recall_ins_sum.append( tmp_ins_sum /len(recall_list))
316 | print('thre@ins@%f, %f\n' %(thre, recall_ins_sum[ii][-1]))
317 | print('top K, recall')
318 | print(recall_ins_sum[ii])
319 | pdb.set_trace()
320 |
321 | print('showing recall distribution over categories')
322 | recall_k_categories_dict = {}
323 | for i, ann in enumerate(ann_list):
324 | class_id = str(ann['track'][0]['class'])
325 | if class_id in recall_k_categories_dict.keys():
326 | recall_cat_list, ins_cat_num = recall_k_categories_dict[class_id]
327 | for ii, recall_thre in enumerate(recall_list[i]):
328 | recall_cat_list[ii] += recall_thre
329 | ins_cat_num +=1
330 | recall_k_categories_dict[class_id] =[recall_cat_list, ins_cat_num]
331 | else:
332 | ins_cat_num =1
333 | recall_k_categories_dict[class_id] = [recall_list[i], ins_cat_num]
334 |
335 | mean_cat_map = list()
336 | for i, thre in enumerate(thre_list):
337 | print('--------------------------------------------------------\n')
338 | print('recall@%f\n' %(thre))
339 | recall_plot = list()
340 | for ii, cat_name in enumerate(recall_k_categories_dict.keys()):
341 | recall_cat_list, ins_num = recall_k_categories_dict[cat_name]
342 | recall_thre = recall_cat_list[i][-1]*1.0/ins_num
343 | print('%s: %f\n' %(cat_name, recall_thre))
344 | recall_plot.append(recall_thre)
345 | cat_list = recall_k_categories_dict.keys()
346 | plt.close()
347 | fig, ax = plt.subplots(figsize=(12, 12))
348 | plt.barh(range(len(cat_list)), recall_plot, tick_label=cat_list)
349 | plt.show()
350 | bar_name = './sample/vid_recall_train@%d.jpg' %(int(thre*10))
351 | plt.savefig(bar_name)
352 | mean_cat_map.append(sum(recall_plot)/float(len(recall_plot)))
353 | for thre, mAp in zip(thre_list, mean_cat_map):
354 | print('thre@%f , map: %f\n' %(thre, mAp))
355 |
356 | def caption_to_word_list(des_str):
357 | import string
358 | des_str = des_str.lower().replace('_', ' ').replace(',' , ' ').replace('-', ' ')
359 | for c in string.punctuation:
360 | des_str = des_str.replace(c, '')
361 |     return split_carefully(des_str.replace('\n', '').replace('\r', '').rstrip(), ' ')  # punctuation is already stripped above
362 |
363 | def build_vid_word_list():
364 | set_name_list = ['train', 'val', 'test']
365 | ann_cap_path = '/data1/zfchen/data/ILSVRC/Data/VID/annSamples'
366 | word_list = list()
367 | for i, set_name in enumerate(set_name_list):
368 | #ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list.txt')
369 | ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list_v2.txt')
370 | cap_lines = textread(ann_cap_set_fn)
371 | for ii, line in enumerate(cap_lines):
372 | ins_id_str, caption = line.split(',', 1)
373 | word_list_tmp = caption_to_word_list(caption)
374 | word_list += word_list_tmp
375 | word_list= list(set(word_list))
376 | return word_list
377 |
378 |
379 |
380 | def statistic_vid_word_list():
381 |     from nltk.corpus import stopwords; set_name_list = ['train', 'val', 'test']  # stopwords is used in the frequency loop below
382 | ann_cap_path = '/data1/zfchen/data/ILSVRC/Data/VID/annSamples'
383 | word_list = list()
384 | cap_num = 0
385 | for i, set_name in enumerate(set_name_list):
386 | #ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list.txt')
387 | ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list_v2.txt')
388 | cap_lines = textread(ann_cap_set_fn)
389 | for ii, line in enumerate(cap_lines):
390 | ins_id_str, caption = line.split(',', 1)
391 | word_list_tmp = caption_to_word_list(caption)
392 | while '' in word_list_tmp:
393 | word_list_tmp.remove('')
394 | #pdb.set_trace()
395 | word_list += word_list_tmp
396 | cap_num +=1
397 | print('Average word length: %f\n'%(len(word_list)*1.0/cap_num))
398 | print('total word number: %f\n'%(len(word_list)))
399 | word_dict = list(set(word_list))
400 | print('word in dictionary: %f\n'%(len(word_dict)*1.0))
401 |
402 | # get frequence
403 | word_to_dict ={}
404 | for i, word in enumerate(word_list):
405 | if word in word_to_dict.keys():
406 | word_to_dict[word] +=1
407 | else:
408 | word_to_dict[word] =1
409 | sorted_word = sorted(word_to_dict.items(), key=operator.itemgetter(1))
410 |
411 | sorted_word.reverse()
412 |
413 |
414 | topK = 30
415 | plot_data =[]
416 | cat_name = []
417 | data_fn = 'word_noun.pdf'
418 | count_num = 0
419 | for i in range(len(sorted_word)):
420 | if sorted_word[i][0] not in stopwords.words("english"):
421 | print(sorted_word[i])
422 | plot_data.append(sorted_word[i][1])
423 | cat_name.append(sorted_word[i][0])
424 | count_num +=1
425 | if count_num>=topK:
426 | break
427 | #pdb.set_trace()
428 | plot_data.reverse()
429 | cat_name.reverse()
430 | plot_distribution_word_ori(plot_data, cat_name, data_fn,rot=30, fsize=110)
431 | #pdb.set_trace()
432 | return word_list
433 |
434 | def get_h5_feature_dict(h5file_path):
435 | img_prp_reader = h5py.File(h5file_path, 'r')
436 | return img_prp_reader
437 |
438 | def draw_specific_tube_proposals(vid_parser, index, tube_id_list, tube_proposal_list, out_fd, color_list=None):
439 |
440 | load_image_flag = True
441 | lbl = index
442 | frmImList = list()
443 | tube_info_sub_prp, frm_info_list = tube_proposal_list
444 | if color_list is None:
445 | color_list =[(255, 0, 0), (0, 255, 0)]
446 | #color_list =[(0, 255, 0), (255, 0, 0)]
447 | dotted = False
448 | line_width = 6
449 | for ii, tube_id in enumerate(tube_id_list):
450 | if ii==1:
451 | dotted = True
452 | line_width =3
453 | tube = copy.deepcopy(tube_info_sub_prp[0][tube_id])
454 |
455 | if load_image_flag:
456 | # visualize sample results
457 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl)
458 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list]
459 | for fId, imPath in enumerate(frmImNameList):
460 | img = cv2.imread(imPath)
461 | frmImList.append(img)
462 | vis_frame_num = 3000
463 | visIner =max(int(len(frmImList) /vis_frame_num), 1)
464 | load_image_flag = False
465 |
466 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)]
467 |
468 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)]
469 | print('visualizing tube %d\n'%(tube_id))
470 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis)
471 | frmImList_vis = vis_image_bbx(frmImList_vis, tube_vis_resize, color_list[ii], line_width, dotted)
472 | #frmImList_vis = vis_gray_but_bbx(frmImList_vis, tube_vis_resize)
473 | break
474 |
475 | out_fd_full = os.path.join(out_fd, vid_parser.set_name + str(lbl))
476 | makedirs_if_missing(out_fd_full)
477 | frm_name_list = list()
478 | for i, idx in enumerate(range(0, len(frmImList), visIner)):
479 | out_fn_full = os.path.join(out_fd_full, frm_info_list[idx]+'.jpg')
480 | cv2.imwrite(out_fn_full, frmImList_vis[i])
481 | frm_name_list.append(frm_info_list[idx])
482 | return frmImList_vis, frm_name_list
483 |
484 | def get_gt_bbx(ins_ann):
485 | tube_length = len(ins_ann['track'])
486 | gt_bbx = list()
487 | for i in range(tube_length):
488 | tmp_bbx = copy.deepcopy(ins_ann['track'][i]['bbox'])
489 | h, w = ins_ann['track'][i]['frame_size']
490 | tmp_bbx[0] = tmp_bbx[0]*1.0/w
491 | tmp_bbx[2] = tmp_bbx[2]*1.0/w
492 | tmp_bbx[1] = tmp_bbx[1]*1.0/h
493 | tmp_bbx[3] = tmp_bbx[3]*1.0/h
494 | gt_bbx.append(tmp_bbx)
495 | return gt_bbx
496 |
497 |
498 | if __name__ == '__main__':
499 | pdb.set_trace()
500 |
501 |
502 |
--------------------------------------------------------------------------------
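`evaluate_tube_recall_vid` in vidDatasetParser.py above builds a cumulative recall@K curve: an instance counts as recalled from rank i+1 onward once any of the top-i proposals overlaps the ground-truth tube by at least `thre` (the overlap itself comes from `compute_LS`, which is not reimplemented here). A minimal sketch of the ranking logic with hypothetical overlap scores:

```python
def recall_at_k(overlap_scores, thre=0.5):
    # overlap_scores[i]: spatio-temporal overlap of the rank-i tube proposal
    # with the ground-truth tube (hypothetical stand-in for compute_LS).
    top_k = len(overlap_scores)
    recall_k = [0.0] * (top_k + 1)
    annotated = False
    for i in range(top_k):
        recall_k[i + 1] = recall_k[i]  # carry the cumulative count forward
        if annotated:
            continue
        if overlap_scores[i] >= thre:
            recall_k[i + 1] += 1.0
            annotated = True
    return recall_k

print(recall_at_k([0.1, 0.3, 0.6, 0.2]))  # [0.0, 0.0, 0.0, 1.0, 1.0]
```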
/fun/wsParamParser.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | from util.base_parser import BaseParser
4 |
5 | class wsParamParser(BaseParser):
6 | def __init__(self, *arg_list, **arg_dict):
7 | super(wsParamParser, self).__init__(*arg_list, **arg_dict)
8 | self.add_argument('--batchSize', default=16, type=int)
9 | self.add_argument('--dim_ftr', default=128, type=int)
10 | self.add_argument('--n_pairs', default=50, type=int)
11 | self.add_argument('--initmodel', default=None)
12 | self.add_argument('--resume', default='')
13 | self.add_argument('--gpu', default=1, type=int)
14 | self.add_argument('--decay', default=0.001, type=float)
15 | self.add_argument('--optimizer', default='sgd', type=str)
16 | self.add_argument('--margin', default=1, type=float)
17 | self.add_argument('--test', action='store_true', default=False)
18 | self.add_argument('--dbSet', default='otb', type=str)
19 | self.add_argument('--num_workers', default=8, type=int)
20 | self.add_argument('--visIter', default=5, type=int)
21 | self.add_argument('--evalIter', default=10, type=int)
22 | self.add_argument('--epSize', default=10000, type=int)
23 | self.add_argument('--logFd', default='../data/log/wsEmb', type=str)
24 | self.add_argument('--saveEp', default=10, type=int)
25 | self.add_argument('--outPre', default='../data/models/a2d_checkpoint_', type=str)
26 | self.add_argument('--biLoss', action='store_true', default=False)
27 | self.add_argument('--lossW', action='store_true', default=False)
28 | self.add_argument('--lamda', default=0.8, type=float)
29 | self.add_argument('--vwFlag', action='store_true', default=False)
30 | self.add_argument('--wsMode', default='rank', type=str)
31 | self.add_argument('--hdSize', default=128, type=int)
32 | self.add_argument('--vocaSize', default=1900, type=int)
33 | self.add_argument('--conSecFlag', action='store_true', default=False)
34 | self.add_argument('--conFrmNum', default=9, type=int)
35 | self.add_argument('--moduleNum', default=2, type=int)
36 | self.add_argument('--moduleHdSize', default=1024, type=int)
37 | self.add_argument('--stEp', default=0, type=int)
38 | self.add_argument('--keepKeyFrameOnly', action='store_true', default=False)
39 | self.add_argument('--visRsFd', default='../data/visResult/rank_', type=str)
40 | self.add_argument('--logFdTx', default='../logs/wsEmb', type=str)
41 | self.add_argument('--set_name', default='train', type=str)
42 | self.add_argument('--isParal', action='store_true', default=False)
43 | self.add_argument('--capNum', default=1, type=int)
44 | self.add_argument('--maxWordNum', default=20, type=int)
45 | self.add_argument('--rpNum', default=30, type=int)
46 | self.add_argument('--vis_dim', default=2048, type=int)
47 | self.add_argument('--vis_type', default='lstm', type=str)
48 | self.add_argument('--txt_type', default='lstm', type=str)
49 | self.add_argument('--pos_type', default='aiayn', type=str)
50 | self.add_argument('--pos_emb_dim', default=64, type=int)
51 | self.add_argument('--half_size', action='store_true', default=False)
52 | self.add_argument('--server_id', default=36, type=int)
53 | self.add_argument('--vis_ftr_type', default='rgb', type=str)
54 | self.add_argument('--struct_flag', action='store_true', default=False)
55 | self.add_argument('--struct_only', action='store_true', default=False)
56 | self.add_argument('--eval_val_flag', action='store_true', default=False)
57 | self.add_argument('--eval_test_flag', action='store_true', default=False)
58 | self.add_argument('--entropy_regu_flag', action='store_true', default=False)
59 | self.add_argument('--hidden_dim', default=128, type=int)
60 | self.add_argument('--centre_num', default=32, type=int)
61 | self.add_argument('--vlad_alpha', default=1.0, type=float)
62 | self.add_argument('--cache_flag', action='store_true', default=False)
63 | self.add_argument('--use_mean_cache_flag', action='store_true', default=False)
64 | self.add_argument('--batch_size', type=int, default=64)
65 | self.add_argument('--video_embedding_size', type=int, default=512)
66 | self.add_argument('--fc_feat_size', type=int, default=2048)
67 | self.add_argument('--word_embedding_size', type=int, default=512)
68 | self.add_argument('--lstm_hidden_size', type=int, default=512)
69 | self.add_argument('--att_hidden_size', type=int, default=512)
70 | self.add_argument('--word_cnt', type=int, default=20)
71 | self.add_argument('--context_flag', action='store_true', default=False)
72 | self.add_argument('--no_shuffle_flag', action='store_true', default=False)
73 | self.add_argument('--frm_level_flag', action='store_true', default=False)
74 | self.add_argument('--frm_num', type=int, default=1)
75 | self.add_argument('--att_exp', type=int, default=1)
76 | self.add_argument('--loss_type', default='triplet_mil', type=str)
77 | self.add_argument('--seed', default=0, type=int)
78 | self.add_argument('--update_iter', default=4, type=int)
79 | self.add_argument('--lamda2', default=1, type=float)
80 | self.add_argument('--video_time_step', type=int, default=20)
81 | self.add_argument('--caption_time_step', type=int, default=20) # tacos65 DDM 15
82 | self.add_argument('--dropout_prob', type=float, default=0.1)
83 |
84 |
85 | def parse_args():
86 | parser = wsParamParser()
87 | args = parser.parse_args()
88 | half_size ='full'
89 | if args.half_size:
90 | half_size = 'half'
91 | struct_ann = ''
92 | if args.struct_flag:
93 | struct_ann = '_struct_ann_lamda_%d' %(int(args.lamda*10))
94 | if args.struct_only:
95 | struct_ann = struct_ann + '_only'
96 | if args.entropy_regu_flag:
97 | struct_ann = struct_ann + '_lamda2_' + str(args.lamda2*10)
98 |
99 |     struct_ann = struct_ann + args.loss_type
100 |
101 | struct_ann = struct_ann + '_margin_'+ str(args.margin*10)+ '_att_exp' + str(args.att_exp)
102 |
103 | if args.context_flag:
104 | struct_ann = struct_ann + '_context'
105 |
106 | if args.wsMode == 'coAtt':
107 | struct_ann = struct_ann + 'lstm_hd_' + str(args.lstm_hidden_size)
108 |
109 | if args.frm_level_flag:
110 | struct_ann = struct_ann + '_frm_level_'
111 |
112 | if args.lossW:
113 | struct_ann = struct_ann + 'weak_weight_'+str(args.lamda*10)
114 |
115 | struct_ann += '_seed_' + str(args.seed)
116 |
117 | args.logFd = args.logFd +'_bs_'+str(args.batchSize) + '_tn_' + str(args.rpNum) \
118 | +'_wl_' +str(args.maxWordNum) + '_cn_' + str(args.capNum) +'_fd_'+ str(args.dim_ftr) \
119 | + '_' + str(args.wsMode) +'_' +str(args.vis_type)+ '_' + str(args.pos_type) + \
120 | '_' + half_size + '_txt_' + str(args.txt_type) + '_' + str(args.vis_ftr_type) \
121 | + '_lr_' + str(args.lr*100000) + '_' + str(args.dbSet) + struct_ann
122 |
123 | args.outPre = args.outPre +'_bs_'+str(args.batchSize) + '_tn_' + str(args.rpNum) \
124 | +'_wl_'+str(args.maxWordNum) + '_cn_' + str(args.capNum) +'_fd_'+ str(args.dim_ftr) \
125 | + '_' + str(args.wsMode) + str(args.vis_type) + '_'+ str(args.pos_type) + \
126 | '_'+ half_size+'_txt_'+str(args.txt_type)+ '_' +str(args.vis_ftr_type) + \
127 | '_lr_' + str(args.lr*100000) + '_' + str(args.dbSet) + struct_ann +'/'
128 |
129 | args.logFdTx = args.logFdTx +'_bs_'+str(args.batchSize) + '_tn_' + str(args.rpNum) \
130 | +'_wl_' +str(args.maxWordNum) + '_cn_' + str(args.capNum) +'_fd_'+ str(args.dim_ftr) \
131 | + '_' + str(args.wsMode) +'_' +str(args.vis_type)+ '_' + str(args.pos_type) +'_' + \
132 | half_size +'_txt_'+ str(args.txt_type) + '_' + str(args.vis_ftr_type) + '_lr_' + \
133 | str(args.lr*100000) + '_' + str(args.dbSet) + struct_ann
134 |
135 | args.visRsFd = args.visRsFd + args.dbSet + '_'
136 | return args
137 |
--------------------------------------------------------------------------------
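`parse_args()` above does two jobs: it parses the command line, then bakes the key hyper-parameters into `logFd`, `outPre`, and `logFdTx` so that every run writes to a self-describing log/checkpoint prefix. A minimal usage sketch, assuming it is run from inside fun/ (the shell scripts below launch train.py/eval.py from there); the flag values are hypothetical:

```python
import sys
from wsParamParser import parse_args

# Simulated command line; flags not given fall back to the defaults listed above.
sys.argv = ['train.py', '--dbSet', 'vid', '--batchSize', '4',
            '--wsMode', 'coAtt', '--vis_ftr_type', 'rgb_i3d']
opt = parse_args()
print(opt.batchSize)  # 4
print(opt.logFd)      # e.g. ../data/log/wsEmb_bs_4_tn_30_..._vid..., the run prefix
```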
/images/frm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/images/frm.png
--------------------------------------------------------------------------------
/images/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/images/task.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # data preparation
2 | ## linking data and features to ./data
3 | ```Shell
4 | ln -s ../../WSSTL/data/ILSVRC ./ILSVRC
5 | ln -s ../../WSSTL/data/vid ./vid
6 | ```
7 |
--------------------------------------------------------------------------------
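The stub readme.md above only records how the ILSVRC data and the cached vid features are linked into the data folder. A quick sanity check of that layout (a sketch; paths as in the readme):

```python
import os

# Verify that the two symlinks created by the readme resolve to real folders.
for link in ['./ILSVRC', './vid']:
    target = os.path.realpath(link)
    print('%s -> %s (is dir: %s)' % (link, target, os.path.isdir(target)))
```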
/scripts/test_video_emb_att.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python ../fun/eval.py --epSize 1000 \
2 | --dbSet vid \
3 | --maxWordNum 20 --num_workers 0 --batchSize 1 --logFd ../data/final_models/abs_ \
4 | --outPre ../data/vd_model/ --biLoss \
5 | --hdSize 300 --vwFlag --stEp 0 --logFdTx ../data/tensorBoardX/ \
6 | --vis_dim 4096 --set_name val \
7 | --rpNum 30 --saveEp 1 \
8 | --txt_type gru --pos_type none \
9 | --lr 0.001 \
10 | --vis_ftr_type rgb_i3d \
11 | --margin 1 \
12 | --vis_type lstm \
13 | --visIter 1 \
14 | --server_id 36 \
15 | --wsMode coAtt \
16 | --fc_feat_size 4096 \
17 | --dim_ftr 512 \
18 | --no_shuffle_flag \
19 | --eval_test_flag \
20 | --initmodel ../data/models/_ep_29_itr_0.pth \
21 |
--------------------------------------------------------------------------------
/scripts/train_video_emb_att.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python ../fun/train.py --epSize 30 \
2 | --seed 0 \
3 | --dbSet vid \
4 | --maxWordNum 20 --num_workers 0 --batchSize 3 --logFd ../data/log/pm_resume_ \
5 | --outPre ../data/acc_grad --biLoss \
6 | --hdSize 300 --vwFlag --stEp 0 --logFdTx ../data/tensorBoardX/pm_resume_ \
7 | --vis_dim 4096 --set_name train \
8 | --rpNum 30 --saveEp 1 \
9 | --txt_type lstm --pos_type none \
10 | --lr 0 \
11 | --vis_ftr_type rgb_i3d \
12 | --margin 1 \
13 | --vis_type lstm \
14 | --visIter 5 \
15 | --fc_feat_size 4096 \
16 | --dim_ftr 512 \
17 | --server_id 36 \
18 | --stEp 0 \
19 | --optimizer sgd \
20 | --entropy_regu_flag --lamda2 1 \
21 | --wsMode coAtt \
22 | --update_iter 4 \
23 |
--------------------------------------------------------------------------------
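The training script combines `--batchSize 3` with `--update_iter 4`. Assuming `update_iter` drives gradient accumulation in train.py (an assumption; the relevant loop is not shown in full above), the optimizer steps once every four mini-batches, giving an effective batch size of 12 while only three samples are held in memory at a time. A minimal PyTorch sketch of that pattern:

```python
import torch

model = torch.nn.Linear(8, 1)  # stand-in for the grounding model
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
update_iter = 4                # matches --update_iter 4 above

optimizer.zero_grad()
for itr in range(16):          # stand-in training loop
    x, y = torch.randn(3, 8), torch.randn(3, 1)  # mini-batch of 3, as in the script
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / update_iter).backward()   # scale so the accumulated gradients average out
    if (itr + 1) % update_iter == 0:  # effective batch size: 3 * 4 = 12
        optimizer.step()
        optimizer.zero_grad()
```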
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/__init__.py
--------------------------------------------------------------------------------
/util/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/__init__.pyc
--------------------------------------------------------------------------------
/util/base_parser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 |
4 | import argparse
5 |
6 | class BaseParser(argparse.ArgumentParser):
7 | def __init__(self, *arg_list, **arg_dict):
8 | super(BaseParser, self).__init__(*arg_list, **arg_dict)
9 | self.add_argument('--max_iter', default=1500, type=int)
10 | self.add_argument('--start_iter', default=0, type=int)
11 | self.add_argument('--val_interval', default=20, type=int)
12 | self.add_argument('--saving_interval', default=100, type=int)
13 | self.add_argument('--suffix', default='')
14 | self.add_argument('--lrdecay', default=500, type=float)
15 | self.add_argument('--clip_c', default=10., type=float)
16 | self.add_argument('--wo_early_stopping',
17 | dest='early_stopping',
18 | action='store_false',
19 | default=True)
20 | self.add_argument('--alpha', default=0.001, type=float)
21 | self.add_argument('--lr', default=0.001, type=float)
22 | self.add_argument('--momentum', default=0.9, type=float)
23 |
24 | def parse_args(self, *arg_list, **arg_dict):
25 | args = super(BaseParser, self).parse_args(*arg_list, **arg_dict)
26 | if len(args.suffix) > 0:
27 | args.suffix = '_' + args.suffix
28 | return args
29 |
--------------------------------------------------------------------------------
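`BaseParser` above holds the optimizer-level flags shared by every experiment; `wsParamParser` in fun/ subclasses it and adds the task-specific ones. A minimal usage sketch (run from the repo root so that `util` is importable):

```python
from util.base_parser import BaseParser

parser = BaseParser()
args = parser.parse_args(['--lr', '0.01', '--suffix', 'debug'])
print(args.lr)      # 0.01
print(args.suffix)  # '_debug' -- parse_args prepends the underscore itself
```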
/util/base_parser.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/base_parser.pyc
--------------------------------------------------------------------------------
/util/get_image_size.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function
4 | """
5 |
6 | get_image_size.py
7 | ====================
8 |
9 | :Name: get_image_size
10 | :Purpose: extract image dimensions given a file path
11 |
12 | :Author: Paulo Scardine (based on code from Emmanuel VAÏSSE)
13 |
14 | :Created: 26/09/2013
15 | :Copyright: (c) Paulo Scardine 2013
16 | :Licence: MIT
17 |
18 | """
19 | import collections
20 | import json
21 | import os
22 | import struct
23 |
24 | FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
25 |
26 |
27 | class UnknownImageFormat(Exception):
28 | pass
29 |
30 |
31 | types = collections.OrderedDict()
32 | BMP = types['BMP'] = 'BMP'
33 | GIF = types['GIF'] = 'GIF'
34 | ICO = types['ICO'] = 'ICO'
35 | JPEG = types['JPEG'] = 'JPEG'
36 | PNG = types['PNG'] = 'PNG'
37 | TIFF = types['TIFF'] = 'TIFF'
38 |
39 | image_fields = ['path', 'type', 'file_size', 'width', 'height']
40 |
41 |
42 | class Image(collections.namedtuple('Image', image_fields)):
43 |
44 | def to_str_row(self):
45 | return ("%d\t%d\t%d\t%s\t%s" % (
46 | self.width,
47 | self.height,
48 | self.file_size,
49 | self.type,
50 | self.path.replace('\t', '\\t'),
51 | ))
52 |
53 | def to_str_row_verbose(self):
54 | return ("%d\t%d\t%d\t%s\t%s\t##%s" % (
55 | self.width,
56 | self.height,
57 | self.file_size,
58 | self.type,
59 | self.path.replace('\t', '\\t'),
60 | self))
61 |
62 | def to_str_json(self, indent=None):
63 | return json.dumps(self._asdict(), indent=indent)
64 |
65 |
66 | def get_image_size(file_path):
67 | """
68 | Return (width, height) for a given img file content - no external
69 | dependencies except the os and struct builtin modules
70 | """
71 | img = get_image_metadata(file_path)
72 | return (img.width, img.height)
73 |
74 |
75 | def get_image_metadata(file_path):
76 | """
77 | Return an `Image` object for a given img file content - no external
78 | dependencies except the os and struct builtin modules
79 |
80 | Args:
81 | file_path (str): path to an image file
82 |
83 | Returns:
84 | Image: (path, type, file_size, width, height)
85 | """
86 | size = os.path.getsize(file_path)
87 |
88 | # be explicit with open arguments - we need binary mode
89 | with open(file_path, "rb") as input:
90 | height = -1
91 | width = -1
92 | data = input.read(26)
93 | msg = " raised while trying to decode as JPEG."
94 |
95 | if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'):
96 | # GIFs
97 | imgtype = GIF
98 |             w, h = struct.unpack("<HH", data[6:10])
99 |             width = int(w)
100 |             height = int(h)
101 |         elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
102 | and (data[12:16] == b'IHDR')):
103 | # PNGs
104 | imgtype = PNG
105 | w, h = struct.unpack(">LL", data[16:24])
106 | width = int(w)
107 | height = int(h)
108 | elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'):
109 | # older PNGs
110 | imgtype = PNG
111 | w, h = struct.unpack(">LL", data[8:16])
112 | width = int(w)
113 | height = int(h)
114 | elif (size >= 2) and data.startswith(b'\377\330'):
115 | # JPEG
116 | imgtype = JPEG
117 | input.seek(0)
118 | input.read(2)
119 | b = input.read(1)
120 | try:
121 | while (b and ord(b) != 0xDA):
122 | while (ord(b) != 0xFF):
123 | b = input.read(1)
124 | while (ord(b) == 0xFF):
125 | b = input.read(1)
126 | if (ord(b) >= 0xC0 and ord(b) <= 0xC3):
127 | input.read(3)
128 | h, w = struct.unpack(">HH", input.read(4))
129 | break
130 | else:
131 | input.read(
132 | int(struct.unpack(">H", input.read(2))[0]) - 2)
133 | b = input.read(1)
134 | width = int(w)
135 | height = int(h)
136 | except struct.error:
137 | raise UnknownImageFormat("StructError" + msg)
138 | except ValueError:
139 | raise UnknownImageFormat("ValueError" + msg)
140 | except Exception as e:
141 | raise UnknownImageFormat(e.__class__.__name__ + msg)
142 | elif (size >= 26) and data.startswith(b'BM'):
143 | # BMP
144 | imgtype = 'BMP'
145 |             headersize = struct.unpack("<I", data[14:18])[0]
146 |             if headersize == 12:
147 |                 w, h = struct.unpack("<HH", data[18:22])
148 |                 width = int(w)
149 |                 height = int(h)
150 |             elif headersize >= 40:
151 |                 w, h = struct.unpack("<ii", data[18:26])
152 |                 width = int(w)
153 |                 # as h is negative when stored upside down
154 |                 height = abs(int(h))
155 |             else:
156 |                 raise UnknownImageFormat(
157 |                     "Unkown DIB header size:" +
158 |                     str(headersize))
159 |         elif (size >= 8) and data[:4] in (b"II\052\000", b"MM\000\052"):
160 | # Standard TIFF, big- or little-endian
161 | # BigTIFF and other different but TIFF-like formats are not
162 | # supported currently
163 | imgtype = TIFF
164 | byteOrder = data[:2]
165 | boChar = ">" if byteOrder == "MM" else "<"
166 | # maps TIFF type id to size (in bytes)
167 | # and python format char for struct
168 | tiffTypes = {
169 | 1: (1, boChar + "B"), # BYTE
170 | 2: (1, boChar + "c"), # ASCII
171 | 3: (2, boChar + "H"), # SHORT
172 | 4: (4, boChar + "L"), # LONG
173 | 5: (8, boChar + "LL"), # RATIONAL
174 | 6: (1, boChar + "b"), # SBYTE
175 | 7: (1, boChar + "c"), # UNDEFINED
176 | 8: (2, boChar + "h"), # SSHORT
177 | 9: (4, boChar + "l"), # SLONG
178 | 10: (8, boChar + "ll"), # SRATIONAL
179 | 11: (4, boChar + "f"), # FLOAT
180 | 12: (8, boChar + "d") # DOUBLE
181 | }
182 | ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
183 | try:
184 | countSize = 2
185 | input.seek(ifdOffset)
186 | ec = input.read(countSize)
187 | ifdEntryCount = struct.unpack(boChar + "H", ec)[0]
188 | # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4
189 | # bytes: value offset
190 | ifdEntrySize = 12
191 | for i in range(ifdEntryCount):
192 | entryOffset = ifdOffset + countSize + i * ifdEntrySize
193 | input.seek(entryOffset)
194 | tag = input.read(2)
195 | tag = struct.unpack(boChar + "H", tag)[0]
196 | if(tag == 256 or tag == 257):
197 | # if type indicates that value fits into 4 bytes, value
198 | # offset is not an offset but value itself
199 | type = input.read(2)
200 | type = struct.unpack(boChar + "H", type)[0]
201 | if type not in tiffTypes:
202 | raise UnknownImageFormat(
203 | "Unkown TIFF field type:" +
204 | str(type))
205 | typeSize = tiffTypes[type][0]
206 | typeChar = tiffTypes[type][1]
207 | input.seek(entryOffset + 8)
208 | value = input.read(typeSize)
209 | value = int(struct.unpack(typeChar, value)[0])
210 | if tag == 256:
211 | width = value
212 | else:
213 | height = value
214 | if width > -1 and height > -1:
215 | break
216 | except Exception as e:
217 | raise UnknownImageFormat(str(e))
218 | elif size >= 2:
219 | # see http://en.wikipedia.org/wiki/ICO_(file_format)
220 | imgtype = 'ICO'
221 | input.seek(0)
222 | reserved = input.read(2)
223 |             if 0 != struct.unpack("<H", reserved)[0]:
224 |                 raise UnknownImageFormat(FILE_UNKNOWN)
225 |             format = input.read(2)
226 |             assert 1 == struct.unpack("<H", format)[0]
227 |             num = input.read(2)
228 |             num = struct.unpack("<H", num)[0]
229 |             if num > 1:
230 | import warnings
231 | warnings.warn("ICO File contains more than one image")
232 | # http://msdn.microsoft.com/en-us/library/ms997538.aspx
233 | w = input.read(1)
234 | h = input.read(1)
235 | width = ord(w)
236 | height = ord(h)
237 | else:
238 | raise UnknownImageFormat(FILE_UNKNOWN)
239 |
240 | return Image(path=file_path,
241 | type=imgtype,
242 | file_size=size,
243 | width=width,
244 | height=height)
245 |
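# Usage sketch (editorial): the returned Image namedtuple exposes the five
# fields documented in the docstring above.
def _demo_get_image_metadata():
    img = get_image_metadata('lookmanodeps.png')  # sample file, assumed present
    print("%s: %s, %d bytes, %dx%d" %
          (img.path, img.type, img.file_size, img.width, img.height))
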
246 |
247 | import unittest
248 |
249 |
250 | class Test_get_image_size(unittest.TestCase):
251 | data = [{
252 | 'path': 'lookmanodeps.png',
253 | 'width': 251,
254 | 'height': 208,
255 | 'file_size': 22228,
256 | 'type': 'PNG'}]
257 |
258 | def setUp(self):
259 | pass
260 |
261 | def test_get_image_metadata(self):
262 | img = self.data[0]
263 | output = get_image_metadata(img['path'])
264 | self.assertTrue(output)
265 | self.assertEqual(output.path, img['path'])
266 | self.assertEqual(output.width, img['width'])
267 | self.assertEqual(output.height, img['height'])
268 | self.assertEqual(output.type, img['type'])
269 | self.assertEqual(output.file_size, img['file_size'])
270 | for field in image_fields:
271 | self.assertEqual(getattr(output, field), img[field])
272 |
273 | def test_get_image_metadata__ENOENT_OSError(self):
274 | with self.assertRaises(OSError):
275 | get_image_metadata('THIS_DOES_NOT_EXIST')
276 |
277 | def test_get_image_metadata__not_an_image_UnknownImageFormat(self):
278 | with self.assertRaises(UnknownImageFormat):
279 |             get_image_metadata('README.md')
280 |
281 | def test_get_image_size(self):
282 | img = self.data[0]
283 | output = get_image_size(img['path'])
284 | self.assertTrue(output)
285 | self.assertEqual(output,
286 | (img['width'],
287 | img['height']))
288 |
289 | def tearDown(self):
290 | pass
291 |
292 |
293 | def main(argv=None):
294 | """
295 | Print image metadata fields for the given file path.
296 |
297 | Keyword Arguments:
298 | argv (list): commandline arguments (e.g. sys.argv[1:])
299 | Returns:
300 | int: zero for OK
301 | """
302 | import logging
303 | import optparse
304 | import sys
305 |
306 | prs = optparse.OptionParser(
307 | usage="%prog [-v|--verbose] [--json|--json-indent] []",
308 | description="Print metadata for the given image paths "
309 | "(without image library bindings).")
310 |
311 | prs.add_option('--json',
312 | dest='json',
313 | action='store_true')
314 | prs.add_option('--json-indent',
315 | dest='json_indent',
316 | action='store_true')
317 |
318 | prs.add_option('-v', '--verbose',
319 | dest='verbose',
320 | action='store_true',)
321 | prs.add_option('-q', '--quiet',
322 | dest='quiet',
323 | action='store_true',)
324 | prs.add_option('-t', '--test',
325 | dest='run_tests',
326 | action='store_true',)
327 |
328 | argv = list(argv) if argv is not None else sys.argv[1:]
329 | (opts, args) = prs.parse_args(args=argv)
330 | loglevel = logging.INFO
331 | if opts.verbose:
332 | loglevel = logging.DEBUG
333 | elif opts.quiet:
334 | loglevel = logging.ERROR
335 | logging.basicConfig(level=loglevel)
336 | log = logging.getLogger()
337 | log.debug('argv: %r', argv)
338 | log.debug('opts: %r', opts)
339 | log.debug('args: %r', args)
340 |
341 | if opts.run_tests:
342 | import sys
343 | sys.argv = [sys.argv[0]] + args
344 | import unittest
345 | return unittest.main()
346 |
347 | output_func = Image.to_str_row
348 | if opts.json_indent:
349 | import functools
350 | output_func = functools.partial(Image.to_str_json, indent=2)
351 | elif opts.json:
352 | output_func = Image.to_str_json
353 | elif opts.verbose:
354 | output_func = Image.to_str_row_verbose
355 |
356 | EX_OK = 0
357 | EX_NOT_OK = 2
358 |
359 | if len(args) < 1:
360 | prs.print_help()
361 | print('')
362 | prs.error("You must specify one or more paths to image files")
363 |
364 | errors = []
365 | for path_arg in args:
366 | try:
367 | img = get_image_metadata(path_arg)
368 | print(output_func(img))
369 | except KeyboardInterrupt:
370 | raise
371 | except OSError as e:
372 | log.error((path_arg, e))
373 | errors.append((path_arg, e))
374 | except Exception as e:
375 | log.exception(e)
376 | errors.append((path_arg, e))
377 | pass
378 | if len(errors):
379 | import pprint
380 | print("ERRORS", file=sys.stderr)
381 | print("======", file=sys.stderr)
382 | print(pprint.pformat(errors, indent=2), file=sys.stderr)
383 | return EX_NOT_OK
384 | return EX_OK
385 |
386 |
387 | if __name__ == "__main__":
388 | import sys
389 | sys.exit(main(argv=sys.argv[1:]))
390 |
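# Example invocations (editorial; the image paths are hypothetical):
#   python get_image_size.py images/frm.png
#   python get_image_size.py --json-indent images/frm.png images/task.png
#   python get_image_size.py -t   # run the unittest suite defined above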
--------------------------------------------------------------------------------
/util/mytoolbox.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 |
4 | import commands
5 | import copy
6 | import cPickle
7 | import datetime
8 | import inspect
9 | import json
10 | import os
11 | import pickle
12 | import sys
13 | import time
14 | import scipy.io
15 | import cv2
16 | import pdb
17 |
18 | def set_debugger():
19 | from IPython.core import ultratb
20 | sys.excepthook = ultratb.FormattedTB(call_pdb=True)
21 |
22 | class TimeReporter:
23 | def __init__(self, max_count, interval=1, moving_average=False):
24 | self.time = time.time
25 | self.start_time = time.time()
26 | self.max_count = max_count
27 | self.cur_count = 0
28 | self.prev_time = time.time()
29 | self.interval = interval
30 | self.moving_average = moving_average
31 | def report(self, cur_count=None, max_count=None, overwrite=True, prefix=None, postfix=None, interval=None):
32 | if cur_count is not None:
33 | self.cur_count = cur_count
34 | else:
35 | self.cur_count += 1
36 | if max_count is None:
37 | max_count = self.max_count
38 | cur_time = self.time()
39 | elapsed = cur_time - self.start_time
40 | if self.cur_count <= 0:
41 | ave_time = float('inf')
42 | elif self.moving_average and self.cur_count == 1:
43 | ave_time = float('inf')
44 | self.ma_prev_time = cur_time
45 | elif self.moving_average and self.cur_count == 2:
46 | self.ma_time = cur_time - self.ma_prev_time
47 | ave_time = self.ma_time
48 | self.ma_prev_time = cur_time
49 | elif self.moving_average:
50 | self.ma_time = self.ma_time * 0.95 + (cur_time - self.ma_prev_time) * 0.05
51 | ave_time = self.ma_time
52 | self.ma_prev_time = cur_time
53 | else:
54 | ave_time = elapsed / self.cur_count
55 | ETA = (max_count - self.cur_count) * ave_time
56 |         print_str = 'count : %d / %d, elapsed time : %f, ETA : %f' % (self.cur_count, max_count, elapsed, ETA)
57 | if prefix is not None:
58 | print_str = str(prefix) + ' ' + print_str
59 | if postfix is not None:
60 | print_str += ' ' + str(postfix)
61 | this_interval = self.interval
62 | if interval is not None:
63 | this_interval = interval
64 | if cur_time - self.prev_time < this_interval:
65 | return
66 | if overwrite and self.cur_count != self.max_count:
67 | printr(print_str)
68 | self.prev_time = cur_time
69 | else:
70 | print print_str
71 | self.prev_time = cur_time
72 |
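# Usage sketch (editorial): report() prints a progress counter, elapsed time
# and an ETA estimate; interval=0 makes every call print.
def _demo_time_reporter():
    tr = TimeReporter(50, interval=0)
    for _ in range(50):
        time.sleep(0.01)
        tr.report(prefix='[demo]')
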
73 | def textread(path):
74 | f = open(path)
75 | lines = f.readlines()
76 | f.close()
77 | for i in range(len(lines)):
78 | lines[i] = lines[i].replace('\n', '').replace('\r', '')
79 | return lines
80 |
81 | def textdump(path, lines, need_asking=False):
82 | if os.path.exists(path) and need_asking:
83 | if 'n' == choosebyinput(['Y', 'n'], path + ' exists. Would you replace? [Y/n]'):
84 | return False
85 | f = open(path, 'w')
86 | for index, i in enumerate(lines):
87 | try:
88 | f.write(i.encode("utf-8") + '\n')
89 | except:
90 | print(index)
91 | pdb.set_trace()
92 |
93 | f.close()
94 |
95 | def pickleload(path):
96 | f = open(path)
97 | this_ans = pickle.load(f)
98 | f.close()
99 | return this_ans
100 |
101 | def pickledump(path, this_dic):
102 | f = open(path, 'w')
103 | this_ans = pickle.dump(this_dic, f)
104 | f.close()
105 |
106 | def cPickleload(path):
107 | f = open(path, 'rb')
108 | this_ans = cPickle.load(f)
109 | f.close()
110 | return this_ans
111 |
112 | def cPickledump(path, this_dic):
113 | f = open(path, 'wb')
114 | this_ans = cPickle.dump(this_dic, f, -1)
115 | f.close()
116 |
117 | def jsonload(path):
118 | f = open(path)
119 | this_ans = json.load(f)
120 | f.close()
121 | return this_ans
122 |
123 | def jsondump(path, this_dic):
124 | f = open(path, 'w')
125 | this_ans = json.dump(this_dic, f)
126 | f.close()
127 |
128 | def choosebyinput(cand, message=False):
129 | if not type(cand) == list and not type(cand) == int:
130 | print 'The type of cand_list has to be \'list\' or \'int\' .'
131 | return
132 | if type(cand) == int:
133 | cand_list = range(cand)
134 | if type(cand) == list:
135 | cand_list = cand
136 | int_cand_list = []
137 | for i in cand_list:
138 | if type(i) == int:
139 | int_cand_list.append(str(i))
140 | if message == False:
141 | message = 'choose by input ['
142 | for i in int_cand_list:
143 | message += i + ' / '
144 | for i in cand_list:
145 | if not str(i) in int_cand_list:
146 | message += i + ' / '
147 | message = message[:-3] + '] : '
148 | while True:
149 | your_ans = raw_input(message)
150 |         if your_ans in int_cand_list:
151 |             return int(your_ans)
152 |         if your_ans in cand_list:
153 |             return your_ans
156 |
157 | def printr(*targ_str):
158 | str_to_print = ''
159 | for temp_str in targ_str:
160 | str_to_print += str(temp_str) + ' '
161 | str_to_print = str_to_print[:-1]
162 | sys.stdout.write(str_to_print + '\r')
163 | sys.stdout.flush()
164 |
165 | def make_red(prt):
166 | return '\033[91m%s\033[00m' % prt
167 |
168 | def emphasize(*targ_str):
169 | str_to_print = ''
170 | for temp_str in targ_str:
171 | str_to_print += str(temp_str) + ' '
172 | str_to_print = str_to_print[:-1]
173 | num_repeat = len(str_to_print) / 2 + 1
174 | print '_' + '人' * (num_repeat + 1) + '_'
175 | print '> %s <' % make_red(str_to_print)
176 | print ' ̄' + 'Y^' * num_repeat + 'Y ̄'
177 |
178 | def mkdir_if_missing(dir_path):
179 | if not os.path.exists(dir_path):
180 | os.mkdir(dir_path)
181 |
182 | def makedirs_if_missing(dir_path):
183 | if not os.path.exists(dir_path):
184 | os.makedirs(dir_path)
185 |
186 | def makebsdirs_if_missing(f_path):
187 | makedirs_if_missing(os.path.dirname(f_path) if '/' in f_path else f_path)
188 |
189 | def split_inds(all_num, split_num, split_targ):
190 | assert split_num >= 1
191 | assert split_targ >= 0
192 | assert split_targ < split_num
193 | part = all_num // split_num
194 | if not split_num == split_targ+1:
195 | return split_targ * part, (split_targ+1) * part
196 | else:
197 | return split_targ * part, all_num
198 |
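# Worked example (editorial): splitting 10 items into 3 chunks yields
# [0, 3), [3, 6) and [6, 10); the last chunk absorbs the remainder.
def _demo_split_inds():
    for targ in range(3):
        print split_inds(10, 3, targ)   # (0, 3), (3, 6), (6, 10)
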
199 | try:
200 | import numpy as np
201 | def are_same_vecs(vec_a, vec_b, this_eps1=1e-5, verbose=False):
202 | if not vec_a.ravel().shape == vec_b.ravel().shape:
203 | return False
204 | if np.linalg.norm(vec_a.ravel()) == 0:
205 | if not np.linalg.norm(vec_b.ravel()) == 0:
206 | if verbose:
207 | print 'assertion failed.'
208 | print 'diff norm : %f' % (np.linalg.norm(vec_a.ravel() - vec_b.ravel()))
209 | return False
210 | else:
211 | if not np.linalg.norm(vec_a.ravel() - vec_b.ravel()) / np.linalg.norm(vec_a.ravel()) < this_eps1:
212 | if verbose:
213 | print 'assertion failed.'
214 | print 'diff norm : %f' % (np.linalg.norm(vec_a.ravel() - vec_b.ravel()) / np.linalg.norm(vec_a.ravel()))
215 | return False
216 | return True
217 | def comp_vecs(vec_a, vec_b, this_eps1=1e-5):
218 | assert are_same_vecs(vec_a, vec_b, this_eps1, True)
219 | def arrayinfo(np_array):
220 | print 'max: %04f, min: %04f, abs_min: %04f, norm: %04f,' % (np_array.max(), np_array.min(), np.abs(np_array).min(), np.linalg.norm(np_array)),
221 | print 'dtype: %s,' % np_array.dtype,
222 | print 'shape: %s,' % str(np_array.shape),
223 | print
224 | except:
225 | def comp_vecs(*input1, **input2):
226 | print 'comp_vecs() cannot be loaded.'
227 | return
228 | def arrayinfo(*input1, **input2):
229 | print 'arrayinfo() cannot be loaded.'
230 | return
231 |
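# Usage sketch (editorial; assumes the numpy import above succeeded):
# are_same_vecs() compares flattened arrays by relative L2 distance.
def _demo_are_same_vecs():
    a = np.ones(5)
    print are_same_vecs(a, a + 1e-9)   # True  (relative diff < 1e-5)
    print are_same_vecs(a, a + 1.0)    # False
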
232 | try:
233 | import Levenshtein
234 | def search_nn_str(targ_str, str_lists):
235 | dist = float('inf')
236 | dist_str = None
237 | for i in sorted(str_lists):
238 | cur_dist = Levenshtein.distance(i, targ_str)
239 | if dist > cur_dist:
240 | dist = cur_dist
241 | dist_str = i
242 | return dist_str
243 | except:
244 | def search_nn_str(targ_str, str_lists):
245 | print 'search_nn_str() cannot be imported.'
246 | return
247 |
248 | def flatten(targ_list):
249 | new_list = copy.deepcopy(targ_list)
250 | for i in reversed(range(len(new_list))):
251 | if isinstance(new_list[i], list) or isinstance(new_list[i], tuple):
252 | new_list[i:i+1] = flatten(new_list[i])
253 | return new_list
254 |
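# Worked example (editorial): flatten() recursively inlines nested lists and
# tuples without mutating its input.
def _demo_flatten():
    print flatten([1, [2, (3, 4)], [[5]]])   # [1, 2, 3, 4, 5]
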
255 | def predict_charset(targ_str):
256 | targ_charsets = ['utf-8', 'cp932', 'euc-jp', 'iso-2022-jp']
257 | for targ_charset in targ_charsets:
258 | try:
259 | targ_str.decode(targ_charset)
260 | return targ_charset
261 | except UnicodeDecodeError:
262 | pass
263 | return None
264 |
265 | def remove_non_ascii(targ_str, charset=None):
266 | if charset is not None:
267 | assert isinstance(targ_str, str)
268 | targ_str = targ_str.decode(charset)
269 | else:
270 | assert isinstance(targ_str, unicode)
271 | return ''.join([x for x in targ_str if ord(x) < 128]).encode('ascii')
272 |
273 | class StopWatch(object):
274 | def __init__(self):
275 | self._time = {}
276 | self._bef_time = {}
277 | def tic(self, name):
278 | self._bef_time[name] = time.time()
279 | def toc(self, name):
280 |         self._time[name] = time.time() - self._bef_time[name]
282 | return self._time[name]
283 | def show(self):
284 | show_str = ''
285 | for name, elp in self._time.iteritems():
286 | show_str += '%s: %03.3f, ' % (name, elp)
287 | printr(show_str[:-2])
288 |
289 | Timer = StopWatch # deprecated
290 |
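# Usage sketch (editorial): tic()/toc() time named sections; show() prints
# all recorded timings on one line.
def _demo_stopwatch():
    sw = StopWatch()
    sw.tic('sleep')
    time.sleep(0.1)
    print sw.toc('sleep')   # ~0.1
    sw.show()
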
291 | def get_free_gpu(default_gpu):
292 | FORMAT = '--format=csv,noheader'
293 | COM_GPU_UTIL = 'nvidia-smi --query-gpu=index,uuid ' + FORMAT
294 | COM_GPU_PROCESS = 'nvidia-smi --query-compute-apps=gpu_uuid ' + FORMAT
295 | uuid2id = {cur_line.split(',')[1].strip(): int(cur_line.split(',')[0])
296 | for cur_line in commands.getoutput(COM_GPU_UTIL).split('\n')}
297 | used_gpus = set()
298 | for cur_line in commands.getoutput(COM_GPU_PROCESS).split('\n'):
299 | used_gpus.add(cur_line)
300 | if len(uuid2id) == len(used_gpus):
301 | return default_gpu
302 | elif os.uname()[1] == 'konoshiro':
303 | return str(1 - int(uuid2id[list(set(uuid2id.keys()) - used_gpus)[0]]))
304 | else:
305 | return uuid2id[list(set(uuid2id.keys()) - used_gpus)[0]]
306 |
307 | def get_abs_path():
308 | return os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), __file__)))
309 |
310 | def add_path(path):
311 | if path not in sys.path:
312 | sys.path.insert(0, path)
313 |
314 | def get_time_str():
315 |     return datetime.datetime.now().strftime('Y%yM%mD%dH%HM%M')
316 |
317 | def get_cur_time():
318 | return str(datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S'))
319 |
320 | def split_carefully(text, splitter=',', delimiters=['"', "'"]):
321 | # assertion
322 | assert isinstance(splitter, str)
323 | assert not splitter in delimiters
324 | if not (isinstance(delimiters, list) or isinstance(delimiters, tuple)):
325 | delimiters = [delimiters]
326 | for cur_del in delimiters:
327 | assert len(cur_del) == 1
328 |
329 | cur_ind = 0
330 | prev_ind = 0
331 | splitted = []
332 | is_in_delimiters = False
333 | cur_del = None
334 | while cur_ind < len(text):
335 | if text[cur_ind] in delimiters:
336 | if text[cur_ind] == cur_del:
337 | is_in_delimiters = False
338 | cur_del = None
339 | cur_ind += 1
340 | continue
341 | elif not is_in_delimiters:
342 | is_in_delimiters = True
343 | cur_del = text[cur_ind]
344 | cur_ind += 1
345 | continue
346 | if not is_in_delimiters and text[cur_ind] == splitter:
347 | splitted.append(text[prev_ind:cur_ind])
348 | cur_ind += 1
349 | prev_ind = cur_ind
350 | continue
351 | cur_ind += 1
352 | splitted.append(text[prev_ind:cur_ind])
353 | return splitted
354 |
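# Worked example (editorial): splitter characters inside quoted spans are
# preserved rather than split on.
def _demo_split_carefully():
    print split_carefully('a,"b,c",d')   # ['a', '"b,c"', 'd']
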
355 | def full_listdir(dir_name):
356 | return [os.path.join(dir_name, i) for i in os.listdir(dir_name)]
357 |
358 | def get_list_dir(dir_name):
359 | fileFdList = full_listdir(dir_name)
360 | folder_list = list()
361 | for item in fileFdList:
362 | if os.path.isdir(item):
363 | folder_list.append(item)
364 | return folder_list
365 |
366 |
367 | class tictoc(object):
368 | def __init__(self, targ_list):
369 | self._targ_list = targ_list
370 | self._list_ind = -1
371 | self._TR = TimeReporter(len(targ_list))
372 | def __iter__(self):
373 | return self
374 | def next(self):
375 | self._list_ind += 1
376 | if self._list_ind > 0:
377 | self._TR.report()
378 | if self._list_ind == len(self._targ_list):
379 | raise StopIteration()
380 | return self._targ_list[self._list_ind]
381 |
382 | ONCE_PRINTED = set()
383 |
384 |
385 | def print_once(*targ_str):
386 | frame = inspect.currentframe(1)
387 | fname = inspect.getfile(frame)
388 | cur_loc = frame.f_lineno
389 | cur_key = fname + str(cur_loc)
390 | if cur_key in ONCE_PRINTED:
391 | return
392 | else:
393 | ONCE_PRINTED.add(cur_key)
394 | str_to_print = ''
395 | for temp_str in targ_str:
396 | str_to_print += str(temp_str) + ' '
397 | str_to_print = str_to_print[:-1]
398 | print str_to_print
399 |
400 | def get_specific_file_list_from_fd(dir_name, fileType, nameOnly=True):
401 | list_name = []
402 | for fileTmp in os.listdir(dir_name):
403 | file_path = os.path.join(dir_name, fileTmp)
404 | if os.path.isdir(file_path):
405 | continue
406 | elif os.path.splitext(fileTmp)[1] == fileType:
407 | if nameOnly:
408 | list_name.append(os.path.splitext(fileTmp)[0])
409 | else:
410 | list_name.append(file_path)
411 | return list_name
412 |
413 | # falls back across separators because some annotation files are inconsistently delimited
414 | def parse_mul_num_lines(fileName, toFloat=True, spliter=','):
415 | lineOut = []
416 | lineList= textread(fileName)
417 | for lineTmp in lineList:
418 | splitedStr= split_carefully(lineTmp, spliter)
419 | if(len(splitedStr)<4):
420 |             print('strange encoding for %s!' % (fileName))
421 | splitedStr= split_carefully(lineTmp, '\t')
422 | if(len(splitedStr)<4):
423 |                 print('strange encoding for %s!' % (fileName))
424 | splitedStr= split_carefully(lineTmp, ' ')
425 |
426 | if(toFloat):
427 | splitedTmp= [ float(ele) for ele in splitedStr]
428 | lineOut.append(splitedTmp)
429 | else:
430 | lineOut.append(splitedStr)
431 | return lineOut
432 |
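# Usage sketch (editorial; 'boxes.txt' is a hypothetical annotation file with
# one comma-separated box per line, e.g. "10,20,110,220"):
def _demo_parse_mul_num_lines():
    boxes = parse_mul_num_lines('boxes.txt')
    print boxes[0]   # e.g. [10.0, 20.0, 110.0, 220.0]
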
433 | def pck2mat(pckFn, outFn):
434 | data = pickleload(pckFn)
435 | scipy.io.savemat(outFn, data)
436 |     print('finished transformation')
437 |
438 | def putCapOnImage(imgVis, capList):
439 | if isinstance(capList, list):
440 | cap = ''
441 | for ele in capList:
442 | cap +=ele
443 | cap +=' '
444 | else:
445 | cap = capList
446 | cv2.putText(imgVis, cap,
447 | (10, 50),
448 | cv2.FONT_HERSHEY_PLAIN,
449 | 2, (0, 0, 255),
450 | 2)
451 | return imgVis
452 |
453 | def get_all_file_list(dir_name):
454 | file_list=list()
455 | for fileTmp in os.listdir(dir_name):
456 | file_path = os.path.join(dir_name, fileTmp)
457 | if os.path.isdir(file_path):
458 | continue
459 | file_list.append(file_path)
460 | return file_list
461 |
462 | def resize_image_with_fixed_height(img, hSize =320):
463 | h, w, c = img.shape
464 | scl = hSize*1.0/h
465 | imgResize = cv2.resize(img, None, None, fx=scl, fy=scl)
466 | return imgResize, scl, h, w
467 |
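# Usage sketch (editorial; 'frame.jpg' is hypothetical): resize so the height
# becomes hSize while preserving the aspect ratio.
def _demo_resize_fixed_height():
    img = cv2.imread('frame.jpg')
    imgResize, scl, h, w = resize_image_with_fixed_height(img, hSize=320)
    print imgResize.shape, scl, (h, w)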
468 |
--------------------------------------------------------------------------------
/util/mytoolbox.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/mytoolbox.pyc
--------------------------------------------------------------------------------