├── LICENSE
├── README.md
├── annotations
│   ├── ptd_api.py
│   ├── script_test_annotation.py
│   └── utils.py
├── data
│   ├── __init__.py
│   ├── data_loader.py
│   └── dictForDb_vid_v2.pd
├── fun
│   ├── __init__.py
│   ├── classSST.py
│   ├── create_word2vec_for_dataset.py
│   ├── dashed_rect.py
│   ├── datasetLoader.py
│   ├── datasetParser.py
│   ├── eval.py
│   ├── evalDet.py
│   ├── image_toolbox.py
│   ├── logInfo.py
│   ├── lossPackage.py
│   ├── modelArc.py
│   ├── netUtil.py
│   ├── optimizers.py
│   ├── train.py
│   ├── vidDataset.py
│   ├── vidDatasetParser.py
│   └── wsParamParser.py
├── images
│   ├── frm.png
│   └── task.png
├── readme.md
├── scripts
│   ├── test_video_emb_att.sh
│   └── train_video_emb_att.sh
└── util
    ├── __init__.py
    ├── __init__.pyc
    ├── base_parser.py
    ├── base_parser.pyc
    ├── get_image_size.py
    ├── mytoolbox.py
    └── mytoolbox.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Weakly-Supervised Spatio-Temporally Grounding Natural Sentence in Video
2 |
3 | This repo contains the main baselines of the VID-sentence dataset introduced in WSSTG.
4 | Please refer to [our paper](https://arxiv.org/abs/1906.02549) and the [repo](https://github.com/JeffCHEN2017/VID-sentence) for more information on the VID-sentence dataset.
5 |
6 |
7 | ### Task
8 |
9 | <div align="center">

10 | Description: "A brown and white dog is lying on the grass and then it stands up."
11 | </div>
12 | <div align="center">
13 | <img src="images/task.png" alt="task">
14 | </div>
15 | <div align="center">
16 | The proposed WSSTG task aims to localize a spatio-temporal tube (i.e., the sequence of green bounding boxes) in the video which semantically corresponds to the given sentence, with no reliance on any spatio-temporal annotations during training.
17 | </div>
18 |
19 | ### Architecture
20 |
21 | <div align="center">
22 | The architecture of the proposed method.
23 | </div>
24 | <div align="center">
25 | <img src="images/frm.png" alt="architecture">
26 | </div>
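To make the grounding score concrete: at each video time step, every tube's LSTM feature attends over the caption's word features, and the attended caption vector is compared with the tube feature by cosine similarity; the scores are then averaged over time steps. Below is a minimal sketch of one such step, assuming simplified tensor shapes (the function name and dimensions are illustrative, not part of this repo's API; see `fun/classSST.py` for the actual model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def attended_similarity(video_h, caption_h, vid_linear, sen_linear, att_linear):
    # video_h:   (num_tubes, hidden) - video LSTM state of each tube at one time step
    # caption_h: (num_words, hidden) - caption LSTM outputs for one sentence
    # Additive attention over caption words, conditioned on each tube feature.
    scores = att_linear(torch.tanh(
        vid_linear(video_h).unsqueeze(1) +    # (num_tubes, 1, att_hidden)
        sen_linear(caption_h).unsqueeze(0)))  # (1, num_words, att_hidden)
    weights = F.softmax(scores.squeeze(2), dim=1)                 # (num_tubes, num_words)
    caption_exp = caption_h.expand(video_h.size(0), *caption_h.size())
    attended = weights.unsqueeze(1).bmm(caption_exp).squeeze(1)   # (num_tubes, hidden)
    return F.cosine_similarity(attended, video_h, dim=1)          # one score per tube

# Toy usage: score 30 tube proposals against a 12-word sentence.
hidden, att = 512, 256
sim = attended_similarity(torch.randn(30, hidden), torch.randn(12, hidden),
                          nn.Linear(hidden, att), nn.Linear(hidden, att), nn.Linear(att, 1))
print(sim.shape)  # torch.Size([30])
```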

27 |
28 | ### Contents
29 | 1. [Requirements: software](#requirements-software)
30 | 2. [Installation](#installation)
31 | 3. [Training](#Training)
32 | 4. [Testing](#Testing)
33 |
34 | ### Requirements: software
35 |
36 | - PyTorch (version 0.4.0)
37 | - python 2.7
38 | - numpy
39 | - scipy
40 | - magic
41 | - easydict
42 | - dill
43 | - matplotlib
44 | - tensorboardX
45 |
46 |
47 | ### Installation
48 |
49 | 1. Clone the WSSTG and VID-sentence repositories
50 |
51 | ```Shell
52 | git clone https://github.com/JeffCHEN2017/WSSTG.git
53 | git clone https://github.com/JeffCHEN2017/VID-Sentence.git
54 | ln -s VID-sentence_ROOT/data/ILSVRC WSSTG_ROOT/data
55 | ```
56 | 2. Download the [tube proposals](https://drive.google.com/file/d/1SHwXtlb7V8PH4_60-0-VZYL-7kXEG_Wj/view?usp=sharing), [RGB features](https://drive.google.com/file/d/1ll_AkiByvQsJTdPNVt1TH6BUPbQxjvG_/view?usp=sharing) and [I3D features](https://drive.google.com/file/d/1SwPGweipeuREZrAXGzu7nPACy9vXNmmp/view?usp=sharing) from Google Drive.
57 |
58 | 3. Extract the *.tar files and make symlinks between the downloaded data and the desired data folders
59 |
60 | ```Shell
61 | tar xvf tubePrp.tar
62 | ln -s tubePrp $WSSTG_ROOT/data/tubePrp
63 |
64 | tar xvf vid_i3d.tar vid_i3d
65 | ln -s vid_i3d $WSSTG_ROOT/data/vid_i3d
66 | ln -s $WSSTG_ROOT/data/vid_i3d/val test
67 |
68 | tar xvf vid_rgb.tar vid_rgb
69 | ln -s vid_rgb $WSSTG_ROOT/data/vid_rgb
70 | ln -s $WSSTG_ROOT/data/vid_rgb/vidTubeCacheFtr/val test
71 | ```
72 |
73 | Note: We extract the tube proposals using the method proposed by [Gkioxari and Malik](https://arxiv.org/abs/1411.6031); a Python implementation by Yamaguchi et al. is available [here](https://www.mi.t.u-tokyo.ac.jp/projects/person_search/).
74 | We extract single-frame proposals and an RGB feature for each frame using a [faster-RCNN](https://arxiv.org/abs/1506.01497) model pretrained on the COCO dataset, provided by [Jianwei Yang](https://github.com/jwyang/faster-rcnn.pytorch).
75 | We extract [I3D-RGB and I3D-flow features](https://arxiv.org/abs/1705.07750) using the model provided by [Carreira and Zisserman](https://github.com/deepmind/kinetics-i3d.git).
76 |
77 |
78 | ### Training
79 | ```Shell
80 | cd $WSSTG_ROOT
81 | sh scripts/train_video_emb_att.sh
82 | ```
83 | Note: Because of changes to the batch size and the random seed, performance may differ slightly from our submission. We provide a checkpoint (linked in Testing below) which achieves accuracy@0.5 similar to the model reported in the paper (38.1 vs. 38.2).
84 |
85 | ### Testing
86 | Download the checkpoint from [Google Drive](https://drive.google.com/file/d/1oM0J4jIbcd4SYo9T29ydk3gugoOjFCKA/view?usp=sharing), put it in $WSSTG_ROOT/data/models, and run
87 | ```Shell
88 | cd $WSSTG_ROOT
89 | sh scripts/test_video_emb_att.sh
90 | ```
91 |
92 | ### License
93 |
94 | WSSTG is released under the CC-BY-NC 4.0 LICENSE (refer to the LICENSE file for details). 
95 | 96 | ### Citing WSSTG 97 | 98 | If you find this repo useful in your research, please consider citing: 99 | 100 | @inproceedings{chen2019weakly, 101 | Title={Weakly-Supervised Spatio-Temporally Grounding Natural Sentence in Video}, 102 | Author={Chen, Zhenfang and Ma, Lin and Luo, Wenhan and Wong, Kwan-Yee~K}, 103 | Booktitle={ACL}, 104 | year={2019} 105 | } 106 | 107 | ### Contact 108 | 109 | You can contact Zhenfang Chen by sending email to chenzhenfang2013@gmail.com 110 | -------------------------------------------------------------------------------- /annotations/ptd_api.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | from collections import defaultdict 5 | from utils import * 6 | import pdb 7 | 8 | ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), __file__))) 9 | DESCS_JSON = os.path.join(ROOT_DIR, 'data', '%s_descriptions.json') 10 | PEOPLE_JSON = os.path.join(ROOT_DIR, 'data', '%s_people.json') 11 | SHOTS_JSON = os.path.join(ROOT_DIR, 'data', '%s_shots.json') 12 | if os.path.exists(os.path.join(ROOT_DIR, 'data', 'video2fps.json')): 13 | VID2FPS_JSON = os.path.join(ROOT_DIR, 'data', 'video2fps.json') 14 | else: 15 | VID2FPS_JSON = os.path.join(ROOT_DIR, 'data', 'video2fps.json.example') 16 | DS = ['train', 'val', 'test'] 17 | AN_JSON = os.path.join(ROOT_DIR, 'activity_net.v1-3.min.json') 18 | 19 | class ActivityNet(object): 20 | _data = None 21 | _video_ids = None 22 | def __init__(self): 23 | if not os.path.exists(AN_JSON): 24 | raise Exception('Download the annotation file of ActivityNet v1.3, and place it to "%s".' % AN_JSON) 25 | self._data = jsonload(AN_JSON) 26 | assert self._data['version'] == 'VERSION 1.3' 27 | self._video_ids = sorted(self._data['database'].keys()) 28 | 29 | @property 30 | def video_ids(self): 31 | return self._video_ids 32 | 33 | @property 34 | def data(self): 35 | return self._data 36 | 37 | class PTD(object): 38 | _descriptions = None 39 | _shots = None 40 | _people = None 41 | _person2desc = None 42 | _desc2person = None 43 | _person2shot = None 44 | _shot2person = None 45 | _id2shot = None 46 | _id2person = None 47 | _id2desc = None 48 | _vid2fps = None 49 | _an_data = None 50 | 51 | def __init__(self, dataset): 52 | assert dataset in DS 53 | self._dataset = dataset 54 | 55 | def description(self, index): 56 | return Description(self.id2desc[index], self) 57 | 58 | def person(self, index): 59 | return Person(self.id2person[index], self) 60 | 61 | def shot(self, index): 62 | return Shot(self.id2shot[index], self) 63 | 64 | @property 65 | def descriptions(self): 66 | if self._descriptions is None: 67 | self._descriptions = jsonload(DESCS_JSON % self._dataset) 68 | return self._descriptions 69 | 70 | @property 71 | def id2desc(self): 72 | if self._id2desc is None: 73 | self._id2desc = {} 74 | for desc in self.descriptions: 75 | self._id2desc[desc['id']] = desc 76 | return self._id2desc 77 | 78 | @property 79 | def people(self): 80 | if self._people is None: 81 | self._people = jsonload(PEOPLE_JSON % self._dataset) 82 | return self._people 83 | 84 | @property 85 | def id2person(self): 86 | if self._id2person is None: 87 | self._id2person = {} 88 | for person in self.people: 89 | self._id2person[person['id']] = person 90 | return self._id2person 91 | 92 | @property 93 | def shots(self): 94 | if self._shots is None: 95 | self._shots = jsonload(SHOTS_JSON % self._dataset) 96 | return self._shots 97 | 98 | @property 99 | def 
id2shot(self): 100 | if self._id2shot is None: 101 | self._id2shot = {} 102 | for shot in self.shots: 103 | self._id2shot[shot['id']] = shot 104 | return self._id2shot 105 | 106 | ### id2id ### 107 | @property 108 | def person2desc(self): 109 | if self._person2desc is None: 110 | self._person2desc = {} 111 | for person in self.people: 112 | self._person2desc[person['id']] = person['descriptions'] 113 | return self._person2desc 114 | 115 | @property 116 | def desc2person(self): 117 | if self._desc2person is None: 118 | self._desc2person = {} 119 | for person in self.people: 120 | for desc_id in person['descriptions']: 121 | self._desc2person[desc_id] = person['id'] 122 | return self._desc2person 123 | 124 | @property 125 | def person2shot(self): 126 | if self._person2shot is None: 127 | self._person2shot = {} 128 | for person in self.people: 129 | self._person2shot[person['id']] = person['shot_id'] 130 | return self._person2shot 131 | 132 | @property 133 | def shot2person(self): 134 | if self._shot2person is None: 135 | self._shot2person = defaultdict(list) 136 | for person in self.people: 137 | self._shot2person[person['shot_id']].append(person['id']) 138 | return self._shot2person 139 | 140 | ### others 141 | @property 142 | def vid2fps(self): 143 | if self._vid2fps is None: 144 | self._vid2fps = jsonload(VID2FPS_JSON) 145 | return self._vid2fps 146 | 147 | @property 148 | def an_data(self): 149 | if self._an_data is None: 150 | self._an_data = ActivityNet() 151 | return self._an_data 152 | 153 | class BaseInstance(object): 154 | def __init__(self, info, parent): 155 | self._info = info 156 | self._parent = parent 157 | 158 | def __getitem__(self, index): 159 | return self._info[index] 160 | 161 | @property 162 | def id(self): 163 | return self._info['id'] 164 | 165 | class Description(BaseInstance): 166 | @property 167 | def description(self): 168 | return self._info['description'] 169 | 170 | @property 171 | def person(self): 172 | person_id = self._parent.desc2person[self._info['id']] 173 | return self._parent.person(person_id) 174 | 175 | @property 176 | def shot(self): 177 | return self.person.shot 178 | 179 | class Shot(BaseInstance): 180 | @property 181 | def video_id(self): 182 | return self._parent.an_data.video_ids[self._info['an_video_id']] 183 | 184 | @property 185 | def first_second(self): 186 | return min(self.annotated_seconds) 187 | 188 | @property 189 | def last_second(self): 190 | return max(self.annotated_seconds) 191 | 192 | @property 193 | def annotated_seconds(self): 194 | return self._info['annotated_seconds'] 195 | 196 | @property 197 | def first_frame(self): 198 | return self.sec2frame(self.first_second) 199 | 200 | @property 201 | def last_frame(self): 202 | if self.video_id == 'll91M5topgU': 203 | return 161 204 | if self.video_id == 'uOmCwWVJnLQ': 205 | return 2999 206 | return self.sec2frame(self.last_second) 207 | 208 | @property 209 | def annotated_frames(self): 210 | return [self.sec2frame(sec) for sec in self._info['annotated_seconds']] 211 | 212 | @property 213 | def fully_annotated(self): 214 | return self._info['fully_annotated'] 215 | 216 | @property 217 | def people(self): 218 | person_ids = self._parent.shot2person[self._info['id']] 219 | return [self._parent.person(person_id) for person_id in person_ids] 220 | 221 | @property 222 | def descriptions(self): 223 | descs = [] 224 | for person in self.people: 225 | descs += person.descriptions 226 | return descs 227 | 228 | @property 229 | def fps(self): 230 | return self._parent.vid2fps[self.video_id] 
231 | 232 | def sec2frame(self, sec): 233 | return int(round(sec * self.fps)) 234 | 235 | class Person(BaseInstance): 236 | @property 237 | def boxes(self): 238 | return self._info['boxes'] 239 | 240 | @property 241 | def descriptions(self): 242 | return [self._parent.description(i) for i in self._info['descriptions']] 243 | 244 | @property 245 | def shot(self): 246 | shot_id = self._parent.person2shot[self._info['id']] 247 | return self._parent.shot(shot_id) 248 | 249 | def demo(): 250 | ptd = PTD('test') 251 | print 'Showing information of the shot of which ID is 1...' 252 | shot = ptd.shot(1) 253 | print '[SHOT] ID: %d' % shot.id 254 | print '[SHOT] VIDEO URL: %s' % shot.video_id 255 | print '[SHOT] START TIME: %s' % shot.first_second 256 | print '[SHOT] END TIME: %s' % shot.last_second 257 | print 258 | 259 | print 'Showing information of the person of which ID is 1...' 260 | person = ptd.person(1) 261 | print '[PERSON] ID: %d' % person.id 262 | print '[PERSON] PARENT SHOT ID: %s' % person.shot.id 263 | print '[PERSON] DESCRIPTIONS: %s' % ', '.join(['"%s"' % d.description for d in person.descriptions]) 264 | #print 265 | 266 | print 'Showing information of the description of which ID is 1...' 267 | description = ptd.description(1) 268 | print '[DESCRIPTION] ID: %d' % description.id 269 | print '[DESCRIPTION] PARENT PERSON ID: %d' % description.person.id 270 | print '[DESCRIPTION] DESCRIPTION: %s' % description.description 271 | 272 | if __name__ == '__main__': 273 | demo() 274 | -------------------------------------------------------------------------------- /annotations/script_test_annotation.py: -------------------------------------------------------------------------------- 1 | from ptd_api import * 2 | import cv2 3 | import copy 4 | import numpy as np 5 | import shutil 6 | import commands 7 | import sys 8 | sys.path.append('../') 9 | from util.mytoolbox import * 10 | 11 | TMP_DIR = '.tmp' 12 | FFMPEG = 'ffmpeg' 13 | SAVE_VIDEO = FFMPEG + ' -y -r %d -i %s/%s.jpg %s' 14 | 15 | def draw_rectangle(img, bbox, color=(0,0,255), thickness=3): 16 | img = imread_if_str(img) 17 | if isinstance(bbox, dict): 18 | bbox = [ 19 | bbox['x1'], 20 | bbox['y1'], 21 | bbox['x2'], 22 | bbox['y2'], 23 | ] 24 | assert bbox[2] >= bbox[0] 25 | assert bbox[3] >= bbox[1] 26 | assert bbox[0] >= 0 27 | assert bbox[1] >= 0 28 | assert bbox[2] <= img.shape[1] 29 | assert bbox[3] <= img.shape[0] 30 | cur_img = copy.deepcopy(img) 31 | cv2.rectangle( 32 | cur_img, 33 | (int(bbox[0]), int(bbox[1])), 34 | (int(bbox[2]), int(bbox[3])), 35 | color, 36 | thickness) 37 | return cur_img 38 | 39 | def images2video(image_list, frame_rate, video_path, max_edge=None): 40 | if os.path.exists(TMP_DIR): 41 | shutil.rmtree(TMP_DIR) 42 | os.mkdir(TMP_DIR) 43 | img_size = None 44 | for cur_num, cur_img in enumerate(image_list): 45 | cur_fname = os.path.join(TMP_DIR, '%08d.jpg' % cur_num) 46 | if max_edge is not None: 47 | cur_img = imread_if_str(cur_img) 48 | if isinstance(cur_img, str) or isinstance(cur_img, unicode): 49 | shutil.copyfile(cur_img, cur_fname) 50 | elif isinstance(cur_img, np.ndarray): 51 | max_len = max(cur_img.shape[:2]) 52 | if max_len > max_edge and img_size is None and max_edge is not None: 53 | magnif = float(max_edge) / float(max_len) 54 | img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif)) 55 | cur_img = cv2.resize(cur_img, img_size) 56 | elif max_edge is not None: 57 | if img_size is None: 58 | magnif = float(max_edge) / float(max_len) 59 | img_size = (int(cur_img.shape[1] * magnif), 
int(cur_img.shape[0] * magnif)) 60 | cur_img = cv2.resize(cur_img, img_size) 61 | cv2.imwrite(cur_fname, cur_img) 62 | else: 63 | NotImplementedError() 64 | print commands.getoutput(SAVE_VIDEO % (frame_rate, TMP_DIR, '%08d', video_path)) 65 | shutil.rmtree(TMP_DIR) 66 | 67 | def video2shot(vdName, ptd_list = {}): 68 | setNameList = ['train', 'val', 'test'] 69 | shotInfo = list() 70 | for set_name in setNameList: 71 | if set_name in ptd_list.keys(): 72 | ptd = ptd_list[set_name] 73 | else: 74 | ptd = PTD(set_name) 75 | shotLgh = len(ptd.id2shot) 76 | for i in range(shotLgh): 77 | vdNameShot = ptd.shot(i+1).video_id 78 | if(vdName==vdNameShot): 79 | shotInfo.append([set_name, i+1]) 80 | return shotInfo 81 | 82 | def imread_if_str(img): 83 | if isinstance(img, basestring): 84 | img = cv2.imread(img) 85 | return img 86 | 87 | def get_shot_frames(shot): 88 | annFrmSt = shot.first_frame 89 | annFrmLs = shot.last_frame 90 | tmpFrmList = list(range(annFrmSt, annFrmLs+1)) 91 | frmList = list() 92 | for i, frmIdx in enumerate(tmpFrmList): 93 | strIdx = '%05d' %(frmIdx) 94 | frmList.append(strIdx) 95 | return frmList 96 | 97 | def get_shot_frames_full_path(shot, preFd, fn_ext='.jpg'): 98 | vd_name = shot.video_id 99 | vd_subPath = os.path.join(preFd, 'v_' + vd_name) 100 | framList = get_shot_frames(shot) 101 | framListFull = list() 102 | for frm in framList: 103 | tmpFrm = os.path.join(vd_subPath, frm + fn_ext) 104 | framListFull.append(tmpFrm) 105 | return framListFull 106 | 107 | # shot_proposals : [tubes, frameList] 108 | # 109 | def evaluate_tube_recall(shot_proposals, shot, person_in_shot, thre=0.5 ,topKOri=20): 110 | # pdb.set_trace() 111 | topK = min(topKOri, len(shot_proposals[0][0])) 112 | recall_k = [0.0] * (topK + 1) 113 | boxes = {} 114 | for frame_ind, box in zip(shot.annotated_frames, person_in_shot['boxes']): 115 | keyName = '%05d' %(frame_ind) 116 | boxes[keyName] = box 117 | 118 | #pdb.set_trace() 119 | tube_list, frame_list = shot_proposals 120 | assert(len(tube_list[0][0])== len(frame_list)) 121 | is_person_annotated = False 122 | for i in range(topK): 123 | recall_k[i+1] = recall_k[i] 124 | if is_person_annotated: 125 | continue 126 | curTubeOri = tube_list[0][i] 127 | tube_key_bbxList = {} 128 | for frame_ind, gt_box in boxes.iteritems(): 129 | try: 130 | index_tmp = frame_list.index(frame_ind) 131 | tube_key_bbxList[frame_ind] = curTubeOri[index_tmp] 132 | except: 133 | print('key %s do not exist in shot' %(frame_ind)) 134 | #pdb.set_trace() 135 | ol = compute_LS(tube_key_bbxList, boxes) 136 | if ol < thre: 137 | continue 138 | else: 139 | recall_k[i+1] += 1.0 140 | is_person_annotated = True 141 | return recall_k 142 | 143 | def vis_ann_tube(): 144 | pngFolder ='/data1/zfchen/data/actNet/actNetJpgs' 145 | annFolder ='/data1/zfchen/data/actNet/actNetAnn' 146 | #vdName = '2Peh_gdQCjg' 147 | #vdName = 'jqKK2KH6l4Q' 148 | #vdName = 'W4tmb8RwzQM' 149 | vdName = 'ydRycaBjMVw' 150 | shotList = video2shot(vdName) 151 | 152 | pdb.set_trace() 153 | 154 | txtFn = os.path.join(annFolder, 'v_'+vdName + '.txt') 155 | fH = open(txtFn) 156 | vdInfoStr = fH.readlines()[0] 157 | vdInfo = [float(ele) for ele in vdInfoStr.split(' ')] 158 | insNum = 1 159 | 160 | for i, shotInfo in enumerate(shotList): 161 | set_name, shotId = shotInfo 162 | ptd = PTD(set_name) 163 | shot = ptd.shot(shotId) 164 | annFtrList = shot.annotated_frames 165 | for ii, person in enumerate(shot.people): 166 | imgList = list() 167 | print(person.descriptions[0].description) 168 | for iii in range(len(annFtrList)): 169 | 
frmName = '%05d' %(annFtrList[iii]+1) 170 | imFn = pngFolder + '/v_' + shot.video_id + '/' + frmName + '.jpg' 171 | img = cv2.imread(imFn) 172 | h, w, c = img.shape 173 | bbox = person.boxes[iii] 174 | #pdb.set_trace() 175 | bbox[0] = bbox[0]*w 176 | bbox[2] = bbox[2]*w 177 | bbox[1] = bbox[1]*h 178 | bbox[3] = bbox[3]*h 179 | #pdb.set_trace() 180 | imgNew = draw_rectangle(img, bbox, color=(0,0,255), thickness=3) 181 | imgList.append(imgNew) 182 | images2video(imgList, 10, './' + str(insNum)+'.gif') 183 | insNum +=1 184 | 185 | def extract_instance_caption_list(): 186 | set_list = ['train', 'test', 'val'] 187 | outPre = './data/cap_list_' 188 | for i, set_name in enumerate(set_list): 189 | ptd = PTD(set_name) 190 | outFn = outPre + set_name + '.txt' 191 | desption_list_dict = ptd.descriptions 192 | description_list = [ des_dict['description'] for des_dict in desption_list_dict] 193 | textdump(outFn, description_list) 194 | 195 | if __name__=='__main__': 196 | extract_instance_caption_list() 197 | -------------------------------------------------------------------------------- /annotations/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import json 5 | import os 6 | import pdb 7 | 8 | EPS = 1e-10 9 | 10 | 11 | def compute_IoU(bbox1, bbox2): 12 | bbox1_area = float((bbox1[2] - bbox1[0] + EPS) * (bbox1[3] - bbox1[1] + EPS)) 13 | bbox2_area = float((bbox2[2] - bbox2[0] + EPS) * (bbox2[3] - bbox2[1] + EPS)) 14 | w = max(0.0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]) + EPS) 15 | h = max(0.0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]) + EPS) 16 | inter = float(w * h) 17 | ovr = inter / (bbox1_area + bbox2_area - inter) 18 | return ovr 19 | 20 | def is_annotated(traj, frame_ind): 21 | if not frame_ind in traj: 22 | return False 23 | box = traj[frame_ind] 24 | if box[0] < 0: 25 | #for el_val in box[1:]: 26 | #assert el_val < 0 27 | return False 28 | #for el_val in box[1:]: 29 | # assert el_val >= 0 30 | return True 31 | 32 | def compute_LS(traj, gt_traj): 33 | # see http://jvgemert.github.io/pub/jain-tubelets-cvpr2014.pdf 34 | assert isinstance(traj.keys()[0], type(gt_traj.keys()[0])) 35 | IoU_list = [] 36 | for frame_ind, gt_box in gt_traj.iteritems(): 37 | # make sure the gt_box is within valid 38 | 39 | gt_is_annotated = is_annotated(gt_traj, frame_ind) 40 | pr_is_annotated = is_annotated(traj, frame_ind) 41 | if (not gt_is_annotated) and (not pr_is_annotated): 42 | continue 43 | if (not gt_is_annotated) or (not pr_is_annotated): 44 | IoU_list.append(0.0) 45 | continue 46 | box = traj[frame_ind] 47 | IoU_list.append(compute_IoU(box, gt_box)) 48 | return sum(IoU_list) / len(IoU_list) 49 | 50 | def jsonload(path): 51 | f = open(path) 52 | json_data = json.load(f) 53 | f.close() 54 | return json_data 55 | 56 | def get_abs_path(): 57 | return os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), __file__))) 58 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/data/__init__.py -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | def creatDataloader(opts): 2 | from data.custom_data_loader import customDataLoader 3 | 
data_loader = customDataLoader() 4 | data_loader.initialize(opt) 5 | return data_loader 6 | -------------------------------------------------------------------------------- /fun/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/fun/__init__.py -------------------------------------------------------------------------------- /fun/classSST.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import ipdb 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.autograd import * 14 | 15 | import opts 16 | import pdb 17 | 18 | 19 | class SST(nn.Module): 20 | def __init__(self, opt): 21 | super(SST, self).__init__() 22 | 23 | self.word_cnt = opt.word_cnt 24 | self.fc_feat_size = opt.fc_feat_size 25 | self.video_embedding_size = opt.video_embedding_size 26 | self.word_embedding_size = opt.word_embedding_size 27 | self.lstm_hidden_size = opt.lstm_hidden_size 28 | self.video_time_step = opt.video_time_step 29 | self.caption_time_step = opt.caption_time_step 30 | self.dropout_prob = opt.dropout_prob 31 | self.att_hidden_size = opt.att_hidden_size 32 | 33 | self.video_embedding = nn.Linear(self.fc_feat_size, self.video_embedding_size) 34 | 35 | self.lstm_video = torch.nn.LSTMCell(self.video_embedding_size, self.lstm_hidden_size, bias=True) 36 | self.lstm_caption = torch.nn.LSTMCell(self.word_embedding_size, self.lstm_hidden_size, bias=True) 37 | 38 | 39 | self.vid_linear = nn.Linear(self.lstm_hidden_size, self.att_hidden_size) 40 | self.sen_linear = nn.Linear(self.lstm_hidden_size, self.att_hidden_size) 41 | 42 | self.att_linear = nn.Linear(self.att_hidden_size, 1) 43 | 44 | self.word_embedding = nn.Linear(300, self.word_embedding_size) 45 | self.init_weights() 46 | 47 | def init_weights(self): 48 | initrange = 0.1 49 | self.video_embedding.weight.data.uniform_(-initrange, initrange) 50 | self.video_embedding.bias.data.fill_(0) 51 | 52 | self.word_embedding.weight.data.uniform_(-initrange, initrange) 53 | self.word_embedding.bias.data.fill_(0) 54 | 55 | self.vid_linear.weight.data.uniform_(-initrange, initrange) 56 | self.vid_linear.bias.data.fill_(0) 57 | 58 | self.sen_linear.weight.data.uniform_(-initrange, initrange) 59 | self.sen_linear.bias.data.fill_(0) 60 | 61 | self.att_linear.weight.data.uniform_(-initrange, initrange) 62 | self.att_linear.bias.data.fill_(0) 63 | 64 | 65 | def init_hidden(self, batch_size): 66 | weight = next(self.parameters()).data 67 | init_h = Variable(weight.new(1, batch_size, self.lstm_hidden_size).zero_()) 68 | init_c = Variable(weight.new(1, batch_size, self.lstm_hidden_size).zero_()) 69 | init_state = (init_h, init_c) 70 | 71 | return init_state 72 | 73 | def init_hidden_new(self, batch_size): 74 | weight = next(self.parameters()).data 75 | init_h = Variable(weight.new(batch_size, self.lstm_hidden_size).zero_()) 76 | init_c = Variable(weight.new(batch_size, self.lstm_hidden_size).zero_()) 77 | init_state = (init_h, init_c) 78 | 79 | return init_state 80 | 81 | def forward(self, video_fc_feat, video_caption, cap_length_list=None): 82 | if self.training: 83 | return self.forward_training(video_fc_feat, video_caption, cap_length_list) 84 | else: 85 | return 
self.forward_val(video_fc_feat, video_caption, cap_length_list) 86 | 87 | # video_fc_feats: batch * encode_time_step * fc_feat_size 88 | def forward_training(self, video_fc_feat, video_caption, cap_length_list=None): 89 | #pdb.set_trace() 90 | batch_size = video_fc_feat.size(0) 91 | batch_size_caption = video_caption.size(0) 92 | 93 | video_state = self.init_hidden_new(batch_size) 94 | caption_state = self.init_hidden_new(batch_size_caption) 95 | 96 | caption_outputs = [] 97 | caption_time_step = video_caption.size(1) 98 | for i in range(caption_time_step): 99 | word = video_caption[:, i].clone() 100 | #if video_caption[:, i].data.sum() == 0: 101 | # break 102 | #import ipdb 103 | #pdb.set_trace() 104 | caption_xt = self.word_embedding(word) 105 | #caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state) 106 | caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state) 107 | caption_outputs.append(caption_output) 108 | caption_state = (caption_output, caption_state) 109 | # caption_outputs: batch * caption_time_step * lstm_hidden_size 110 | caption_outputs = torch.cat([_.unsqueeze(1) for _ in caption_outputs], 1).contiguous() 111 | 112 | # LSTM encoding 113 | video_outputs = [] 114 | for i in range(self.video_time_step): 115 | video_xt = self.video_embedding(video_fc_feat[:, i, :]) 116 | video_output, video_state = self.lstm_video.forward(video_xt, video_state) 117 | video_outputs.append(video_output) 118 | video_state = (video_output, video_state) 119 | # video_outputs: batch * video_time_step * lstm_hidden_size 120 | video_outputs = torch.cat([_.unsqueeze(1) for _ in video_outputs], 1).contiguous() 121 | 122 | # soft attention for caption based on each video 123 | output_list = list() 124 | for i in range(self.video_time_step): 125 | # part 1 126 | video_outputs_linear = self.vid_linear(video_outputs[:, i, :]) 127 | video_outputs_linear_expand = video_outputs_linear.expand(caption_outputs.size(1), video_outputs_linear.size(0), 128 | video_outputs_linear.size(1)).transpose(0, 1) 129 | 130 | # part 2 131 | caption_outputs_flatten = caption_outputs.view(-1, self.lstm_hidden_size) 132 | caption_outputs_linear = self.sen_linear(caption_outputs_flatten) 133 | caption_outputs_linear = caption_outputs_linear.view(batch_size_caption, caption_outputs.size(1), self.att_hidden_size) 134 | 135 | # part 1 and part 2 attention 136 | sig_probs = [] 137 | for cap_id in range(batch_size_caption): 138 | #pdb.set_trace() 139 | cap_length = max(cap_length_list[cap_id], 1) 140 | caption_output_linear_cap_id = caption_outputs_linear[cap_id, : cap_length, :] 141 | video_outputs_linear_expand_clip = video_outputs_linear_expand[:, :cap_length, :] 142 | caption_outputs_linear_cap_id_exp = caption_output_linear_cap_id.expand_as(video_outputs_linear_expand_clip) 143 | video_caption = F.tanh(video_outputs_linear_expand_clip \ 144 | + caption_outputs_linear_cap_id_exp) 145 | 146 | video_caption_view = video_caption.contiguous().view(-1, self.att_hidden_size) 147 | video_caption_out = self.att_linear(video_caption_view) 148 | video_caption_out_view = video_caption_out.view(-1, cap_length) 149 | atten_weights = nn.Softmax(dim=1)(video_caption_out_view).unsqueeze(2) 150 | 151 | caption_output_cap_id = caption_outputs[cap_id, : cap_length, :] 152 | caption_output_cap_id_exp = caption_output_cap_id.expand(batch_size,\ 153 | caption_output_cap_id.size(0), caption_output_cap_id.size(1)) 154 | atten_caption = torch.bmm(caption_output_cap_id_exp.transpose(1, 2), 
atten_weights).squeeze(2) 155 | 156 | video_caption_hidden = torch.cat((atten_caption, video_outputs[:, i, :]), dim=1) 157 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 158 | cur_probs = cos(atten_caption, video_outputs[: ,i, :]).unsqueeze(1) 159 | 160 | 161 | sig_probs.append(cur_probs) 162 | 163 | sig_probs = torch.cat([_ for _ in sig_probs], 1).contiguous() 164 | output_list.append(sig_probs) 165 | simMM = torch.stack(output_list, dim=2).mean(2) 166 | return simMM 167 | 168 | # video_fc_feats: batch * encode_time_step * fc_feat_size 169 | def forward_val(self, video_fc_feat, video_caption, cap_length_list=None): 170 | #pdb.set_trace() 171 | batch_size = video_fc_feat.size(0) 172 | batch_size_caption = video_caption.size(0) 173 | num_tube_per_video = int(batch_size / batch_size_caption) 174 | 175 | video_state = self.init_hidden_new(batch_size) 176 | caption_state = self.init_hidden_new(batch_size_caption) 177 | 178 | caption_outputs = [] 179 | caption_time_step = video_caption.size(1) 180 | for i in range(caption_time_step): 181 | word = video_caption[:, i].clone() 182 | #if video_caption[:, i].data.sum() == 0: 183 | # break 184 | #import ipdb 185 | #pdb.set_trace() 186 | caption_xt = self.word_embedding(word) 187 | #caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state) 188 | caption_output, caption_state = self.lstm_caption.forward(caption_xt, caption_state) 189 | caption_outputs.append(caption_output) 190 | caption_state = (caption_output, caption_state) 191 | # caption_outputs: batch * caption_time_step * lstm_hidden_size 192 | caption_outputs = torch.cat([_.unsqueeze(1) for _ in caption_outputs], 1).contiguous() 193 | 194 | # LSTM encoding 195 | video_outputs = [] 196 | for i in range(self.video_time_step): 197 | video_xt = self.video_embedding(video_fc_feat[:, i, :]) 198 | video_output, video_state = self.lstm_video.forward(video_xt, video_state) 199 | video_outputs.append(video_output) 200 | video_state = (video_output, video_state) 201 | # video_outputs: batch * video_time_step * lstm_hidden_size 202 | video_outputs = torch.cat([_.unsqueeze(1) for _ in video_outputs], 1).contiguous() 203 | 204 | # batch_size_caption * num_word * lstm_hidden_size 205 | caption_outputs_linear = self.sen_linear(caption_outputs) 206 | 207 | # soft attention for caption based on each video 208 | output_list = list() 209 | for i in range(self.video_time_step): 210 | # batch_size * lstm_hidden_size 211 | video_outputs_linear = self.vid_linear(video_outputs[:, i, :]) 212 | # batch_size * num_word * lstm_hidden_size 213 | video_outputs_linear_expand = video_outputs_linear.expand(caption_outputs.size(1), video_outputs_linear.size(0), 214 | video_outputs_linear.size(1)).transpose(0, 1) 215 | 216 | # part 1 and part 2 attention 217 | sig_probs = [] 218 | 219 | for cap_id in range(batch_size_caption): 220 | 221 | tube_id_st = cap_id*num_tube_per_video 222 | tube_id_ed = (cap_id+1)*num_tube_per_video 223 | 224 | cap_length = max(cap_length_list[cap_id], 1) 225 | # num_tube_per_video * cap_length * lstm_hidden_size 226 | tube_outputs_aligned = video_outputs_linear_expand[tube_id_st:tube_id_ed, : cap_length, :] 227 | # cap_length * lstm_hidden_size 228 | caption_output_linear_cap_id = caption_outputs_linear[cap_id, : cap_length, :] 229 | # num_tube_per_video * cap_length * lstm_hidden_size 230 | caption_outputs_linear_cap_id_exp = caption_output_linear_cap_id.expand_as(tube_outputs_aligned) 231 | # num_tube_per_video * cap_length * lstm_hidden_size 232 | video_caption = 
F.tanh(tube_outputs_aligned \ 233 | + caption_outputs_linear_cap_id_exp) 234 | 235 | video_caption_view = video_caption.contiguous().view(-1, self.att_hidden_size) 236 | video_caption_out = self.att_linear(video_caption_view) 237 | video_caption_out_view = video_caption_out.view(-1, cap_length) 238 | 239 | # num_tube_per_video * cap_length 240 | atten_weights = nn.Softmax(dim=1)(video_caption_out_view).unsqueeze(2) 241 | 242 | # cap_length * lstm_hidden_size 243 | caption_output_cap_id = caption_outputs[cap_id, : cap_length, :] 244 | # num_tube_per_video * cap_length * lstm_hidden_size 245 | caption_output_cap_id_exp = caption_output_cap_id.expand(num_tube_per_video,\ 246 | caption_output_cap_id.size(0), caption_output_cap_id.size(1)) 247 | # num_tube_per_video * lstm_hidden_size 248 | atten_caption = torch.bmm(caption_output_cap_id_exp.transpose(1, 2), atten_weights).squeeze(2) 249 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 250 | # num_tube_per_video 251 | cur_probs = cos(atten_caption, video_outputs[tube_id_st:tube_id_ed, i, :]).unsqueeze(1) 252 | # num_tube_per_video 253 | sig_probs.append(cur_probs) 254 | 255 | # num_tube_per_video * batch_size_caption 256 | sig_probs = torch.cat([_ for _ in sig_probs], 1).contiguous() 257 | output_list.append(sig_probs) 258 | # num_tube_per_video * batch_size_caption 259 | simMM = torch.stack(output_list, dim=2).mean(2) 260 | # batch_size_caption * num_tube_per_video 261 | simMM = simMM.transpose(0, 1) 262 | #pdb.set_trace() 263 | return simMM 264 | 265 | -------------------------------------------------------------------------------- /fun/create_word2vec_for_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('../') 4 | sys.path.append('../util') 5 | 6 | from gensim.models import KeyedVectors 7 | from datasetParser import * 8 | import numpy as np 9 | from mytoolbox import pickledump, set_debugger 10 | from util.base_parser import BaseParser 11 | from vidDatasetParser import * 12 | 13 | set_debugger() 14 | 15 | class dictParser(BaseParser): 16 | def __init__(self, *arg_list, **arg_dict): 17 | super(dictParser, self).__init__(*arg_list, **arg_dict) 18 | self.add_argument('--dictOutPath', default='../data/dictForDb', type=str) 19 | self.add_argument('--annoteFd', default='/disk2/zfchen/data/OTB_sentences', type=str) 20 | self.add_argument('--dictPath', default='/disk2/zfchen/data/dict/GoogleNews-vectors-negative300.bin', type=str) 21 | self.add_argument('--setName', default='otb', type=str) 22 | self.add_argument('--setOutPath', default='../data/annForDb', type=str) 23 | self.add_argument('--annFn', default='', type=str) 24 | self.add_argument('--annFd', default='', type=str) 25 | self.add_argument('--annIgListFn', default='', type=str) 26 | self.add_argument('--annOriFn', default='', type=str) 27 | 28 | 29 | 30 | def parse_args(): 31 | parser = dictParser() 32 | args = parser.parse_args() 33 | return args 34 | 35 | def buildVoc(concaList): 36 | vocList =[] 37 | for capList in concaList: 38 | for subCapList in capList: 39 | for ele in subCapList: 40 | if ele not in vocList: 41 | vocList.append(ele) 42 | word2idx={} 43 | idx2word={} 44 | for i, ele in enumerate(vocList): 45 | word2idx[ele]=i 46 | idx2word[i] =ele 47 | return word2idx, idx2word 48 | 49 | def buildVocA2d(concaList): 50 | vocList =[] 51 | for capList in concaList: 52 | for ele in capList: 53 | if ele not in vocList: 54 | vocList.append(ele) 55 | word2idx={} 56 | idx2word={} 57 | for i, ele in 
enumerate(vocList): 58 | word2idx[ele]=i 59 | idx2word[i] =ele 60 | return word2idx, idx2word 61 | 62 | def buildVocActNet(vocListOri): 63 | vocList = list() 64 | for ele in vocListOri: 65 | if ele not in vocList: 66 | vocList.append(ele) 67 | word2idx={} 68 | idx2word={} 69 | for i, ele in enumerate(vocList): 70 | word2idx[ele]=i 71 | idx2word[i] =ele 72 | return word2idx, idx2word 73 | 74 | def build_word_vec(word_list, model_word2vec): 75 | matrix_word2vec = [] 76 | igNoreList = list() 77 | for i, word in enumerate(word_list): 78 | print(i, word) 79 | try: 80 | matrix_word2vec.append(model_word2vec[word]) 81 | except: 82 | igNoreList.append(word) 83 | #matrix_word2vec.append(np.zeros((300), dtype=np.float32)) 84 | randArray=np.random.rand((300)).astype('float32') 85 | matrix_word2vec.append(randArray) 86 | try: 87 | print('%s is not the vocaburary'% word) 88 | except: 89 | print('fail to print the word!') 90 | pdb.set_trace() 91 | return matrix_word2vec, igNoreList 92 | 93 | if __name__ == '__main__': 94 | opt = parse_args() 95 | 96 | if opt.setName=='actNet': 97 | print('begin parsing dataset: %s\n' %(opt.setName)) 98 | word_list = build_actNet_word_list() 99 | print(len(word_list)) 100 | pdb.set_trace() 101 | word2idx, idx2word= buildVocActNet(word_list) 102 | model_word2vec = KeyedVectors.load_word2vec_format(opt.dictPath, binary=True) 103 | matrix_word2vec, igNoreList = build_word_vec(word2idx.keys(), model_word2vec) 104 | matrix_word2vec = np.asarray(matrix_word2vec).astype(np.float32) 105 | pdb.set_trace() 106 | 107 | outDict = {'idx2word': idx2word, 'word2idx': word2idx, 'word2vec': matrix_word2vec, 'out_voca': igNoreList} 108 | pickledump(opt.dictOutPath+'_'+opt.setName+'.pd', outDict) 109 | print('Finish constructing dictionary for dataset: %s\n' %(opt.setName)) 110 | 111 | elif opt.setName=='vid': 112 | print('begin parsing dataset: %s\n' %(opt.setName)) 113 | word_list = build_vid_word_list() 114 | print(len(word_list)) 115 | word2idx, idx2word= buildVocActNet(word_list) 116 | model_word2vec = KeyedVectors.load_word2vec_format(opt.dictPath, binary=True) 117 | matrix_word2vec, igNoreList = build_word_vec(word2idx.keys(), model_word2vec) 118 | matrix_word2vec = np.asarray(matrix_word2vec).astype(np.float32) 119 | pdb.set_trace() 120 | 121 | outDict = {'idx2word': idx2word, 'word2idx': word2idx, 'word2vec': matrix_word2vec, 'out_voca': igNoreList} 122 | pickledump(opt.dictOutPath+'_'+opt.setName+'_v2.pd', outDict) 123 | print('Finish constructing dictionary for dataset: %s\n' %(opt.setName)) 124 | 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /fun/dashed_rect.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | def drawline(img,pt1,pt2,color,thickness=1,style='dotted',gap=10): 4 | dist =((pt1[0]-pt2[0])**2+(pt1[1]-pt2[1])**2)**.5 5 | pts= [] 6 | for i in np.arange(0,dist,gap): 7 | r=i/dist 8 | x=int((pt1[0]*(1-r)+pt2[0]*r)+.5) 9 | y=int((pt1[1]*(1-r)+pt2[1]*r)+.5) 10 | p = (x,y) 11 | pts.append(p) 12 | 13 | if style=='dotted': 14 | for p in pts: 15 | cv2.circle(img,p,thickness,color,-1) 16 | else: 17 | s=pts[0] 18 | e=pts[0] 19 | i=0 20 | for p in pts: 21 | s=e 22 | e=p 23 | if i%2==1: 24 | cv2.line(img,s,e,color,thickness) 25 | i+=1 26 | 27 | def drawpoly(img,pts,color,thickness=1,style='dotted',): 28 | s=pts[0] 29 | e=pts[0] 30 | pts.append(pts.pop(0)) 31 | for p in pts: 32 | s=e 33 | e=p 34 | drawline(img,s,e,color,thickness,style) 
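# How drawline works: it samples a point every `gap` pixels along the
# segment from pt1 to pt2; with style='dotted' it stamps a filled circle at
# each sample, and with any other style it joins alternating pairs of
# samples into short dashes. drawpoly (above) rotates its point list via
# pts.append(pts.pop(0)) so that every vertex is joined to the next and the
# last vertex closes back to the first; note that this mutates the caller's
# list in place.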
35 | 36 | def drawrect(img,pt1,pt2,color,thickness=1,style='dotted'): 37 | pts = [pt1,(pt2[0],pt1[1]),pt2,(pt1[0],pt2[1])] 38 | drawpoly(img,pts,color,thickness,style) 39 | 40 | im = np.zeros((800,800,3),dtype='uint8') 41 | s=(234,222) 42 | e=(500,700) 43 | drawrect(im,s,e,(0,255,255),1,'dotted') 44 | -------------------------------------------------------------------------------- /fun/datasetLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('..') 4 | import torch.utils.data as data 5 | import cv2 6 | import numpy as np 7 | from util.mytoolbox import * 8 | import random 9 | import scipy.io as sio 10 | import copy 11 | import torch 12 | from util.get_image_size import get_image_size 13 | from evalDet import * 14 | from datasetParser import extAllFrmFn 15 | import pdb 16 | from netUtil import * 17 | from wsParamParser import parse_args 18 | from vidDataset import * 19 | import h5py 20 | 21 | def build_dataloader(opt): 22 | if opt.dbSet=='vid': 23 | ftrPath = '' # Path for frame level feature 24 | tubePath = '../data/tubePrp' # Path for information of each tube proposals 25 | dictFile = '../data/dictForDb_vid_v2.pd' # Path for word embedding 26 | out_cached_folder = '../data/vid_rgb/vidTubeCacheFtr' # Path for RGB features of tubes 27 | ann_folder = '../data/ILSVRC' # Path for tube annotations 28 | i3d_ftr_path ='../data/vid_i3d' # Path for I3D features 29 | prp_type = 'coco_30_2' # frame-level proposal extractors 30 | mean_cache_ftr_path = '../data/vid/meanFeature' 31 | ftr_context_path = '../data/vid/coco32/context/Data/VID' 32 | 33 | set_name = opt.set_name 34 | dataset = vidDataloader(ann_folder, prp_type, set_name, dictFile, tubePath \ 35 | , ftrPath, out_cached_folder) 36 | capNum = opt.capNum 37 | maxWordNum = opt.maxWordNum 38 | rpNum = opt.rpNum 39 | pos_emb_dim = opt.pos_emb_dim 40 | pos_type = opt.pos_type 41 | vis_ftr_type = opt.vis_ftr_type 42 | use_mean_cache_flag = opt.use_mean_cache_flag 43 | context_flag = opt.context_flag 44 | frm_level_flag = opt.frm_level_flag 45 | frm_num = opt.frm_num 46 | dataset.image_samper_set_up(rpNum= rpNum, capNum = capNum, \ 47 | maxWordNum= maxWordNum, usedBadWord=False, \ 48 | pos_emb_dim=pos_emb_dim, pos_type=pos_type, vis_ftr_type=vis_ftr_type, \ 49 | i3d_ftr_path=i3d_ftr_path, use_mean_cache_flag=use_mean_cache_flag,\ 50 | mean_cache_ftr_path=mean_cache_ftr_path, context_flag=context_flag, ftr_context_path=ftr_context_path, frm_level_flag=frm_level_flag, frm_num=frm_num) 51 | 52 | shuffle_flag = not opt.no_shuffle_flag 53 | 54 | data_loader = data.DataLoader(dataset, opt.batchSize, \ 55 | num_workers=opt.num_workers, collate_fn=dis_collate_vid, \ 56 | shuffle=shuffle_flag, pin_memory=True) 57 | return data_loader, dataset 58 | 59 | data_loader = data.DataLoader(dataset, opt.batchSize, \ 60 | num_workers=opt.num_workers, collate_fn=dis_collate_actNet, \ 61 | shuffle=shuffle_flag, pin_memory=True) 62 | return data_loader, dataset 63 | 64 | else: 65 | print('Not implemented for dataset %s\n' %(opt.dbSet)) 66 | return 67 | 68 | if __name__=='__main__': 69 | opt = parse_args() 70 | opt.dbSet = 'vid' 71 | opt.set_name = 'train' 72 | opt.vis_ftr_type = 'i3d' 73 | out_pre = '/data1/zfchen/data/' 74 | opt.cache_flag = False 75 | opt.pos_type = 'none' 76 | data_loader, dataset_ori = build_dataloader(opt) 77 | -------------------------------------------------------------------------------- /fun/datasetParser.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('../') 4 | sys.path.append('../util') 5 | from util.mytoolbox import get_specific_file_list_from_fd, textread, split_carefully, parse_mul_num_lines, pickleload, pickledump, get_list_dir 6 | import pdb 7 | import h5py 8 | import csv 9 | 10 | def build_idx_from_list(vdList): 11 | vdListU = list(set(vdList)) 12 | vd2idx = {} 13 | idx2vd = {} 14 | for i, ele in enumerate(vdList): 15 | vd2idx[ele] = i 16 | idx2vd[i] = ele 17 | return vd2idx, idx2vd 18 | 19 | def get_word_list(full_name_list): 20 | capList =[] 21 | for filePath in full_name_list: 22 | lines = textread(filePath) 23 | subCapList=[] 24 | for line in lines: 25 | wordList=line.split(line, ' ') 26 | subCapList.append(wordList.lower()) 27 | capList.append(subCapList) 28 | return capList 29 | 30 | def fix_otb_frm_bbxList(annFn, outFn): 31 | annDict = pickleload(annFn) 32 | 33 | isTrainSet = True 34 | if isTrainSet: 35 | bbxListDict=annDict['train_bbx_list'] 36 | vidList = annDict['trainName'] 37 | frmListDict=annDict['trainImg'] 38 | else: 39 | bbxListDict=annDict['test_bbx_list'] 40 | vidList = annDict['testName'] 41 | frmListDict=annDict['testImg'] 42 | 43 | for i, vidName in enumerate(vidList): 44 | frmList = frmListDict[i] 45 | bbxList = bbxListDict[i] 46 | frmList.sort() 47 | if(vidName=='David'): 48 | annDict['trainImg'][i] = frmList[299:] 49 | 50 | isTrainSet = False 51 | if isTrainSet: 52 | bbxListDict=annDict['train_bbx_list'] 53 | vidList = annDict['trainName'] 54 | frmListDict=annDict['trainImg'] 55 | else: 56 | bbxListDict=annDict['test_bbx_list'] 57 | vidList = annDict['testName'] 58 | frmListDict=annDict['testImg'] 59 | 60 | for i, vidName in enumerate(vidList): 61 | frmList = frmListDict[i] 62 | bbxList = bbxListDict[i] 63 | frmList.sort() 64 | if(vidName=='David'): 65 | annDict['testImg'][i] = frmList[299:] 66 | 67 | pickledump(outFn, annDict) 68 | return annDict 69 | 70 | # parse OTB 99 71 | def get_otb_data(inFd): 72 | test_text_fd = inFd + '/OTB_query_test' 73 | train_text_fd = inFd + '/OTB_query_train' 74 | video_fd = inFd + '/OTB_videos' 75 | test_name_list = get_specific_file_list_from_fd(test_text_fd, '.txt') 76 | train_name_list = get_specific_file_list_from_fd(train_text_fd, '.txt') 77 | test_name_listFull = get_specific_file_list_from_fd(test_text_fd, '.txt', False) 78 | train_name_listFull = get_specific_file_list_from_fd(train_text_fd, '.txt', False) 79 | 80 | test_cap_list = get_word_list(test_name_listFull) 81 | train_cap_list = get_word_list(train_name_listFull) 82 | 83 | test_im_list =[] 84 | test_bbx_list =[] 85 | for i, vdName in enumerate(test_name_list): 86 | full_vd_path = video_fd + '/'+ vdName 87 | vd_frame_path = full_vd_path + '/img' 88 | imgList = get_specific_file_list_from_fd(vd_frame_path, '.jpg', True) 89 | test_im_list.append(imgList) 90 | 91 | gtBoxFn = full_vd_path +'/groundtruth_rect.txt' 92 | bbxList= parse_mul_num_lines(gtBoxFn) 93 | test_bbx_list.append(bbxList) 94 | 95 | 96 | train_im_list =[] 97 | train_bbx_list =[] 98 | for i, vdName in enumerate(train_name_list): 99 | full_vd_path = video_fd + '/'+ vdName 100 | vd_frame_path = full_vd_path + '/img' 101 | imgList = get_specific_file_list_from_fd(vd_frame_path, '.jpg') 102 | train_im_list.append(imgList) 103 | 104 | gtBoxFn = full_vd_path +'/groundtruth_rect.txt' 105 | bbxList= parse_mul_num_lines(gtBoxFn) 106 | train_bbx_list.append(bbxList) 107 | 108 | otb_info_raw= {'trainName': train_name_list, 
'testName': test_name_list, 'trainCap': train_cap_list, 'testCap': test_cap_list, 'trainImg': train_im_list, 'testImg': test_im_list, 'train_bbx_list': train_bbx_list, 'test_bbx_list': test_bbx_list} 109 | return otb_info_raw 110 | 111 | def otbPCK2List(pckFn): 112 | otbDict = pickleload(pckFn) 113 | testVdList= otbDict['testName'] 114 | testImList = otbDict['testImg'] 115 | trainVdList= otbDict['trainName'] 116 | trainImList = otbDict['trainImg'] 117 | imgList = list() 118 | for i, vdName in enumerate(testVdList): 119 | vd_frame_path = vdName + '/img' 120 | for j, imName in enumerate(testImList[i]): 121 | imNameFull = vd_frame_path+'/'+imName+'.jpg' 122 | imgList.append(imNameFull) 123 | 124 | for i, vdName in enumerate(trainVdList): 125 | vd_frame_path = vdName + '/img' 126 | for j, imName in enumerate(trainImList[i]): 127 | imNameFull = vd_frame_path+'/'+imName+'.jpg' 128 | imgList.append(imNameFull) 129 | 130 | return imgList 131 | 132 | def a2dSetParser(annFn, annFd, annIgListFn, annFnOri): 133 | fLineList = textread(annFn) 134 | videoList = list() 135 | capList = list() 136 | insList = list() 137 | bbxList = list() 138 | frmNameList = list() 139 | splitDict = {} 140 | 141 | with open(annFnOri, 'rb') as csvFile: 142 | lines = csv.reader(csvFile) 143 | for line in lines: 144 | eleSegs = split_carefully(line, ',') 145 | splitDict[eleSegs[0][0]]=int(eleSegs[0][-1]) 146 | 147 | for i, line in enumerate(fLineList): 148 | if(i<=0): 149 | continue 150 | splitSegs= split_carefully(line, ',') 151 | annFdSub = annFd + '/' + splitSegs[0] 152 | annNameList = get_specific_file_list_from_fd(annFdSub, '.h5') 153 | tmpFrmList = list() 154 | tmpBbxList = list() 155 | #pdb.set_trace() 156 | insIdx = int(splitSegs[1]) 157 | print(splitSegs[2]) 158 | print('%s %d %d\n' %(splitSegs[0], i, insIdx)) 159 | for ii, annName in enumerate(annNameList): 160 | annSubFullPath = annFdSub + '/' + annName +'.h5' 161 | annIns = h5py.File(annSubFullPath) 162 | tmpInsList = list(annIns['instance'][:]) 163 | if(insIdx in tmpInsList): 164 | tmpFrmList.append(annName) 165 | bxIdx = tmpInsList.index(insIdx) 166 | tmpBbxList.append(annIns['reBBox'][:, bxIdx]) 167 | frmNameList.append(tmpFrmList) 168 | bbxList.append(tmpBbxList) 169 | videoList.append(splitSegs[0]) 170 | insList.append(int(splitSegs[1])) 171 | capSegs = splitSegs[2].lower().split(' ') 172 | capList.append(capSegs) 173 | 174 | vd2idx, idx2vd = build_idx_from_list(videoList) 175 | igNameList =textread(annIgListFn) 176 | 177 | a2d_info_raw= {'cap': capList, 'vd': videoList, 'bbxList': bbxList, 'frmList': frmNameList, 'insList' : insList, 'igList': igNameList, 'splitDict': splitDict, 'vd2idx': vd2idx, 'idx2vd': idx2vd} 178 | return a2d_info_raw 179 | 180 | #data = h5py.file() 181 | def a2dPCK2List(pckFn): 182 | a2dDict = pickleload(pckFn) 183 | imgList = list() 184 | testCapList = a2dDict['cap'] 185 | for i, cap in enumerate(testCapList): 186 | vdName = a2dDict['vd'][i] 187 | frmList = a2dDict['frmList'][i] 188 | for j, imName in enumerate(frmList): 189 | imNameFull = vdName+'/'+imName+'.png' 190 | imgList.append(imNameFull) 191 | imgList= list(set(imgList)) 192 | return imgList 193 | 194 | def extAllFrm(annFn, fdPre): 195 | annDict = pickleload(annFn) 196 | videoListU = list(set(annDict['vd'])) 197 | frmFullList = list() 198 | for vdName in videoListU: 199 | subPre = fdPre + '/' + vdName 200 | frmNameList = get_specific_file_list_from_fd(subPre, '.png') 201 | for frm in frmNameList: 202 | frmFullList.append(vdName+'/'+frm+'.png') 203 | return frmFullList 204 | 
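# Illustrative usage (the `pngs_root` name is hypothetical): extAllFrm above
# returns one flat list of '<video>/<frame>.png' paths covering every video
# referenced by the annotation pickle, whereas extAllFrmFn below returns a
# list of sorted frame-name lists, one per video, for an explicitly given
# video list.
#   frmFullList = extAllFrm('../data/annoted_a2d.pd', pngs_root)
#   frm_per_vid = extAllFrmFn(videoList, pngs_root)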
205 | def extAllFrmFn(videoList, fdPre): 206 | frmvDict = list() 207 | for vdName in videoList: 208 | subPre = fdPre + '/' + vdName 209 | frmNameList = get_specific_file_list_from_fd(subPre, '.png') 210 | frmNameList.sort() 211 | frmvDict.append(frmNameList) 212 | #pdb.set_trace() 213 | return frmvDict 214 | 215 | def getFrmFn(fdPre, extFn='.jpg'): 216 | frmListFull = list() 217 | sub_fd_list = get_list_dir(fdPre) 218 | for i, vdName in enumerate(sub_fd_list): 219 | frmList = get_specific_file_list_from_fd(vdName, extFn, nameOnly=False) 220 | frmListFull +=frmList 221 | return frmListFull 222 | 223 | 224 | 225 | 226 | if __name__=='__main__': 227 | 228 | fdListFn = '/data1/zfchen/data/actNet/actNetJpgs/' 229 | #frmList = getFrmFn(fdListFn) 230 | #pdb.set_trace() 231 | 232 | #pckFn = '../data/annoted_a2d.pd' 233 | #imPre ='/data1/zfchen/data/A2D/Release/pngs320H' 234 | #dataAnn = pickleload(pckFn) 235 | #frmFullList = extAllFrm(pckFn, imPre) 236 | #pdb.set_trace() 237 | #dataAnn['frmListGt'] = dataAnn['frmList'] 238 | #dataAnn['frmList'] = frmFullList 239 | #pickledump('../data/annoted_a2dV2.pd', dataAnn) 240 | #pdb.set_trace() 241 | #print('finish') 242 | #imgList = a2dPCK2List(pckFn) 243 | #print('finish') 244 | #annFd = '/disk2/zfchen/data/A2D/Release/sentenceAnno/a2d_annotation_with_instances' 245 | #annFn = '/disk2/zfchen/data/A2D/Release/sentenceAnno/a2d_annotation.txt' 246 | #annIgListFn = '/disk2/zfchen/data/A2D/Release/sentenceAnno/a2d_missed_videos.txt' 247 | #annOriFn = '/disk2/zfchen/data/A2D/Release/videoset.csv' 248 | #a2dSetParser(annFn, annFd, annIgListFn, annOriFn) 249 | #otbPKFile ='../data/annForDb_otb.pd' 250 | #otbPKFileV2 ='../data/annForDb_otbV2.pd' 251 | #otbNew= fix_otb_frm_bbxList(otbPKFile, otbPKFileV2) 252 | #imList = otbPCK2List(otbPKFile) 253 | #print(imList[:10]) 254 | #print(imList[10000]) 255 | #print(len(imList)) 256 | -------------------------------------------------------------------------------- /fun/eval.py: -------------------------------------------------------------------------------- 1 | from wsParamParser import parse_args 2 | from data.data_loader import* 3 | from datasetLoader import * 4 | from modelArc import * 5 | from optimizers import * 6 | from logInfo import logInF 7 | from lossPackage import * 8 | from netUtil import * 9 | from tensorboardX import SummaryWriter 10 | import time 11 | 12 | import pdb 13 | 14 | if __name__=='__main__': 15 | opt = parse_args() 16 | # build dataloader 17 | dataLoader, datasetOri= build_dataloader(opt) 18 | # build network 19 | model = build_network(opt) 20 | # build_optimizer 21 | logger = logInF(opt.logFd) 22 | ep = 0 23 | #thre_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 24 | thre_list = [0.5] 25 | acc_list = list() 26 | more_detailed_flag = True 27 | for thre in thre_list: 28 | acc_list.append([]) 29 | 30 | if opt.eval_val_flag: 31 | model.eval() 32 | resultList = list() 33 | vIdList = list() 34 | set_name_ori= opt.set_name 35 | opt.set_name = 'val' 36 | dataLoaderEval, datasetEvalOri = build_dataloader(opt) 37 | for itr_eval, inputData in enumerate(dataLoaderEval): 38 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData 39 | dataIdx = None 40 | b_size = tube_embedding.shape[0] 41 | # B*P*T*D 42 | imDis = tube_embedding.cuda() 43 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 44 | wordEmb = cap_embedding.cuda() 45 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 46 | imDis.requires_grad=False 47 | 
wordEmb.requires_grad=False 48 | 49 | if opt.wsMode =='coAtt': 50 | simMM = model(imDis, wordEmb, cap_length_list) 51 | #pdb.set_trace() 52 | simMM = simMM.view(b_size, opt.rpNum) 53 | for i, thre in enumerate(thre_list): 54 | acc_list[i] += evalAcc_att_test(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False, topK=1, thre_list=[thre], more_detailed_flag=more_detailed_flag) 55 | 56 | for i, thre in enumerate(thre_list): 57 | resultList = acc_list[i] 58 | accSum = 0 59 | for ele in resultList: 60 | recall_k= ele[1] 61 | accSum +=recall_k 62 | logger('thre @ %f, Average accuracy on validation set is %3f\n' %(thre, accSum/len(resultList))) 63 | 64 | out_result_fn = opt.logFd + 'result_val_' +os.path.basename(opt.initmodel).split('.')[0] + '.pk' 65 | pickledump(out_result_fn, acc_list) 66 | 67 | if opt.eval_test_flag: 68 | model.eval() 69 | resultList = list() 70 | vIdList = list() 71 | set_name_ori= opt.set_name 72 | opt.set_name = 'test' 73 | dataLoaderEval, datasetEvalOri = build_dataloader(opt) 74 | #pdb.set_trace() 75 | for itr_eval, inputData in enumerate(dataLoaderEval): 76 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData 77 | dataIdx = None 78 | b_size = tube_embedding.shape[0] 79 | # B*P*T*D 80 | imDis = tube_embedding.cuda() 81 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 82 | wordEmb = cap_embedding.cuda() 83 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 84 | imDis.requires_grad=False 85 | wordEmb.requires_grad=False 86 | 87 | if opt.wsMode =='coAtt': 88 | simMM = model(imDis, wordEmb, cap_length_list) 89 | simMM = simMM.view(b_size, opt.rpNum) 90 | for i, thre in enumerate(thre_list): 91 | acc_list[i] += evalAcc_att_test(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False, thre_list=[thre], more_detailed_flag=more_detailed_flag) 92 | 93 | for i, thre in enumerate(thre_list): 94 | resultList = acc_list[i] 95 | accSum = 0 96 | for ele in resultList: 97 | recall_k= ele[1] 98 | accSum +=recall_k 99 | logger('thre @ %f, Average accuracy on testing set is %3f\n' %(thre, accSum/len(resultList))) 100 | out_result_fn = opt.logFd + 'result_test_' +os.path.basename(opt.initmodel).split('.')[0] + '.pk' 101 | pickledump(out_result_fn, acc_list) 102 | 103 | -------------------------------------------------------------------------------- /fun/evalDet.py: -------------------------------------------------------------------------------- 1 | #import os 2 | import numpy as np 3 | import sys 4 | import magic 5 | import re 6 | sys.path.append('..') 7 | from util.mytoolbox import * 8 | from util.get_image_size import get_image_size 9 | import cv2 10 | import pdb 11 | import ipdb 12 | import copy 13 | import torch 14 | from fun.datasetLoader import * 15 | from vidDatasetParser import evaluate_tube_recall_vid, resize_tube_bbx 16 | from netUtil import * 17 | 18 | sys.path.append('../annotation') 19 | from script_test_annotation import evaluate_tube_recall 20 | 21 | def computeIoU(box1, box2): 22 | # each box is of [x1, y1, w, h] 23 | inter_x1 = max(box1[0], box2[0]) 24 | inter_y1 = max(box1[1], box2[1]) 25 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1) 26 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1) 27 | 28 | if inter_x1 < inter_x2 and inter_y1 < inter_y2: 29 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) 30 | else: 31 | inter = 0 32 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter 33 | return float(inter)/union 34 | 35 | # get original size of the 
proposals
36 | # the proposals are extracted at a larger size
37 | def transFormBbx(bbx, im_shape, reSize=800, maxSize=1200, im_scale=None):  # map [x1,y1,x2,y2] at network scale back to [x,y,w,h] at the original image scale
38 |     if im_scale is not None:
39 |         bbx = [ ele/im_scale for ele in bbx]
40 |         bbx[2] = bbx[2]-bbx[0]
41 |         bbx[3] = bbx[3]-bbx[1]
42 |         return bbx
43 |     im_size_min = np.min(im_shape[0:2])
44 |     im_size_max = np.max(im_shape[0:2])
45 |     im_scale = float(reSize)/im_size_min  # shorter side was resized to reSize, capped so the longer side stays within maxSize
46 |     if(np.round(im_size_max*im_scale) > maxSize):
47 |         im_scale = float(maxSize)/ float(im_size_max)
48 |     bbx = [ ele/im_scale for ele in bbx]
49 |     bbx[2] = bbx[2]-bbx[0]
50 |     bbx[3] = bbx[3]-bbx[1]
51 |     return bbx
52 |
53 |
54 | class evalDetAcc(object):
55 |     def __init__(self, gtBbxList=None, IoU=0.5, topK=1):
56 |         self.gtList = gtBbxList
57 |         self.thre = IoU
58 |         self.topK= topK
59 |
60 |     def evalList(self, prpList):
61 |         imgNum = len(prpList)
62 |         posNum = 0
63 |         for i in range(imgNum):
64 |             tmpRpList= prpList[i]
65 |             gtBbx = self.gtList[i]
66 |             for j in range(self.topK):
67 |                 iou = computeIoU(gtBbx, tmpRpList[j])
68 |                 if iou>self.thre:
69 |                     posNum +=1
70 |                     break
71 |         return float(posNum)/imgNum
72 |
73 | def rpMatPreprocess(rpMatrix, imWH, isA2D=False):
74 |     rpList= list()
75 |     rpNum = rpMatrix.shape[0]
76 |     for i in range(rpNum):
77 |         tmpRp = list(rpMatrix[i, :])
78 |         if not isA2D:
79 |             bbx = transFormBbx(tmpRp, imWH)
80 |         else:
81 |             bbx = transFormBbx(tmpRp, imWH, im_scale=imWH)
82 |         rpList.append(bbx)
83 |     return rpList
84 |
85 | def evalAcc_att_test(simMM_batch, tubeInfo, indexOri, datasetOri, visRsFd, visFlag=False, topK = 1, thre_list=[0.5], more_detailed_flag=False):
86 |     resultList= list()
87 |     bSize = len(indexOri)
88 |     tube_Prp_num = len(tubeInfo[0][0][0])
89 |     stIdx = 0
90 |     vid_parser = datasetOri.vid_parser
91 |     for idx, lbl in enumerate(indexOri):
92 |         simMM = simMM_batch[idx, :]
93 |         simMMReshape = simMM.view(-1, tube_Prp_num)
94 |
95 |         sortSim, simIdx = torch.sort(simMMReshape, dim=1, descending=True)
96 |         simIdx = simIdx.data.cpu().numpy().squeeze(axis=0)
97 |         sort_sim_np = sortSim.data.cpu().numpy().squeeze(axis=0)
98 |
99 |         tube_Info_sub = tubeInfo[idx]
100 |         tube_info_sub_prp, frm_info_list = tube_Info_sub
101 |         tube_info_sub_prp_bbx, tube_info_sub_prp_score = tube_info_sub_prp
102 |         #prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]], sort_sim_np[i] ]for i in range(topK)]
103 |         prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]] for i in range(topK)], [sort_sim_np[i] for i in range(topK)] ]
104 |         shot_proposals = [prpListSort, frm_info_list]
105 |         for ii, thre in enumerate(thre_list):
106 |             if more_detailed_flag:
107 |                 recall_tmp, iou_list = evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK, more_detailed_flag=more_detailed_flag)
108 |             else:
109 |                 recall_tmp= evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK)
110 |             if more_detailed_flag:
111 |                 resultList.append((lbl, recall_tmp[-1], simIdx, sort_sim_np, iou_list))
112 |             else:
113 |                 resultList.append((lbl, recall_tmp[-1]))
114 |             print('accuracy for %d: %.3f' %(lbl, recall_tmp[-1]))
115 |
116 |         if visFlag:
117 |
118 |             print(vid_parser.tube_cap_dict[lbl])
119 |             print(lbl)
120 |             #continue
121 |             vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl)
122 |             frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list]
123 |             frmImList = list()
124 |             for fId, imPath in enumerate(frmImNameList):
125 |                 img = cv2.imread(imPath)
126 |                 frmImList.append(img)
127 |             vis_frame_num = 30
128 |             visIner =max(int(len(frmImList)
/vis_frame_num), 1) 129 | 130 | for ii in range(topK): 131 | print('visualizing tube %d\n'%(ii)) 132 | #pdb.set_trace() 133 | tube = prpListSort[0][ii] 134 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)] 135 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)] 136 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis) 137 | vd_name_raw = vd_name.split('/')[-1] 138 | makedirs_if_missing(visRsFd) 139 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, visRsFd+'/'+vd_name_raw+ '_' + str(ii)+'.gif') 140 | pdb.set_trace() 141 | return resultList 142 | 143 | 144 | def evalAcc_att(simMMFull, tubeInfo, indexOri, datasetOri, visRsFd, visFlag=False, topK = 1, thre_list=[0.5], more_detailed_flag=False): 145 | # pdb.set_trace() 146 | resultList= list() 147 | bSize = len(indexOri) 148 | tube_Prp_num = len(tubeInfo[0][0][0]) 149 | stIdx = 0 150 | #thre_list = [0.2, 0.3, 0.4, 0.5] 151 | vid_parser = datasetOri.vid_parser 152 | for idx, lbl in enumerate(indexOri): 153 | simMM = simMMFull[idx, :, idx] 154 | simMMReshape = simMM.view(-1, tube_Prp_num) 155 | #pdb.set_trace() 156 | sortSim, simIdx = torch.sort(simMMReshape, dim=1, descending=True) 157 | simIdx = simIdx.data.cpu().numpy().squeeze(axis=0) 158 | sort_sim_np = sortSim.data.cpu().numpy().squeeze(axis=0) 159 | 160 | tube_Info_sub = tubeInfo[idx] 161 | tube_info_sub_prp, frm_info_list = tube_Info_sub 162 | tube_info_sub_prp_bbx, tube_info_sub_prp_score = tube_info_sub_prp 163 | #prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]], sort_sim_np[i] ]for i in range(topK)] 164 | prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]] for i in range(topK)], [sort_sim_np[i] for i in range(topK)] ] 165 | shot_proposals = [prpListSort, frm_info_list] 166 | for ii, thre in enumerate(thre_list): 167 | if more_detailed_flag: 168 | recall_tmp, iou_list = evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK, more_detailed_flag=more_detailed_flag) 169 | else: 170 | recall_tmp= evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK) 171 | if more_detailed_flag: 172 | resultList.append((lbl, recall_tmp[-1], simIdx, sort_sim_np, iou_list)) 173 | else: 174 | resultList.append((lbl, recall_tmp[-1])) 175 | print('accuracy for %d: %3f' %(lbl, recall_tmp[-1])) 176 | 177 | #pdb.set_trace() 178 | if visFlag: 179 | 180 | 181 | # visualize sample results 182 | #if(recall_tmp[-1]<=0.5): 183 | # continue 184 | print(vid_parser.tube_cap_dict[lbl]) 185 | print(lbl) 186 | #pdb.set_trace() 187 | #continue 188 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl) 189 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list] 190 | frmImList = list() 191 | for fId, imPath in enumerate(frmImNameList): 192 | img = cv2.imread(imPath) 193 | frmImList.append(img) 194 | vis_frame_num = 30 195 | visIner =max(int(len(frmImList) /vis_frame_num), 1) 196 | 197 | for ii in range(topK): 198 | print('visualizing tube %d\n'%(ii)) 199 | #pdb.set_trace() 200 | tube = prpListSort[0][ii] 201 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)] 202 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)] 203 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis) 204 | vd_name_raw = vd_name.split('/')[-1] 205 | makedirs_if_missing(visRsFd) 206 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, visRsFd+'/'+vd_name_raw+ '_' + str(ii)+'.gif') 207 | pdb.set_trace() 208 
| return resultList 209 | 210 | 211 | 212 | 213 | def evalAcc(imFtr, txtFtr, tubeInfo, indexOri, datasetOri, visRsFd, visFlag=False, topK = 1, thre_list=[0.5], more_detailed_flag=False): 214 | resultList= list() 215 | bSize = len(indexOri) 216 | tube_Prp_num = imFtr.shape[1] 217 | stIdx = 0 218 | #thre_list = [0.2, 0.3, 0.4, 0.5] 219 | vid_parser = datasetOri.vid_parser 220 | assert txtFtr.shape[1]==1 221 | for idx, lbl in enumerate(indexOri): 222 | imFtrSub = imFtr[idx] 223 | txtFtrSub = txtFtr[idx].view(-1,1) 224 | simMM = torch.mm(imFtrSub, txtFtrSub) 225 | #pdb.set_trace() 226 | simMMReshape = simMM.view(-1, tube_Prp_num) 227 | sortSim, simIdx = torch.sort(simMMReshape, dim=1, descending=True) 228 | simIdx = simIdx.data.cpu().numpy().squeeze(axis=0) 229 | sort_sim_np = sortSim.data.cpu().numpy().squeeze(axis=0) 230 | 231 | tube_Info_sub = tubeInfo[idx] 232 | tube_info_sub_prp, frm_info_list = tube_Info_sub 233 | tube_info_sub_prp_bbx, tube_info_sub_prp_score = tube_info_sub_prp 234 | #prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]], sort_sim_np[i] ]for i in range(topK)] 235 | prpListSort = [ [tube_info_sub_prp_bbx[simIdx[i]] for i in range(topK)], [sort_sim_np[i] for i in range(topK)] ] 236 | shot_proposals = [prpListSort, frm_info_list] 237 | for ii, thre in enumerate(thre_list): 238 | if more_detailed_flag: 239 | recall_tmp, iou_list= evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK, more_detailed_flag=more_detailed_flag) 240 | else: 241 | recall_tmp = evaluate_tube_recall_vid(shot_proposals, vid_parser, lbl, thre, topKOri=topK) 242 | if more_detailed_flag: 243 | resultList.append((lbl, recall_tmp[-1], simIdx, sort_sim_np, iou_list)) 244 | else: 245 | resultList.append((lbl, recall_tmp[-1])) 246 | #print('accuracy for %d: %3f' %(lbl, recall_tmp[-1])) 247 | 248 | #pdb.set_trace() 249 | if visFlag: 250 | # visualize sample results 251 | #if(recall_tmp[-1]<=0.5): 252 | # continue 253 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl) 254 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list] 255 | frmImList = list() 256 | for fId, imPath in enumerate(frmImNameList): 257 | img = cv2.imread(imPath) 258 | frmImList.append(img) 259 | vis_frame_num = 30 260 | visIner =max(int(len(frmImList) /vis_frame_num), 1) 261 | 262 | for ii in range(topK): 263 | print('visualizing tube %d\n'%(ii)) 264 | #pdb.set_trace() 265 | tube = prpListSort[0][ii] 266 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)] 267 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)] 268 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis) 269 | vd_name_raw = vd_name.split('/')[-1] 270 | makedirs_if_missing(visRsFd) 271 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, visRsFd+'/'+vd_name_raw+ '_' + str(ii)+'.gif') 272 | pdb.set_trace() 273 | return resultList 274 | 275 | if __name__=='__main__': 276 | pdb.set_trace() 277 | -------------------------------------------------------------------------------- /fun/image_toolbox.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import commands 5 | import copy 6 | import cv2 7 | from easydict import EasyDict as edict 8 | import numpy as np 9 | import os 10 | import shutil 11 | from dashed_rect import drawrect 12 | #from settings import FFMPEG 13 | import pdb 14 | TMP_DIR = '.tmp' 15 | FFMPEG = 'ffmpeg' 16 | 17 | 
SAVE_VIDEO = FFMPEG + ' -y -r %d -i %s/%s.jpg %s'
18 |
19 | def images2video(image_list, frame_rate, video_path, max_edge=None):  # dump frames into a temp folder, then stitch them into a video with ffmpeg
20 |     if os.path.exists(TMP_DIR):
21 |         shutil.rmtree(TMP_DIR)
22 |     os.mkdir(TMP_DIR)
23 |     img_size = None
24 |     for cur_num, cur_img in enumerate(image_list):
25 |         cur_fname = os.path.join(TMP_DIR, '%08d.jpg' % cur_num)
26 |         if max_edge is not None:
27 |             cur_img = imread_if_str(cur_img)
28 |         if isinstance(cur_img, str) or isinstance(cur_img, unicode):
29 |             shutil.copyfile(cur_img, cur_fname)
30 |         elif isinstance(cur_img, np.ndarray):
31 |             max_len = max(cur_img.shape[:2])
32 |             if max_len > max_edge and img_size is None and max_edge is not None:
33 |                 magnif = float(max_edge) / float(max_len)
34 |                 img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif))
35 |                 cur_img = cv2.resize(cur_img, img_size)
36 |             elif max_edge is not None:
37 |                 if img_size is None:
38 |                     magnif = float(max_edge) / float(max_len)
39 |                     img_size = (int(cur_img.shape[1] * magnif), int(cur_img.shape[0] * magnif))
40 |                 cur_img = cv2.resize(cur_img, img_size)
41 |             cv2.imwrite(cur_fname, cur_img)
42 |         else:
43 |             raise NotImplementedError()
44 |     print commands.getoutput(SAVE_VIDEO % (frame_rate, TMP_DIR, '%08d', video_path))
45 |     shutil.rmtree(TMP_DIR)
46 |
47 | def imread_if_str(img):
48 |     if isinstance(img, basestring):
49 |         img = cv2.imread(img)
50 |     return img
51 |
52 | def draw_rectangle(img, bbox, color=(0,0,255), thickness=3, use_dashed_line=False):
53 |     img = imread_if_str(img)
54 |     if isinstance(bbox, dict):
55 |         bbox = [
56 |             bbox['x1'],
57 |             bbox['y1'],
58 |             bbox['x2'],
59 |             bbox['y2'],
60 |         ]
61 |     bbox[0] = max(bbox[0], 0)
62 |     bbox[1] = max(bbox[1], 0)
63 |     bbox[0] = min(bbox[0], img.shape[1])
64 |     bbox[1] = min(bbox[1], img.shape[0])
65 |     bbox[2] = max(bbox[2], 0)
66 |     bbox[3] = max(bbox[3], 0)
67 |     bbox[2] = min(bbox[2], img.shape[1])
68 |     bbox[3] = min(bbox[3], img.shape[0])
69 |     assert bbox[2] >= bbox[0]
70 |     assert bbox[3] >= bbox[1]
71 |     assert bbox[0] >= 0
72 |     assert bbox[1] >= 0
73 |     assert bbox[2] <= img.shape[1]
74 |     assert bbox[3] <= img.shape[0]
75 |     cur_img = copy.deepcopy(img)
76 |     #pdb.set_trace()
77 |     if use_dashed_line:
78 |         drawrect(
79 |             cur_img,
80 |             (int(bbox[0]), int(bbox[1])),
81 |             (int(bbox[2]), int(bbox[3])),
82 |             color,
83 |             thickness,
84 |             'dotted'
85 |         )
86 |     else:
87 |         #pdb.set_trace()
88 |         cv2.rectangle(
89 |             cur_img,
90 |             (int(bbox[0]), int(bbox[1])),
91 |             (int(bbox[2]), int(bbox[3])),
92 |             color,
93 |             thickness)
94 |     return cur_img
95 |
96 | def rgb2gray(rgb):
97 |     r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
98 |     gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
99 |     return gray
100 |
101 | def gray_background(img, bbox):  # keep the bbox region in color and gray out everything else
102 |     img = imread_if_str(img)
103 |     if isinstance(bbox, dict):
104 |         bbox = [
105 |             bbox['x1'],
106 |             bbox['y1'],
107 |             bbox['x2'],
108 |             bbox['y2'],
109 |         ]
110 |     bbox[0] = int(max(bbox[0], 0))
111 |     bbox[1] = int(max(bbox[1], 0))
112 |     bbox[0] = min(bbox[0], img.shape[1])
113 |     bbox[1] = min(bbox[1], img.shape[0])
114 |     bbox[2] = int(max(bbox[2], 0))
115 |     bbox[3] = int(max(bbox[3], 0))
116 |     bbox[2] = min(bbox[2], img.shape[1])
117 |     bbox[3] = min(bbox[3], img.shape[0])
118 |     assert bbox[2] >= bbox[0]
119 |     assert bbox[3] >= bbox[1]
120 |     assert bbox[0] >= 0
121 |     assert bbox[1] >= 0
122 |     assert bbox[2] <= img.shape[1]
123 |     assert bbox[3] <= img.shape[0]
124 |     #gray_img = copy.deepcopy(img)
125 |     #gray_img = np.stack((gray_img, gray_img, gray_img), axis=2)
126 |     #gray_img = rgb2gray(gray_img)
127 |     gray_img = cv2.cvtColor(img,
cv2.COLOR_BGR2GRAY) 128 | gray_img = np.stack((gray_img, gray_img, gray_img), axis=2) 129 | #pdb.set_trace() 130 | gray_img[bbox[1]:bbox[3], bbox[0]:bbox[2], ...] = img[bbox[1]:bbox[3], bbox[0]:bbox[2], ...] 131 | #pdb.set_trace() 132 | return gray_img 133 | -------------------------------------------------------------------------------- /fun/logInfo.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import os 4 | class logInF (object): 5 | def __init__(self, logPre): 6 | dirname = os.path.dirname(logPre) 7 | if not os.path.exists(dirname): 8 | os.makedirs(dirname) 9 | logFile=logPre+time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + 'log.txt' 10 | self.prHandle=open(logFile, 'w') 11 | 12 | def __call__(self, logData): 13 | print(logData) 14 | self.prHandle.write(logData) 15 | self.prHandle.flush() 16 | -------------------------------------------------------------------------------- /fun/lossPackage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | import time 6 | 7 | class HLoss(nn.Module): 8 | def __init__(self): 9 | super(HLoss, self).__init__() 10 | 11 | def forward(self, x): 12 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 13 | b = -1.0 * b.sum(dim=1) 14 | b = torch.mean(b) 15 | return b 16 | 17 | class lossEvaluator(nn.Module): 18 | def __init__(self, margin=0.1, biLossFlag=True, lossWFlag=False, lamda=0.8, struct_flag=False, struct_only_flag=False, entropy_regu_flag=False, lamda2=0.1, loss_type=''): 19 | super(lossEvaluator, self).__init__() 20 | self.margin =margin 21 | self.biLossFlag = biLossFlag 22 | self.lossWFlag = lossWFlag 23 | self.lamda= lamda 24 | self.struct_flag = struct_flag 25 | self.struct_only_flag = struct_only_flag 26 | self.entropy_regu_flag = entropy_regu_flag 27 | self.lamda2 = lamda2 28 | self.loss_type = loss_type 29 | if self.entropy_regu_flag: 30 | self.entropy_calculator = HLoss() 31 | 32 | def forward(self, imFtr=None, disFtr=None, lblList=None, simMM=None, region_gt_ori=None): 33 | if self.wsMode=='coAtt': 34 | loss = self.forwardCoAtt_efficient(simMM, lblList) 35 | return loss 36 | 37 | def forwardCoAtt_efficient(self, simMMRe, lblList): 38 | 39 | simMax, maxIdx= torch.max(simMMRe.squeeze(3), dim=1) 40 | loss = torch.zeros(1).cuda() 41 | pair_num = 0.000001 42 | 43 | t1 = time.time() 44 | #print(lblList) 45 | bSize = len(lblList) 46 | b_size = simMax.shape[0] 47 | pos_diag = torch.cat([simMax[i, i].unsqueeze(0) for i in range(b_size)]) 48 | one_mat = torch.ones(b_size, b_size).cuda() 49 | pos_diag_mat = torch.mul(pos_diag, one_mat) 50 | pos_diag_trs = pos_diag_mat.transpose(0,1) 51 | 52 | mask_val = torch.ones(b_size, b_size).cuda() 53 | 54 | for i in range(b_size): 55 | lblI = lblList[i] 56 | for j in range(i, b_size): 57 | lblJ = lblList[j] 58 | if lblI==lblJ: 59 | mask_val[i, j]=0 60 | mask_val[j, i]=0 61 | pair_num = pair_num + torch.sum(mask_val) 62 | 63 | loss_mat_1 = simMax -pos_diag + self.margin 64 | loss_mask = (loss_mat_1>0).float() 65 | loss_mat_1_mask = loss_mat_1 *loss_mask * mask_val 66 | loss1 = torch.sum(loss_mat_1_mask) 67 | 68 | loss_mat_2 = simMax -pos_diag_trs +self.margin 69 | loss_mask_2 = (loss_mat_2>0).float() 70 | loss_mat_2_mask = loss_mat_2 *loss_mask_2 * mask_val 71 | loss2 = torch.sum(loss_mat_2_mask) 72 | loss = (loss1+loss2)/pair_num 73 | 74 | if self.entropy_regu_flag: 75 | simMMRe = simMMRe.squeeze() 76 | 
ftr_match_pair_list = list() 77 | ftr_unmatch_pair_list = list() 78 | 79 | for i in range(bSize): 80 | for j in range(bSize): 81 | if i==j: 82 | ftr_match_pair_list.append(simMMRe[i, ..., i]) 83 | elif lblList[i]!=lblList[j]: 84 | ftr_unmatch_pair_list.append(simMMRe[i, ..., j]) 85 | 86 | ftr_match_pair_mat = torch.stack(ftr_match_pair_list, 0) 87 | ftr_unmatch_pair_mat = torch.stack(ftr_unmatch_pair_list, 0) 88 | match_num = len(ftr_match_pair_list) 89 | unmatch_num = len(ftr_unmatch_pair_list) 90 | if match_num>0: 91 | entro_loss = self.entropy_calculator(ftr_match_pair_mat) 92 | loss +=self.lamda2*entro_loss 93 | print('entropy loss: %3f ' %(float(entro_loss))) 94 | #pdb.set_trace() 95 | print('\n') 96 | return loss 97 | 98 | def build_lossEval(opts): 99 | if opts.wsMode == 'rankTube' or opts.wsMode=='coAtt' or opts.wsMode=='coAttV2' or opts.wsMode=='coAttV3' or opts.wsMode == 'coAttV4' or opts.wsMode=='rankFrm' or opts.wsMode=='coAttBi' or opts.wsMode=='coAttBiV2' or opts.wsMode =='coAttBiV3': 100 | loss_criterion = lossEvaluator(opts.margin, opts.biLoss, opts.lossW, \ 101 | opts.lamda, opts.struct_flag, opts.struct_only, \ 102 | opts.entropy_regu_flag, opts.lamda2, \ 103 | loss_type=opts.loss_type) 104 | loss_criterion.wsMode =opts.wsMode 105 | return loss_criterion 106 | elif opts.wsMode =='rankGroundR' or opts.wsMode =='coAttGroundR' or opts.wsMode=='rankGroundRV2': 107 | loss_criterion = lossGroundR(entropy_regu_flag=opts.entropy_regu_flag, \ 108 | lamda2=opts.lamda2 ) 109 | loss_criterion.wsMode = opts.wsMode 110 | return loss_criterion 111 | -------------------------------------------------------------------------------- /fun/modelArc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | #from multiGraphAttention import * 7 | from netvlad import NetVLAD 8 | from classSST import * 9 | 10 | class wsEmb(nn.Module): 11 | def __init__(self, imEncoder, wordEncoder): 12 | super(wsEmb, self).__init__() 13 | self._train=True 14 | self.imEncoder = imEncoder 15 | self.wordEncoder = wordEncoder 16 | self._initialize_weights() 17 | self.wsMode = 'rank' 18 | self.vis_type = 'fc' 19 | 20 | def forward(self, imDis, wordEmb, capLengthsFull=None, frmListFull=None, rpListFull=None, dataIdx=None): 21 | if dataIdx is not None: 22 | numEle = len(dataIdx) 23 | if capLengthsFull is not None: 24 | capLengths = [capLengthsFull[int(dataIdx[i])] for i in range(numEle)] 25 | else: 26 | capLengths = capLengthsFull 27 | if frmListFull is not None: 28 | frmList = [frmListFull[int(dataIdx[i])] for i in range(numEle)] 29 | else: 30 | frmList = frmListFull 31 | if rpListFull is not None: 32 | rpList = [rpListFull[int(dataIdx[i])] for i in range(numEle)] 33 | else: 34 | rpList = rpListFull 35 | else: 36 | capLengths = capLengthsFull 37 | frmList = frmListFull 38 | rpList = rpListFull 39 | 40 | if self.wsMode == 'rankTube' or self.wsMode =='rankFrm': 41 | imEnDis, wordEnDis = self.forwardRank(imDis, wordEmb, capLengths) 42 | return imEnDis, wordEnDis 43 | 44 | #pdb.set_trace() 45 | if self.wsMode == 'coAtt' or self.wsMode == 'coAttV2' or self.wsMode=='coAttV3' or self.wsMode=='coAttV4' or self.wsMode=='coAttBi' or self.wsMode=='coAttBiV2' or self.wsMode=='coAttBiV3': 46 | simMM = self.forwardCoAtt(imDis, wordEmb, capLengths) 47 | return simMM 48 | 49 | elif self.wsMode == 'rankGroundR' or self.wsMode=='rankGroundRV2': 50 | logDist, 
rpSS= self.forwardGroundR(imDis, wordEmb, capLengths) 51 | return logDist, rpSS 52 | 53 | elif self.wsMode == 'coAttGroundR': 54 | logDist, rpSS= self.forwardCoAttGroundR(imDis, wordEmb, capLengths) 55 | return logDist, rpSS 56 | 57 | 58 | def forwardCoAttGroundR(self, imDis, wordEmb, capLengths): 59 | b_size = imDis.shape[0] 60 | assert len(imDis.size())==4 61 | assert len(wordEmb.size())==4 62 | 63 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 64 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 65 | imEnDis, wordEnDis = self.imEncoder(imDis, wordEmb, capLengths) 66 | 67 | imEnDis = imEnDis.view(b_size , -1, imEnDis.shape[2]) 68 | wordEnDis = wordEnDis.view(b_size, -1, wordEnDis.shape[2]) 69 | visFtrEm, rpSS= self.visDecoder(imEnDis, wordEnDis) 70 | reCnsSen =None 71 | if self.training: 72 | assert visFtrEm.shape[2]==1 73 | visFtrEm = visFtrEm.squeeze(dim=2) 74 | reCnsSen= self.recntr(visFtrEm, wordEmb, capLengths) 75 | return reCnsSen, rpSS 76 | 77 | 78 | def forwardRank(self, imDis, wordEmb, capLengths): 79 | if self.vis_type =='fc': 80 | imDis = imDis.mean(1) 81 | else: 82 | assert len(imDis.size())==3 83 | assert len(wordEmb.size())==3 84 | #pdb.set_trace() 85 | imEnDis = self.imEncoder(imDis) 86 | wordEnDis = self.wordEncoder(wordEmb, capLengths) 87 | assert len(imEnDis.size())==2 88 | assert len(wordEnDis.size())==2 89 | imEnDis = F.normalize(imEnDis, p=2, dim=1) 90 | wordEnDis = F.normalize(wordEnDis, p=2, dim=1) 91 | return imEnDis, wordEnDis 92 | 93 | def forwardGroundR(self, imDis, wordEmb, capLengths, frmList=None): 94 | #pdb.set_trace() 95 | b_size = imDis.shape[0] 96 | assert len(imDis.size())==4 97 | assert len(wordEmb.size())==4 98 | 99 | if self.vis_type =='fc': 100 | imDis = imDis.mean(2) 101 | else: 102 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 103 | imEnDis = self.imEncoder(imDis) 104 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 105 | wordEnDis = self.wordEncoder(wordEmb, capLengths) 106 | #imEnDis = F.normalize(imEnDis, p=2, dim=2) 107 | #wordEnDis = F.normalize(wordEnDis, p=2, dim=2) 108 | # pdb.set_trace() 109 | if (len(imEnDis.shape)==3): 110 | imEnDis = imEnDis.view(b_size , -1, imEnDis.shape[2]) 111 | elif(len(imEnDis.shape)==2): 112 | imEnDis = imEnDis.view(b_size , -1, imEnDis.shape[1]) 113 | wordEnDis = wordEnDis.view(b_size, -1, wordEnDis.shape[1]) 114 | visFtrEm, rpSS= self.visDecoder(imEnDis, wordEnDis) 115 | 116 | #pdb.set_trace() 117 | if hasattr(self, 'fc2recontr'): 118 | visFtrEm_trs = visFtrEm.transpose(1,2) 119 | visFtrEm_trs = self.fc2recontr(visFtrEm_trs) 120 | visFtrEm = visFtrEm_trs.transpose(1,2) 121 | reCnsSen =None 122 | # pdb.set_trace() 123 | if self.training: 124 | assert visFtrEm.shape[2]==1 125 | visFtrEm = visFtrEm.squeeze(dim=2) 126 | reCnsSen= self.recntr(visFtrEm, wordEmb, capLengths) 127 | return reCnsSen, rpSS 128 | 129 | 130 | def forwardCoAtt(self, imDis, wordEmb, capLengths): 131 | assert len(imDis.size())==3 132 | assert len(wordEmb.size())==3 133 | #pdb.set_trace() 134 | simMM = self.imEncoder(imDis, wordEmb, capLengths) 135 | return simMM 136 | 137 | def _initialize_weights(self): 138 | #pdb.set_trace() 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d): 141 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 142 | if m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | elif isinstance(m, nn.BatchNorm1d): 148 | 
nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | elif isinstance(m, nn.Linear): 151 | nn.init.normal_(m.weight, 0, 0.01) 152 | if m.bias is not None: 153 | nn.init.constant_(m.bias, 0) 154 | 155 | def build_groundR(self, opts): 156 | self.recntr = build_recontructor(opts) 157 | self.visDecoder = build_visDecoder(opts) 158 | if opts.wsMode=='rankGroundRV2': 159 | self.fc2recontr = torch.nn.Linear(opts.dim_ftr, 300) 160 | 161 | class txtDecoder(nn.Module): 162 | def __init__(self, embedDim, hidden_dim, vocaSize): 163 | super(txtDecoder, self).__init__() 164 | self.hidden_dim = hidden_dim 165 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True) 166 | self.hidden = self.init_hidden() 167 | self.dec_log = torch.nn.Linear(hidden_dim, vocaSize) 168 | 169 | def init_hidden(self, batchSize=10): 170 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(), 171 | torch.zeros(1, batchSize, self.hidden_dim).cuda()) 172 | 173 | def forward(self, visFtr, capFtr, capLght): 174 | #pdb.set_trace() 175 | inputs = torch.cat((visFtr.unsqueeze(1), capFtr), 1) 176 | 177 | # pack data (prepare it for pytorch model) 178 | # inputs_packed = pack_padded_sequence(inputs, capLght, batch_first=True) 179 | inputs_packed = inputs 180 | # run data through recurrent network 181 | hiddens, _ = self.lstm(inputs_packed) 182 | #pdb.set_trace() 183 | outputs = self.dec_log(hiddens) 184 | #pdb.set_trace() 185 | return outputs 186 | 187 | class visDecoder(nn.Module): 188 | def __init__(self, dim_ftr, hdSize, coAtt_flag=False): 189 | super(visDecoder, self).__init__() 190 | self.attNet = torch.nn.Sequential( 191 | torch.nn.Linear(dim_ftr*2, hdSize), 192 | torch.nn.ReLU(), 193 | torch.nn.Dropout(p=0.5), 194 | torch.nn.Linear(hdSize, 1), 195 | ) 196 | self.coAtt_flag = coAtt_flag 197 | 198 | def _initialize_weights(self): 199 | for m in self.modules(): 200 | if isinstance(m, nn.Conv2d): 201 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 202 | if m.bias is not None: 203 | nn.init.constant_(m.bias, 0) 204 | elif isinstance(m, nn.BatchNorm2d): 205 | nn.init.constant_(m.weight, 1) 206 | nn.init.constant_(m.bias, 0) 207 | elif isinstance(m, nn.BatchNorm1d): 208 | nn.init.constant_(m.weight, 1) 209 | nn.init.constant_(m.bias, 0) 210 | elif isinstance(m, nn.Linear): 211 | nn.init.normal_(m.weight, 0, 0.01) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | def forward(self, vis_rp_ftr, cap_ftr): 215 | b_size, per_cap_num, dim_cap = cap_ftr.shape 216 | b_szie_v, rp_num, dim_vis =vis_rp_ftr.shape 217 | 218 | # for co-Attention 219 | if self.coAtt_flag: 220 | conca_ftr = torch.cat((cap_ftr, vis_rp_ftr), 2) 221 | rpScore = self.attNet(conca_ftr) 222 | rpScore = rpScore.view(b_size, rp_num) 223 | rpSS= F.softmax(rpScore, dim=1) 224 | rpSS = rpSS.view(b_size, rp_num, 1) 225 | visFtrAtt = torch.sum(torch.mul(vis_rp_ftr, rpSS), dim=1) 226 | return visFtrAtt.unsqueeze(2), rpSS 227 | 228 | # for independent embedding 229 | rp_ss_list = list() 230 | vis_att_list = list() 231 | for i in range(per_cap_num): 232 | cap_ftr_exp = cap_ftr[:, i, :].unsqueeze(dim=1).expand(b_size, rp_num, dim_cap) 233 | conca_ftr = torch.cat((cap_ftr_exp, vis_rp_ftr), 2) 234 | rpScore = self.attNet(conca_ftr) 235 | rpScore = rpScore.view(b_size, rp_num) 236 | assert(len(rpScore.shape)==2) 237 | rpSS= F.softmax(rpScore, dim=1) 238 | rpSS = rpSS.view(b_size, rp_num, 1) 239 | visFtrAtt = torch.sum(torch.mul(vis_rp_ftr, rpSS), dim=1) 240 | rp_ss_list.append(rpSS) 241 | vis_att_list.append(visFtrAtt.unsqueeze(dim=2)) 
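        # shape note: each visFtrAtt above is (b_size, dim_vis) and each rpSS is
        # (b_size, rp_num, 1), so the concatenations below give vis_ftr_mat of
        # shape (b_size, dim_vis, per_cap_num) and rp_ss_mat of shape
        # (b_size, rp_num, per_cap_num): one attention-pooled visual feature and
        # one proposal-score vector per caption.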
242 | vis_ftr_mat = torch.cat(vis_att_list, dim=2) 243 | rp_ss_mat = torch.cat(rp_ss_list, dim=2) 244 | return vis_ftr_mat, rp_ss_mat 245 | 246 | 247 | class txtEncoderV2(nn.Module): 248 | def __init__(self, embedDim, hidden_dim, seq_type='lstmV2'): 249 | super(txtEncoderV2, self).__init__() 250 | self.hidden_dim = hidden_dim 251 | self.seq_type = seq_type 252 | self.fc1 = torch.nn.Linear(embedDim, hidden_dim) 253 | if seq_type =='lstmV2': 254 | self.lstm =nn.LSTM(hidden_dim, hidden_dim, batch_first=True) 255 | elif seq_type =='gruV2': 256 | self.lstm =nn.GRU(hidden_dim, hidden_dim, batch_first=True) 257 | self.hidden = self.init_hidden() 258 | 259 | def init_hidden(self, batchSize=10): 260 | if self.seq_type=='lstmV2': 261 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(), 262 | torch.zeros(1, batchSize, self.hidden_dim).cuda()) 263 | elif self.seq_type=='gruV2': 264 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda() 265 | 266 | def forward(self, wordMatrixOri, wordLeg=None): 267 | #pdb.set_trace() 268 | # shorten steps for faster training 269 | wordMatrix = self.fc1(wordMatrixOri) 270 | self.init_hidden(wordMatrix.shape[0]) 271 | #pdb.set_trace() 272 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden) 273 | #pdb.set_trace() 274 | # lstmOut: B*T*D 275 | if wordLeg==None: 276 | return self.hidden[0].squeeze() 277 | else: 278 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))] 279 | #pdb.set_trace() 280 | txtEmbMat = torch.stack(txtEMb) 281 | return txtEmbMat 282 | 283 | 284 | class txtEncoder(nn.Module): 285 | def __init__(self, embedDim, hidden_dim, seq_type='lstm'): 286 | super(txtEncoder, self).__init__() 287 | self.hidden_dim = hidden_dim 288 | self.seq_type = seq_type 289 | if seq_type =='lstm': 290 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True) 291 | elif seq_type =='gru': 292 | self.lstm =nn.GRU(embedDim, hidden_dim, batch_first=True) 293 | self.hidden = self.init_hidden() 294 | 295 | def init_hidden(self, batchSize=10): 296 | if self.seq_type=='lstm': 297 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(), 298 | torch.zeros(1, batchSize, self.hidden_dim).cuda()) 299 | elif self.seq_type=='gru': 300 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda() 301 | 302 | 303 | def forward(self, wordMatrix, wordLeg=None): 304 | #pdb.set_trace() 305 | # shorten steps for faster training 306 | self.init_hidden(wordMatrix.shape[0]) 307 | #pdb.set_trace() 308 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden) 309 | #pdb.set_trace() 310 | # lstmOut: B*T*D 311 | if wordLeg==None: 312 | return self.hidden[0].squeeze() 313 | else: 314 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))] 315 | #pdb.set_trace() 316 | txtEmbMat = torch.stack(txtEMb) 317 | return txtEmbMat 318 | 319 | def build_txt_encoder(opts): 320 | embedDim = 300 321 | if opts.txt_type=='lstmV2': 322 | txt_encoder = txtEncoderV2(embedDim, opts.dim_ftr, opts.txt_type) 323 | return txt_encoder 324 | else: 325 | txt_encoder = txtEncoder(embedDim, opts.dim_ftr, opts.txt_type) 326 | return txt_encoder 327 | 328 | class visSeqEncoder(nn.Module): 329 | def __init__(self, embedDim, hidden_dim, seq_Type='lstm'): 330 | super(visSeqEncoder, self).__init__() 331 | self.hidden_dim = hidden_dim 332 | self.seq_type = seq_Type 333 | if seq_Type =='lstm': 334 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True) 335 | elif seq_Type =='gru': 336 | self.lstm =nn.GRU(embedDim, hidden_dim, batch_first=True) 337 | 338 | self.hidden 
= self.init_hidden() 339 | 340 | def init_hidden(self, batchSize=10): 341 | if self.seq_type=='lstm': 342 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(), 343 | torch.zeros(1, batchSize, self.hidden_dim).cuda()) 344 | elif self.seq_type=='gru': 345 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda() 346 | 347 | def forward(self, wordMatrix, wordLeg=None): 348 | # shorten steps for faster training 349 | self.init_hidden(wordMatrix.shape[0]) 350 | #pdb.set_trace() 351 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden) 352 | #pdb.set_trace() 353 | # lstmOut: B*T*D 354 | if wordLeg==None: 355 | return self.hidden[0].squeeze() 356 | else: 357 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))] 358 | #pdb.set_trace() 359 | txtEmbMat = torch.stack(txtEMb) 360 | return txtEmbMat 361 | 362 | 363 | class visSeqEncoderV2(nn.Module): 364 | def __init__(self, embedDim, hidden_dim, seq_Type='lstmV2'): 365 | super(visSeqEncoderV2, self).__init__() 366 | self.hidden_dim = hidden_dim 367 | self.seq_type = seq_Type 368 | if seq_Type =='lstm': 369 | self.lstm =nn.LSTM(embedDim, hidden_dim, batch_first=True) 370 | elif seq_Type =='gru': 371 | self.lstm =nn.GRU(embedDim, hidden_dim, batch_first=True) 372 | if seq_Type =='lstmV2': 373 | self.lstm =nn.LSTM(hidden_dim, hidden_dim, batch_first=True) 374 | self.fc1 = torch.nn.Linear(embedDim, hidden_dim) 375 | 376 | self.hidden = self.init_hidden() 377 | 378 | def init_hidden(self, batchSize=10): 379 | if self.seq_type=='lstm' or self.seq_type=='lstmV2': 380 | self.hidden=(torch.zeros(1, batchSize, self.hidden_dim).cuda(), 381 | torch.zeros(1, batchSize, self.hidden_dim).cuda()) 382 | elif self.seq_type=='gru': 383 | self.hidden=torch.zeros(1, batchSize, self.hidden_dim).cuda() 384 | 385 | def forward(self, wordMatrixOri, wordLeg=None): 386 | # shorten steps for faster training 387 | if self.seq_type =='lstmV2': 388 | wordMatrix = self.fc1(wordMatrixOri) 389 | else: 390 | wordMatrix = wordMatrixOri 391 | self.init_hidden(wordMatrix.shape[0]) 392 | #pdb.set_trace() 393 | lstmOut, self.hidden = self.lstm(wordMatrix, self.hidden) 394 | #pdb.set_trace() 395 | # lstmOut: B*T*D 396 | if wordLeg==None: 397 | return self.hidden[0].squeeze() 398 | else: 399 | txtEMb = [lstmOut[i, wordLeg[i]-1,:] for i in range(len(wordLeg))] 400 | #pdb.set_trace() 401 | txtEmbMat = torch.stack(txtEMb) 402 | return txtEmbMat 403 | 404 | 405 | def build_vis_fc_encoder(opts): 406 | inputDim = opts.vis_dim 407 | visNet = torch.nn.Sequential( 408 | torch.nn.Linear(inputDim, inputDim), 409 | torch.nn.ReLU(), 410 | torch.nn.Dropout(p=0.5), 411 | torch.nn.Linear(inputDim, opts.dim_ftr), 412 | ) 413 | return visNet 414 | 415 | class vlad_encoder(nn.Module): 416 | def __init__(self, input_dim, out_dim, hidden_dim, centre_num, alpha=1.0): 417 | super(vlad_encoder, self).__init__() 418 | self.fc1 = torch.nn.Linear(input_dim, hidden_dim) 419 | self.net_vlad = NetVLAD(num_clusters= centre_num, dim= hidden_dim, alpha=1.0) 420 | self.fc2 = torch.nn.Linear(centre_num*hidden_dim, out_dim) 421 | 422 | def forward(self, input_data): 423 | input_hidden = self.fc1(input_data) 424 | input_hidden = F.relu(input_hidden) 425 | input_hidden = input_hidden.unsqueeze(dim=3) 426 | input_hidden = input_hidden.transpose(dim0=1, dim1=2) 427 | pdb.set_trace() 428 | input_vlad = self.net_vlad(input_hidden) 429 | #pdb.set_trace() 430 | out_vlad = self.fc2(input_vlad) 431 | out_vlad = F.relu(out_vlad) 432 | return out_vlad 433 | 434 | def build_vis_vlad_encoder_v1(opts): 
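    '''
    Build a NetVLAD-based tube encoder: per-segment features are projected to
    opts.hidden_dim, soft-assigned to opts.centre_num clusters, and the
    flattened VLAD descriptor is mapped to opts.dim_ftr. Note that
    opts.vlad_alpha is read below but alpha=1.0 is what actually reaches the
    encoder. A rough shape sketch, assuming the bundled NetVLAD flattens its
    output to (B, centre_num*hidden_dim) (the pdb.set_trace() left inside
    vlad_encoder.forward would pause this):

        encoder = build_vis_vlad_encoder_v1(opts)
        out = encoder(torch.randn(4, 20, opts.vis_dim))  # -> (4, opts.dim_ftr)
    '''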
435 | input_dim = opts.vis_dim 436 | hidden_dim = opts.hidden_dim 437 | centre_num = opts.centre_num 438 | out_dim = opts.dim_ftr 439 | alpha = opts.vlad_alpha 440 | vis_encoder = vlad_encoder(input_dim, out_dim, hidden_dim, centre_num, alpha=1.0 ) 441 | return vis_encoder 442 | 443 | def build_vis_seq_encoder(opts): 444 | embedDim = opts.vis_dim 445 | if opts.vis_type == 'lstm' or opts.vis_type=='gru': 446 | vis_seq_encoder = visSeqEncoder(embedDim, opts.dim_ftr, opts.vis_type) 447 | return vis_seq_encoder 448 | if opts.vis_type == 'lstmV2': 449 | vis_seq_encoder = visSeqEncoderV2(embedDim, opts.dim_ftr, opts.vis_type) 450 | return vis_seq_encoder 451 | elif opts.vis_type == 'fc': 452 | vis_avg_encoder = build_vis_fc_encoder(opts) 453 | return vis_avg_encoder 454 | elif opts.vis_type == 'vlad_v1': 455 | vis_vlad_encoder = build_vis_vlad_encoder_v1(opts) 456 | return vis_vlad_encoder 457 | elif opts.vis_type == 'avgMIL': 458 | vis_avg_encoder = build_vis_fc_encoder(opts) 459 | return vis_avg_encoder 460 | 461 | def build_recontructor(opts): 462 | if opts.wsMode=='coAttGroundR': 463 | return txtDecoder(opts.lstm_hidden_size, opts.hdSize, opts.vocaSize) 464 | else: 465 | return txtDecoder(opts.hdSize, opts.hdSize, opts.vocaSize) 466 | 467 | def build_visDecoder(opts): 468 | if opts.wsMode=='coAttGroundR': 469 | return visDecoder(opts.lstm_hidden_size, opts.hdSize, True) 470 | else: 471 | return visDecoder(opts.dim_ftr, opts.hdSize, False) 472 | 473 | 474 | def build_network(opts): 475 | # pdb.set_trace() 476 | if opts.wsMode == 'rankTube' or opts.wsMode=='rankFrm': 477 | imEncoder= build_vis_seq_encoder(opts) 478 | wordEncoder = build_txt_encoder(opts) 479 | wsEncoder = wsEmb(imEncoder, wordEncoder) 480 | elif opts.wsMode == 'coAtt': 481 | sst_Obj = SST(opts) 482 | wsEncoder = wsEmb(sst_Obj, None) 483 | elif opts.wsMode == 'coAttBi': 484 | sst_Obj = SSTBi(opts) 485 | wsEncoder = wsEmb(sst_Obj, None) 486 | elif opts.wsMode == 'coAttBiV2': 487 | sst_Obj = SSTBiV2(opts) 488 | wsEncoder = wsEmb(sst_Obj, None) 489 | elif opts.wsMode == 'coAttBiV3': 490 | sst_Obj = SSTBiV3(opts) 491 | wsEncoder = wsEmb(sst_Obj, None) 492 | elif opts.wsMode == 'rankGroundR': 493 | imEncoder= build_vis_seq_encoder(opts) 494 | wordEncoder = build_txt_encoder(opts) 495 | wsEncoder = wsEmb(imEncoder, wordEncoder) 496 | wsEncoder.build_groundR(opts) 497 | elif opts.wsMode == 'rankGroundRV2': 498 | imEncoder= build_vis_seq_encoder(opts) 499 | wordEncoder = build_txt_encoder(opts) 500 | wsEncoder = wsEmb(imEncoder, wordEncoder) 501 | wsEncoder.build_groundR(opts) 502 | elif opts.wsMode == 'coAttGroundR': 503 | sst_Gr = SSTGroundR(opts) 504 | wsEncoder = wsEmb(sst_Gr, None) 505 | wsEncoder.build_groundR(opts) 506 | elif opts.wsMode == 'coAttV2': 507 | sst_mul = SSTMul(opts) 508 | wsEncoder = wsEmb(sst_mul, None) 509 | elif opts.wsMode == 'coAttV3': 510 | sst_v3 = SSTV3(opts) 511 | wsEncoder = wsEmb(sst_v3, None) 512 | elif opts.wsMode == 'coAttV4': 513 | sst_v4 = SSTV4(opts) 514 | wsEncoder = wsEmb(sst_v4, None) 515 | 516 | wsEncoder.wsMode = opts.wsMode 517 | wsEncoder.vis_type = opts.vis_type 518 | if opts.gpu: 519 | wsEncoder= wsEncoder.cuda() 520 | if opts.initmodel is not None: 521 | md_stat = torch.load(opts.initmodel) 522 | wsEncoder.load_state_dict(md_stat, strict=False) 523 | if opts.isParal: 524 | wsEncoder = nn.DataParallel(wsEncoder).cuda() 525 | 526 | return wsEncoder 527 | 528 | -------------------------------------------------------------------------------- /fun/netUtil.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import sys 5 | import pdb 6 | #from evalDet import * 7 | sys.path.append('..') 8 | from util.mytoolbox import * 9 | from image_toolbox import * 10 | 11 | def visTube_from_image(frmList, tube, outName): 12 | image_list = list() 13 | for i, bbx in enumerate(tube): 14 | imName = frmList[i] 15 | img = draw_rectangle(imName, bbx) 16 | image_list.append(img) 17 | images2video(image_list, 10, outName) 18 | 19 | def vis_image_bbx(frmList, tube, color=(0,0,255), thickness=3, dotted=False): 20 | image_list = list() 21 | for i, bbx in enumerate(tube): 22 | imName = frmList[i] 23 | img = draw_rectangle(imName, bbx, color, thickness, dotted) 24 | image_list.append(img) 25 | return image_list 26 | 27 | def vis_gray_but_bbx(frmList, tube): 28 | image_list = list() 29 | for i, bbx in enumerate(tube): 30 | imName = frmList[i] 31 | img = gray_background(imName, bbx) 32 | image_list.append(img) 33 | return image_list 34 | 35 | 36 | 37 | KEYS = ['x1', 'y1', 'x2', 'y2'] 38 | def compute_IoU(box1, box2): 39 | if isinstance(box1, list): 40 | box1 = {key: val for key, val in zip(KEYS, box1)} 41 | if isinstance(box2, list): 42 | box2 = {key: val for key, val in zip(KEYS, box2)} 43 | width = max(min(box1['x2'], box2['x2']) - max(box1['x1'], box2['x1']), 0) 44 | height = max(min(box1['y2'], box2['y2']) - max(box1['y1'], box2['y1']), 0) 45 | intersection = width * height 46 | box1_area = (box1['x2'] - box1['x1']) * (box1['y2'] - box1['y1']) 47 | box2_area = (box2['x2'] - box2['x1']) * (box2['y2'] - box2['y1']) 48 | union = box1_area + box2_area - intersection 49 | return float(intersection) / (float(union) +0.000001) # avoid overthow 50 | 51 | 52 | EPS = 1e-10 53 | def compute_IoU_v2(bbox1, bbox2): 54 | bbox1_area = float((bbox1[2] - bbox1[0] + EPS) * (bbox1[3] - bbox1[1] + EPS)) 55 | bbox2_area = float((bbox2[2] - bbox2[0] + EPS) * (bbox2[3] - bbox2[1] + EPS)) 56 | w = max(0.0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]) + EPS) 57 | h = max(0.0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]) + EPS) 58 | inter = float(w * h) 59 | ovr = inter / (bbox1_area + bbox2_area - inter) 60 | return ovr 61 | 62 | def is_annotated(traj, frame_ind): 63 | if not frame_ind in traj: 64 | return False 65 | box = traj[frame_ind] 66 | if box[0] < 0: 67 | for el_val in box[1:]: 68 | assert el_val < 0 69 | return False 70 | for el_val in box[1:]: 71 | assert el_val >= 0 72 | return True 73 | 74 | def compute_LS(traj, gt_traj): 75 | # see http://jvgemert.github.io/pub/jain-tubelets-cvpr2014.pdf 76 | assert isinstance(traj.keys()[0], type(gt_traj.keys()[0])) 77 | IoU_list = [] 78 | for frame_ind, gt_box in gt_traj.iteritems(): 79 | gt_is_annotated = is_annotated(gt_traj, frame_ind) 80 | pr_is_annotated = is_annotated(traj, frame_ind) 81 | if (not gt_is_annotated) and (not pr_is_annotated): 82 | continue 83 | if (not gt_is_annotated) or (not pr_is_annotated): 84 | IoU_list.append(0.0) 85 | continue 86 | box = traj[frame_ind] 87 | IoU_list.append(compute_IoU_v2(box, gt_box)) 88 | return sum(IoU_list) / len(IoU_list) 89 | 90 | def get_tubes(det_list_org, alpha): 91 | det_list = copy.deepcopy(det_list_org) 92 | tubes = [] 93 | continue_flg = True 94 | tube_scores = [] 95 | while continue_flg: 96 | timestep = 0 97 | score_list = [] 98 | score_list.append(np.zeros(det_list[timestep][0].shape[0])) 99 | prevind_list = [] 100 | prevind_list.append([-1] * det_list[timestep][0].shape[0]) 101 | timestep += 1 
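        # Viterbi-style forward pass: every box in frame `timestep` is linked to
        # the best predecessor in frame `timestep-1`, where a single link scores
        #     link_score = prev_box_score + cur_box_score + alpha * IoU(prev_box, cur_box)
        # score_list[t][j] holds the best accumulated path score ending at box j
        # of frame t, and prevind_list[t][j] the backpointer used further below
        # to trace the highest-scoring tube back through the video.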
102 |         while timestep < len(det_list):
103 |             n_curbox = det_list[timestep][0].shape[0]
104 |             n_prevbox = score_list[-1].shape[0]
105 |             cur_scores = np.zeros(n_curbox) - np.inf
106 |             prev_inds = [-1] * n_curbox
107 |             for i_prevbox in range(n_prevbox):
108 |                 prevbox_coods = det_list[timestep-1][1][i_prevbox, :]
109 |                 prevbox_score = det_list[timestep-1][0][i_prevbox, 0]
110 |                 for i_curbox in range(n_curbox):
111 |                     curbox_coods = det_list[timestep][1][i_curbox, :]
112 |                     curbox_score = det_list[timestep][0][i_curbox, 0]
113 |                     try:
114 |                         e_score = compute_IoU(prevbox_coods.tolist(), curbox_coods.tolist())
115 |                     except:
116 |                         pdb.set_trace()
117 |                     link_score = prevbox_score + curbox_score + alpha * e_score
118 |                     cur_score = score_list[-1][i_prevbox] + link_score
119 |                     if cur_score > cur_scores[i_curbox]:
120 |                         cur_scores[i_curbox] = cur_score
121 |                         prev_inds[i_curbox] = i_prevbox
122 |             score_list.append(cur_scores)
123 |             prevind_list.append(prev_inds)
124 |             timestep += 1
125 |
126 |         # get path and remove used boxes
127 |         cur_tube = [None] * len(det_list)
128 |         tube_score = np.max(score_list[-1]) / len(det_list)
129 |         prev_ind = np.argmax(score_list[-1])
130 |         timestep = len(det_list) - 1
131 |         while timestep >= 0:
132 |             cur_tube[timestep] = det_list[timestep][1][prev_ind, :].tolist()
133 |             det_list[timestep][0] = np.delete(det_list[timestep][0], prev_ind, axis=0)
134 |             det_list[timestep][1] = np.delete(det_list[timestep][1], prev_ind, axis=0)
135 |             prev_ind = prevind_list[timestep][prev_ind]
136 |             if det_list[timestep][1].shape[0] == 0:
137 |                 continue_flg = False
138 |             timestep -= 1
139 |         assert prev_ind < 0
140 |         tubes.append(cur_tube)
141 |         tube_scores.append(tube_score)
142 |     return tubes, tube_scores
143 |
144 | def save_check_point(state, is_best=False, file_name='../data/models/checkpoint.pth'):
145 |     fdName = os.path.dirname(file_name)
146 |     makedirs_if_missing(fdName)
147 |     torch.save(state, file_name)
148 |     if is_best:
149 |         shutil.copyfile(file_name, '../data/models/best_model.pth')
150 |
151 | def load_model_state(model, file_name):
152 |     states = torch.load(file_name)
153 |     model.load_state_dict(states)
154 |
155 |
--------------------------------------------------------------------------------
/fun/optimizers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def build_opt(args, model):
4 |     if args.optimizer.lower() == 'adam':
5 |         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
6 |     elif args.optimizer.lower() == 'rmsprop':
7 |         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.decay)
8 |     elif args.optimizer.lower() == 'sgd':
9 |         optimizer = torch.optim.SGD(model.parameters(),
10 |             lr=args.lr,
11 |             momentum=args.momentum, weight_decay=args.decay)
12 |     else:
13 |         raise Exception('unsupported optimizer: %s' % args.optimizer)
14 |     return optimizer
15 |
--------------------------------------------------------------------------------
/fun/train.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | import random
4 | def random_seeding(seed_value, use_cuda):
5 |     numpy.random.seed(seed_value) # cpu vars
6 |     torch.manual_seed(seed_value) # cpu vars
7 |     random.seed(seed_value)
8 |     if use_cuda:
9 |         torch.cuda.manual_seed_all(seed_value) # gpu vars
10 |
11 |
12 | from wsParamParser import parse_args
13 | from data.data_loader import *
14 | from datasetLoader import *
15 | from modelArc import *
16 | from optimizers import *
17 | from logInfo
import logInF 18 | from lossPackage import * 19 | from netUtil import * 20 | from tensorboardX import SummaryWriter 21 | import time 22 | 23 | import pdb 24 | 25 | if __name__=='__main__': 26 | #pdb.set_trace() 27 | opt = parse_args() 28 | random_seeding(opt.seed, True) 29 | # build dataloader 30 | dataLoader, datasetOri= build_dataloader(opt) 31 | # build network 32 | model = build_network(opt) 33 | # build_optimizer 34 | optimizer = build_opt(opt, model) 35 | # build loss layer 36 | lossEster = build_lossEval(opt) 37 | # build logger 38 | logger = logInF(opt.logFd) 39 | writer = SummaryWriter(opt.logFdTx+time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 40 | #pdb.set_trace() 41 | for ep in range(opt.stEp, opt.epSize): 42 | resultList_full = list() 43 | tBf = time.time() 44 | for itr, inputData in enumerate(dataLoader): 45 | #tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list, region_gt_ori = inputData 46 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list = inputData 47 | #pdb.set_trace() 48 | dataIdx = None 49 | tmp_bsize = tube_embedding.shape[0] 50 | imDis = tube_embedding.cuda() 51 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 52 | wordEmb = cap_embedding.cuda() 53 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 54 | tAf = time.time() 55 | 56 | if opt.wsMode =='coAtt': 57 | #pdb.set_trace() 58 | simMM = model(imDis, wordEmb, cap_length_list) 59 | # pdb.set_trace() 60 | simMM = simMM.view(tmp_bsize, opt.rpNum, tmp_bsize, opt.capNum) 61 | if len(set(vd_name_list))<2: 62 | continue 63 | 64 | loss = lossEster(simMM=simMM, lblList =vd_name_list) 65 | resultList = evalAcc_att(simMM, tubeInfo, indexOri, datasetOri, opt.visRsFd+str(ep), False) 66 | resultList_full +=resultList 67 | #pdb.set_trace() 68 | if loss<=0: 69 | continue 70 | loss = loss/opt.update_iter 71 | loss.backward(retain_graph=True ) 72 | if (itr+1)%opt.update_iter==0: 73 | #optimizer.zero_grad() 74 | optimizer.step() 75 | optimizer.zero_grad() 76 | tNf = time.time() 77 | if(itr%opt.visIter==0): 78 | #resultList = evalAcc_att(simMM, tubeInfo, indexOri, datasetOri, opt.visRsFd+str(ep), False) 79 | #resultList_full +=resultList 80 | #tAf = time.time() 81 | logger('Ep: %d, Iter: %d, T1: %3f, T2:%3f, loss: %3f\n' %(ep, itr, (tAf-tBf), (tNf-tAf)/opt.visIter, float(loss.data.cpu().numpy()))) 82 | tBf = time.time() 83 | writer.add_scalar('loss', loss.data.cpu()[0], ep*len(datasetOri)+itr*opt.batchSize) 84 | accSum = 0 85 | resultList = list() 86 | for ele in resultList: 87 | index, recall_k= ele 88 | accSum +=recall_k 89 | #logger('Average accuracy on training batch is %3f\n' %(accSum/len(resultList))) 90 | tBf = time.time() 91 | 92 | ## evaluation within an epoch 93 | if(ep % opt.saveEp==0 and itr==0 and ep >0): 94 | checkName = opt.outPre+'_ep_'+str(ep) +'_itr_'+str(itr)+'.pth' 95 | save_check_point(model.state_dict(), file_name=checkName) 96 | model.eval() 97 | resultList = list() 98 | vIdList = list() 99 | set_name_ori= opt.set_name 100 | opt.set_name = 'val' 101 | batchSizeOri = opt.batchSize 102 | opt.batchSize = 1 103 | dataLoaderEval, datasetEvalOri = build_dataloader(opt) 104 | opt.batchSize = batchSizeOri 105 | #pdb.set_trace() 106 | for itr_eval, inputData in enumerate(dataLoaderEval): 107 | #tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_list, region_gt_ori = inputData 108 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, 
word_lbl_list = inputData 109 | #pdb.set_trace() 110 | dataIdx = None 111 | #pdb.set_trace() 112 | b_size = tube_embedding.shape[0] 113 | # B*P*T*D 114 | imDis = tube_embedding.cuda() 115 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 116 | wordEmb = cap_embedding.cuda() 117 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 118 | imDis.requires_grad=False 119 | wordEmb.requires_grad=False 120 | if opt.wsMode =='coAtt': 121 | simMM = model(imDis, wordEmb, cap_length_list) 122 | simMM = simMM.view(b_size, opt.rpNum, b_size) 123 | resultList += evalAcc_att(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False) 124 | 125 | accSum = 0 126 | for ele in resultList: 127 | index, recall_k= ele 128 | accSum +=recall_k 129 | logger('Average accuracy on validation set is %3f\n' %(accSum/len(resultList))) 130 | writer.add_scalar('Average validation accuracy', accSum/len(resultList), ep*len(datasetOri)+ itr*opt.batchSize) 131 | #pdb.set_trace() 132 | model.train() 133 | opt.set_name = set_name_ori 134 | 135 | if(ep % opt.saveEp==0 and itr==0 and ep > 0): 136 | model.eval() 137 | resultList = list() 138 | vIdList = list() 139 | set_name_ori= opt.set_name 140 | opt.set_name = 'test' 141 | batchSizeOri = opt.batchSize 142 | opt.batchSize = 1 143 | dataLoaderEval, datasetEvalOri = build_dataloader(opt) 144 | opt.batchSize = batchSizeOri 145 | #pdb.set_trace() 146 | for itr_eval, inputData in enumerate(dataLoaderEval): 147 | tube_embedding, cap_embedding, tubeInfo, indexOri, cap_length_list, vd_name_list, word_lbl_lis =inputData 148 | #pdb.set_trace() 149 | dataIdx = None 150 | #pdb.set_trace() 151 | b_size = tube_embedding.shape[0] 152 | # B*P*T*D 153 | imDis = tube_embedding.cuda() 154 | imDis = imDis.view(-1, imDis.shape[2], imDis.shape[3]) 155 | wordEmb = cap_embedding.cuda() 156 | wordEmb = wordEmb.view(-1, wordEmb.shape[2], wordEmb.shape[3]) 157 | imDis.requires_grad=False 158 | wordEmb.requires_grad=False 159 | if opt.wsMode =='coAtt': 160 | simMM = model(imDis, wordEmb, cap_length_list) 161 | simMM = simMM.view(b_size, opt.rpNum, b_size) 162 | resultList += evalAcc_att(simMM, tubeInfo, indexOri, datasetEvalOri, opt.visRsFd+str(ep), False) 163 | 164 | 165 | accSum = 0 166 | for ele in resultList: 167 | index, recall_k= ele 168 | accSum +=recall_k 169 | logger('Average accuracy on testing set is %3f\n' %(accSum/len(resultList))) 170 | writer.add_scalar('Average testing accuracy', accSum/len(resultList), ep*len(datasetOri)+ itr*opt.batchSize) 171 | #pdb.set_trace() 172 | model.train() 173 | opt.set_name = set_name_ori 174 | 175 | accSum = 0 176 | for ele in resultList_full: 177 | index, recall_k= ele 178 | accSum +=recall_k 179 | #writer.add_scalar('Average training accuracy', accSum/len(resultList_full), ep*len(datasetOri)+ itr*opt.batchSize) 180 | -------------------------------------------------------------------------------- /fun/vidDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('..') 4 | sys.path.append('../../../WSSTL/fun') 5 | sys.path.append('../annotations') 6 | sys.path.append('../../annotations') 7 | import torch.utils.data as data 8 | import cv2 9 | import numpy as np 10 | from util.mytoolbox import * 11 | import random 12 | import scipy.io as sio 13 | import copy 14 | import torch 15 | from util.get_image_size import get_image_size 16 | from evalDet import * 17 | from datasetParser import extAllFrmFn 18 | import pdb 19 | from netUtil import * 20 | from 
wsParamParser import parse_args 21 | from ptd_api import * 22 | from vidDatasetParser import * 23 | from multiprocessing import Process, Pipe, cpu_count, Queue 24 | #from vidDatasetParser import vidInfoParser 25 | #from multiGraphAttention import extract_position_embedding 26 | import h5py 27 | 28 | 29 | class vidDataloader(data.Dataset): 30 | def __init__(self, ann_folder, prp_type, set_name, dictFile, tubePath, ftrPath, out_cached_folder): 31 | self.set_name = set_name 32 | self.dict = pickleload(dictFile) 33 | self.rpNum = 30 34 | self.maxWordNum =20 35 | self.maxTubelegth = 20 36 | self.tube_ftr_dim = 2048 37 | self.tubePath = tubePath 38 | self.ftrPath = ftrPath 39 | self.out_cache_folder = out_cached_folder 40 | self.prp_type = prp_type 41 | self.vid_parser = vidInfoParser(set_name, ann_folder) 42 | self.use_key_index = self.vid_parser.tube_cap_dict.keys() 43 | self.use_key_index.sort() 44 | self.online_cache ={} 45 | self.i3d_cache_flag = False 46 | self.cache_flag = False 47 | self.cache_ftr_dict = {} 48 | self.use_mean_cache_flag = False 49 | self.mean_cache_ftr_path = '' 50 | self.context_flag =False 51 | self.extracting_context =False 52 | 53 | 54 | def get_gt_embedding_i3d(self, index, maxTubelegth, out_cached_folder = ''): 55 | ''' 56 | get the grounding truth region embedding 57 | ''' 58 | rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 59 | flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 60 | set_name = self.set_name 61 | i3d_ftr_path = os.path.join(self.i3d_ftr_path[:-4], 'gt/vid',set_name, str(index) +'.h5') 62 | if i3d_ftr_path in self.online_cache.keys() and self.i3d_cache_flag: 63 | tube_embedding = self.online_cache[i3d_ftr_path] 64 | return tube_embedding 65 | h5_handle = h5py.File(i3d_ftr_path, 'r') 66 | for tube_id in range(1): 67 | rgb_tube_ftr = h5_handle[str(tube_id)]['rgb_feature'][()].squeeze() 68 | flow_tube_ftr = h5_handle[str(tube_id)]['flow_feature'][()].squeeze() 69 | num_tube_ftr = h5_handle[str(tube_id)]['num_feature'][()].squeeze() 70 | seg_length = max(int(round(num_tube_ftr/maxTubelegth)), 1) 71 | tmp_rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 72 | tmp_flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 73 | for segId in range(maxTubelegth): 74 | #print('%d %d\n' %(tube_id, segId)) 75 | start_id = segId*seg_length 76 | end_id = (segId+1)*seg_length 77 | if end_id > num_tube_ftr and num_tube_ftr < maxTubelegth: 78 | break 79 | end_id = min((segId+1)*(seg_length), num_tube_ftr) 80 | tmp_rgb_tube_embedding[segId, :] = np.mean(rgb_tube_ftr[start_id:end_id], axis=0) 81 | tmp_flow_tube_embedding[segId, :] = np.mean(flow_tube_ftr[start_id:end_id], axis=0) 82 | 83 | rgb_tube_embedding = tmp_rgb_tube_embedding 84 | flow_tube_embedding = tmp_flow_tube_embedding 85 | 86 | tube_embedding = np.concatenate((rgb_tube_embedding, flow_tube_embedding), axis=1) 87 | return tube_embedding 88 | 89 | 90 | def get_gt_embedding(self, index, maxTubelegth, out_cached_folder = ''): 91 | ''' 92 | get the grounding truth region embedding 93 | ''' 94 | gt_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim), dtype=np.float32) 95 | set_name = self.set_name 96 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index) 97 | tube_info_path = os.path.join(self.tubePath, set_name, self.prp_type, str(index)+'.pd') 98 | tubeInfo = pickleload(tube_info_path) 99 | tube_list, frame_list = tubeInfo 100 | frmNum = 
len(frame_list) 101 | seg_length = max(int(frmNum/maxTubelegth), 1) 102 | 103 | tube_to_prp_idx = list() 104 | ftr_tube_list = list() 105 | prp_range_num = len(tube_list[0]) 106 | tmp_cache_gt_feature_path = os.path.join(out_cached_folder, \ 107 | 'gt' , set_name, self.prp_type, str(index) + '.pk') 108 | if os.path.isfile(tmp_cache_gt_feature_path): 109 | tmp_gt_ftr_info = None 110 | try: 111 | tmp_gt_ftr_info = pickleload(tmp_cache_gt_feature_path) 112 | except: 113 | print('--------------------------------------------------') 114 | print(tmp_cache_gt_feature_path) 115 | print('--------------------------------------------------') 116 | if tmp_gt_ftr_info is not None: 117 | return tmp_gt_ftr_info 118 | 119 | # cache data for saving IO time 120 | cache_data_dict ={} 121 | 122 | for frmId, frmName in enumerate(frame_list): 123 | frmName = frame_list[frmId] 124 | img_prp_ftr_info_path = os.path.join(self.ftr_gt_path, self.set_name, str(index), frmName+ '.pd') 125 | img_prp_ftr_info = pickleload(img_prp_ftr_info_path) 126 | cache_data_dict[frmName] = img_prp_ftr_info 127 | 128 | for segId in range(maxTubelegth): 129 | start_id = segId*seg_length 130 | end_id = (segId+1)*seg_length 131 | if end_id>frmNum and frmNum=self.maxWordNum): 197 | break 198 | idx = self.dict['word2idx'][word] 199 | wordEmbMatrix[valCount, :]= self.dict['word2vec'][idx] 200 | valCount +=1 201 | wordLbl.append(idx) 202 | return wordEmbMatrix, valCount, wordLbl 203 | 204 | def get_cap_emb(self, index, capNum): 205 | cap_list_index = self.vid_parser.tube_cap_dict[index] 206 | assert len(cap_list_index)>=capNum 207 | cap_sample_index = random.sample(range(len(cap_list_index)), capNum) 208 | 209 | # get word embedding 210 | wordEmbMatrix= np.zeros((capNum, self.maxWordNum, 300), dtype=np.float32) 211 | cap_length_list = list() 212 | word_lbl_list = list() 213 | for i, capIdx in enumerate(cap_sample_index): 214 | capString = cap_list_index[capIdx] 215 | wordEmbMatrix[i, ...], valid_length, wordLbl = self.get_word_emb_from_str(capString, self.maxWordNum) 216 | cap_length_list.append(valid_length) 217 | word_lbl_list.append(wordLbl) 218 | return wordEmbMatrix, cap_length_list, word_lbl_list 219 | 220 | def get_tube_embedding(self, index, maxTubelegth, out_cached_folder = ''): 221 | tube_embedding = np.zeros((self.rpNum, maxTubelegth, self.tube_ftr_dim), dtype=np.float32) 222 | set_name = self.set_name 223 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index) 224 | tube_info_path = os.path.join(self.tubePath, set_name, self.prp_type, str(index)+'.pd') 225 | tubeInfo = pickleload(tube_info_path) 226 | tube_list, frame_list = tubeInfo 227 | frmNum = len(frame_list) 228 | seg_length = max(int(frmNum/maxTubelegth), 1) 229 | 230 | tube_to_prp_idx = list() 231 | ftr_tube_list = list() 232 | prp_range_num = len(tube_list[0]) 233 | tmp_cache_tube_feature_path = os.path.join(out_cached_folder, \ 234 | set_name, self.prp_type, str(index) + '.pk') 235 | if os.path.isfile(tmp_cache_tube_feature_path): 236 | tmp_tube_ftr_info = None 237 | try: 238 | tmp_tube_ftr_info = pickleload(tmp_cache_tube_feature_path) 239 | except: 240 | print('--------------------------------------------------') 241 | print(tmp_cache_tube_feature_path) 242 | print('--------------------------------------------------') 243 | if tmp_tube_ftr_info is not None: 244 | tube_embedding, tubeInfo, tube_to_prp_idx = tmp_tube_ftr_info 245 | 246 | #if((tube_to_prp_idx[0])>maxTubelegth): 247 | if tube_embedding.shape[0]>=self.rpNum: 248 | return 
tube_embedding[:self.rpNum], tubeInfo, tube_to_prp_idx 249 | else: 250 | tube_embedding = np.zeros((self.rpNum, maxTubelegth, self.tube_ftr_dim), dtype=np.float32) 251 | 252 | # cache data for saving IO time 253 | cache_data_dict ={} 254 | for tubeId, tube in enumerate(tube_list[0]): 255 | if tubeId>= self.rpNum: 256 | continue 257 | tube_prp_map = list() 258 | # find proposals 259 | for frmId, bbox in enumerate(tube): 260 | frmName = frame_list[frmId] 261 | if frmName in cache_data_dict.keys(): 262 | img_prp_ftr_info = cache_data_dict[frmName] 263 | else: 264 | img_prp_ftr_info_path = os.path.join(self.ftrPath, self.set_name, vd_name, frmName+ '.pd') 265 | img_prp_ftr_info = pickleload(img_prp_ftr_info_path) 266 | cache_data_dict[frmName] = img_prp_ftr_info 267 | 268 | tmp_bbx = copy.deepcopy(img_prp_ftr_info['rois'][:prp_range_num]) # to be modified 269 | tmp_info = img_prp_ftr_info['imFo'].squeeze() 270 | tmp_bbx[:, 0] = tmp_bbx[:, 0]/tmp_info[1] 271 | tmp_bbx[:, 2] = tmp_bbx[:, 2]/tmp_info[1] 272 | tmp_bbx[:, 1] = tmp_bbx[:, 1]/tmp_info[0] 273 | tmp_bbx[:, 3] = tmp_bbx[:, 3]/tmp_info[0] 274 | img_prp_res = tmp_bbx - bbox 275 | img_prp_res_sum = np.sum(img_prp_res, axis=1) 276 | for prpId in range(prp_range_num): 277 | if(abs(img_prp_res_sum[prpId])<0.00001): 278 | tube_prp_map.append(prpId) 279 | break 280 | #assert("fail to find proposals") 281 | if (len(tube_prp_map)!=len(tube)): 282 | pdb.set_trace() 283 | assert(len(tube_prp_map)==len(tube)) 284 | 285 | tube_to_prp_idx.append(tube_prp_map) 286 | 287 | # extract features 288 | tmp_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim), dtype=np.float32) 289 | for segId in range(maxTubelegth): 290 | start_id = segId*seg_length 291 | end_id = (segId+1)*seg_length 292 | if end_id>frmNum and frmNumfrmNum and frmNum num_tube_ftr and num_tube_ftr < maxTubelegth: 403 | break 404 | end_id = min((segId+1)*(seg_length), num_tube_ftr) 405 | tmp_rgb_tube_embedding[segId, :] = np.mean(rgb_tube_ftr[start_id:end_id], axis=0) 406 | tmp_flow_tube_embedding[segId, :] = np.mean(flow_tube_ftr[start_id:end_id], axis=0) 407 | 408 | rgb_tube_embedding[tube_id, ...] = tmp_rgb_tube_embedding 409 | flow_tube_embedding[tube_id, ...] 
= tmp_flow_tube_embedding 410 | 411 | tube_embedding = np.concatenate((rgb_tube_embedding, flow_tube_embedding), axis=2) 412 | return tube_embedding 413 | 414 | 415 | def get_context_embedding_i3d(self, index, maxTubelegth, out_cached_folder = ''): 416 | rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 417 | flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 418 | set_name = self.set_name 419 | i3d_ftr_path = os.path.join(self.i3d_ftr_path, 'context/vid',set_name, str(index) +'.h5') 420 | if i3d_ftr_path in self.online_cache.keys() and self.i3d_cache_flag: 421 | tube_embedding = self.online_cache[i3d_ftr_path] 422 | return tube_embedding 423 | h5_handle = h5py.File(i3d_ftr_path, 'r') 424 | for tube_id in range(1): 425 | rgb_tube_ftr = h5_handle[str(tube_id)]['rgb_feature'][()].squeeze() 426 | flow_tube_ftr = h5_handle[str(tube_id)]['flow_feature'][()].squeeze() 427 | num_tube_ftr = h5_handle[str(tube_id)]['num_feature'][()].squeeze() 428 | seg_length = max(int(round(num_tube_ftr/maxTubelegth)), 1) 429 | tmp_rgb_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 430 | tmp_flow_tube_embedding = np.zeros((maxTubelegth, self.tube_ftr_dim_i3d), dtype=np.float32) 431 | for segId in range(maxTubelegth): 432 | start_id = segId*seg_length 433 | end_id = (segId+1)*seg_length 434 | if end_id > num_tube_ftr and num_tube_ftr < maxTubelegth: 435 | break 436 | end_id = min((segId+1)*(seg_length), num_tube_ftr) 437 | tmp_rgb_tube_embedding[segId, :] = np.mean(rgb_tube_ftr[start_id:end_id], axis=0) 438 | tmp_flow_tube_embedding[segId, :] = np.mean(flow_tube_ftr[start_id:end_id], axis=0) 439 | 440 | rgb_tube_embedding = tmp_rgb_tube_embedding 441 | flow_tube_embedding = tmp_flow_tube_embedding 442 | 443 | tube_embedding = np.concatenate((rgb_tube_embedding, flow_tube_embedding), axis=1) 444 | self.online_cache[i3d_ftr_path] = tube_embedding 445 | return tube_embedding 446 | 447 | def get_tube_pos_embedding(self, tubeInfo, tube_length, feat_dim=64, feat_type='aiayn'): 448 | tube_list, frame_list = tubeInfo 449 | position_mat_raw = torch.zeros((1, self.rpNum, tube_length, 4)) 450 | if feat_type=='aiayn': 451 | bSize = 1 452 | prpSize = self.rpNum 453 | kNN = tube_length 454 | for tubeId, tube in enumerate(tube_list[0]): 455 | if tubeId>=self.rpNum: 456 | break 457 | tube_length_ori = len(tube) 458 | tube_seg_length = max(int(tube_length_ori/tube_length), 1) 459 | 460 | for tube_seg_id in range(0, tube_length): 461 | tube_seg_id_st = tube_seg_id*tube_seg_length 462 | tube_seg_id_end = min((tube_seg_id+1)*tube_seg_length, tube_length_ori) 463 | if(tube_seg_id_st)>=tube_length_ori: 464 | position_mat_raw[0, tubeId, tube_seg_id, :] = position_mat_raw[0, tubeId, tube_seg_id-1, :] 465 | continue 466 | bbox_list = tube[tube_seg_id_st:tube_seg_id_end] 467 | box_np = np.concatenate(bbox_list, axis=0) 468 | box_tf = torch.FloatTensor(box_np).view(-1, 4) 469 | position_mat_raw[0, tubeId, tube_seg_id, :]= box_tf.mean(0) 470 | position_mat_raw_v2 = copy.deepcopy(position_mat_raw) 471 | position_mat_raw_v2[:, 0] = (position_mat_raw[:, 0] + position_mat_raw[:, 2])/2 472 | position_mat_raw_v2[:, 1] = (position_mat_raw[:, 1] + position_mat_raw[:, 3])/2 473 | position_mat_raw_v2[:, 2] = position_mat_raw[:, 2] - position_mat_raw[:, 0] 474 | position_mat_raw_v2[:, 3] = position_mat_raw[:, 3] - position_mat_raw[:, 1] 475 | 476 | pos_emb = extract_position_embedding(position_mat_raw_v2, feat_dim, wave_length=1000) 477 | 478 | 
return pos_emb.squeeze(0) 479 | else: 480 | raise ValueError('%s is not implemented!' %(feat_type)) 481 | 482 | 483 | def get_visual_item(self, indexOri): 484 | index = self.use_key_index[indexOri] 485 | sumInd = 0 486 | tube_embedding = None 487 | cap_embedding = None 488 | cap_length_list = -1 489 | tAf = time.time() 490 | cap_embedding, cap_length_list, word_lbl_list = self.get_cap_emb(index, self.capNum) 491 | tBf = time.time() 492 | 493 | cache_str_shot_str = str(self.maxTubelegth) +'_' + str(index) 494 | tube_embedding_list = list() 495 | if cache_str_shot_str in self.cache_ftr_dict.keys(): 496 | if self.vis_ftr_type=='rgb'or self.vis_ftr_type=='rgb_i3d': 497 | tube_embedding, tubeInfo, tube_to_prp_idx = self.cache_ftr_dict[cache_str_shot_str] 498 | elif self.vis_ftr_type=='i3d': 499 | tube_embedding, tubeInfo = self.cache_ftr_dict[cache_str_shot_str] 500 | # get visual tube embedding 501 | else: 502 | if self.vis_ftr_type =='rgb'or self.vis_ftr_type=='rgb_i3d': 503 | tube_embedding, tubeInfo, tube_to_prp_idx = self.get_tube_embedding(index, self.maxTubelegth, self.out_cache_folder) 504 | if self.context_flag: 505 | tube_embedding_context = self.get_context_embedding(index, self.maxTubelegth, self.out_cache_folder) 506 | tube_embedding_context_exp = np.expand_dims(tube_embedding_context, axis=0).repeat(self.rpNum, axis=0) 507 | tube_embedding = np.concatenate([tube_embedding, tube_embedding_context_exp], axis=2) 508 | 509 | 510 | if self.vis_ftr_type=='rgb_i3d': 511 | tube_embedding_list.append(tube_embedding) 512 | tube_embedding = torch.FloatTensor(tube_embedding) 513 | 514 | if self.vis_ftr_type =='i3d' or self.vis_ftr_type=='rgb_i3d': 515 | tube_embedding = self.get_tube_embedding_i3d(index, self.maxTubelegth, self.out_cache_folder) 516 | if self.context_flag: 517 | tube_embedding_context = self.get_context_embedding_i3d(index, self.maxTubelegth, self.out_cache_folder) 518 | tube_embedding_context_exp = np.expand_dims(tube_embedding_context, axis=0).repeat(self.rpNum, axis=0) 519 | tube_embedding = np.concatenate([tube_embedding, tube_embedding_context_exp], axis=2) 520 | 521 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index) 522 | tube_info_path = os.path.join(self.tubePath, self.set_name, self.prp_type, str(index)+'.pd') 523 | tubeInfo = pickleload(tube_info_path) 524 | if self.vis_ftr_type=='rgb_i3d': 525 | tube_embedding_list.append(tube_embedding) 526 | tube_embedding = torch.FloatTensor(tube_embedding) 527 | 528 | if self.vis_ftr_type=='rgb_i3d' and cache_str_shot_str not in self.cache_ftr_dict.keys(): 529 | tube_embedding = np.concatenate(tube_embedding_list, axis=2) 530 | tube_embedding = torch.FloatTensor(tube_embedding) 531 | 532 | # get position embedding 533 | if self.pos_type !='none': 534 | tp1 = time.time() 535 | tube_embedding_pos = self.get_tube_pos_embedding(tubeInfo, tube_length=self.maxTubelegth, \ 536 | feat_dim=self.pos_emb_dim, feat_type=self.pos_type) 537 | tp2 = time.time() 538 | tube_embedding = torch.cat((tube_embedding, tube_embedding_pos), dim=2) 539 | 540 | if self.cache_flag and self.vis_ftr_type=='rgb': 541 | self.cache_ftr_dict[cache_str_shot_str] = [tube_embedding, tubeInfo, tube_to_prp_idx] 542 | elif self.cache_flag and self.vis_ftr_type=='rgb_i3d': 543 | self.cache_ftr_dict[cache_str_shot_str] = [tube_embedding, tubeInfo, tube_to_prp_idx] 544 | elif self.cache_flag and self.vis_ftr_type=='i3d': 545 | self.cache_ftr_dict[cache_str_shot_str] = [tube_embedding, tubeInfo] 546 | tAf2 = time.time() 547 | vd_name, ins_in_vd = 
self.vid_parser.get_shot_info_from_index(index) 548 | 549 | return tube_embedding, cap_embedding, tubeInfo, index, cap_length_list, vd_name, word_lbl_list 550 | 551 | def get_tube_info(self, indexOri): 552 | index = self.use_key_index[indexOri] 553 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index) 554 | tube_info_path = os.path.join(self.tubePath, self.set_name, self.prp_type, str(index)+'.pd') 555 | tubeInfo = pickleload(tube_info_path) 556 | return tubeInfo, index 557 | 558 | def get_tube_info_gt(self, indexOri): 559 | ''' 560 | get ground truth info 561 | ''' 562 | index = self.use_key_index[indexOri] 563 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index) 564 | tube_info_path = os.path.join(self.tubePath, self.set_name, self.prp_type, str(index)+'.pd') 565 | tubeInfo = pickleload(tube_info_path) 566 | return ins_ann, index, vd_name 567 | 568 | def get_frm_embedding(self, index): 569 | set_name = self.set_name 570 | ins_ann, vd_name = self.vid_parser.get_shot_anno_from_index(index) 571 | tube_info_path = os.path.join(self.tubePath, set_name, self.prp_type, str(index)+'.pd') 572 | tubeInfo = pickleload(tube_info_path) 573 | tube_list, frame_list = tubeInfo 574 | frmNum = len(frame_list) 575 | rpNum = self.rpNum 576 | tube_to_prp_idx = list() 577 | ftr_tube_list = list() 578 | if self.frm_num>0: 579 | sample_index_list = random.sample(range(frmNum), self.frm_num) 580 | else: 581 | sample_index_list = list(range(frmNum)) 582 | 583 | frm_ftr_list = list() 584 | bbx_list = list() 585 | for i, frm_id in enumerate(sample_index_list): 586 | frmName = frame_list[frm_id] 587 | img_prp_ftr_info_path = os.path.join(self.ftrPath, self.set_name, vd_name, frmName+ '.pd') 588 | img_prp_ftr_info = pickleload(img_prp_ftr_info_path) 589 | tmp_frm_ftr = img_prp_ftr_info['roiFtr'][:rpNum] 590 | frm_ftr_list.append(np.expand_dims(tmp_frm_ftr, axis=0)) 591 | tmp_bbx = copy.deepcopy(img_prp_ftr_info['rois'][:rpNum]) # to be modified 592 | tmp_info = img_prp_ftr_info['imFo'].squeeze() 593 | tmp_bbx[:, 0] = tmp_bbx[:, 0]/tmp_info[1] 594 | tmp_bbx[:, 2] = tmp_bbx[:, 2]/tmp_info[1] 595 | tmp_bbx[:, 1] = tmp_bbx[:, 1]/tmp_info[0] 596 | tmp_bbx[:, 3] = tmp_bbx[:, 3]/tmp_info[0] 597 | bbx_list.append(tmp_bbx) 598 | frm_embedding = np.concatenate(frm_ftr_list, axis=0) 599 | return frm_embedding, tubeInfo, sample_index_list, bbx_list 600 | 601 | def get_visual_frm_item(self, indexOri): 602 | #pdb.set_trace() 603 | index = self.use_key_index[indexOri] 604 | 605 | # for testing a specific sample: 606 | sumInd = 0 607 | tube_embedding = None 608 | cap_embedding = None 609 | cap_length_list = -1 610 | tAf = time.time() 611 | cap_embedding, cap_length_list, word_lbl_list = self.get_cap_emb(index, self.capNum) 612 | tBf = time.time() 613 | 614 | cache_str_shot_str = str(self.maxTubelegth) +'_' + str(index) 615 | frm_embedding_list = list() 616 | if self.vis_ftr_type == 'rgb' or self.vis_ftr_type == 'rgb_i3d': 617 | frm_embedding, tubeInfo, frm_idx, bbx_list = self.get_frm_embedding(index) 618 | #pdb.set_trace() 619 | if self.vis_ftr_type=='rgb_i3d': 620 | frm_embedding_list.append(frm_embedding) 621 | frm_embedding = np.concatenate(frm_embedding_list, axis=2) 622 | 623 | frm_embedding = torch.FloatTensor(frm_embedding) 624 | 625 | tAf2 = time.time() 626 | vd_name, ins_in_vd = self.vid_parser.get_shot_info_from_index(index) 627 | return frm_embedding, cap_embedding, tubeInfo, index, cap_length_list, vd_name, word_lbl_list, frm_idx, bbx_list 628 | 629 | def __getitem__(self, index): 630 | if not 
self.frm_level_flag: 631 | return self.get_visual_item(index) 632 | else: 633 | return self.get_visual_frm_item(index) 634 | 635 | def dis_collate_vid(batch): 636 | ftr_tube_list = list() 637 | ftr_cap_list = list() 638 | tube_info_list = list() 639 | cap_length_list = list() 640 | index_list = list() 641 | vd_name_list = list() 642 | word_lbl_list = list() 643 | max_length = 0 644 | frm_idx_list = list() 645 | bbx_list = list() 646 | region_gt_ori = list() 647 | for sample in batch: 648 | ftr_tube_list.append(sample[0]) 649 | ftr_cap_list.append(torch.FloatTensor(sample[1])) 650 | tube_info_list.append(sample[2]) 651 | index_list.append(sample[3]) 652 | vd_name_list.append(sample[5]) 653 | 654 | for tmp_length in sample[4]: 655 | if(tmp_length>max_length): 656 | max_length = tmp_length 657 | cap_length_list.append(tmp_length) 658 | word_lbl_list.append(sample[6]) 659 | if len(sample)>8: 660 | frm_idx_list.append(sample[7]) 661 | bbx_list.append(sample[8]) 662 | 663 | capMatrix = torch.stack(ftr_cap_list, 0) 664 | capMatrix = capMatrix[:, :, :max_length, :] 665 | if len(frm_idx_list)>0: 666 | return torch.stack(ftr_tube_list, 0), capMatrix, tube_info_list, index_list, cap_length_list, vd_name_list, word_lbl_list, frm_idx_list, bbx_list 667 | else: 668 | return torch.stack(ftr_tube_list, 0), capMatrix, tube_info_list, index_list, cap_length_list, vd_name_list, word_lbl_list 669 | 670 | 671 | if __name__=='__main__': 672 | from datasetLoader import build_dataloader 673 | opt = parse_args() 674 | opt.dbSet = 'vid' 675 | opt.set_name ='train' 676 | opt.batchSize = 4 677 | opt.num_workers = 0 678 | opt.rpNum =30 679 | opt.vis_ftr_type = 'i3d' 680 | data_loader, dataset = build_dataloader(opt) 681 | -------------------------------------------------------------------------------- /fun/vidDatasetParser.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append('../') 4 | sys.path.append('../util') 5 | from util.mytoolbox import * 6 | import pdb 7 | import h5py 8 | import csv 9 | import numpy as np 10 | from netUtil import * 11 | sys.path.append('../annotations') 12 | 13 | from itertools import izip 14 | import multiprocessing 15 | from multiprocessing import Pool 16 | import dill 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | from wsParamParser import parse_args 21 | import random 22 | import operator 23 | from fun.datasetLoader import * 24 | import math 25 | 26 | class vidInfoParser(object): 27 | def __init__(self, set_name, annFd): 28 | self.tube_gt_path = os.path.join(annFd, 'Annotations/VID/tubeGt', set_name) 29 | self.tube_name_list_fn = os.path.join(annFd, 'Data/VID/annSamples/', set_name+'_valid_list.txt') 30 | self.jpg_folder = os.path.join(annFd, 'Data/VID/', set_name) 31 | self.info_lines = textread(self.tube_name_list_fn) 32 | self.set_name = set_name 33 | self.tube_ann_list_fn = os.path.join(annFd, 'Data/VID/annSamples/', set_name + '_ann_list_v2.txt') 34 | #pdb.set_trace() 35 | ins_lines = textread(self.tube_ann_list_fn) 36 | ann_dict_set_dict = {} 37 | for line in ins_lines: 38 | ins_id_str, caption = line.split(',', 1) 39 | ins_id = int(ins_id_str) 40 | if ins_id not in ann_dict_set_dict.keys(): 41 | ann_dict_set_dict[ins_id] = list() 42 | ann_dict_set_dict[ins_id].append(caption) 43 | self.tube_cap_dict = ann_dict_set_dict 44 | 45 | def get_length(self): 46 | return len(self.info_lines) 47 | 48 | def get_shot_info_from_index(self, index): 49 | info_Str = 
self.info_lines[index] 50 | vd_name, ins_id_str = info_Str.split(',') 51 | return vd_name, ins_id_str 52 | 53 | def get_shot_anno_from_index(self, index): 54 | vd_name, ins_id_str = self.get_shot_info_from_index(index) 55 | jsFn = os.path.join(self.tube_gt_path, vd_name + '.js') 56 | annDict = jsonload(jsFn) 57 | ann = None 58 | for ii, ann in enumerate(annDict['annotations']): 59 | track = ann['track'] 60 | trackId = ann['id'] 61 | if(trackId!=ins_id_str): 62 | continue 63 | break; 64 | return ann, vd_name 65 | 66 | def get_shot_frame_list_from_index(self, index): 67 | ann, vd_name = self.get_shot_anno_from_index(index) 68 | frm_list = list() 69 | track = ann['track'] 70 | trackId = ann['id'] 71 | frmNum = len(track) 72 | for iii in range(frmNum): 73 | vdFrmInfo = track[iii] 74 | imPath = '%06d' %(vdFrmInfo['frame']-1) 75 | frm_list.append(imPath) 76 | return frm_list, vd_name 77 | 78 | def proposal_path_set_up(self, prpPath): 79 | self.propsal_path = os.path.join(prpPath, self.set_name) 80 | 81 | 82 | def get_all_instance_frames(set_name, annFd): 83 | tubeGtPath = os.path.join(annFd, 'Annotations/VID/tubeGt', set_name) 84 | tube_name_list_fn = os.path.join(annFd, 'Data/VID/annSamples/', set_name+'_valid_list.txt') 85 | jpg_folder = os.path.join(annFd, 'Data/VID/', set_name) 86 | all_frm_path_list = list() 87 | 88 | info_lines = textread(tube_name_list_fn) 89 | for i, vd_info in enumerate(info_lines): 90 | #if(i!=300): 91 | # continue 92 | vd_name, ins_id_str = vd_info.split(',') 93 | jsFn = os.path.join(tubeGtPath, vd_name + '.js') 94 | annDict = jsonload(jsFn) 95 | for ii, ann in enumerate(annDict['annotations']): 96 | track = ann['track'] 97 | trackId = ann['id'] 98 | if(trackId!=ins_id_str): 99 | continue 100 | frmNum = len(track) 101 | for iii in range(frmNum): 102 | vdFrmInfo = track[iii] 103 | imPath = jpg_folder + '/' + vd_name + '/' + '%06d.JPEG' %(vdFrmInfo['frame']-1) 104 | all_frm_path_list.append(imPath) 105 | break; 106 | all_frm_path_list_unique = list(set(all_frm_path_list)) 107 | all_frm_path_list_unique.sort() 108 | return all_frm_path_list_unique 109 | 110 | def extract_shot_prp_list_from_pickle(vidParser, shot_index, prp_num=20,do_norm=1): 111 | frm_list, vd_name = vidParser.get_shot_frame_list_from_index(shot_index) 112 | prp_list = list() 113 | 114 | for i, frm_name in enumerate(frm_list): 115 | frm_name_raw = frm_name.split('.')[0] 116 | prp_path = os.path.join(vidParser.propsal_path, vd_name, frm_name_raw+'.pd') 117 | frm_prp_info = cPickleload(prp_path) 118 | tmp_bbx = frm_prp_info['rois'][:prp_num] 119 | tmp_score = frm_prp_info['roisS'][:prp_num] 120 | tmp_info = frm_prp_info['imFo'].squeeze() 121 | 122 | #pdb.set_trace() 123 | if do_norm==1: 124 | tmp_bbx[:, 0] = tmp_bbx[:, 0]/tmp_info[1] 125 | tmp_bbx[:, 2] = tmp_bbx[:, 2]/tmp_info[1] 126 | tmp_bbx[:, 1] = tmp_bbx[:, 1]/tmp_info[0] 127 | tmp_bbx[:, 3] = tmp_bbx[:, 3]/tmp_info[0] 128 | elif do_norm==2: 129 | tmp_bbx[:, 0] = tmp_bbx[:, 0]*tmp_info[2]/tmp_info[1] 130 | tmp_bbx[:, 2] = tmp_bbx[:, 2]*tmp_info[2]/tmp_info[1] 131 | tmp_bbx[:, 1] = tmp_bbx[:, 1]*tmp_info[2]/tmp_info[0] 132 | tmp_bbx[:, 3] = tmp_bbx[:, 3]*tmp_info[2]/tmp_info[0] 133 | else: 134 | tmp_bbx = tmp_bbx/tmp_info[2] 135 | tmp_score = np.expand_dims(tmp_score, axis=1 ) 136 | prp_list.append([tmp_score, tmp_bbx]) 137 | return prp_list, frm_list, vd_name 138 | 139 | def resize_tube_bbx(tube_vis, frmImList_vis): 140 | for prpId, frm in enumerate(tube_vis): 141 | h, w, c = frmImList_vis[prpId].shape 142 | tube_vis[prpId][0] = tube_vis[prpId][0]*w 
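        # Worked example (hypothetical numbers, not taken from the dataset): the
        # tube boxes were normalized to [0, 1] by extract_shot_prp_list_from_pickle
        # above, so on a frame with h=360, w=640 a normalized box
        # [0.1, 0.2, 0.5, 0.6] is mapped back by these four scalings to the pixel
        # box [64, 72, 320, 216] before drawing.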
143 | tube_vis[prpId][2] = tube_vis[prpId][2]*w 144 | tube_vis[prpId][1] = tube_vis[prpId][1]*h 145 | tube_vis[prpId][3] = tube_vis[prpId][3]*h 146 | return tube_vis 147 | 148 | def evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, thre=0.5 ,topKOri=20, more_detailed_flag=False): 149 | #pdb.set_trace() 150 | topK = min(topKOri, len(shot_proposals[0][0])) 151 | recall_k = [0.0] * (topK + 1) 152 | iou_list = list() 153 | ann, vid_name = vid_parser.get_shot_anno_from_index(tube_index) 154 | boxes = {} 155 | for i, ann_frame in enumerate(ann['track']): 156 | frame_ind = ann_frame['frame'] 157 | box = ann_frame['bbox'] 158 | h, w = ann_frame['frame_size'] 159 | box[0] = box[0]*1.0/w 160 | box[2] = box[2]*1.0/w 161 | box[1] = box[1]*1.0/h 162 | box[3] = box[3]*1.0/h 163 | keyName = '%06d' %(frame_ind-1) 164 | boxes[keyName] = box 165 | 166 | # pdb.set_trace() 167 | tube_list, frame_list = shot_proposals 168 | assert(len(tube_list[0][0])== len(frame_list)) 169 | is_instance_annotated = False 170 | for i in range(topK): 171 | recall_k[i+1] = recall_k[i] 172 | if is_instance_annotated: 173 | continue 174 | curTubeOri = tube_list[0][i] 175 | tube_key_bbxList = {} 176 | for frame_ind, gt_box in boxes.iteritems(): 177 | try: 178 | index_tmp = frame_list.index(frame_ind) 179 | tube_key_bbxList[frame_ind] = curTubeOri[index_tmp] 180 | except: 181 | print('key %s do not exist in shot' %(frame_ind)) 182 | #pdb.set_trace() 183 | ol = compute_LS(tube_key_bbxList, boxes) 184 | if ol < thre: 185 | if more_detailed_flag: 186 | iou_list.append(ol) 187 | continue 188 | else: 189 | recall_k[i+1] += 1.0 190 | is_instance_annotated = True 191 | if more_detailed_flag: 192 | iou_list.append(ol) 193 | if more_detailed_flag: 194 | return recall_k, iou_list 195 | else: 196 | return recall_k 197 | 198 | def multi_process_connect_tubes(param_list): 199 | tube_index, tube_save_path, prp_num, tube_model_name, connect_w, set_name, annFd, vid_parser = param_list 200 | #vid_parser = vidInfoParser(set_name, annFd) 201 | print(tube_save_path) 202 | if os.path.isfile(tube_save_path): 203 | print('\n file exist\n') 204 | return 205 | if tube_model_name =='coco' : 206 | prp_list, frmList, vd_name = extract_shot_prp_list_from_pickle(vid_parser, tube_index, prp_num, do_norm=1) 207 | else: 208 | prp_list, frmList, vd_name = extract_shot_prp_list_from_pickle(vid_parser, tube_index, prp_num, do_norm=2) 209 | results = get_tubes(prp_list, connect_w) 210 | shot_proposals = [results, frmList] 211 | makedirs_if_missing(os.path.dirname(tube_save_path)) 212 | pickledump(tube_save_path, shot_proposals) 213 | 214 | def visual_tube_proposals(tube_save_path, vid_parser, tube_index, prp_num): 215 | 216 | topK = prp_num 217 | recall_k2_sum = np.array([0.0] * (topK + 1)) 218 | recall_k3_sum = np.array([0.0] * (topK + 1)) 219 | recall_k4_sum = np.array([0.0] * (topK + 1)) 220 | recall_k5_sum = np.array([0.0] * (topK + 1)) 221 | 222 | shot_proposals = pickleload(tube_save_path) 223 | results, frmList = shot_proposals 224 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(tube_index) 225 | 226 | recallK2 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.2 ,topKOri=prp_num) 227 | recallK3 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.3 ,topKOri=prp_num) 228 | recallK4 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.4 ,topKOri=prp_num) 229 | recallK5 = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, 0.5 ,topKOri=prp_num) 230 | recall_k2_sum += 
np.array(recallK2) 231 | recall_k3_sum += np.array(recallK3) 232 | recall_k4_sum += np.array(recallK4) 233 | recall_k5_sum += np.array(recallK5) 234 | 235 | print('%d/%d' %(tube_index, vid_parser.get_length())) 236 | print('thre: %f %f %f %f\n' %( 0.2,0.3,0.4,0.5)) 237 | print((recall_k2_sum)*1.0/(tube_index+1)) 238 | print((recall_k3_sum)*1.0/(tube_index+1)) 239 | print((recall_k4_sum)*1.0/(tube_index+1)) 240 | print((recall_k5_sum)*1.0/(tube_index+1)) 241 | 242 | #continue 243 | # visualization 244 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frmList] 245 | frmImList = list() 246 | for fId, imPath in enumerate(frmImNameList): 247 | img = cv2.imread(imPath) 248 | frmImList.append(img) 249 | vis_frame_num = 30 250 | visIner = int(len(frmImList) /vis_frame_num) 251 | #pdb.set_trace() 252 | for ii in range(len(results[0])): 253 | print('visualizing tube %d\n'%(ii)) 254 | tube = results[0][ii] 255 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)] 256 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)] 257 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis) 258 | vd_name_raw = vd_name.split('/')[-1] 259 | visTube_from_image(copy.deepcopy(frmImList_vis), tube_vis_resize, 'sample/'+vd_name_raw+ '_'+str(prp_num) + str(ii)+'.gif') 260 | 261 | 262 | ############################################################################################### 263 | def get_recall_for_tube_proposals(tube_save_path, vid_parser, tube_index, prp_num, thre_list=[0.2, 0.3, 0.4, 0.5]): 264 | 265 | topK = prp_num 266 | 267 | shot_proposals = pickleload(tube_save_path) 268 | results, frmList = shot_proposals 269 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(tube_index) 270 | 271 | recall_list = list() 272 | for thre in thre_list: 273 | recallK = evaluate_tube_recall_vid(shot_proposals, vid_parser, tube_index, thre ,topKOri=prp_num) 274 | recall_list.append(np.array(recallK)) 275 | return recall_list 276 | 277 | ##################################################################################### 278 | def vis_im_prp(): 279 | #set_name = 'val' 280 | set_name = 'train' 281 | annFd = '/data1/zfchen/data/ILSVRC' 282 | #prpFd = '/data1/zfchen/data/ILSVRC/vid_prp_zf/Data/VID' 283 | prpFd = '/mnt/ceph_cv/aicv_image_data/forestlma/zfchen/vidPrp/Data/VID' 284 | #prpFd = '/data1/zfchen/data/ILSVRC/vid_prp_vg/Data/VID' 285 | tube_model_name = 'coco' 286 | #tube_model_name = 'vg' 287 | #tube_model_name = 'zf' 288 | vis_frame_num = 30 289 | prp_num = 30 290 | connect_w =0.2 291 | tube_index = 0 292 | thre = 0.5 293 | vis_flag = True 294 | vid_parser = vidInfoParser(set_name, annFd) 295 | vid_parser.proposal_path_set_up(prpFd) 296 | 297 | for tube_index in range(100, vid_parser.get_length()): 298 | #for tube_index in range(0, 2): 299 | tube_save_path = os.path.join(annFd, 'tubePrp' ,set_name, tube_model_name + '_' + str(prp_num) +'_' + str(int(10*connect_w)) , str(tube_index) + '.pd') 300 | 301 | if not vis_flag: 302 | continue 303 | visual_tube_proposals(tube_save_path, vid_parser, tube_index, prp_num) 304 | pdb.set_trace() 305 | 306 | ##################################################################################### 307 | def show_distribute_over_categories(recall_list, ann_list, thre_list): 308 | #pdb.set_trace() 309 | print('average instance level performance') 310 | recall_ins_sum = list() 311 | for ii, thre in enumerate(thre_list): 312 | tmp_ins_sum = np.array([0.0] * (recall_list[0][ii].shape[0])) 
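        # Each recall_ins[ii] accumulated below is a cumulative top-K recall curve
        # from evaluate_tube_recall_vid; e.g. (hypothetical values) [0., 0., 1., 1.]
        # means the ground-truth tube was first matched by the 2nd-ranked proposal.
        # Averaging these curves over instances gives mean recall@K per threshold.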
313 | for i, recall_ins in enumerate(recall_list) : 314 | tmp_ins_sum +=recall_ins[ii] 315 | recall_ins_sum.append( tmp_ins_sum /len(recall_list)) 316 | print('thre@ins@%f, %f\n' %(thre, recall_ins_sum[ii][-1])) 317 | print('top K, recall') 318 | print(recall_ins_sum[ii]) 319 | pdb.set_trace() 320 | 321 | print('showing recall distribution over categories') 322 | recall_k_categories_dict = {} 323 | for i, ann in enumerate(ann_list): 324 | class_id = str(ann['track'][0]['class']) 325 | if class_id in recall_k_categories_dict.keys(): 326 | recall_cat_list, ins_cat_num = recall_k_categories_dict[class_id] 327 | for ii, recall_thre in enumerate(recall_list[i]): 328 | recall_cat_list[ii] += recall_thre 329 | ins_cat_num +=1 330 | recall_k_categories_dict[class_id] =[recall_cat_list, ins_cat_num] 331 | else: 332 | ins_cat_num =1 333 | recall_k_categories_dict[class_id] = [recall_list[i], ins_cat_num] 334 | 335 | mean_cat_map = list() 336 | for i, thre in enumerate(thre_list): 337 | print('--------------------------------------------------------\n') 338 | print('recall@%f\n' %(thre)) 339 | recall_plot = list() 340 | for ii, cat_name in enumerate(recall_k_categories_dict.keys()): 341 | recall_cat_list, ins_num = recall_k_categories_dict[cat_name] 342 | recall_thre = recall_cat_list[i][-1]*1.0/ins_num 343 | print('%s: %f\n' %(cat_name, recall_thre)) 344 | recall_plot.append(recall_thre) 345 | cat_list = recall_k_categories_dict.keys() 346 | plt.close() 347 | fig, ax = plt.subplots(figsize=(12, 12)) 348 | plt.barh(range(len(cat_list)), recall_plot, tick_label=cat_list) 349 | plt.show() 350 | bar_name = './sample/vid_recall_train@%d.jpg' %(int(thre*10)) 351 | plt.savefig(bar_name) 352 | mean_cat_map.append(sum(recall_plot)/float(len(recall_plot))) 353 | for thre, mAp in zip(thre_list, mean_cat_map): 354 | print('thre@%f , map: %f\n' %(thre, mAp)) 355 | 356 | def caption_to_word_list(des_str): 357 | import string 358 | des_str = des_str.lower().replace('_', ' ').replace(',' , ' ').replace('-', ' ') 359 | for c in string.punctuation: 360 | des_str = des_str.replace(c, '') 361 | return split_carefully(des_str.lower().replace('_', ' ').replace('.', '').replace(',', '').replace("\'", '').replace('-', '').replace('\n', '').replace('\r', '').replace('\"', '').rstrip().replace("\\",'').replace('?', '').replace('/','').replace('#','').replace('(', '').replace(')','').replace(';','').replace('!', '').replace('/',''), ' ') 362 | 363 | def build_vid_word_list(): 364 | set_name_list = ['train', 'val', 'test'] 365 | ann_cap_path = '/data1/zfchen/data/ILSVRC/Data/VID/annSamples' 366 | word_list = list() 367 | for i, set_name in enumerate(set_name_list): 368 | #ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list.txt') 369 | ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list_v2.txt') 370 | cap_lines = textread(ann_cap_set_fn) 371 | for ii, line in enumerate(cap_lines): 372 | ins_id_str, caption = line.split(',', 1) 373 | word_list_tmp = caption_to_word_list(caption) 374 | word_list += word_list_tmp 375 | word_list= list(set(word_list)) 376 | return word_list 377 | 378 | 379 | 380 | def statistic_vid_word_list(): 381 | set_name_list = ['train', 'val', 'test'] 382 | ann_cap_path = '/data1/zfchen/data/ILSVRC/Data/VID/annSamples' 383 | word_list = list() 384 | cap_num = 0 385 | for i, set_name in enumerate(set_name_list): 386 | #ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list.txt') 387 | ann_cap_set_fn = os.path.join(ann_cap_path, set_name+'_ann_list_v2.txt') 388 | cap_lines = 
textread(ann_cap_set_fn) 389 | for ii, line in enumerate(cap_lines): 390 | ins_id_str, caption = line.split(',', 1) 391 | word_list_tmp = caption_to_word_list(caption) 392 | while '' in word_list_tmp: 393 | word_list_tmp.remove('') 394 | #pdb.set_trace() 395 | word_list += word_list_tmp 396 | cap_num +=1 397 | print('Average word length: %f\n'%(len(word_list)*1.0/cap_num)) 398 | print('total word number: %f\n'%(len(word_list))) 399 | word_dict = list(set(word_list)) 400 | print('word in dictionary: %f\n'%(len(word_dict)*1.0)) 401 | 402 | # get frequence 403 | word_to_dict ={} 404 | for i, word in enumerate(word_list): 405 | if word in word_to_dict.keys(): 406 | word_to_dict[word] +=1 407 | else: 408 | word_to_dict[word] =1 409 | sorted_word = sorted(word_to_dict.items(), key=operator.itemgetter(1)) 410 | 411 | sorted_word.reverse() 412 | 413 | 414 | topK = 30 415 | plot_data =[] 416 | cat_name = [] 417 | data_fn = 'word_noun.pdf' 418 | count_num = 0 419 | for i in range(len(sorted_word)): 420 | if sorted_word[i][0] not in stopwords.words("english"): 421 | print(sorted_word[i]) 422 | plot_data.append(sorted_word[i][1]) 423 | cat_name.append(sorted_word[i][0]) 424 | count_num +=1 425 | if count_num>=topK: 426 | break 427 | #pdb.set_trace() 428 | plot_data.reverse() 429 | cat_name.reverse() 430 | plot_distribution_word_ori(plot_data, cat_name, data_fn,rot=30, fsize=110) 431 | #pdb.set_trace() 432 | return word_list 433 | 434 | def get_h5_feature_dict(h5file_path): 435 | img_prp_reader = h5py.File(h5file_path, 'r') 436 | return img_prp_reader 437 | 438 | def draw_specific_tube_proposals(vid_parser, index, tube_id_list, tube_proposal_list, out_fd, color_list=None): 439 | 440 | load_image_flag = True 441 | lbl = index 442 | frmImList = list() 443 | tube_info_sub_prp, frm_info_list = tube_proposal_list 444 | if color_list is None: 445 | color_list =[(255, 0, 0), (0, 255, 0)] 446 | #color_list =[(0, 255, 0), (255, 0, 0)] 447 | dotted = False 448 | line_width = 6 449 | for ii, tube_id in enumerate(tube_id_list): 450 | if ii==1: 451 | dotted = True 452 | line_width =3 453 | tube = copy.deepcopy(tube_info_sub_prp[0][tube_id]) 454 | 455 | if load_image_flag: 456 | # visualize sample results 457 | vd_name, ins_id_str = vid_parser.get_shot_info_from_index(lbl) 458 | frmImNameList = [os.path.join(vid_parser.jpg_folder, vd_name, frame_name + '.JPEG') for frame_name in frm_info_list] 459 | for fId, imPath in enumerate(frmImNameList): 460 | img = cv2.imread(imPath) 461 | frmImList.append(img) 462 | vis_frame_num = 3000 463 | visIner =max(int(len(frmImList) /vis_frame_num), 1) 464 | load_image_flag = False 465 | 466 | frmImList_vis = [frmImList[iii] for iii in range(0, len(frmImList), visIner)] 467 | 468 | tube_vis = [tube[iii] for iii in range(0, len(frmImList), visIner)] 469 | print('visualizing tube %d\n'%(tube_id)) 470 | tube_vis_resize = resize_tube_bbx(tube_vis, frmImList_vis) 471 | frmImList_vis = vis_image_bbx(frmImList_vis, tube_vis_resize, color_list[ii], line_width, dotted) 472 | #frmImList_vis = vis_gray_but_bbx(frmImList_vis, tube_vis_resize) 473 | break 474 | 475 | out_fd_full = os.path.join(out_fd, vid_parser.set_name + str(lbl)) 476 | makedirs_if_missing(out_fd_full) 477 | frm_name_list = list() 478 | for i, idx in enumerate(range(0, len(frmImList), visIner)): 479 | out_fn_full = os.path.join(out_fd_full, frm_info_list[idx]+'.jpg') 480 | cv2.imwrite(out_fn_full, frmImList_vis[i]) 481 | frm_name_list.append(frm_info_list[idx]) 482 | return frmImList_vis, frm_name_list 483 | 484 | def 
get_gt_bbx(ins_ann): 485 | tube_length = len(ins_ann['track']) 486 | gt_bbx = list() 487 | for i in range(tube_length): 488 | tmp_bbx = copy.deepcopy(ins_ann['track'][i]['bbox']) 489 | h, w = ins_ann['track'][i]['frame_size'] 490 | tmp_bbx[0] = tmp_bbx[0]*1.0/w 491 | tmp_bbx[2] = tmp_bbx[2]*1.0/w 492 | tmp_bbx[1] = tmp_bbx[1]*1.0/h 493 | tmp_bbx[3] = tmp_bbx[3]*1.0/h 494 | gt_bbx.append(tmp_bbx) 495 | return gt_bbx 496 | 497 | 498 | if __name__ == '__main__': 499 | pdb.set_trace() 500 | 501 | 502 | -------------------------------------------------------------------------------- /fun/wsParamParser.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | from util.base_parser import BaseParser 4 | 5 | class wsParamParser(BaseParser): 6 | def __init__(self, *arg_list, **arg_dict): 7 | super(wsParamParser, self).__init__(*arg_list, **arg_dict) 8 | self.add_argument('--batchSize', default=16, type=int) 9 | self.add_argument('--dim_ftr', default=128, type=int) 10 | self.add_argument('--n_pairs', default=50, type=int) 11 | self.add_argument('--initmodel', default=None) 12 | self.add_argument('--resume', default='') 13 | self.add_argument('--gpu', default=1, type=int) 14 | self.add_argument('--decay', default=0.001, type=float) 15 | self.add_argument('--optimizer', default='sgd', type=str) 16 | self.add_argument('--margin', default=1, type=float) 17 | self.add_argument('--test', action='store_true', default=False) 18 | self.add_argument('--dbSet', default='otb', type=str) 19 | self.add_argument('--num_workers', default=8, type=int) 20 | self.add_argument('--visIter', default=5, type=int) 21 | self.add_argument('--evalIter', default=10, type=int) 22 | self.add_argument('--epSize', default=10000, type=int) 23 | self.add_argument('--logFd', default='../data/log/wsEmb', type=str) 24 | self.add_argument('--saveEp', default=10, type=int) 25 | self.add_argument('--outPre', default='../data/models/a2d_checkpoint_', type=str) 26 | self.add_argument('--biLoss', action='store_true', default=False) 27 | self.add_argument('--lossW', action='store_true', default=False) 28 | self.add_argument('--lamda', default=0.8, type=float) 29 | self.add_argument('--vwFlag', action='store_true', default=False) 30 | self.add_argument('--wsMode', default='rank', type=str) 31 | self.add_argument('--hdSize', default=128, type=int) 32 | self.add_argument('--vocaSize', default=1900, type=int) 33 | self.add_argument('--conSecFlag', action='store_true', default=False) 34 | self.add_argument('--conFrmNum', default=9, type=int) 35 | self.add_argument('--moduleNum', default=2, type=int) 36 | self.add_argument('--moduleHdSize', default=1024, type=int) 37 | self.add_argument('--stEp', default=0, type=int) 38 | self.add_argument('--keepKeyFrameOnly', action='store_true', default=False) 39 | self.add_argument('--visRsFd', default='../data/visResult/rank_', type=str) 40 | self.add_argument('--logFdTx', default='../logs/wsEmb', type=str) 41 | self.add_argument('--set_name', default='train', type=str) 42 | self.add_argument('--isParal', action='store_true', default=False) 43 | self.add_argument('--capNum', default=1, type=int) 44 | self.add_argument('--maxWordNum', default=20, type=int) 45 | self.add_argument('--rpNum', default=30, type=int) 46 | self.add_argument('--vis_dim', default=2048, type=int) 47 | self.add_argument('--vis_type', default='lstm', type=str) 48 | self.add_argument('--txt_type', default='lstm', type=str) 49 | self.add_argument('--pos_type', 
default='aiayn', type=str) 50 | self.add_argument('--pos_emb_dim', default=64, type=int) 51 | self.add_argument('--half_size', action='store_true', default=False) 52 | self.add_argument('--server_id', default=36, type=int) 53 | self.add_argument('--vis_ftr_type', default='rgb', type=str) 54 | self.add_argument('--struct_flag', action='store_true', default=False) 55 | self.add_argument('--struct_only', action='store_true', default=False) 56 | self.add_argument('--eval_val_flag', action='store_true', default=False) 57 | self.add_argument('--eval_test_flag', action='store_true', default=False) 58 | self.add_argument('--entropy_regu_flag', action='store_true', default=False) 59 | self.add_argument('--hidden_dim', default=128, type=int) 60 | self.add_argument('--centre_num', default=32, type=int) 61 | self.add_argument('--vlad_alpha', default=1.0, type=float) 62 | self.add_argument('--cache_flag', action='store_true', default=False) 63 | self.add_argument('--use_mean_cache_flag', action='store_true', default=False) 64 | self.add_argument('--batch_size', type=int, default=64) 65 | self.add_argument('--video_embedding_size', type=int, default=512) 66 | self.add_argument('--fc_feat_size', type=int, default=2048) 67 | self.add_argument('--word_embedding_size', type=int, default=512) 68 | self.add_argument('--lstm_hidden_size', type=int, default=512) 69 | self.add_argument('--att_hidden_size', type=int, default=512) 70 | self.add_argument('--word_cnt', type=int, default=20) 71 | self.add_argument('--context_flag', action='store_true', default=False) 72 | self.add_argument('--no_shuffle_flag', action='store_true', default=False) 73 | self.add_argument('--frm_level_flag', action='store_true', default=False) 74 | self.add_argument('--frm_num', type=int, default=1) 75 | self.add_argument('--att_exp', type=int, default=1) 76 | self.add_argument('--loss_type', default='triplet_mil', type=str) 77 | self.add_argument('--seed', default=0, type=int) 78 | self.add_argument('--update_iter', default=4, type=int) 79 | self.add_argument('--lamda2', default=1, type=float) 80 | self.add_argument('--video_time_step', type=int, default=20) 81 | self.add_argument('--caption_time_step', type=int, default=20) # tacos65 DDM 15 82 | self.add_argument('--dropout_prob', type=float, default=0.1) 83 | 84 | 85 | def parse_args(): 86 | parser = wsParamParser() 87 | args = parser.parse_args() 88 | half_size = 'full' 89 | if args.half_size: 90 | half_size = 'half' 91 | struct_ann = '' 92 | if args.struct_flag: 93 | struct_ann = '_struct_ann_lamda_%d' %(int(args.lamda*10)) 94 | if args.struct_only: 95 | struct_ann = struct_ann + '_only' 96 | if args.entropy_regu_flag: 97 | struct_ann = struct_ann + '_lamda2_' + str(args.lamda2*10) 98 | 99 | struct_ann = struct_ann + args.loss_type 100 | 101 | struct_ann = struct_ann + '_margin_'+ str(args.margin*10)+ '_att_exp' + str(args.att_exp) 102 | 103 | if args.context_flag: 104 | struct_ann = struct_ann + '_context' 105 | 106 | if args.wsMode == 'coAtt': 107 | struct_ann = struct_ann + 'lstm_hd_' + str(args.lstm_hidden_size) 108 | 109 | if args.frm_level_flag: 110 | struct_ann = struct_ann + '_frm_level_' 111 | 112 | if args.lossW: 113 | struct_ann = struct_ann + 'weak_weight_'+str(args.lamda*10) 114 | 115 | struct_ann += '_seed_' + str(args.seed) 116 | 117 | args.logFd = args.logFd +'_bs_'+str(args.batchSize) + '_tn_' + str(args.rpNum) \ 118 | +'_wl_' +str(args.maxWordNum) + '_cn_' + str(args.capNum) +'_fd_'+ str(args.dim_ftr) \ 119 | + '_' + str(args.wsMode) +'_' +str(args.vis_type)+ '_' + 
str(args.pos_type) + \ 120 | '_' + half_size + '_txt_' + str(args.txt_type) + '_' + str(args.vis_ftr_type) \ 121 | + '_lr_' + str(args.lr*100000) + '_' + str(args.dbSet) + struct_ann 122 | 123 | args.outPre = args.outPre +'_bs_'+str(args.batchSize) + '_tn_' + str(args.rpNum) \ 124 | +'_wl_'+str(args.maxWordNum) + '_cn_' + str(args.capNum) +'_fd_'+ str(args.dim_ftr) \ 125 | + '_' + str(args.wsMode) + str(args.vis_type) + '_'+ str(args.pos_type) + \ 126 | '_'+ half_size+'_txt_'+str(args.txt_type)+ '_' +str(args.vis_ftr_type) + \ 127 | '_lr_' + str(args.lr*100000) + '_' + str(args.dbSet) + struct_ann +'/' 128 | 129 | args.logFdTx = args.logFdTx +'_bs_'+str(args.batchSize) + '_tn_' + str(args.rpNum) \ 130 | +'_wl_' +str(args.maxWordNum) + '_cn_' + str(args.capNum) +'_fd_'+ str(args.dim_ftr) \ 131 | + '_' + str(args.wsMode) +'_' +str(args.vis_type)+ '_' + str(args.pos_type) +'_' + \ 132 | half_size +'_txt_'+ str(args.txt_type) + '_' + str(args.vis_ftr_type) + '_lr_' + \ 133 | str(args.lr*100000) + '_' + str(args.dbSet) + struct_ann 134 | 135 | args.visRsFd = args.visRsFd + args.dbSet + '_' 136 | return args 137 | -------------------------------------------------------------------------------- /images/frm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/images/frm.png -------------------------------------------------------------------------------- /images/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/images/task.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # data preparation 2 | ## linking data and features to ./data 3 | ## ln -s ../../WSSTL/data/ILSVRC ./ILSVRC 4 | ## ln -s ../../WSSTL/data/vid ./vid 5 | -------------------------------------------------------------------------------- /scripts/test_video_emb_att.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python ../fun/eval.py --epSize 1000 \ 2 | --dbSet vid \ 3 | --maxWordNum 20 --num_workers 0 --batchSize 1 --logFd ../data/final_models/abs_ \ 4 | --outPre ../data/vd_model/ --biLoss \ 5 | --hdSize 300 --vwFlag --stEp 0 --logFdTx ../data/tensorBoardX/ \ 6 | --vis_dim 4096 --set_name val\ 7 | --rpNum 30 --saveEp 1 \ 8 | --txt_type gru --pos_type none \ 9 | --lr 0.001 \ 10 | --vis_ftr_type rgb_i3d \ 11 | --margin 1 \ 12 | --vis_type lstm \ 13 | --visIter 1\ 14 | --server_id 36\ 15 | --wsMode coAtt \ 16 | --fc_feat_size 4096 \ 17 | --dim_ftr 512 \ 18 | --no_shuffle_flag \ 19 | --eval_test_flag \ 20 | --initmodel ../data/models/_ep_29_itr_0.pth \ 21 | -------------------------------------------------------------------------------- /scripts/train_video_emb_att.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python ../fun/train.py --epSize 30 \ 2 | --seed 0\ 3 | --dbSet vid \ 4 | --maxWordNum 20 --num_workers 0 --batchSize 3 --logFd ../data/log/pm_resume_ \ 5 | --outPre ../data/acc_grad --biLoss \ 6 | --hdSize 300 --vwFlag --stEp 0 --logFdTx ../data/tensorBoardX/pm_resume_ \ 7 | --vis_dim 4096 --set_name train\ 8 | --rpNum 30 --saveEp 1 \ 9 | --txt_type lstm --pos_type none \ 10 | --lr 0 \ 11 | --vis_ftr_type rgb_i3d \ 12 
| --margin 1 \ 13 | --vis_type lstm \ 14 | --visIter 5\ 15 | --fc_feat_size 4096 \ 16 | --dim_ftr 512 \ 17 | --server_id 36\ 18 | --stEp 0 \ 19 | --optimizer sgd \ 20 | --entropy_regu_flag --lamda2 1 \ 21 | --wsMode coAtt \ 22 | --update_iter 4 \ 23 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/__init__.py -------------------------------------------------------------------------------- /util/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/__init__.pyc -------------------------------------------------------------------------------- /util/base_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import argparse 5 | 6 | class BaseParser(argparse.ArgumentParser): 7 | def __init__(self, *arg_list, **arg_dict): 8 | super(BaseParser, self).__init__(*arg_list, **arg_dict) 9 | self.add_argument('--max_iter', default=1500, type=int) 10 | self.add_argument('--start_iter', default=0, type=int) 11 | self.add_argument('--val_interval', default=20, type=int) 12 | self.add_argument('--saving_interval', default=100, type=int) 13 | self.add_argument('--suffix', default='') 14 | self.add_argument('--lrdecay', default=500, type=float) 15 | self.add_argument('--clip_c', default=10., type=float) 16 | self.add_argument('--wo_early_stopping', 17 | dest='early_stopping', 18 | action='store_false', 19 | default=True) 20 | self.add_argument('--alpha', default=0.001, type=float) 21 | self.add_argument('--lr', default=0.001, type=float) 22 | self.add_argument('--momentum', default=0.9, type=float) 23 | 24 | def parse_args(self, *arg_list, **arg_dict): 25 | args = super(BaseParser, self).parse_args(*arg_list, **arg_dict) 26 | if len(args.suffix) > 0: 27 | args.suffix = '_' + args.suffix 28 | return args 29 | -------------------------------------------------------------------------------- /util/base_parser.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/base_parser.pyc -------------------------------------------------------------------------------- /util/get_image_size.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | """ 5 | 6 | get_image_size.py 7 | ==================== 8 | 9 | :Name: get_image_size 10 | :Purpose: extract image dimensions given a file path 11 | 12 | :Author: Paulo Scardine (based on code from Emmanuel VAÏSSE) 13 | 14 | :Created: 26/09/2013 15 | :Copyright: (c) Paulo Scardine 2013 16 | :Licence: MIT 17 | 18 | """ 19 | import collections 20 | import json 21 | import os 22 | import struct 23 | 24 | FILE_UNKNOWN = "Sorry, don't know how to get size for this file." 
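# Usage sketch for the two helpers defined below (the file name is
# hypothetical): only the file header is parsed, never the pixel payload,
# so this stays cheap even for large video frames.
#   width, height = get_image_size('000000.JPEG')
#   meta = get_image_metadata('000000.JPEG')  # Image(path, type, file_size, width, height)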
25 | 26 | 27 | class UnknownImageFormat(Exception): 28 | pass 29 | 30 | 31 | types = collections.OrderedDict() 32 | BMP = types['BMP'] = 'BMP' 33 | GIF = types['GIF'] = 'GIF' 34 | ICO = types['ICO'] = 'ICO' 35 | JPEG = types['JPEG'] = 'JPEG' 36 | PNG = types['PNG'] = 'PNG' 37 | TIFF = types['TIFF'] = 'TIFF' 38 | 39 | image_fields = ['path', 'type', 'file_size', 'width', 'height'] 40 | 41 | 42 | class Image(collections.namedtuple('Image', image_fields)): 43 | 44 | def to_str_row(self): 45 | return ("%d\t%d\t%d\t%s\t%s" % ( 46 | self.width, 47 | self.height, 48 | self.file_size, 49 | self.type, 50 | self.path.replace('\t', '\\t'), 51 | )) 52 | 53 | def to_str_row_verbose(self): 54 | return ("%d\t%d\t%d\t%s\t%s\t##%s" % ( 55 | self.width, 56 | self.height, 57 | self.file_size, 58 | self.type, 59 | self.path.replace('\t', '\\t'), 60 | self)) 61 | 62 | def to_str_json(self, indent=None): 63 | return json.dumps(self._asdict(), indent=indent) 64 | 65 | 66 | def get_image_size(file_path): 67 | """ 68 | Return (width, height) for a given img file content - no external 69 | dependencies except the os and struct builtin modules 70 | """ 71 | img = get_image_metadata(file_path) 72 | return (img.width, img.height) 73 | 74 | 75 | def get_image_metadata(file_path): 76 | """ 77 | Return an `Image` object for a given img file content - no external 78 | dependencies except the os and struct builtin modules 79 | 80 | Args: 81 | file_path (str): path to an image file 82 | 83 | Returns: 84 | Image: (path, type, file_size, width, height) 85 | """ 86 | size = os.path.getsize(file_path) 87 | 88 | # be explicit with open arguments - we need binary mode 89 | with open(file_path, "rb") as input: 90 | height = -1 91 | width = -1 92 | data = input.read(26) 93 | msg = " raised while trying to decode as JPEG." 
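        # The branches below dispatch on magic bytes already read into `data`:
        # GIF starts with b'GIF87a'/b'GIF89a', PNG with b'\211PNG\r\n\032\n',
        # JPEG with b'\377\330' (0xFFD8) and BMP with b'BM'; width and height
        # are then unpacked from the format-specific header fields via struct.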
94 | 95 | if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'): 96 | # GIFs 97 | imgtype = GIF 98 | w, h = struct.unpack("= 24) and data.startswith(b'\211PNG\r\n\032\n') 102 | and (data[12:16] == b'IHDR')): 103 | # PNGs 104 | imgtype = PNG 105 | w, h = struct.unpack(">LL", data[16:24]) 106 | width = int(w) 107 | height = int(h) 108 | elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'): 109 | # older PNGs 110 | imgtype = PNG 111 | w, h = struct.unpack(">LL", data[8:16]) 112 | width = int(w) 113 | height = int(h) 114 | elif (size >= 2) and data.startswith(b'\377\330'): 115 | # JPEG 116 | imgtype = JPEG 117 | input.seek(0) 118 | input.read(2) 119 | b = input.read(1) 120 | try: 121 | while (b and ord(b) != 0xDA): 122 | while (ord(b) != 0xFF): 123 | b = input.read(1) 124 | while (ord(b) == 0xFF): 125 | b = input.read(1) 126 | if (ord(b) >= 0xC0 and ord(b) <= 0xC3): 127 | input.read(3) 128 | h, w = struct.unpack(">HH", input.read(4)) 129 | break 130 | else: 131 | input.read( 132 | int(struct.unpack(">H", input.read(2))[0]) - 2) 133 | b = input.read(1) 134 | width = int(w) 135 | height = int(h) 136 | except struct.error: 137 | raise UnknownImageFormat("StructError" + msg) 138 | except ValueError: 139 | raise UnknownImageFormat("ValueError" + msg) 140 | except Exception as e: 141 | raise UnknownImageFormat(e.__class__.__name__ + msg) 142 | elif (size >= 26) and data.startswith(b'BM'): 143 | # BMP 144 | imgtype = 'BMP' 145 | headersize = struct.unpack("= 40: 151 | w, h = struct.unpack("= 8) and data[:4] in (b"II\052\000", b"MM\000\052"): 160 | # Standard TIFF, big- or little-endian 161 | # BigTIFF and other different but TIFF-like formats are not 162 | # supported currently 163 | imgtype = TIFF 164 | byteOrder = data[:2] 165 | boChar = ">" if byteOrder == "MM" else "<" 166 | # maps TIFF type id to size (in bytes) 167 | # and python format char for struct 168 | tiffTypes = { 169 | 1: (1, boChar + "B"), # BYTE 170 | 2: (1, boChar + "c"), # ASCII 171 | 3: (2, boChar + "H"), # SHORT 172 | 4: (4, boChar + "L"), # LONG 173 | 5: (8, boChar + "LL"), # RATIONAL 174 | 6: (1, boChar + "b"), # SBYTE 175 | 7: (1, boChar + "c"), # UNDEFINED 176 | 8: (2, boChar + "h"), # SSHORT 177 | 9: (4, boChar + "l"), # SLONG 178 | 10: (8, boChar + "ll"), # SRATIONAL 179 | 11: (4, boChar + "f"), # FLOAT 180 | 12: (8, boChar + "d") # DOUBLE 181 | } 182 | ifdOffset = struct.unpack(boChar + "L", data[4:8])[0] 183 | try: 184 | countSize = 2 185 | input.seek(ifdOffset) 186 | ec = input.read(countSize) 187 | ifdEntryCount = struct.unpack(boChar + "H", ec)[0] 188 | # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4 189 | # bytes: value offset 190 | ifdEntrySize = 12 191 | for i in range(ifdEntryCount): 192 | entryOffset = ifdOffset + countSize + i * ifdEntrySize 193 | input.seek(entryOffset) 194 | tag = input.read(2) 195 | tag = struct.unpack(boChar + "H", tag)[0] 196 | if(tag == 256 or tag == 257): 197 | # if type indicates that value fits into 4 bytes, value 198 | # offset is not an offset but value itself 199 | type = input.read(2) 200 | type = struct.unpack(boChar + "H", type)[0] 201 | if type not in tiffTypes: 202 | raise UnknownImageFormat( 203 | "Unkown TIFF field type:" + 204 | str(type)) 205 | typeSize = tiffTypes[type][0] 206 | typeChar = tiffTypes[type][1] 207 | input.seek(entryOffset + 8) 208 | value = input.read(typeSize) 209 | value = int(struct.unpack(typeChar, value)[0]) 210 | if tag == 256: 211 | width = value 212 | else: 213 | height = value 214 | if width > -1 and height > -1: 215 | break 
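                # Only TIFF tags 256 (ImageWidth) and 257 (ImageLength) are read
                # from the IFD entries above; the loop exits as soon as both
                # dimensions have been filled in.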
216 | except Exception as e: 217 | raise UnknownImageFormat(str(e)) 218 | elif size >= 2: 219 | # see http://en.wikipedia.org/wiki/ICO_(file_format) 220 | imgtype = 'ICO' 221 | input.seek(0) 222 | reserved = input.read(2) 223 | if 0 != struct.unpack(" 1: 230 | import warnings 231 | warnings.warn("ICO File contains more than one image") 232 | # http://msdn.microsoft.com/en-us/library/ms997538.aspx 233 | w = input.read(1) 234 | h = input.read(1) 235 | width = ord(w) 236 | height = ord(h) 237 | else: 238 | raise UnknownImageFormat(FILE_UNKNOWN) 239 | 240 | return Image(path=file_path, 241 | type=imgtype, 242 | file_size=size, 243 | width=width, 244 | height=height) 245 | 246 | 247 | import unittest 248 | 249 | 250 | class Test_get_image_size(unittest.TestCase): 251 | data = [{ 252 | 'path': 'lookmanodeps.png', 253 | 'width': 251, 254 | 'height': 208, 255 | 'file_size': 22228, 256 | 'type': 'PNG'}] 257 | 258 | def setUp(self): 259 | pass 260 | 261 | def test_get_image_metadata(self): 262 | img = self.data[0] 263 | output = get_image_metadata(img['path']) 264 | self.assertTrue(output) 265 | self.assertEqual(output.path, img['path']) 266 | self.assertEqual(output.width, img['width']) 267 | self.assertEqual(output.height, img['height']) 268 | self.assertEqual(output.type, img['type']) 269 | self.assertEqual(output.file_size, img['file_size']) 270 | for field in image_fields: 271 | self.assertEqual(getattr(output, field), img[field]) 272 | 273 | def test_get_image_metadata__ENOENT_OSError(self): 274 | with self.assertRaises(OSError): 275 | get_image_metadata('THIS_DOES_NOT_EXIST') 276 | 277 | def test_get_image_metadata__not_an_image_UnknownImageFormat(self): 278 | with self.assertRaises(UnknownImageFormat): 279 | get_image_metadata('README.rst') 280 | 281 | def test_get_image_size(self): 282 | img = self.data[0] 283 | output = get_image_size(img['path']) 284 | self.assertTrue(output) 285 | self.assertEqual(output, 286 | (img['width'], 287 | img['height'])) 288 | 289 | def tearDown(self): 290 | pass 291 | 292 | 293 | def main(argv=None): 294 | """ 295 | Print image metadata fields for the given file path. 296 | 297 | Keyword Arguments: 298 | argv (list): commandline arguments (e.g. 
sys.argv[1:]) 299 | Returns: 300 | int: zero for OK 301 | """ 302 | import logging 303 | import optparse 304 | import sys 305 | 306 | prs = optparse.OptionParser( 307 | usage="%prog [-v|--verbose] [--json|--json-indent] []", 308 | description="Print metadata for the given image paths " 309 | "(without image library bindings).") 310 | 311 | prs.add_option('--json', 312 | dest='json', 313 | action='store_true') 314 | prs.add_option('--json-indent', 315 | dest='json_indent', 316 | action='store_true') 317 | 318 | prs.add_option('-v', '--verbose', 319 | dest='verbose', 320 | action='store_true',) 321 | prs.add_option('-q', '--quiet', 322 | dest='quiet', 323 | action='store_true',) 324 | prs.add_option('-t', '--test', 325 | dest='run_tests', 326 | action='store_true',) 327 | 328 | argv = list(argv) if argv is not None else sys.argv[1:] 329 | (opts, args) = prs.parse_args(args=argv) 330 | loglevel = logging.INFO 331 | if opts.verbose: 332 | loglevel = logging.DEBUG 333 | elif opts.quiet: 334 | loglevel = logging.ERROR 335 | logging.basicConfig(level=loglevel) 336 | log = logging.getLogger() 337 | log.debug('argv: %r', argv) 338 | log.debug('opts: %r', opts) 339 | log.debug('args: %r', args) 340 | 341 | if opts.run_tests: 342 | import sys 343 | sys.argv = [sys.argv[0]] + args 344 | import unittest 345 | return unittest.main() 346 | 347 | output_func = Image.to_str_row 348 | if opts.json_indent: 349 | import functools 350 | output_func = functools.partial(Image.to_str_json, indent=2) 351 | elif opts.json: 352 | output_func = Image.to_str_json 353 | elif opts.verbose: 354 | output_func = Image.to_str_row_verbose 355 | 356 | EX_OK = 0 357 | EX_NOT_OK = 2 358 | 359 | if len(args) < 1: 360 | prs.print_help() 361 | print('') 362 | prs.error("You must specify one or more paths to image files") 363 | 364 | errors = [] 365 | for path_arg in args: 366 | try: 367 | img = get_image_metadata(path_arg) 368 | print(output_func(img)) 369 | except KeyboardInterrupt: 370 | raise 371 | except OSError as e: 372 | log.error((path_arg, e)) 373 | errors.append((path_arg, e)) 374 | except Exception as e: 375 | log.exception(e) 376 | errors.append((path_arg, e)) 377 | pass 378 | if len(errors): 379 | import pprint 380 | print("ERRORS", file=sys.stderr) 381 | print("======", file=sys.stderr) 382 | print(pprint.pformat(errors, indent=2), file=sys.stderr) 383 | return EX_NOT_OK 384 | return EX_OK 385 | 386 | 387 | if __name__ == "__main__": 388 | import sys 389 | sys.exit(main(argv=sys.argv[1:])) 390 | -------------------------------------------------------------------------------- /util/mytoolbox.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import commands 5 | import copy 6 | import cPickle 7 | import datetime 8 | import inspect 9 | import json 10 | import os 11 | import pickle 12 | import sys 13 | import time 14 | import scipy.io 15 | import cv2 16 | import pdb 17 | 18 | def set_debugger(): 19 | from IPython.core import ultratb 20 | sys.excepthook = ultratb.FormattedTB(call_pdb=True) 21 | 22 | class TimeReporter: 23 | def __init__(self, max_count, interval=1, moving_average=False): 24 | self.time = time.time 25 | self.start_time = time.time() 26 | self.max_count = max_count 27 | self.cur_count = 0 28 | self.prev_time = time.time() 29 | self.interval = interval 30 | self.moving_average = moving_average 31 | def report(self, cur_count=None, max_count=None, overwrite=True, prefix=None, postfix=None, interval=None): 32 | if 
/util/mytoolbox.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding:utf-8 -*-

import commands
import copy
import cPickle
import datetime
import inspect
import json
import os
import pickle
import sys
import time
import scipy.io
import cv2
import pdb

def set_debugger():
    from IPython.core import ultratb
    sys.excepthook = ultratb.FormattedTB(call_pdb=True)

class TimeReporter:
    def __init__(self, max_count, interval=1, moving_average=False):
        self.time = time.time
        self.start_time = time.time()
        self.max_count = max_count
        self.cur_count = 0
        self.prev_time = time.time()
        self.interval = interval
        self.moving_average = moving_average
    def report(self, cur_count=None, max_count=None, overwrite=True, prefix=None, postfix=None, interval=None):
        if cur_count is not None:
            self.cur_count = cur_count
        else:
            self.cur_count += 1
        if max_count is None:
            max_count = self.max_count
        cur_time = self.time()
        elapsed = cur_time - self.start_time
        if self.cur_count <= 0:
            ave_time = float('inf')
        elif self.moving_average and self.cur_count == 1:
            ave_time = float('inf')
            self.ma_prev_time = cur_time
        elif self.moving_average and self.cur_count == 2:
            self.ma_time = cur_time - self.ma_prev_time
            ave_time = self.ma_time
            self.ma_prev_time = cur_time
        elif self.moving_average:
            self.ma_time = self.ma_time * 0.95 + (cur_time - self.ma_prev_time) * 0.05
            ave_time = self.ma_time
            self.ma_prev_time = cur_time
        else:
            ave_time = elapsed / self.cur_count
        ETA = (max_count - self.cur_count) * ave_time
        print_str = 'count : %d / %d, elapsed time : %f, ETA : %f' % (self.cur_count, max_count, elapsed, ETA)
        if prefix is not None:
            print_str = str(prefix) + ' ' + print_str
        if postfix is not None:
            print_str += ' ' + str(postfix)
        this_interval = self.interval
        if interval is not None:
            this_interval = interval
        if cur_time - self.prev_time < this_interval:
            return
        if overwrite and self.cur_count != self.max_count:
            printr(print_str)
        else:
            print print_str
        self.prev_time = cur_time

def textread(path):
    f = open(path)
    lines = f.readlines()
    f.close()
    for i in range(len(lines)):
        lines[i] = lines[i].replace('\n', '').replace('\r', '')
    return lines

def textdump(path, lines, need_asking=False):
    if os.path.exists(path) and need_asking:
        if 'n' == choosebyinput(['Y', 'n'], path + ' exists. Would you replace? [Y/n]'):
            return False
    f = open(path, 'w')
    for index, i in enumerate(lines):
        try:
            f.write(i.encode("utf-8") + '\n')
        except:
            print(index)
            pdb.set_trace()
    f.close()

def pickleload(path):
    f = open(path)
    this_ans = pickle.load(f)
    f.close()
    return this_ans

def pickledump(path, this_dic):
    f = open(path, 'w')
    pickle.dump(this_dic, f)
    f.close()

def cPickleload(path):
    f = open(path, 'rb')
    this_ans = cPickle.load(f)
    f.close()
    return this_ans

def cPickledump(path, this_dic):
    f = open(path, 'wb')
    cPickle.dump(this_dic, f, -1)
    f.close()

def jsonload(path):
    f = open(path)
    this_ans = json.load(f)
    f.close()
    return this_ans

def jsondump(path, this_dic):
    f = open(path, 'w')
    json.dump(this_dic, f)
    f.close()
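# Usage sketch (illustrative comment, not in the original file): TimeReporter
# and the dump/load helpers above are typically combined like this; `loader`
# is a hypothetical iterable of work items.
#
#     tr = TimeReporter(max_count=len(loader), interval=10)
#     for batch in loader:
#         ...                        # one unit of work
#         tr.report(prefix='train')  # prints "count : i / N, elapsed, ETA"
#     pickledump('cache.pkl', {'acc': 38.1})
#     assert pickleload('cache.pkl')['acc'] == 38.1
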
def choosebyinput(cand, message=False):
    if not type(cand) == list and not type(cand) == int:
        print 'The type of cand_list has to be \'list\' or \'int\'.'
        return
    if type(cand) == int:
        cand_list = range(cand)
    if type(cand) == list:
        cand_list = cand
    int_cand_list = []
    for i in cand_list:
        if type(i) == int:
            int_cand_list.append(str(i))
    if message == False:
        message = 'choose by input ['
        for i in int_cand_list:
            message += i + ' / '
        for i in cand_list:
            if not str(i) in int_cand_list:
                message += i + ' / '
        message = message[:-3] + '] : '
    while True:
        your_ans = raw_input(message)
        if your_ans in int_cand_list:
            return int(your_ans)
        if your_ans in cand_list:
            return your_ans

def printr(*targ_str):
    str_to_print = ''
    for temp_str in targ_str:
        str_to_print += str(temp_str) + ' '
    str_to_print = str_to_print[:-1]
    sys.stdout.write(str_to_print + '\r')
    sys.stdout.flush()

def make_red(prt):
    return '\033[91m%s\033[00m' % prt

def emphasize(*targ_str):
    str_to_print = ''
    for temp_str in targ_str:
        str_to_print += str(temp_str) + ' '
    str_to_print = str_to_print[:-1]
    num_repeat = len(str_to_print) / 2 + 1
    print '_' + '人' * (num_repeat + 1) + '_'
    print '> %s <' % make_red(str_to_print)
    print ' ̄' + 'Y^' * num_repeat + 'Y ̄'

def mkdir_if_missing(dir_path):
    if not os.path.exists(dir_path):
        os.mkdir(dir_path)

def makedirs_if_missing(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def makebsdirs_if_missing(f_path):
    makedirs_if_missing(os.path.dirname(f_path) if '/' in f_path else f_path)

def split_inds(all_num, split_num, split_targ):
    assert split_num >= 1
    assert split_targ >= 0
    assert split_targ < split_num
    part = all_num // split_num
    if not split_num == split_targ+1:
        return split_targ * part, (split_targ+1) * part
    else:
        return split_targ * part, all_num
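# Worked example (comment only, not in the original file): split_inds
# partitions `all_num` items into `split_num` contiguous chunks and gives the
# last chunk any remainder:
#
#     split_inds(10, 3, 0)  ->  (0, 3)
#     split_inds(10, 3, 1)  ->  (3, 6)
#     split_inds(10, 3, 2)  ->  (6, 10)   # last split absorbs the remainder
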
try:
    import numpy as np
    def are_same_vecs(vec_a, vec_b, this_eps1=1e-5, verbose=False):
        if not vec_a.ravel().shape == vec_b.ravel().shape:
            return False
        if np.linalg.norm(vec_a.ravel()) == 0:
            if not np.linalg.norm(vec_b.ravel()) == 0:
                if verbose:
                    print 'assertion failed.'
                    print 'diff norm : %f' % (np.linalg.norm(vec_a.ravel() - vec_b.ravel()))
                return False
        else:
            if not np.linalg.norm(vec_a.ravel() - vec_b.ravel()) / np.linalg.norm(vec_a.ravel()) < this_eps1:
                if verbose:
                    print 'assertion failed.'
                    print 'diff norm : %f' % (np.linalg.norm(vec_a.ravel() - vec_b.ravel()) / np.linalg.norm(vec_a.ravel()))
                return False
        return True
    def comp_vecs(vec_a, vec_b, this_eps1=1e-5):
        assert are_same_vecs(vec_a, vec_b, this_eps1, True)
    def arrayinfo(np_array):
        print 'max: %04f, min: %04f, abs_min: %04f, norm: %04f,' % (np_array.max(), np_array.min(), np.abs(np_array).min(), np.linalg.norm(np_array)),
        print 'dtype: %s,' % np_array.dtype,
        print 'shape: %s,' % str(np_array.shape),
        print
except ImportError:
    def comp_vecs(*input1, **input2):
        print 'comp_vecs() cannot be loaded.'
    def arrayinfo(*input1, **input2):
        print 'arrayinfo() cannot be loaded.'

try:
    import Levenshtein
    def search_nn_str(targ_str, str_lists):
        # return the string in str_lists with the smallest edit distance
        dist = float('inf')
        dist_str = None
        for i in sorted(str_lists):
            cur_dist = Levenshtein.distance(i, targ_str)
            if dist > cur_dist:
                dist = cur_dist
                dist_str = i
        return dist_str
except ImportError:
    def search_nn_str(targ_str, str_lists):
        print 'search_nn_str() cannot be imported.'

def flatten(targ_list):
    new_list = copy.deepcopy(targ_list)
    for i in reversed(range(len(new_list))):
        if isinstance(new_list[i], list) or isinstance(new_list[i], tuple):
            new_list[i:i+1] = flatten(new_list[i])
    return new_list

def predict_charset(targ_str):
    targ_charsets = ['utf-8', 'cp932', 'euc-jp', 'iso-2022-jp']
    for targ_charset in targ_charsets:
        try:
            targ_str.decode(targ_charset)
            return targ_charset
        except UnicodeDecodeError:
            pass
    return None

def remove_non_ascii(targ_str, charset=None):
    if charset is not None:
        assert isinstance(targ_str, str)
        targ_str = targ_str.decode(charset)
    else:
        assert isinstance(targ_str, unicode)
    return ''.join([x for x in targ_str if ord(x) < 128]).encode('ascii')

class StopWatch(object):
    def __init__(self):
        self._time = {}
        self._bef_time = {}
    def tic(self, name):
        self._bef_time[name] = time.time()
    def toc(self, name):
        self._time[name] = time.time() - self._bef_time[name]
        return self._time[name]
    def show(self):
        show_str = ''
        for name, elp in self._time.iteritems():
            show_str += '%s: %03.3f, ' % (name, elp)
        printr(show_str[:-2])

Timer = StopWatch  # deprecated

def get_free_gpu(default_gpu):
    FORMAT = '--format=csv,noheader'
    COM_GPU_UTIL = 'nvidia-smi --query-gpu=index,uuid ' + FORMAT
    COM_GPU_PROCESS = 'nvidia-smi --query-compute-apps=gpu_uuid ' + FORMAT
    uuid2id = {cur_line.split(',')[1].strip(): int(cur_line.split(',')[0])
               for cur_line in commands.getoutput(COM_GPU_UTIL).split('\n')}
    used_gpus = set()
    for cur_line in commands.getoutput(COM_GPU_PROCESS).split('\n'):
        used_gpus.add(cur_line)
    if len(uuid2id) == len(used_gpus):
        return default_gpu
    elif os.uname()[1] == 'konoshiro':
        return str(1 - int(uuid2id[list(set(uuid2id.keys()) - used_gpus)[0]]))
    else:
        return uuid2id[list(set(uuid2id.keys()) - used_gpus)[0]]

def get_abs_path():
    return os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), __file__)))

def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)

def get_time_str():
    # %H (not %h) is the zero-padded hour
    return datetime.datetime.now().strftime('Y%yM%mD%dH%HM%M')

def get_cur_time():
    return str(datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S'))
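# Usage sketch (comment only, not in the original file): StopWatch keeps one
# named timing per tic/toc pair and prints them all on a single line;
# `heavy_io` and `model` are hypothetical.
#
#     sw = StopWatch()
#     sw.tic('load'); data = heavy_io()
#     sw.toc('load')
#     sw.tic('fwd');  out = model(data)
#     sw.toc('fwd')
#     sw.show()    # e.g. "load: 0.512, fwd: 0.034"
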
def split_carefully(text, splitter=',', delimiters=['"', "'"]):
    # split `text` on `splitter`, ignoring splitters inside quoted spans
    assert isinstance(splitter, str)
    assert not splitter in delimiters
    if not (isinstance(delimiters, list) or isinstance(delimiters, tuple)):
        delimiters = [delimiters]
    for cur_del in delimiters:
        assert len(cur_del) == 1

    cur_ind = 0
    prev_ind = 0
    splitted = []
    is_in_delimiters = False
    cur_del = None
    while cur_ind < len(text):
        if text[cur_ind] in delimiters:
            if text[cur_ind] == cur_del:
                is_in_delimiters = False
                cur_del = None
                cur_ind += 1
                continue
            elif not is_in_delimiters:
                is_in_delimiters = True
                cur_del = text[cur_ind]
                cur_ind += 1
                continue
        if not is_in_delimiters and text[cur_ind] == splitter:
            splitted.append(text[prev_ind:cur_ind])
            cur_ind += 1
            prev_ind = cur_ind
            continue
        cur_ind += 1
    splitted.append(text[prev_ind:cur_ind])
    return splitted

def full_listdir(dir_name):
    return [os.path.join(dir_name, i) for i in os.listdir(dir_name)]

def get_list_dir(dir_name):
    fileFdList = full_listdir(dir_name)
    folder_list = list()
    for item in fileFdList:
        if os.path.isdir(item):
            folder_list.append(item)
    return folder_list


class tictoc(object):
    def __init__(self, targ_list):
        self._targ_list = targ_list
        self._list_ind = -1
        self._TR = TimeReporter(len(targ_list))
    def __iter__(self):
        return self
    def next(self):
        self._list_ind += 1
        if self._list_ind > 0:
            self._TR.report()
        if self._list_ind == len(self._targ_list):
            raise StopIteration()
        return self._targ_list[self._list_ind]

ONCE_PRINTED = set()


def print_once(*targ_str):
    frame = inspect.currentframe(1)
    fname = inspect.getfile(frame)
    cur_loc = frame.f_lineno
    cur_key = fname + str(cur_loc)
    if cur_key in ONCE_PRINTED:
        return
    else:
        ONCE_PRINTED.add(cur_key)
    str_to_print = ''
    for temp_str in targ_str:
        str_to_print += str(temp_str) + ' '
    str_to_print = str_to_print[:-1]
    print str_to_print

def get_specific_file_list_from_fd(dir_name, fileType, nameOnly=True):
    list_name = []
    for fileTmp in os.listdir(dir_name):
        file_path = os.path.join(dir_name, fileTmp)
        if os.path.isdir(file_path):
            continue
        elif os.path.splitext(fileTmp)[1] == fileType:
            if nameOnly:
                list_name.append(os.path.splitext(fileTmp)[0])
            else:
                list_name.append(file_path)
    return list_name
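# Worked example (comment only, not in the original file): collect files with
# a given extension; note `fileType` includes the leading dot. Using this
# repo's own annotations folder:
#
#     get_specific_file_list_from_fd('annotations', '.py')
#     # -> ['ptd_api', 'script_test_annotation', 'utils']   (stems only)
#     get_specific_file_list_from_fd('annotations', '.py', nameOnly=False)
#     # -> ['annotations/ptd_api.py', ...]                  (full paths)
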
# works around inconsistently delimited ("buggy") annotation files by falling
# back from comma to tab to space splitting
def parse_mul_num_lines(fileName, toFloat=True, spliter=','):
    lineOut = []
    lineList = textread(fileName)
    for lineTmp in lineList:
        splitedStr = split_carefully(lineTmp, spliter)
        if(len(splitedStr) < 4):
            print('strange encoding for %s!' % (fileName))
            splitedStr = split_carefully(lineTmp, '\t')
            if(len(splitedStr) < 4):
                print('strange encoding for %s!' % (fileName))
                splitedStr = split_carefully(lineTmp, ' ')

        if(toFloat):
            splitedTmp = [float(ele) for ele in splitedStr]
            lineOut.append(splitedTmp)
        else:
            lineOut.append(splitedStr)
    return lineOut

def pck2mat(pckFn, outFn):
    data = pickleload(pckFn)
    scipy.io.savemat(outFn, data)
    print('finish transformation')

def putCapOnImage(imgVis, capList):
    if isinstance(capList, list):
        cap = ''
        for ele in capList:
            cap += ele
            cap += ' '
    else:
        cap = capList
    cv2.putText(imgVis, cap,
                (10, 50),
                cv2.FONT_HERSHEY_PLAIN,
                2, (0, 0, 255),
                2)
    return imgVis

def get_all_file_list(dir_name):
    file_list = list()
    for fileTmp in os.listdir(dir_name):
        file_path = os.path.join(dir_name, fileTmp)
        if os.path.isdir(file_path):
            continue
        file_list.append(file_path)
    return file_list

def resize_image_with_fixed_height(img, hSize=320):
    h, w, c = img.shape
    scl = hSize * 1.0 / h
    imgResize = cv2.resize(img, None, None, fx=scl, fy=scl)
    return imgResize, scl, h, w
--------------------------------------------------------------------------------
/util/mytoolbox.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zfchenUnique/WSSTG/ba19919b5010321651088afaecba877aebb4ed3a/util/mytoolbox.pyc
--------------------------------------------------------------------------------