├── .gitignore ├── Data ├── VisualGenome │ ├── LICENSE.txt │ ├── __init__.py │ ├── api.py │ ├── local.py │ ├── models.py │ └── utils.py └── __init__.py ├── FilesManager ├── FilesManager.py ├── LanguageModule │ ├── object_embeddings.p │ └── predicate_embeddings.p ├── __init__.py ├── files.yaml └── module │ ├── params.yaml │ ├── saver │ └── __init__.py │ └── tf_logs │ └── __init__.py ├── Module ├── Eval.py ├── Module.py ├── Train.py └── __init__.py ├── README.md ├── Run.py ├── Utils ├── Logger.py ├── Singleton.py └── __init__.py ├── __init__.py ├── qualitive_results_att_boxes.png ├── requirements.txt ├── sg_example_final.png └── sgp_arch_git.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | #additional 104 | FilesManager/module/tf_logs/ 105 | FilesManager/module/saver/ 106 | -------------------------------------------------------------------------------- /Data/VisualGenome/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Ranjay Krishna 2 | 3 | 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 
16 | 17 | 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 25 | THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /Data/VisualGenome/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/Data/VisualGenome/__init__.py -------------------------------------------------------------------------------- /Data/VisualGenome/api.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import utils 3 | 4 | 5 | def GetAllImageIds(): 6 | """ 7 | Get all Image ids. 8 | """ 9 | page = 1 10 | next = '/api/v0/images/all?page=' + str(page) 11 | ids = [] 12 | while True: 13 | data = utils.RetrieveData(next) 14 | ids.extend(data['results']) 15 | if data['next'] is None: 16 | break 17 | page += 1 18 | next = '/api/v0/images/all?page=' + str(page) 19 | return ids 20 | 21 | 22 | def GetImageIdsInRange(startIndex=0, endIndex=99): 23 | """ 24 | Get Image ids from startIndex to endIndex. 25 | """ 26 | idsPerPage = 1000 27 | startPage = startIndex / idsPerPage + 1 28 | endPage = endIndex / idsPerPage + 1 29 | ids = [] 30 | for page in range(startPage, endPage + 1): 31 | data = utils.RetrieveData('/api/v0/images/all?page=' + str(page)) 32 | ids.extend(data['results']) 33 | ids = ids[startIndex % 100:] 34 | ids = ids[:endIndex - startIndex + 1] 35 | return ids 36 | 37 | 38 | def GetImageData(id=61512): 39 | """ 40 | Get data about an image. 41 | """ 42 | data = utils.RetrieveData('/api/v0/images/' + str(id)) 43 | if 'detail' in data and data['detail'] == 'Not found.': 44 | return None 45 | image = utils.ParseImageData(data) 46 | return image 47 | 48 | 49 | def GetRegionDescriptionsOfImage(id=61512): 50 | """ 51 | Get the region descriptions of an image. 52 | """ 53 | image = GetImageData(id=id) 54 | data = utils.RetrieveData('/api/v0/images/' + str(id) + '/regions') 55 | if 'detail' in data and data['detail'] == 'Not found.': 56 | return None 57 | return utils.ParseRegionDescriptions(data, image) 58 | 59 | 60 | def GetRegionGraphOfRegion(image_id=61512, region_id=1): 61 | """ 62 | Get Region Graph of a particular Region in an image. 63 | """ 64 | image = GetImageData(id=image_id) 65 | data = utils.RetrieveData('/api/v0/images/' + str(image_id) + '/regions/' + str(region_id)) 66 | if 'detail' in data and data['detail'] == 'Not found.': 67 | return None 68 | return utils.ParseGraph(data[0], image) 69 | 70 | 71 | def GetSceneGraphOfImage(id=61512): 72 | """ 73 | Get Scene Graph of an image. 74 | """ 75 | image = GetImageData(id=id) 76 | data = utils.RetrieveData('/api/v0/images/' + str(id) + '/graph') 77 | if 'detail' in data and data['detail'] == 'Not found.': 78 | return None 79 | return utils.ParseGraph(data, image) 80 | 81 | 82 | def GetAllQAs(qtotal=100): 83 | """ 84 | Gets all the QA from the dataset. 85 | qtotal int total number of QAs to return. 
Set to None if all QAs should be returned 86 | """ 87 | page = 1 88 | next = '/api/v0/qa/all?page=' + str(page) 89 | qas = [] 90 | image_map = {} 91 | while True: 92 | data = utils.RetrieveData(next) 93 | for d in data['results']: 94 | if d['image'] not in image_map: 95 | image_map[d['image']] = GetImageData(id=d['image']) 96 | qas.extend(utils.ParseQA(data['results'], image_map)) 97 | if qtotal is not None and len(qas) > qtotal: 98 | return qas 99 | if data['next'] is None: 100 | break 101 | page += 1 102 | next = '/api/v0/qa/all?page=' + str(page) 103 | return qas 104 | 105 | 106 | def GetQAofType(qtype='why', qtotal=100): 107 | """ 108 | Get all QA's of a particular type - example, 'why' 109 | qtype string possible values: what, where, when, why, who, how. 110 | qtotal int total number of QAs to return. Set to None if all QAs should be returned 111 | """ 112 | page = 1 113 | next = '/api/v0/qa/' + qtype + '?page=' + str(page) 114 | qas = [] 115 | image_map = {} 116 | while True: 117 | data = utils.RetrieveData(next) 118 | for d in data['results']: 119 | if d['image'] not in image_map: 120 | image_map[d['image']] = GetImageData(id=d['image']) 121 | qas.extend(utils.ParseQA(data['results'], image_map)) 122 | if qtotal is not None and len(qas) > qtotal: 123 | return qas 124 | if data['next'] is None: 125 | break 126 | page += 1 127 | next = '/api/v0/qa/' + qtype + '?page=' + str(page) 128 | return qas 129 | 130 | 131 | def GetQAofImage(id=61512): 132 | """ 133 | Get all QAs for a particular image. 134 | """ 135 | page = 1 136 | next = '/api/v0/image/' + str(id) + '/qa?page=' + str(page) 137 | qas = [] 138 | image_map = {} 139 | while True: 140 | data = utils.RetrieveData(next) 141 | for d in data['results']: 142 | if d['image'] not in image_map: 143 | image_map[d['image']] = GetImageData(id=d['image']) 144 | qas.extend(utils.ParseQA(data['results'], image_map)) 145 | if data['next'] is None: 146 | break 147 | page += 1 148 | next = '/api/v0/image/' + str(id) + '/qa?page=' + str(page) 149 | return qas 150 | 151 | -------------------------------------------------------------------------------- /Data/VisualGenome/local.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | 5 | import utils 6 | from models import Graph, Synset 7 | from models import Image, Object, Attribute, Relationship 8 | 9 | 10 | def GetAllImageData(dataDir=None): 11 | """ 12 | Get Image ids from startIndex to endIndex. 13 | """ 14 | 15 | if dataDir is None: 16 | dataDir = utils.GetDataDir() 17 | dataFile = os.path.join(dataDir, 'image_data.json') 18 | data = json.load(open(dataFile)) 19 | return [utils.ParseImageData(image) for image in data] 20 | 21 | 22 | def GetAllRegionDescriptions(dataDir=None): 23 | """ 24 | Get all region descriptions. 25 | """ 26 | if dataDir is None: 27 | dataDir = utils.GetDataDir() 28 | dataFile = os.path.join(dataDir, 'region_descriptions.json') 29 | imageData = GetAllImageData(dataDir) 30 | imageMap = {} 31 | for d in imageData: 32 | imageMap[d.id] = d 33 | images = json.load(open(dataFile)) 34 | output = [] 35 | for image in images: 36 | output.append(utils.ParseRegionDescriptions(image['regions'], imageMap[image['id']])) 37 | return output 38 | 39 | 40 | def GetAllQAs(dataDir=None): 41 | """ 42 | Get all question answers. 
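    A minimal usage sketch, assuming the Visual Genome json files
    (`image_data.json`, `question_answers.json`) were downloaded into `data/`
    (the path here is only an example):

        >>> qas_per_image = GetAllQAs(dataDir='data/')
        >>> len(qas_per_image[0])  # number of QA objects for the first image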
43 | """ 44 | if dataDir is None: 45 | dataDir = utils.GetDataDir() 46 | dataFile = os.path.join(dataDir, 'question_answers.json') 47 | imageData = GetAllImageData(dataDir) 48 | imageMap = {} 49 | for d in imageData: 50 | imageMap[d.id] = d 51 | images = json.load(open(dataFile)) 52 | output = [] 53 | for image in images: 54 | output.append(utils.ParseQA(image['qas'], imageMap)) 55 | return output 56 | 57 | 58 | # -------------------------------------------------------------------------------------------------- 59 | # GetSceneGraphs and sub-methods 60 | 61 | 62 | def GetSceneGraph(image_id, images='data/', imageDataDir='data/by-id/', synsetFile='data/synsets.json'): 63 | """ 64 | Load a single scene graph from a .json file. 65 | """ 66 | if type(images) is str: 67 | # Instead of a string, we can pass this dict as the argument `images` 68 | images = {img.id: img for img in GetAllImageData(images)} 69 | 70 | fname = str(image_id) + '.json' 71 | image = images[image_id] 72 | data = json.load(open(imageDataDir + fname, 'r')) 73 | 74 | scene_graph = ParseGraphLocal(data, image) 75 | scene_graph = InitSynsets(scene_graph, synsetFile) 76 | return scene_graph 77 | 78 | 79 | def GetSceneGraphs(startIndex=0, endIndex=-1, 80 | dataDir='data/', imageDataDir='data/by-id/', 81 | minRels=0, maxRels=100): 82 | """ 83 | Get scene graphs given locally stored .json files; requires `SaveSceneGraphsById`. 84 | 85 | startIndex, endIndex : get scene graphs listed by image, from startIndex through endIndex 86 | dataDir : directory with `image_data.json` and `synsets.json` 87 | imageDataDir : directory of scene graph jsons saved by image id (see `SaveSceneGraphsById`) 88 | minRels, maxRels: only get scene graphs with at least / less than this number of relationships 89 | """ 90 | images = {img.id: img for img in GetAllImageData(dataDir)} 91 | scene_graphs = [] 92 | 93 | img_fnames = os.listdir(imageDataDir) 94 | if endIndex < 1: 95 | endIndex = len(img_fnames) 96 | 97 | for fname in img_fnames[startIndex: endIndex]: 98 | image_id = int(fname.split('.')[0]) 99 | scene_graph = GetSceneGraph(image_id, images, imageDataDir, dataDir + 'synsets.json') 100 | n_rels = len(scene_graph.relationships) 101 | if minRels <= n_rels <= maxRels: 102 | scene_graphs.append(scene_graph) 103 | 104 | return scene_graphs 105 | 106 | 107 | def MapObject(object_map, obj): 108 | """ 109 | Use object ids as hashes to `src.models.Object` instances. If item not 110 | in table, create new `Object`. Used when building scene graphs from json. 111 | """ 112 | oid = obj['object_id'] 113 | obj['id'] = oid 114 | del obj['object_id'] 115 | 116 | if oid in object_map: 117 | object_ = object_map[oid] 118 | 119 | else: 120 | if 'attributes' in obj: 121 | attrs = obj['attributes'] 122 | del obj['attributes'] 123 | else: 124 | attrs = [] 125 | if 'w' in obj: 126 | obj['width'] = obj['w'] 127 | obj['height'] = obj['h'] 128 | del obj['w'], obj['h'] 129 | 130 | object_ = Object(**obj) 131 | 132 | object_.attributes = attrs 133 | object_map[oid] = object_ 134 | 135 | return object_map, object_ 136 | 137 | 138 | global count_skips 139 | count_skips = [0, 0] 140 | 141 | 142 | def ParseGraphLocal(data, image, verbose=False): 143 | """ 144 | Modified version of `utils.ParseGraph`. 
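    Relationships or attributes that refer to an object id missing from
    `data['objects']` are skipped; the numbers of skipped relationships and
    skipped attributes are accumulated in the module-level `count_skips` list.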
145 | """ 146 | 147 | global count_skips 148 | objects = [] 149 | object_map = {} 150 | relationships = [] 151 | attributes = [] 152 | 153 | for obj in data['objects']: 154 | object_map, o_ = MapObject(object_map, obj) 155 | objects.append(o_) 156 | for rel in data['relationships']: 157 | if rel['subject_id'] in object_map and rel['object_id'] in object_map: 158 | object_map, s = MapObject(object_map, {'object_id': rel['subject_id']}) 159 | v = rel['predicate'] 160 | object_map, o = MapObject(object_map, {'object_id': rel['object_id']}) 161 | rid = rel['relationship_id'] 162 | relationships.append(Relationship(rid, s, v, o, rel['synsets'])) 163 | else: 164 | # Skip this relationship if we don't have the subject and object in 165 | # the object_map for this scene graph. Some data is missing in this way. 166 | count_skips[0] += 1 167 | if 'attributes' in data: 168 | for attr in data['attributes']: 169 | a = attr['attribute'] 170 | if a['object_id'] in object_map: 171 | attributes.append(Attribute(attr['attribute_id'], a['object_id'], a['names'], a['synsets'])) 172 | else: 173 | count_skips[1] += 1 174 | if verbose: 175 | print 'Skipped {} rels, {} attrs total'.format(*count_skips) 176 | return Graph(image, objects, relationships, attributes) 177 | 178 | 179 | def InitSynsets(scene_graph, synset_file): 180 | """ 181 | Convert synsets in a scene graph from strings to Synset objects. 182 | """ 183 | 184 | syn_data = json.load(open(synset_file, 'r')) 185 | syn_class = {s['synset_name']: Synset(s['synset_name'], s['synset_definition']) for s in syn_data} 186 | 187 | for obj in scene_graph.objects: 188 | obj.synsets = [syn_class[sn] for sn in obj.synsets] 189 | for rel in scene_graph.relationships: 190 | rel.synset = [syn_class[sn] for sn in rel.synset] 191 | for attr in scene_graph.attributes: 192 | obj.synset = [syn_class[sn] for sn in attr.synset] 193 | 194 | return scene_graph 195 | 196 | 197 | # -------------------------------------------------------------------------------------------------- 198 | # This is a pre-processing step that only needs to be executed once. 199 | # You can download .jsons segmented with these methods from: 200 | # https://drive.google.com/file/d/0Bygumy5BKFtcQ1JrcFpyQWdaQWM 201 | 202 | 203 | def SaveSceneGraphsById(dataDir='data/', imageDataDir='data/by-id/'): 204 | """ 205 | Save a separate .json file for each image id in `imageDataDir`. 206 | 207 | Notes 208 | ----- 209 | - If we don't save .json's by id, `scene_graphs.json` is >6G in RAM 210 | - Separated .json files are ~1.1G on disk 211 | - Run `AddAttrsToSceneGraphs` before `ParseGraphLocal` will work 212 | - Attributes are only present in objects, and do not have synset info 213 | 214 | Each output .json has the following keys: 215 | - "id" 216 | - "objects" 217 | - "relationships" 218 | """ 219 | 220 | if not os.path.exists(imageDataDir): 221 | os.mkdir(imageDataDir) 222 | 223 | all_data = json.load(open(os.path.join(dataDir, 'scene_graphs.json'))) 224 | for sg_data in all_data: 225 | img_fname = str(sg_data['image_id']) + '.json' 226 | with open(os.path.join(imageDataDir, img_fname), 'w') as f: 227 | json.dump(sg_data, f) 228 | 229 | del all_data 230 | gc.collect() # clear memory 231 | 232 | 233 | def AddAttrsToSceneGraphs(dataDir='data/'): 234 | """ 235 | Add attributes to `scene_graph.json`, extracted from `attributes.json`. 236 | 237 | This also adds a unique id to each attribute, and separates individual 238 | attibutes for each object (these are grouped in `attributes.json`). 
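    A possible one-time preprocessing order, following the calls sketched in
    the `__main__` block at the bottom of this file (paths are the module
    defaults):

        >>> AddAttrsToSceneGraphs(dataDir='data/')
        >>> SaveSceneGraphsById(dataDir='data/', imageDataDir='data/by-id/')
        >>> graph = GetSceneGraph(1, images='data/', imageDataDir='data/by-id/')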
239 | """ 240 | 241 | attr_data = json.load(open(os.path.join(dataDir, 'attributes.json'))) 242 | with open(os.path.join(dataDir, 'scene_graphs.json')) as f: 243 | sg_dict = {sg['image_id']: sg for sg in json.load(f)} 244 | 245 | id_count = 0 246 | for img_attrs in attr_data: 247 | attrs = [] 248 | for attribute in img_attrs['attributes']: 249 | a = img_attrs.copy() 250 | del a['attributes'] 251 | a['attribute'] = attribute 252 | a['attribute_id'] = id_count 253 | attrs.append(a) 254 | id_count += 1 255 | iid = img_attrs['image_id'] 256 | sg_dict[iid]['attributes'] = attrs 257 | 258 | with open(os.path.join(dataDir, 'scene_graphs.json'), 'w') as f: 259 | json.dump(sg_dict.values(), f) 260 | del attr_data, sg_dict 261 | gc.collect() 262 | 263 | 264 | # -------------------------------------------------------------------------------------------------- 265 | # For info on VRD dataset, see: 266 | # http://cs.stanford.edu/people/ranjaykrishna/vrd/ 267 | 268 | def GetSceneGraphsVRD(json_file='data/vrd/json/test.json'): 269 | """ 270 | Load VRD dataset into scene graph format. 271 | """ 272 | 273 | scene_graphs = [] 274 | with open(json_file, 'r') as f: 275 | D = json.load(f) 276 | 277 | scene_graphs = [ParseGraphVRD(d) for d in D] 278 | return scene_graphs 279 | 280 | 281 | def ParseGraphVRD(d): 282 | image = Image(d['photo_id'], d['filename'], d['width'], d['height'], '', '') 283 | 284 | id2obj = {} 285 | objs = [] 286 | rels = [] 287 | atrs = [] 288 | 289 | for i, o in enumerate(d['objects']): 290 | b = o['bbox'] 291 | obj = Object(i, b['x'], b['y'], b['w'], b['h'], o['names'], []) 292 | id2obj[i] = obj 293 | objs.append(obj) 294 | 295 | for j, a in enumerate(o['attributes']): 296 | atrs.append(Attribute(j, obj, a['attribute'], [])) 297 | 298 | for i, r in enumerate(d['relationships']): 299 | s = id2obj[r['objects'][0]] 300 | o = id2obj[r['objects'][1]] 301 | v = r['relationship'] 302 | rels.append(Relationship(i, s, v, o, [])) 303 | 304 | return Graph(image, objs, rels, atrs) 305 | 306 | 307 | if __name__ == '__main__': 308 | # AddAttrsToSceneGraphs() 309 | # SaveSceneGraphsById() 310 | image_data = GetAllImageData() 311 | region_interest = GetAllRegionDescriptions() 312 | qas = GetAllQAs() 313 | tt = GetSceneGraph(1) 314 | print("We got an image") 315 | -------------------------------------------------------------------------------- /Data/VisualGenome/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual Genome Python API wrapper, models 3 | """ 4 | 5 | 6 | class Image: 7 | """ 8 | Image. 9 | ID int 10 | url hyperlink string 11 | width int 12 | height int 13 | """ 14 | 15 | def __init__(self, id, url, width, height, coco_id, flickr_id): 16 | """ 17 | 18 | :param id: ID int 19 | :param url: url hyperlink string 20 | :param width: width int 21 | :param height: height int 22 | :param coco_id: 23 | :param flickr_id: 24 | """ 25 | self.id = id 26 | self.url = url 27 | self.width = width 28 | self.height = height 29 | self.coco_id = coco_id 30 | self.flickr_id = flickr_id 31 | 32 | def __str__(self): 33 | return 'id: %d, coco_id: %d, flickr_id: %d, width: %d, url: %s' \ 34 | % ( 35 | self.id, -1 if self.coco_id is None else self.coco_id, 36 | -1 if self.flickr_id is None else self.flickr_id, 37 | self.width, self.url) 38 | 39 | def __repr__(self): 40 | return str(self) 41 | 42 | 43 | class Region: 44 | """ 45 | Region. 
46 | image int 47 | phrase string 48 | x int 49 | y int 50 | width int 51 | height int 52 | """ 53 | 54 | def __init__(self, id, image, phrase, x, y, width, height): 55 | """ 56 | 57 | :param id: 58 | :param image: image int 59 | :param phrase: phrase string 60 | :param x: x int 61 | :param y: y int 62 | :param width: width int 63 | :param height: height int 64 | """ 65 | self.id = id 66 | self.image = image 67 | self.phrase = phrase 68 | self.x = x 69 | self.y = y 70 | self.width = width 71 | self.height = height 72 | 73 | def __str__(self): 74 | return 'id: %d, x: %d, y: %d, width: %d, height: %d, phrase: %s, image: %d' % \ 75 | (self.id, self.x, self.y, self.width, self.height, self.phrase, self.image.id) 76 | 77 | def __repr__(self): 78 | return str(self) 79 | 80 | 81 | class Graph: 82 | """ 83 | Graphs contain objects, relationships and attributes 84 | image Image 85 | bboxes Object array 86 | relationships Relationship array 87 | attributes Attribute array 88 | """ 89 | 90 | def __init__(self, image, objects, relationships, attributes): 91 | self.image = image 92 | self.objects = objects 93 | self.relationships = relationships 94 | self.attributes = attributes 95 | 96 | 97 | class Object: 98 | """ 99 | Objects. 100 | id int 101 | x int 102 | y int 103 | width int 104 | height int 105 | names string array 106 | synsets Synset array 107 | """ 108 | 109 | def __init__(self, id, x, y, width, height, names, synsets): 110 | self.id = id 111 | self.x = x 112 | self.y = y 113 | self.width = width 114 | self.height = height 115 | self.names = names 116 | self.synsets = synsets 117 | 118 | def __str__(self): 119 | name = self.names[0] if len(self.names) != 0 else 'None' 120 | return '%s' % (name) 121 | 122 | def __repr__(self): 123 | return str(self) 124 | 125 | 126 | class Relationship: 127 | """ 128 | Relationships. Ex, 'man - jumping over - fire hydrant'. 129 | subject int 130 | predicate string 131 | object int 132 | rel_canon Synset 133 | """ 134 | 135 | def __init__(self, id, subject, predicate, object, synset): 136 | self.id = id 137 | self.subject = subject 138 | self.predicate = predicate 139 | self.object = object 140 | self.synset = synset 141 | 142 | def __str__(self): 143 | return "%d: %s %s %s" % (self.id, self.subject, self.predicate, self.object) 144 | 145 | def __repr__(self): 146 | return str(self) 147 | 148 | 149 | class Attribute: 150 | """ 151 | Attributes. Ex, 'man - old'. 152 | subject Object 153 | attribute string 154 | synset Synset 155 | """ 156 | def __init__(self, id, subject, attribute, synset): 157 | self.id = id 158 | self.subject = subject 159 | self.attribute = attribute 160 | self.synset = synset 161 | 162 | def __str__(self): 163 | return "%d: %s is %s" % (self.id, self.subject, self.attribute) 164 | 165 | def __repr__(self): 166 | return str(self) 167 | 168 | 169 | class QA: 170 | """ 171 | Question Answer Pairs. 
172 | ID int 173 | image int 174 | question string 175 | answer string 176 | q_objects QAObject array 177 | a_objects QAObject array 178 | """ 179 | 180 | def __init__(self, id, image, question, answer, question_objects, answer_objects): 181 | self.id = id 182 | self.image = image 183 | self.question = question 184 | self.answer = answer 185 | self.q_objects = question_objects 186 | self.a_objects = answer_objects 187 | 188 | def __str__(self): 189 | return 'id: %d, image: %d, question: %s, answer: %s' \ 190 | % (self.id, self.image.id, self.question, self.answer) 191 | 192 | def __repr__(self): 193 | return str(self) 194 | 195 | 196 | class QAObject: 197 | """ 198 | Question Answer Objects are localized in the image and refer to a part 199 | of the question text or the answer text. 200 | start_idx int 201 | end_idx int 202 | name string 203 | synset_name string 204 | synset_definition string 205 | """ 206 | 207 | def __init__(self, start_idx, end_idx, name, synset): 208 | self.start_idx = start_idx 209 | self.end_idx = end_idx 210 | self.name = name 211 | self.synset = synset 212 | 213 | def __repr__(self): 214 | return str(self) 215 | 216 | 217 | class Synset: 218 | """ 219 | Wordnet Synsets. 220 | name string 221 | definition string 222 | """ 223 | 224 | def __init__(self, name, definition): 225 | self.name = name 226 | self.definition = definition 227 | 228 | def __str__(self): 229 | return '{} - {}'.format(self.name, self.definition) 230 | 231 | def __repr__(self): 232 | return str(self) 233 | 234 | 235 | class ObjectMapping(object): 236 | """ 237 | This class inherits Object and adds a url field for mapping between an object and the img url 238 | Objects Mapping. 239 | id int 240 | x int 241 | y int 242 | width int 243 | height int 244 | names string array 245 | synsets Synset array 246 | url hyperlink string 247 | """ 248 | 249 | def __init__(self, id, x, y, width, height, names, synsets, url): 250 | self.url = url 251 | self.id = id 252 | self.x = x 253 | self.y = y 254 | self.width = width 255 | self.height = height 256 | self.names = names 257 | self.synsets = synsets 258 | 259 | def __str__(self): 260 | name = self.names[0] if len(self.names) != 0 else 'None' 261 | return '%s' % (name) 262 | 263 | def __repr__(self): 264 | return str(self) 265 | 266 | 267 | class RelationshipMapping(object): 268 | """ 269 | Relationships. Ex, 'man - jumping over - fire hydrant'. 270 | subject int 271 | predicate string 272 | object int 273 | rel_canon Synset 274 | """ 275 | 276 | def __init__(self, id, subject, predicate, object, synset, url, filtered_id): 277 | self.id = id 278 | self.subject = subject 279 | self.predicate = predicate 280 | self.object = object 281 | self.synset = synset 282 | self.url = url 283 | self.filtered_id = filtered_id 284 | 285 | def __str__(self): 286 | return "%d: %s %s %s" % (self.id, self.subject, self.predicate, self.object) 287 | 288 | def __repr__(self): 289 | return str(self) 290 | 291 | 292 | class ImageMapping(object): 293 | """ 294 | Image. 
295 | ID int 296 | url hyperlink string 297 | width int 298 | height int 299 | """ 300 | 301 | def __init__(self, id, url, width, height, coco_id, flickr_id, width_org, height_org, img_ind): 302 | """ 303 | 304 | :param id: ID int 305 | :param url: url hyperlink string 306 | :param width: width from images_1024 int 307 | :param height: height from images_1024 int 308 | :param coco_id: 309 | :param flickr_id: 310 | """ 311 | 312 | self.img_ind = img_ind 313 | self.height_org = height_org 314 | self.width_org = width_org 315 | self.id = id 316 | self.url = url 317 | self.width = width 318 | self.height = height 319 | self.coco_id = coco_id 320 | self.flickr_id = flickr_id 321 | 322 | def __str__(self): 323 | return 'id: %d, coco_id: %d, flickr_id: %d, width: %d, url: %s' \ 324 | % ( 325 | self.id, -1 if self.coco_id is None else self.coco_id, 326 | -1 if self.flickr_id is None else self.flickr_id, 327 | self.width, self.url) 328 | 329 | def __repr__(self): 330 | return str(self) 331 | 332 | -------------------------------------------------------------------------------- /Data/VisualGenome/utils.py: -------------------------------------------------------------------------------- 1 | from models import Image, Object, Attribute, Relationship 2 | from models import Region, Graph, QA, QAObject, Synset 3 | import httplib 4 | import json 5 | 6 | 7 | def GetDataDir(): 8 | """ 9 | Get the local directory where the Visual Genome data is locally stored. 10 | """ 11 | from os.path import dirname, realpath, join 12 | dataDir = join(dirname(realpath('__file__')), 'data') 13 | return dataDir 14 | 15 | 16 | def RetrieveData(request): 17 | """ 18 | Helper Method used to get all data from request string. 19 | """ 20 | connection = httplib.HTTPSConnection("visualgenome.org", '443') 21 | connection.request("GET", request) 22 | response = connection.getresponse() 23 | jsonString = response.read() 24 | data = json.loads(jsonString) 25 | return data 26 | 27 | 28 | def ParseSynset(canon): 29 | """ 30 | Helper to Extract Synset from canon object. 31 | """ 32 | if len(canon) == 0: 33 | return None 34 | return Synset(canon[0]['synset_name'], canon[0]['synset_definition']) 35 | 36 | 37 | def ParseGraph(data, image): 38 | """ 39 | Helper to parse a Graph object from API data. 40 | """ 41 | 42 | objects = [] 43 | object_map = {} 44 | relationships = [] 45 | attributes = [] 46 | # Create the Objects 47 | for obj in data['bounding_boxes']: 48 | names = [] 49 | synsets = [] 50 | for s in obj['boxed_objects']: 51 | names.append(s['name']) 52 | synsets.append(ParseSynset(s['object_canon'])) 53 | object_ = Object(obj['id'], obj['x'], obj['y'], obj['width'], obj['height'], names, synsets) 54 | object_map[obj['id']] = object_ 55 | objects.append(object_) 56 | # Create the Relationships 57 | for rel in data['relationships']: 58 | relationships.append(Relationship(rel['id'], object_map[rel['subject']], \ 59 | rel['predicate'], object_map[rel['object']], 60 | ParseSynset(rel['relationship_canon']))) 61 | # Create the Attributes 62 | for atr in data['attributes']: 63 | attributes.append(Attribute(atr['id'], object_map[atr['subject']], \ 64 | atr['attribute'], ParseSynset(atr['attribute_canon']))) 65 | return Graph(image, objects, relationships, attributes) 66 | 67 | 68 | def ParseImageData(data): 69 | """ 70 | Helper to parse the image data for one image. 
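    A minimal sketch of the expected input dict; the keys match the fields
    read below, and the values are made up for illustration:

        >>> img = ParseImageData({'image_id': 1, 'url': 'https://example.org/1.jpg',
        ...                       'width': 800, 'height': 600,
        ...                       'coco_id': None, 'flickr_id': None})
        >>> img.width
        800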
71 | """ 72 | 73 | img_id = data['id'] if 'id' in data else data['image_id'] 74 | url = data['url'] 75 | width = data['width'] 76 | height = data['height'] 77 | coco_id = data['coco_id'] 78 | flickr_id = data['flickr_id'] 79 | image = Image(img_id, url, width, height, coco_id, flickr_id) 80 | return image 81 | 82 | 83 | def ParseRegionDescriptions(data, image): 84 | """ 85 | Helper to parse region descriptions. 86 | """ 87 | 88 | regions = [] 89 | if data[0].has_key('region_id'): 90 | region_id_key = 'region_id' 91 | else: 92 | region_id_key = 'id' 93 | for d in data: 94 | regions.append(Region(d[region_id_key], image, d['phrase'], d['x'], d['y'], d['width'], d['height'])) 95 | return regions 96 | 97 | 98 | def ParseQA(data, image_map): 99 | """ 100 | Helper to parse a list of question answers. 101 | """ 102 | 103 | qas = [] 104 | for d in data: 105 | qos = [] 106 | aos = [] 107 | if 'question_objects' in d: 108 | for qo in d['question_objects']: 109 | synset = Synset(qo['synset_name'], qo['synset_definition']) 110 | qos.append(QAObject(qo['entity_idx_start'], qo['entity_idx_end'], qo['entity_name'], synset)) 111 | if 'answer_objects' in d: 112 | for ao in d['answer_objects']: 113 | synset = Synset(o['synset_name'], ao['synset_definition']) 114 | aos.append(QAObject(ao['entity_idx_start'], ao['entity_idx_end'], ao['entity_name'], synset)) 115 | qas.append(QA(d['qa_id'], image_map[d['image_id']], d['question'], d['answer'], qos, aos)) 116 | return qas 117 | -------------------------------------------------------------------------------- /Data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/Data/__init__.py -------------------------------------------------------------------------------- /FilesManager/FilesManager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import sys 4 | 5 | sys.path.append("..") 6 | from Utils.Singleton import Singleton 7 | import cPickle 8 | import yaml 9 | import copy 10 | import json 11 | import h5py 12 | 13 | FILE_MANAGER_PATH = os.path.abspath(os.path.dirname(__file__)) 14 | FILE_MANAGER_FILENAME = "files.yaml" 15 | 16 | 17 | class FilesManager(object): 18 | """ 19 | Files Manager used to load and save any kind of files 20 | """ 21 | __metaclass__ = Singleton 22 | 23 | def __init__(self, overrides_filename=None): 24 | """ 25 | Constructor for FilesManager 26 | :param overrides_filename: "*.yaml file used to override paths to files 27 | """ 28 | # save input data 29 | self.overrides_filename = overrides_filename 30 | # prints 31 | self.log = lambda x: self.log_str(x) 32 | 33 | # load file paths 34 | stream = file(os.path.join(FILE_MANAGER_PATH, FILE_MANAGER_FILENAME), 'r') 35 | self.files = yaml.load(stream) 36 | 37 | # override file paths 38 | if overrides_filename != None: 39 | print("FilesManager: overrides files according to " + str(overrides_filename)) 40 | # read yaml file 41 | stream = file(os.path.join(FILE_MANAGER_PATH, overrides_filename), 'r') 42 | overrides = yaml.load(stream) 43 | # override 44 | self.override(self.files, overrides) 45 | 46 | def load_file(self, tokens, version=None): 47 | """ 48 | load file given file tokens 49 | :param tokens: tokens delimited with '.' 
(each toked is a level in files.yaml) 50 | :param version: useful if old version is required (according to numbers in files.yaml) 51 | :return: data read from file according to file type 52 | """ 53 | # get file path 54 | fileinfo = self.get_file_info(tokens, version) 55 | filetype = fileinfo["type"] 56 | filename = os.path.join(FILE_MANAGER_PATH, fileinfo["name"]) 57 | 58 | # load data per file type 59 | if filetype == "pickle": 60 | self.log("FilesManager: Load pickle file: %s=%s" % (tokens, filename)) 61 | picklefile = open(filename, "rb") 62 | 63 | # get number of objects stored in the pickle file 64 | nof_objects = 1 65 | if "nof_objects" in fileinfo: 66 | nof_objects = fileinfo["nof_objects"] 67 | 68 | # load data 69 | if nof_objects == 1: 70 | data = cPickle.load(picklefile) 71 | else: 72 | data = [] 73 | for i in range(nof_objects): 74 | data.append(cPickle.load(picklefile)) 75 | 76 | picklefile.close() 77 | return data 78 | elif filetype == "yaml": 79 | self.log("FilesManager: Load yaml file: %s=%s" % (tokens, filename)) 80 | stream = open(filename, 'r') 81 | data = yaml.load(stream) 82 | return data 83 | elif filetype == "json": 84 | self.log("FilesManager: Load json file: %s=%s" % (tokens, filename)) 85 | f = open(filename, 'r') 86 | data = json.load(f) 87 | return data 88 | elif filetype == "h5py": 89 | self.log("FilesManager: Load h5py file: %s=%s" % (tokens, filename)) 90 | data = h5py.File(filename, 'r') 91 | return data 92 | elif filetype == "text": 93 | self.log("FilesManager: Load text file: %s=%s" % (tokens, filename)) 94 | with open(filename) as f: 95 | lines = f.readlines() 96 | return lines 97 | 98 | def save_file(self, tokens, data, version=None): 99 | """ 100 | save file given tokens in pickle format 101 | :param tokens: tokens delimited with '.' (each toked is a level in files.yaml) 102 | :param version: useful if old version is required (according to numbers in files.yaml) 103 | :param data: data to save 104 | :return: void 105 | """ 106 | # get file path 107 | fileinfo = self.get_file_info(tokens, version) 108 | filename = os.path.join(FILE_MANAGER_PATH, fileinfo["name"]) 109 | 110 | self.log("FilesManager: Save pickle file: " + filename) 111 | 112 | # load data 113 | picklefile = open(filename, "wb") 114 | # get number of objects stored in the pickle file 115 | nof_objects = 1 116 | if "nof_objects" in fileinfo: 117 | nof_objects = fileinfo["nof_objects"] 118 | if nof_objects == 1: 119 | cPickle.dump(data, picklefile) 120 | else: 121 | for elem in data: 122 | cPickle.dump(elem, picklefile) 123 | 124 | picklefile.close() 125 | 126 | def file_exist(self, tokens, version=None): 127 | """ 128 | check if file exists given tokens 129 | :param tokens: tokens delimited with '.' (each toked is a level in files.yaml) 130 | :param version: useful if old version is required (according to numbers in files.yaml) 131 | :return: True if file exist 132 | """ 133 | # get file path 134 | fileinfo = self.get_file_info(tokens, version) 135 | filename = os.path.join(FILE_MANAGER_PATH, fileinfo["name"]) 136 | 137 | return os.path.exists(filename) 138 | 139 | def get_file_info(self, tokens, version=None): 140 | 141 | """ 142 | get file name given file tokens 143 | :param tokens: tokens delimited with '.' 
(each toked is a level in files.yaml) 144 | :param version: useful if old version is required (according to numbers in files.yaml) 145 | :return: dictionary with file info 146 | """ 147 | # get list of tokens 148 | tokens_lst = tokens.split(".") 149 | 150 | # get filename 151 | fileinfo = self.files 152 | for token in tokens_lst: 153 | if fileinfo.has_key(token): 154 | fileinfo = fileinfo[token] 155 | else: 156 | raise Exception("unknown name token {0} for name {1}".format(token, tokens)) 157 | 158 | # make sure fileinfo was extracted 159 | if not "name" in fileinfo: 160 | raise Exception("uncomplete file tokens", tokens) 161 | 162 | # handle versions 163 | if version is not None: 164 | if "versions" in fileinfo: 165 | versions = fileinfo["versions"] 166 | if version in versions: 167 | # deep copy to be able to override info 168 | fileinfo = copy.deepcopy(fileinfo) 169 | fileinfo["name"] = versions[version]["name"] 170 | self.log("FilesManager: %s - Use Old Version %d" % (tokens, version)) 171 | if "doc" in versions[version]: 172 | self.log("FilesManager: Version %d doc - %s" % (version, versions[version]["doc"])) 173 | else: 174 | raise Exception("FilesManager: %s version num %d wasn't found" % (tokens, version)) 175 | else: 176 | raise Exception("FilesManager: %s versions token wasn't found" % (tokens)) 177 | 178 | return fileinfo 179 | 180 | 181 | def get_file_path(self, tokens, version=None): 182 | """ 183 | return full path to file 184 | :param tokens: tokens delimited with '.' (each toked is a level in files.yaml) 185 | :param version: useful if old version is required (according to numbers in files.yaml) 186 | :return: path to file 187 | """ 188 | # get file path 189 | fileinfo = self.get_file_info(tokens, version) 190 | filename = os.path.join(FILE_MANAGER_PATH, fileinfo["name"]) 191 | 192 | self.log("FilesManager: file path : %s=%s" % (tokens, filename)) 193 | return filename 194 | 195 | def override(self, files_db, overrides): 196 | """ 197 | Overrides files data base according to overrides data base 198 | :param files_db: 199 | :param overrides: 200 | :return: void 201 | """ 202 | for elem in overrides: 203 | if elem in files_db: 204 | if type(overrides[elem]) == str: 205 | files_db[elem] = overrides[elem] 206 | 207 | else: 208 | self.override(files_db[elem], overrides[elem]) 209 | else: 210 | files_db[elem] = overrides[elem] 211 | 212 | def add_logger(self, logger): 213 | """ 214 | Log all prints to file 215 | :return: 216 | """ 217 | self.log = lambda x: logger.log(x) 218 | 219 | def log_str(self, x): 220 | print(x) 221 | -------------------------------------------------------------------------------- /FilesManager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/FilesManager/__init__.py -------------------------------------------------------------------------------- /FilesManager/files.yaml: -------------------------------------------------------------------------------- 1 | # the file holds the paths to default files that will be loaded or saved 2 | logs: 3 | name: "logs" 4 | type: "dir" 5 | doc: "default folder for logs" 6 | 7 | sg_module: 8 | train: 9 | tf_logs: 10 | name: "module/tf_logs" 11 | type: "dir" 12 | doc: "holds tensor flow log" 13 | saver: 14 | name: "module/saver" 15 | type: "dir" 16 | doc: "stores module weights" 17 | params: 18 | name: "module/params.yaml" 19 | type: "yaml" 20 | doc: "train parameters" 21 | 22 | 
language_module: 23 | word2vec: 24 | object_embeddings: 25 | name: "LanguageModule/object_embeddings.p" 26 | type: "pickle" 27 | doc: "Testing Language Model Word2vec trained" 28 | predicate_embeddings: 29 | name: "LanguageModule/predicate_embeddings.p" 30 | type: "pickle" 31 | doc: "Testing Language Model Word2vec trained" 32 | 33 | data: 34 | visual_genome: 35 | train: 36 | name: "../Module/data/train" 37 | type: "dir" 38 | doc: "dir with entities files. each file inclues 1000 entities" 39 | train_baseline: 40 | name: "../Module/data/train/baseline" 41 | type: "dir" 42 | doc: "dir with baseline confidence per each relation and entity" 43 | test: 44 | name: "../Module/data/test" 45 | type: "dir" 46 | doc: "dir with entities files. each file inclues 1000 entities" 47 | test_baseline: 48 | name: "../Module/data/test/baseline" 49 | type: "dir" 50 | doc: "dir with baseline confidence per each relation and entity" 51 | data: 52 | name: "../Module" 53 | type: "dir" 54 | doc: "dir with zip file and data folder" 55 | -------------------------------------------------------------------------------- /FilesManager/module/params.yaml: -------------------------------------------------------------------------------- 1 | name: gpi_ling_new 2 | nof_iterations: 10000 3 | learning_rate: 0.0001 4 | learning_rate_steps: 1000 5 | learning_rate_decay: 0.5 6 | load_model_name: "gpi_ling_new_best_module.ckpt" 7 | use_saved_model: False 8 | batch_size: 20 9 | predicate_pos_neg_ratio: 10 10 | lr_object_coeff: 4 11 | layers: [500, 500, 500] 12 | gpu: 0 13 | 14 | -------------------------------------------------------------------------------- /FilesManager/module/saver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/FilesManager/module/saver/__init__.py -------------------------------------------------------------------------------- /FilesManager/module/tf_logs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/FilesManager/module/tf_logs/__init__.py -------------------------------------------------------------------------------- /Module/Eval.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | import sys 4 | 5 | sys.path.append("..") 6 | 7 | from FilesManager.FilesManager import FilesManager 8 | from Module import Module 9 | from Utils.Logger import Logger 10 | import tensorflow as tf 11 | import numpy as np 12 | from Train import test, NOF_OBJECTS, NOF_PREDICATES 13 | import cPickle 14 | import matplotlib.pyplot 15 | matplotlib.pyplot.switch_backend('agg') 16 | import time 17 | 18 | def iou(box_a, box_b): 19 | union_area = (max(box_a[2], box_b[2]) - min(box_a[0], box_b[0]) + 1) * (max(box_a[3], box_b[3]) - min(box_a[1], box_b[1]) + 1) 20 | overlap_w = min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]) + 1 21 | if overlap_w <= 0: 22 | return 0 23 | overlap_h = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1 24 | if overlap_h <= 0: 25 | return 0 26 | return float(overlap_w * overlap_h) / union_area 27 | 28 | def eval_image(entity, labels_relation, labels_entity, out_confidence_relation_val, out_confidence_entity_val, k=100): 29 | """ 30 | Scene Graph Classification - 31 | R@k metric (measures the fraction of ground truth relationships 32 | triplets that appear among the 
k most confident triplet prediction in an image) 33 | :param labels_relation: labels of image predicates (each one is one hot vector) - shape (N, N, NOF_PREDICATES) 34 | :param labels_entity: labels of image objects (each one is one hot vector) - shape (N, NOF_OBJECTS) 35 | :param out_confidence_relation_val: confidence of image predicates - shape (N, N, NOF_PREDICATES) 36 | :param out_confidence_entity_val: confidence of image objects - shape (N, NOF_OBJECTS) 37 | :param k: k most confident predictions to consider 38 | :return: image score, number of the gt triplets that appear in the k most confident predictions, 39 | number of the gt triplets 40 | """ 41 | # iterate over each relation to predict and find k highest predictions 42 | top_predictions = np.zeros((0, 12)) 43 | 44 | # results per relation 45 | per_relation_correct = np.zeros(NOF_PREDICATES) 46 | per_relation_total = np.zeros(NOF_PREDICATES) 47 | 48 | N = labels_entity.shape[0] 49 | if N == 1: 50 | return 0, 0, 0, per_relation_correct, per_relation_total 51 | 52 | relation_pred = np.argmax(out_confidence_relation_val[:, :, :NOF_PREDICATES - 1], axis=2) 53 | relation_scores = np.max(out_confidence_relation_val[:, :, :NOF_PREDICATES - 1], axis=2) 54 | entity_pred = np.argmax(out_confidence_entity_val, axis=1) 55 | entity_scores = np.max(out_confidence_entity_val, axis=1) 56 | 57 | # get list of the top k most confident triplets predictions 58 | for subject_index in range(N): 59 | for object_index in range(N): 60 | # filter if subject equals to object 61 | if subject_index == object_index: 62 | continue 63 | 64 | # create entry with the scores 65 | triplet_prediction = np.zeros((1, 12)) 66 | triplet_prediction[0][0] = entity.objects[subject_index].x 67 | triplet_prediction[0][1] = entity.objects[subject_index].y 68 | triplet_prediction[0][2] = entity.objects[subject_index].x + entity.objects[subject_index].width 69 | triplet_prediction[0][3] = entity.objects[subject_index].y + entity.objects[subject_index].height 70 | triplet_prediction[0][4] = entity.objects[object_index].x 71 | triplet_prediction[0][5] = entity.objects[object_index].y 72 | triplet_prediction[0][6] = entity.objects[object_index].x + entity.objects[object_index].width 73 | triplet_prediction[0][7] = entity.objects[object_index].y + entity.objects[object_index].height 74 | 75 | triplet_prediction[0][8] = entity_pred[subject_index] 76 | triplet_prediction[0][9] = relation_pred[subject_index][object_index] 77 | triplet_prediction[0][10] = entity_pred[object_index] 78 | triplet_prediction[0][11] = relation_scores[subject_index][object_index] * entity_scores[subject_index] * \ 79 | entity_scores[object_index] 80 | 81 | # append to the list of highest predictions 82 | top_predictions = np.concatenate((top_predictions, triplet_prediction)) 83 | 84 | # get k highest confidence 85 | top_k_indices = np.argsort(top_predictions[:, 11])[-k:] 86 | sub_boxes = top_predictions[top_k_indices, :4] 87 | obj_boxes = top_predictions[top_k_indices, 4:8] 88 | sub_pred = top_predictions[top_k_indices, 8] 89 | relation_pred = top_predictions[top_k_indices, 9] 90 | obj_pred = top_predictions[top_k_indices, 10] 91 | 92 | relations_gt = np.argmax(labels_relation, axis=2) 93 | entities_gt = np.argmax(labels_entity, axis=1) 94 | 95 | img_score = 0 96 | nof_pos_relationship = 0 97 | for subject_index in range(N): 98 | for object_index in range(N): 99 | # filter if subject equals to object 100 | if subject_index == object_index: 101 | continue 102 | # filter negative relationship 103 | if 
relations_gt[subject_index, object_index] == NOF_PREDICATES - 1: 104 | continue 105 | 106 | gt_sub_box = np.zeros((4)) 107 | gt_sub_box[0] = entity.objects[subject_index].x 108 | gt_sub_box[1] = entity.objects[subject_index].y 109 | gt_sub_box[2] = entity.objects[subject_index].x + entity.objects[subject_index].width 110 | gt_sub_box[3] = entity.objects[subject_index].y + entity.objects[subject_index].height 111 | 112 | gt_obj_box = np.zeros((4)) 113 | gt_obj_box[0] = entity.objects[object_index].x 114 | gt_obj_box[1] = entity.objects[object_index].y 115 | gt_obj_box[2] = entity.objects[object_index].x + entity.objects[object_index].width 116 | gt_obj_box[3] = entity.objects[object_index].y + entity.objects[object_index].height 117 | 118 | predicate_id = relations_gt[subject_index][object_index] 119 | sub_id = entities_gt[subject_index] 120 | obj_id = entities_gt[object_index] 121 | 122 | nof_pos_relationship += 1 123 | per_relation_total[predicate_id] += 1 124 | 125 | # filter according to iou 126 | found = False 127 | for top_k_i in range(k): 128 | if sub_id != sub_pred[top_k_i] or obj_id != obj_pred[top_k_i] or predicate_id !=relation_pred[top_k_i]: 129 | continue 130 | iou_sub_val = iou(gt_sub_box, sub_boxes[top_k_i]) 131 | if iou_sub_val < 0.5: 132 | continue 133 | iou_obj_val = iou(gt_obj_box, obj_boxes[top_k_i]) 134 | if iou_obj_val < 0.5: 135 | continue 136 | 137 | found = True 138 | break 139 | 140 | if found: 141 | img_score += 1 142 | per_relation_correct[predicate_id] += 1 143 | else: 144 | img_score = img_score 145 | 146 | if nof_pos_relationship != 0: 147 | img_score_percent = float(img_score) / float(nof_pos_relationship) 148 | else: 149 | img_score_percent = 0 150 | 151 | return img_score_percent, img_score, nof_pos_relationship, per_relation_correct, per_relation_total 152 | 153 | 154 | def eval(load_module_name=None, k=100, layers=[500, 500, 500], gpu=1): 155 | """ 156 | Evaluate module: 157 | - Scene Graph Classification - R@k metric (measures the fraction of ground truth relationships 158 | triplets that appear among the k most confident triplet prediction in an image) 159 | :param load_module_name: name of the module to load 160 | :param k: see description 161 | :param layers: hidden layers of relation and entity classifier 162 | :param gpu: gpu number to use 163 | :return: nothing - output to logger instead 164 | """ 165 | gpi_type = "Linguistic" 166 | k_recall = True 167 | filesmanager = FilesManager() 168 | # create logger 169 | logger = Logger() 170 | 171 | # print eval params 172 | frame = inspect.currentframe() 173 | args, _, _, values = inspect.getargvalues(frame) 174 | logger.log('function name "%s"' % inspect.getframeinfo(frame)[2]) 175 | for i in args: 176 | logger.log(" %s = %s" % (i, values[i])) 177 | 178 | # set gpu 179 | if gpu != None: 180 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) 181 | logger.log("os.environ[\"CUDA_VISIBLE_DEVICES\"] = " + str(gpu)) 182 | 183 | # create module 184 | module = Module(gpi_type=gpi_type, nof_predicates=NOF_PREDICATES, nof_objects=NOF_OBJECTS, 185 | is_train=False, layers=layers, including_object=True) 186 | 187 | # get input place holders 188 | confidence_relation_ph, confidence_entity_ph, bb_ph, word_embed_relations_ph, word_embed_entities_ph = module.get_in_ph() 189 | # get module output 190 | out_relation_probes, out_entity_probes = module.get_output() 191 | 192 | # Initialize the Computational Graph 193 | init = tf.global_variables_initializer() 194 | # Add ops to save and restore all the variables. 
195 | saver = tf.train.Saver() 196 | 197 | # read data 198 | entities_path = filesmanager.get_file_path("data.visual_genome.test") 199 | 200 | test_files_list = range(35) 201 | 202 | # embeddings 203 | embed_obj = FilesManager().load_file("language_module.word2vec.object_embeddings") 204 | embed_pred = FilesManager().load_file("language_module.word2vec.predicate_embeddings") 205 | embed_pred = np.concatenate((embed_pred, np.zeros(embed_pred[:1].shape)), axis=0) # concat negative represntation 206 | accum_results = None 207 | with tf.Session() as sess: 208 | if load_module_name is not None: 209 | # Restore variables from disk. 210 | if load_module_name=="gpi_linguistic_pretrained": 211 | module_path = os.path.join(filesmanager.get_file_path("data.visual_genome.data"), "data") 212 | else: 213 | module_path = filesmanager.get_file_path("sg_module.train.saver") 214 | module_path_load = os.path.join(module_path, load_module_name + "_module.ckpt") 215 | if os.path.exists(module_path_load + ".index"): 216 | saver.restore(sess, module_path_load) 217 | logger.log("Model restored.") 218 | else: 219 | raise Exception("Module not found") 220 | else: 221 | sess.run(init) 222 | # eval module 223 | 224 | nof = 0 225 | total = 0 226 | correct_all = 0 227 | total_all = 0 228 | 229 | # create one hot vector for null relation 230 | relation_neg = np.zeros(NOF_PREDICATES) 231 | relation_neg[NOF_PREDICATES - 1] = 1 232 | 233 | index = 0 234 | basline_path = filesmanager.get_file_path("data.visual_genome.test_baseline") 235 | for file_name in test_files_list: 236 | file_path = os.path.join(entities_path, str(file_name) + ".p") 237 | file_handle = open(file_path, "rb") 238 | test_entities = cPickle.load(file_handle) 239 | file_handle.close() 240 | 241 | for entity in test_entities: 242 | file_path = os.path.join(basline_path, str(entity.image.id) + ".p") 243 | if not os.path.exists(file_path): 244 | continue 245 | file_handle = open(file_path, "rb") 246 | detector_data = cPickle.load(file_handle) 247 | file_handle.close() 248 | 249 | entity.predicates_outputs_with_no_activation = detector_data["rel_dist_mapped"] 250 | entity.objects_outputs_with_no_activations = detector_data["obj_dist_mapped"] 251 | # set diagonal to be negative relation 252 | N = entity.predicates_outputs_with_no_activation.shape[0] 253 | indices = np.arange(N) 254 | entity.predicates_outputs_with_no_activation[indices, indices, :] = relation_neg 255 | entity.predicates_labels[indices, indices, :] = relation_neg 256 | 257 | # create bounding box info per object 258 | obj_bb = np.zeros((len(entity.objects), 14)) 259 | for obj_id in range(len(entity.objects)): 260 | obj_bb[obj_id][0] = entity.objects[obj_id].x / 1200.0 261 | obj_bb[obj_id][1] = entity.objects[obj_id].y / 1200.0 262 | obj_bb[obj_id][2] = (entity.objects[obj_id].x + entity.objects[obj_id].width) / 1200.0 263 | obj_bb[obj_id][3] = (entity.objects[obj_id].y + entity.objects[obj_id].height) / 1200.0 264 | obj_bb[obj_id][4] = entity.objects[obj_id].x 265 | obj_bb[obj_id][5] = -1 * entity.objects[obj_id].x 266 | obj_bb[obj_id][6] = entity.objects[obj_id].y 267 | obj_bb[obj_id][7] = -1 * entity.objects[obj_id].y 268 | obj_bb[obj_id][8] = entity.objects[obj_id].width * entity.objects[obj_id].height 269 | obj_bb[obj_id][9] = -1 * entity.objects[obj_id].width * entity.objects[obj_id].height 270 | obj_bb[:, 4] = np.argsort(obj_bb[:, 4]) 271 | obj_bb[:, 5] = np.argsort(obj_bb[:, 5]) 272 | obj_bb[:, 6] = np.argsort(obj_bb[:, 6]) 273 | obj_bb[:, 7] = np.argsort(obj_bb[:, 7]) 274 | obj_bb[:, 8] 
= np.argsort(obj_bb[:, 8]) 275 | obj_bb[:, 9] = np.argsort(obj_bb[:, 9]) 276 | obj_bb[:, 10] = np.argsort(np.max(entity.objects_outputs_with_no_activations, axis=1)) 277 | obj_bb[:, 11] = np.argsort(-1 * np.max(entity.objects_outputs_with_no_activations, axis=1)) 278 | obj_bb[:, 12] = np.arange(obj_bb.shape[0]) 279 | obj_bb[:, 13] = np.arange(obj_bb.shape[0], 0, -1) 280 | 281 | # filter images with no positive relations 282 | relations_neg_labels = entity.predicates_labels[:, :, NOF_PREDICATES - 1:] 283 | if np.sum(entity.predicates_labels[:, :, :NOF_PREDICATES - 1]) == 0: 284 | continue 285 | 286 | # use object class labels for pred class (multiply be some factor to convert to confidence) 287 | in_entity_confidence = entity.objects_outputs_with_no_activations 288 | 289 | # create the feed dictionary 290 | feed_dict = {confidence_relation_ph: entity.predicates_outputs_with_no_activation, 291 | confidence_entity_ph: in_entity_confidence, 292 | bb_ph: obj_bb, 293 | module.phase_ph: False, 294 | word_embed_entities_ph: embed_obj, word_embed_relations_ph: embed_pred} 295 | 296 | out_relation_probes_val, out_entity_probes_val = \ 297 | sess.run([out_relation_probes, out_entity_probes], 298 | feed_dict=feed_dict) 299 | 300 | out_relation_probes_val[indices, indices, :] = relation_neg 301 | 302 | results = test(entity.predicates_labels, entity.objects_labels, out_relation_probes_val, 303 | out_entity_probes_val) 304 | 305 | # accumulate results 306 | if accum_results is None: 307 | accum_results = results 308 | else: 309 | for key in results: 310 | accum_results[key] += results[key] 311 | 312 | # eval image 313 | k_metric_res, correct_image, total_image, img_per_relation_correct, img_per_relation_total = eval_image(entity, 314 | entity.predicates_labels, 315 | entity.objects_labels, out_relation_probes_val, out_entity_probes_val, k=min(k, N * N - N)) 316 | # filter images without positive relations 317 | if total_image == 0: 318 | continue 319 | 320 | nof += 1 321 | total += k_metric_res 322 | total_score = float(total) / nof 323 | correct_all += correct_image 324 | total_all += total_image 325 | logger.log("\rresult %d - %f (%d / %d) - total %f (%d)" % ( 326 | index, k_metric_res, correct_image, total_image, total_score, entity.image.id)) 327 | 328 | index += 1 329 | 330 | relation_accuracy = float(accum_results['entity_correct']) / accum_results['entity_total'] 331 | relation_pos_accuracy = float(accum_results['relations_pos_correct']) / accum_results[ 332 | 'relations_pos_total'] 333 | relationships_pos_accuracy = float(accum_results['relationships_pos_correct']) / accum_results[ 334 | 'relations_pos_total'] 335 | logger.log("entity %f - positive relation %f - positive triplet %f" % 336 | (relation_accuracy, relation_pos_accuracy, relationships_pos_accuracy)) 337 | 338 | time.sleep(3) 339 | 340 | 341 | logger.log("(%s) Final Result for k=%d - %f" % (load_module_name, k, total_score)) 342 | 343 | 344 | if __name__ == "__main__": 345 | k_recall = True 346 | gpu = 1 347 | layers = [500, 500, 500] 348 | 349 | load_module_name = "gpi_linguistic_pretrained" 350 | k = 100 351 | eval(load_module_name, k, layers, gpu) 352 | exit() 353 | -------------------------------------------------------------------------------- /Module/Module.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import sys 3 | 4 | sys.path.append("..") 5 | from Utils.Logger import Logger 6 | 7 | 8 | class Module(object): 9 | """ 10 | SGP Module which gets as an input 
the confidence of relations (predicates) and entities (objects) 11 | and outputs an improved confidence for predicates and objects 12 | """ 13 | 14 | def __init__(self, gpi_type="Linguistic", nof_predicates=51, nof_objects=150, rnn_steps=2, is_train=True, 15 | learning_rate=0.0001, 16 | learning_rate_steps=1000, learning_rate_decay=0.5, 17 | including_object=False, layers=[500, 500, 500], reg_factor=0.0, lr_object_coeff=4): 18 | """ 19 | Construct module: 20 | - create input placeholders 21 | - apply SGP rnn_steps times 22 | - create labels placeholders 23 | - create module loss and train_step 24 | 25 | :type gpi_type: "Linguistic", "FeatureAttention", "NeighbourAttention" 26 | :param nof_predicates: nof predicate labels 27 | :param nof_objects: nof object labels 28 | :param rnn_steps: number of time to apply SGP 29 | :param is_train: whether the module will be used to train or eval 30 | """ 31 | # save input 32 | self.learning_rate_decay = learning_rate_decay 33 | self.learning_rate_steps = learning_rate_steps 34 | self.learning_rate = learning_rate 35 | self.nof_predicates = nof_predicates 36 | self.nof_objects = nof_objects 37 | self.is_train = is_train 38 | self.rnn_steps = rnn_steps 39 | self.embed_size = 300 40 | self.gpi_type = gpi_type 41 | 42 | self.including_object = including_object 43 | self.lr_object_coeff = lr_object_coeff 44 | self.layers = layers 45 | self.reg_factor = reg_factor 46 | self.activation_fn = tf.nn.relu 47 | self.reuse = None 48 | # logging module 49 | logger = Logger() 50 | 51 | ## 52 | # module input 53 | self.phase_ph = tf.placeholder(tf.bool, name='phase') 54 | 55 | # confidence 56 | self.confidence_relation_ph = tf.placeholder(dtype=tf.float32, shape=(None, None, self.nof_predicates), 57 | name="confidence_relation") 58 | #self.confidence_relation_ph = tf.contrib.layers.dropout(self.confidence_relation_ph, keep_prob=0.9, is_training=self.phase_ph) 59 | self.confidence_entity_ph = tf.placeholder(dtype=tf.float32, shape=(None, self.nof_objects), 60 | name="confidence_entity") 61 | #self.confidence_entity_ph = tf.contrib.layers.dropout(self.confidence_entity_ph, keep_prob=0.9, is_training=self.phase_ph) 62 | # spatial features 63 | self.entity_bb_ph = tf.placeholder(dtype=tf.float32, shape=(None, 14), name="obj_bb") 64 | 65 | # word embeddings 66 | self.word_embed_entities_ph = tf.placeholder(dtype=tf.float32, shape=(self.nof_objects, self.embed_size), 67 | name="word_embed_objects") 68 | self.word_embed_relations_ph = tf.placeholder(dtype=tf.float32, shape=(self.nof_predicates, self.embed_size), 69 | name="word_embed_predicates") 70 | 71 | # labels 72 | if self.is_train: 73 | self.labels_relation_ph = tf.placeholder(dtype=tf.float32, shape=(None, None, self.nof_predicates), 74 | name="labels_predicate") 75 | self.labels_entity_ph = tf.placeholder(dtype=tf.float32, shape=(None, self.nof_objects), 76 | name="labels_object") 77 | self.labels_coeff_loss_ph = tf.placeholder(dtype=tf.float32, shape=(None), name="labels_coeff_loss") 78 | 79 | # store all the outputs of of rnn steps 80 | self.out_confidence_entity_lst = [] 81 | self.out_confidence_relation_lst = [] 82 | # rnn stage module 83 | confidence_relation = self.confidence_relation_ph 84 | confidence_entity = self.confidence_entity_ph 85 | 86 | # features msg 87 | for step in range(self.rnn_steps): 88 | confidence_relation, confidence_entity_temp = \ 89 | self.sgp(in_confidence_relation=confidence_relation, 90 | in_confidence_entity=confidence_entity, 91 | scope_name="deep_graph") 92 | # store the confidence 
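            # the sgp() cell shares its weights across all rnn_steps (self.reuse is set after the first pass);
            # the raw output of every step is kept so that the training loss in module_loss() can supervise
            # each refinement step, not only the last one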
93 | self.out_confidence_relation_lst.append(confidence_relation) 94 | if self.including_object: 95 | confidence_entity = confidence_entity_temp 96 | # store the confidence 97 | self.out_confidence_entity_lst.append(confidence_entity_temp) 98 | self.reuse = True 99 | 100 | #confidence_entity = confidence_entity_temp 101 | self.out_confidence_relation = confidence_relation 102 | self.out_confidence_entity = confidence_entity 103 | reshaped_relation_confidence = tf.reshape(confidence_relation, (-1, self.nof_predicates)) 104 | self.reshaped_relation_probes = tf.nn.softmax(reshaped_relation_confidence) 105 | self.out_relation_probes = tf.reshape(self.reshaped_relation_probes, tf.shape(confidence_relation), 106 | name="out_relation_probes") 107 | self.out_entity_probes = tf.nn.softmax(confidence_entity, name="out_entity_probes") 108 | 109 | # loss 110 | if self.is_train: 111 | # Learning rate 112 | self.lr_ph = tf.placeholder(dtype=tf.float32, shape=[], name="lr_ph") 113 | 114 | self.loss, self.gradients, self.grad_placeholder, self.train_step = self.module_loss() 115 | 116 | def nn(self, features, layers, out, scope_name, seperated_layer=False, last_activation=None): 117 | """ 118 | simple nn to convert features to confidence 119 | :param features: list of features tensor 120 | :param layers: hidden layers 121 | :param seperated_layer: First run FC one each feature tensor seperately 122 | :param out: output shape (used to reshape to required output shape) 123 | :param scope_name: tensorflow scope name 124 | :param last_activation: activation function for the last layer (None means no activation) 125 | :return: confidence 126 | """ 127 | with tf.variable_scope(scope_name) as scopevar: 128 | 129 | # first layer each feature seperatly 130 | features_h_lst = [] 131 | index = 0 132 | for feature in features: 133 | if seperated_layer: 134 | in_size = feature.shape[-1]._value 135 | scope = str(index) 136 | h = tf.contrib.layers.fully_connected(feature, in_size, reuse=self.reuse, scope=scope, 137 | activation_fn=self.activation_fn) 138 | index += 1 139 | features_h_lst.append(h) 140 | else: 141 | features_h_lst.append(feature) 142 | 143 | h = tf.concat(features_h_lst, axis=-1) 144 | h = tf.contrib.layers.dropout(h, keep_prob=0.9, is_training=self.phase_ph) 145 | for layer in layers: 146 | scope = str(index) 147 | h = tf.contrib.layers.fully_connected(h, layer, reuse=self.reuse, scope=scope, 148 | activation_fn=self.activation_fn) 149 | h = tf.contrib.layers.dropout(h, keep_prob=0.9, is_training=self.phase_ph) 150 | index += 1 151 | 152 | scope = str(index) 153 | y = tf.contrib.layers.fully_connected(h, out, reuse=self.reuse, scope=scope, activation_fn=last_activation) 154 | return y 155 | 156 | def sgp(self, in_confidence_relation, in_confidence_entity, scope_name="rnn_cell"): 157 | """ 158 | SGP step - which get as an input a confidence of the predicates and objects and return an improved confidence of the predicates and the objects 159 | :return: 160 | :param in_confidence_relation: in relation confidence 161 | :param in_confidence_entity: in entity confidence 162 | :param scope_name: sgp step scope 163 | :return: improved relatiob probabilities, improved relation confidence, improved entity probabilities and improved entity confidence 164 | """ 165 | with tf.variable_scope(scope_name): 166 | 167 | # relation features normalization 168 | self.in_confidence_predicate_actual = in_confidence_relation 169 | relation_probes = tf.nn.softmax(in_confidence_relation) 170 | self.relation_probes = 
relation_probes 171 | relation_features = tf.log(relation_probes + tf.constant(1e-10)) 172 | 173 | # entity features normalization 174 | self.in_confidence_entity_actual = in_confidence_entity 175 | entity_probes = tf.nn.softmax(in_confidence_entity) 176 | self.entity_probes = entity_probes 177 | entity_features_conf = tf.log(entity_probes + tf.constant(1e-10)) 178 | entity_features = tf.concat((entity_features_conf, self.entity_bb_ph), axis=1) 179 | 180 | # word embeddings 181 | # expand object word embed 182 | N = tf.slice(tf.shape(self.confidence_relation_ph), [0], [1], name="N") 183 | if self.gpi_type == "Linguistic": 184 | self.entity_prediction = tf.argmax(self.entity_probes, axis=1) 185 | self.entity_prediction_val = tf.reduce_max(self.entity_probes, axis=1) 186 | self.embed_entities = tf.gather(self.word_embed_entities_ph, self.entity_prediction) 187 | self.embed_entities = tf.transpose( 188 | tf.multiply(tf.transpose(self.embed_entities), self.entity_prediction_val)) 189 | in_extended_confidence_embed_shape = tf.concat((N, tf.shape(self.embed_entities)), 0) 190 | entity_features = tf.concat((self.embed_entities, entity_features), axis=1) 191 | 192 | self.relation_prediction = tf.argmax(self.relation_probes[:, :, :self.nof_predicates - 1], axis=2) 193 | self.relation_prediction_val = tf.reduce_max(self.relation_probes[:, :, :self.nof_predicates - 1], axis=2) 194 | self.embed_relations = tf.gather(self.word_embed_relations_ph, tf.reshape(self.relation_prediction, [-1])) 195 | self.embed_relations = tf.transpose( 196 | tf.multiply(tf.transpose(self.embed_relations), tf.reshape(self.relation_prediction_val, [-1]))) 197 | self.embed_relations = tf.reshape(self.embed_relations, in_extended_confidence_embed_shape) 198 | relation_features = tf.concat((relation_features, self.embed_relations), axis=2) 199 | 200 | # append relations in both directions 201 | self.relation_features = tf.concat((relation_features, tf.transpose(relation_features, perm=[1, 0, 2])), axis=2) 202 | 203 | # expand object confidence 204 | self.extended_confidence_entity_shape = tf.concat((N, tf.shape(entity_features)), 0) 205 | self.expand_object_features = tf.add(tf.zeros(self.extended_confidence_entity_shape), 206 | entity_features, 207 | name="expand_object_features") 208 | # expand subject confidence 209 | self.expand_subject_features = tf.transpose(self.expand_object_features, perm=[1, 0, 2], 210 | name="expand_subject_features") 211 | 212 | ## 213 | # Node Neighbours 214 | self.object_ngbrs = [self.expand_object_features, self.expand_subject_features, relation_features] 215 | # apply phi 216 | self.object_ngbrs_phi = self.nn(features=self.object_ngbrs, layers=[], out=500, scope_name="nn_phi") 217 | # Attention mechanism 218 | if self.gpi_type == "FeatureAttention" or self.gpi_type == "Linguistic": 219 | self.object_ngbrs_scores = self.nn(features=self.object_ngbrs, layers=[], out=500, 220 | scope_name="nn_phi_atten") 221 | self.object_ngbrs_weights = tf.nn.softmax(self.object_ngbrs_scores, dim=1) 222 | self.object_ngbrs_phi_all = tf.reduce_sum(tf.multiply(self.object_ngbrs_phi, self.object_ngbrs_weights), 223 | axis=1) 224 | 225 | elif self.gpi_type == "NeighbourAttention": 226 | self.object_ngbrs_scores = self.nn(features=self.object_ngbrs, layers=[], out=1, 227 | scope_name="nn_phi_atten") 228 | self.object_ngbrs_weights = tf.nn.softmax(self.object_ngbrs_scores, dim=1) 229 | self.object_ngbrs_phi_all = tf.reduce_sum(tf.multiply(self.object_ngbrs_phi, self.object_ngbrs_weights), 230 | axis=1) 231 | else: 232 | 
self.object_ngbrs_phi_all = tf.reduce_mean(self.object_ngbrs_phi, axis=1) 233 | 234 | ## 235 | # Nodes 236 | self.object_ngbrs2 = [entity_features, self.object_ngbrs_phi_all] 237 | # apply alpha 238 | self.object_ngbrs2_alpha = self.nn(features=self.object_ngbrs2, layers=[], out=500, scope_name="nn_phi2") 239 | # Attention mechanism 240 | if self.gpi_type == "FeatureAttention" or self.gpi_type == "Linguistic": 241 | self.object_ngbrs2_scores = self.nn(features=self.object_ngbrs2, layers=[], out=500, 242 | scope_name="nn_phi2_atten") 243 | self.object_ngbrs2_weights = tf.nn.softmax(self.object_ngbrs2_scores, dim=0) 244 | self.object_ngbrs2_alpha_all = tf.reduce_sum( 245 | tf.multiply(self.object_ngbrs2_alpha, self.object_ngbrs2_weights), axis=0) 246 | elif self.gpi_type == "NeighbourAttention": 247 | self.object_ngbrs2_scores = self.nn(features=self.object_ngbrs2, layers=[], out=1, 248 | scope_name="nn_phi2_atten") 249 | self.object_ngbrs2_weights = tf.nn.softmax(self.object_ngbrs2_scores, dim=0) 250 | self.object_ngbrs2_alpha_all = tf.reduce_sum( 251 | tf.multiply(self.object_ngbrs2_alpha, self.object_ngbrs2_weights), axis=0) 252 | else: 253 | self.object_ngbrs2_alpha_all = tf.reduce_mean(self.object_ngbrs2_alpha, axis=0) 254 | 255 | expand_graph_shape = tf.concat((N, N, tf.shape(self.object_ngbrs2_alpha_all)), 0) 256 | expand_graph = tf.add(tf.zeros(expand_graph_shape), self.object_ngbrs2_alpha_all) 257 | 258 | ## 259 | # rho relation (relation prediction) 260 | # The input is object features, subject features, relation features and the representation of the graph 261 | self.expand_obj_ngbrs_phi_all = tf.add(tf.zeros_like(self.object_ngbrs_phi), self.object_ngbrs_phi_all) 262 | self.expand_sub_ngbrs_phi_all = tf.transpose(self.expand_obj_ngbrs_phi_all, perm=[1, 0, 2]) 263 | self.relation_all_features = [relation_features, self.expand_object_features, self.expand_subject_features, expand_graph] 264 | 265 | pred_delta = self.nn(features=self.relation_all_features, layers=self.layers, out=self.nof_predicates, 266 | scope_name="nn_pred") 267 | pred_forget_gate = self.nn(features=self.relation_all_features, layers=[], out=1, 268 | scope_name="nn_pred_forgate", last_activation=tf.nn.sigmoid) 269 | out_confidence_relation = pred_delta + pred_forget_gate * in_confidence_relation 270 | 271 | ## 272 | # rho entity (entity prediction) 273 | # The input is entity features, entity neighbour features and the representation of the graph 274 | if self.including_object: 275 | self.object_all_features = [entity_features, expand_graph[0], self.object_ngbrs_phi_all] 276 | obj_delta = self.nn(features=self.object_all_features, layers=self.layers, out=self.nof_objects, 277 | scope_name="nn_obj") 278 | obj_forget_gate = self.nn(features=self.object_all_features, layers=[], out=self.nof_objects, 279 | scope_name="nn_obj_forgate", last_activation=tf.nn.sigmoid) 280 | out_confidence_object = obj_delta + obj_forget_gate * in_confidence_entity 281 | else: 282 | out_confidence_object = in_confidence_entity 283 | 284 | return out_confidence_relation, out_confidence_object 285 | 286 | def module_loss(self, scope_name="loss"): 287 | """ 288 | SGP loss 289 | :param scope_name: tensor flow scope name 290 | :return: loss and train step 291 | """ 292 | with tf.variable_scope(scope_name): 293 | # reshape to batch like shape 294 | shaped_labels_predicate = tf.reshape(self.labels_relation_ph, (-1, self.nof_predicates)) 295 | 296 | # relation gt 297 | self.gt = tf.argmax(shaped_labels_predicate, axis=1) 298 | 299 | loss = 0 300 
| 301 | for rnn_step in range(self.rnn_steps): 302 | 303 | shaped_confidence_predicate = tf.reshape(self.out_confidence_relation_lst[rnn_step], 304 | (-1, self.nof_predicates)) 305 | 306 | # set predicate loss 307 | self.relation_ce_loss = tf.nn.softmax_cross_entropy_with_logits(labels=shaped_labels_predicate, 308 | logits=shaped_confidence_predicate, 309 | name="relation_ce_loss") 310 | 311 | self.loss_relation = self.relation_ce_loss 312 | self.loss_relation_weighted = tf.multiply(self.loss_relation, self.labels_coeff_loss_ph) 313 | 314 | loss += tf.reduce_sum(self.loss_relation_weighted) 315 | 316 | # set object loss 317 | if self.including_object: 318 | self.object_ce_loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.labels_entity_ph, 319 | logits=self.out_confidence_entity_lst[ 320 | rnn_step], 321 | name="object_ce_loss") 322 | 323 | loss += self.lr_object_coeff * tf.reduce_sum(self.object_ce_loss) 324 | 325 | # reg 326 | trainable_vars = tf.trainable_variables() 327 | lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in trainable_vars]) * self.reg_factor 328 | loss += lossL2 329 | 330 | # minimize 331 | #opt = tf.train.GradientDescentOptimizer(self.lr_ph) 332 | opt = tf.train.AdamOptimizer(self.lr_ph) 333 | # opt = tf.train.MomentumOptimizer(self.lr_ph, 0.9, use_nesterov=True) 334 | gradients = opt.compute_gradients(loss) 335 | # create placeholder to minimize in a batch 336 | grad_placeholder = [(tf.placeholder("float", shape=grad[0].get_shape()), grad[1]) for grad in gradients] 337 | 338 | train_step = opt.apply_gradients(grad_placeholder) 339 | return loss, gradients, grad_placeholder, train_step 340 | 341 | def get_in_ph(self): 342 | """ 343 | get input place holders 344 | """ 345 | return self.confidence_relation_ph, self.confidence_entity_ph, self.entity_bb_ph, self.word_embed_relations_ph, self.word_embed_entities_ph 346 | 347 | def get_output(self): 348 | """ 349 | get module output 350 | """ 351 | return self.out_relation_probes, self.out_entity_probes 352 | 353 | def get_labels_ph(self): 354 | """ 355 | get module labels ph (used for train) 356 | """ 357 | return self.labels_relation_ph, self.labels_entity_ph, self.labels_coeff_loss_ph 358 | 359 | def get_module_loss(self): 360 | """ 361 | get module loss and train step 362 | """ 363 | return self.loss, self.gradients, self.grad_placeholder, self.train_step 364 | -------------------------------------------------------------------------------- /Module/Train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import math 5 | import inspect 6 | 7 | from Data.VisualGenome.models import * 8 | from FilesManager.FilesManager import FilesManager 9 | from Module import Module 10 | import tensorflow as tf 11 | import numpy as np 12 | import os 13 | import cPickle 14 | from random import shuffle 15 | 16 | from Utils.Logger import Logger 17 | 18 | # feature sizes 19 | NOF_PREDICATES = 51 20 | NOF_OBJECTS = 150 21 | 22 | # save model every number of iterations 23 | SAVE_MODEL_ITERATIONS = 5 24 | 25 | # test every number of iterations 26 | TEST_ITERATIONS = 1 27 | 28 | 29 | def test(labels_relation, labels_entity, out_confidence_relation_val, out_confidence_entity_val): 30 | """ 31 | returns a dictionary with statistics about object, predicate and relationship accuracy in this image 32 | :param labels_relation: labels of image predicates (each one is one hot vector) - shape (N, N, NOF_PREDICATES) 33 | :param labels_entity: labels of image objects (each 
one is one hot vector) - shape (N, NOF_OBJECTS) 34 | :param out_confidence_relation_val: confidence of image predicates - shape (N, N, NOF_PREDICATES) 35 | :param out_confidence_entity_val: confidence of image objects - shape (N, NOF_OBJECTS) 36 | :return: see description 37 | """ 38 | relation_gt = np.argmax(labels_relation, axis=2) 39 | entity_gt = np.argmax(labels_entity, axis=1) 40 | relation_pred = np.argmax(out_confidence_relation_val, axis=2) 41 | relations_pred_no_neg = np.argmax(out_confidence_relation_val[:, :, :NOF_PREDICATES - 1], axis=2) 42 | entities_pred = np.argmax(out_confidence_entity_val, axis=1) 43 | 44 | # noinspection PyDictCreation 45 | results = {} 46 | # number of objects 47 | results["entity_total"] = entity_gt.shape[0] 48 | # number of predicates / relationships 49 | results["relations_total"] = relation_gt.shape[0] * relation_gt.shape[1] 50 | # number of positive predicates / relationships 51 | pos_indices = np.where(relation_gt != NOF_PREDICATES - 1) 52 | results["relations_pos_total"] = pos_indices[0].shape[0] 53 | 54 | # number of object correct predictions 55 | results["entity_correct"] = np.sum(entity_gt == entities_pred) 56 | # number of correct predicate 57 | results["relations_correct"] = np.sum(relation_gt == relation_pred) 58 | # number of correct positive predicates 59 | relations_gt_pos = relation_gt[pos_indices] 60 | relations_pred_pos = relations_pred_no_neg[pos_indices] 61 | results["relations_pos_correct"] = np.sum(relations_gt_pos == relations_pred_pos) 62 | # number of correct relationships 63 | entity_true_indices = np.where(entity_gt == entities_pred) 64 | relations_gt_true = relation_gt[entity_true_indices[0], :][:, entity_true_indices[0]] 65 | relations_pred_true = relation_pred[entity_true_indices[0], :][:, entity_true_indices[0]] 66 | relations_pred_true_pos = relations_pred_no_neg[entity_true_indices[0], :][:, entity_true_indices[0]] 67 | results["relationships_correct"] = np.sum(relations_gt_true == relations_pred_true) 68 | # number of correct positive relationships 69 | pos_true_indices = np.where(relations_gt_true != NOF_PREDICATES - 1) 70 | relations_gt_pos_true = relations_gt_true[pos_true_indices] 71 | relations_pred_pos_true = relations_pred_true_pos[pos_true_indices] 72 | results["relationships_pos_correct"] = np.sum(relations_gt_pos_true == relations_pred_pos_true) 73 | 74 | return results 75 | 76 | 77 | def train(name="module", 78 | nof_iterations=100, 79 | learning_rate=0.0001, 80 | learning_rate_steps=1000, 81 | learning_rate_decay=0.5, 82 | load_module_name="module.ckpt", 83 | use_saved_module=False, 84 | batch_size=20, 85 | pred_pos_neg_ratio=10, 86 | lr_object_coeff=4, 87 | layers=[500, 500, 500], 88 | gpu=0): 89 | """ 90 | Train SGP module given train parameters and module hyper-parameters 91 | :param name: name of the train session 92 | :param nof_iterations: number of epochs 93 | :param learning_rate: 94 | :param learning_rate_steps: decay after number of steps 95 | :param learning_rate_decay: the factor to decay the learning rate 96 | :param load_module_name: name of already trained module weights to load 97 | :param use_saved_module: start from already train module 98 | :param batch_size: number of images in each mini-batch 99 | :param pred_pos_neg_ratio: Set the loss ratio between positive and negatives (not labeled) predicates 100 | :param lr_object_coeff: Set the loss ratio between objects and predicates 101 | :param layers: list of sizes of the hidden layer of the predicate and object classifier 102 | :param gpu: 
gpu number to use for the training 103 | :return: nothing 104 | """ 105 | gpi_type = "Linguistic" 106 | including_object = True 107 | # get filesmanager 108 | filesmanager = FilesManager() 109 | 110 | # create logger 111 | logger_path = filesmanager.get_file_path("logs") 112 | logger_path = os.path.join(logger_path, name) 113 | logger = Logger(name, logger_path) 114 | 115 | # print train params 116 | frame = inspect.currentframe() 117 | args, _, _, values = inspect.getargvalues(frame) 118 | logger.log('function name "%s"' % inspect.getframeinfo(frame)[2]) 119 | for i in args: 120 | logger.log(" %s = %s" % (i, values[i])) 121 | 122 | # set gpu 123 | if gpu != None: 124 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) 125 | logger.log("os.environ[\"CUDA_VISIBLE_DEVICES\"] = " + str(gpu)) 126 | 127 | # create module 128 | module = Module(gpi_type=gpi_type, nof_predicates=NOF_PREDICATES, nof_objects=NOF_OBJECTS, 129 | is_train=True, 130 | learning_rate=learning_rate, learning_rate_steps=learning_rate_steps, 131 | learning_rate_decay=learning_rate_decay, 132 | lr_object_coeff=lr_object_coeff, 133 | including_object=including_object, 134 | layers=layers) 135 | 136 | ## 137 | # get module place holders 138 | # 139 | # get input place holders 140 | confidence_relation_ph, confidence_entity_ph, bb_ph, word_embed_relations_ph, word_embed_entities_ph = module.get_in_ph() 141 | # get labels place holders 142 | labels_relation_ph, labels_entity_ph, labels_coeff_loss_ph = module.get_labels_ph() 143 | # get loss and train step 144 | loss, gradients, grad_placeholder, train_step = module.get_module_loss() 145 | 146 | ## 147 | # get module output 148 | out_relation_probes, out_entity_probes = module.get_output() 149 | 150 | # Initialize the Computational Graph 151 | init = tf.global_variables_initializer() 152 | # Add ops to save and restore all the variables. 153 | variables = tf.contrib.slim.get_variables_to_restore() 154 | variables_to_restore = variables 155 | saver = tf.train.Saver(variables_to_restore) 156 | 157 | with tf.Session() as sess: 158 | # Restore variables from disk. 159 | module_path = filesmanager.get_file_path("sg_module.train.saver") 160 | module_path_load = os.path.join(module_path, load_module_name) 161 | if os.path.exists(module_path_load + ".index") and use_saved_module: 162 | saver.restore(sess, module_path_load) 163 | logger.log("Model restored.") 164 | else: 165 | sess.run(init) 166 | 167 | # train images 168 | vg_train_path = filesmanager.get_file_path("data.visual_genome.train") 169 | # list of train files 170 | train_files_list = range(2, 72) 171 | shuffle(train_files_list) 172 | 173 | # Actual validation is 5 files. 174 | # After tunning the hyper parameters, use just 2 files for early stopping. 175 | validation_files_list = range(2) 176 | 177 | # create one hot vector for predicate_negative (i.e. 
not labeled) 178 | relation_neg = np.zeros(NOF_PREDICATES) 179 | relation_neg[NOF_PREDICATES - 1] = 1 180 | 181 | # object embedding 182 | embed_obj = FilesManager().load_file("language_module.word2vec.object_embeddings") 183 | embed_pred = FilesManager().load_file("language_module.word2vec.predicate_embeddings") 184 | embed_pred = np.concatenate((embed_pred, np.zeros(embed_pred[:1].shape)), 185 | axis=0) # concat negative represntation 186 | 187 | # train module 188 | lr = learning_rate 189 | best_test_loss = -1 190 | baseline_path = filesmanager.get_file_path("data.visual_genome.train_baseline") 191 | for epoch in xrange(1, nof_iterations): 192 | accum_results = None 193 | total_loss = 0 194 | steps = [] 195 | # read data 196 | file_index = -1 197 | for file_name in train_files_list: 198 | 199 | file_index += 1 200 | 201 | # load data from file 202 | file_path = os.path.join(vg_train_path, str(file_name) + ".p") 203 | file_handle = open(file_path, "rb") 204 | train_images = cPickle.load(file_handle) 205 | file_handle.close() 206 | shuffle(train_images) 207 | 208 | for image in train_images: 209 | # load initial belief by baseline detector 210 | file_path = os.path.join(baseline_path, str(image.image.id) + ".p") 211 | if not os.path.exists(file_path): 212 | continue 213 | file_handle = open(file_path, "rb") 214 | decetctor_data = cPickle.load(file_handle) 215 | file_handle.close() 216 | image.predicates_outputs_with_no_activation = decetctor_data["rel_dist_mapped"] 217 | image.objects_outputs_with_no_activations = decetctor_data["obj_dist_mapped"] 218 | 219 | # set diagonal to be negative predicate (no relation for a single object) 220 | indices = np.arange(image.predicates_outputs_with_no_activation.shape[0]) 221 | image.predicates_outputs_with_no_activation[indices, indices, :] = relation_neg 222 | image.predicates_labels[indices, indices, :] = relation_neg 223 | 224 | # spatial features 225 | entity_bb = np.zeros((len(image.objects), 14)) 226 | for obj_id in range(len(image.objects)): 227 | entity_bb[obj_id][0] = image.objects[obj_id].x / 1200.0 228 | entity_bb[obj_id][1] = image.objects[obj_id].y / 1200.0 229 | entity_bb[obj_id][2] = (image.objects[obj_id].x + image.objects[obj_id].width) / 1200.0 230 | entity_bb[obj_id][3] = (image.objects[obj_id].y + image.objects[obj_id].height) / 1200.0 231 | entity_bb[obj_id][4] = image.objects[obj_id].x 232 | entity_bb[obj_id][5] = -1 * image.objects[obj_id].x 233 | entity_bb[obj_id][6] = image.objects[obj_id].y 234 | entity_bb[obj_id][7] = -1 * image.objects[obj_id].y 235 | entity_bb[obj_id][8] = image.objects[obj_id].width * image.objects[obj_id].height 236 | entity_bb[obj_id][9] = -1 * image.objects[obj_id].width * image.objects[obj_id].height 237 | entity_bb[:, 4] = np.argsort(entity_bb[:, 4]) 238 | entity_bb[:, 5] = np.argsort(entity_bb[:, 5]) 239 | entity_bb[:, 6] = np.argsort(entity_bb[:, 6]) 240 | entity_bb[:, 7] = np.argsort(entity_bb[:, 7]) 241 | entity_bb[:, 8] = np.argsort(entity_bb[:, 8]) 242 | entity_bb[:, 9] = np.argsort(entity_bb[:, 9]) 243 | entity_bb[:, 10] = np.argsort(np.max(image.objects_outputs_with_no_activations, axis=1)) 244 | entity_bb[:, 11] = np.argsort(-1 * np.max(image.objects_outputs_with_no_activations, axis=1)) 245 | entity_bb[:, 12] = np.arange(entity_bb.shape[0]) 246 | entity_bb[:, 13] = np.arange(entity_bb.shape[0], 0, -1) 247 | 248 | # filter non mixed cases 249 | relations_neg_labels = image.predicates_labels[:, :, NOF_PREDICATES - 1:] 250 | if np.sum(image.predicates_labels[:, :, :NOF_PREDICATES - 1]) == 0: 
251 | continue 252 | 253 | if including_object: 254 | in_entity_confidence = image.objects_outputs_with_no_activations 255 | else: 256 | in_entity_confidence = image.objects_labels * 1000 257 | 258 | # give lower weight to negatives 259 | coeff_factor = np.ones(relations_neg_labels.shape) 260 | factor = float(np.sum(image.predicates_labels[:, :, :NOF_PREDICATES - 2])) / np.sum( 261 | relations_neg_labels) / pred_pos_neg_ratio 262 | coeff_factor[relations_neg_labels == 1] *= factor 263 | 264 | coeff_factor[indices, indices] = 0 265 | 266 | # create the feed dictionary 267 | feed_dict = {confidence_relation_ph: image.predicates_outputs_with_no_activation, 268 | confidence_entity_ph: in_entity_confidence, 269 | bb_ph : entity_bb, 270 | module.phase_ph: True, 271 | word_embed_entities_ph: embed_obj, word_embed_relations_ph: embed_pred, 272 | labels_relation_ph: image.predicates_labels, labels_entity_ph: image.objects_labels, 273 | labels_coeff_loss_ph: coeff_factor.reshape((-1)), module.lr_ph: lr} 274 | 275 | # run the network 276 | out_relation_probes_val, out_entity_probes_val, loss_val, gradients_val = \ 277 | sess.run([out_relation_probes, out_entity_probes, loss, gradients], 278 | feed_dict=feed_dict) 279 | if math.isnan(loss_val): 280 | print("NAN") 281 | continue 282 | 283 | # set diagonal to be neg (in order not to take into account in statistics) 284 | out_relation_probes_val[indices, indices, :] = relation_neg 285 | 286 | # append gradient to list (will be applied as a batch of entities) 287 | steps.append(gradients_val) 288 | 289 | # statistic 290 | total_loss += loss_val 291 | 292 | results = test(image.predicates_labels, image.objects_labels, out_relation_probes_val, 293 | out_entity_probes_val) 294 | 295 | # accumulate results 296 | if accum_results is None: 297 | accum_results = results 298 | else: 299 | for key in results: 300 | accum_results[key] += results[key] 301 | 302 | if len(steps) == batch_size: 303 | # apply steps 304 | step = steps[0] 305 | feed_grad_apply_dict = {grad_placeholder[j][0]: step[j][0] for j in 306 | xrange(len(grad_placeholder))} 307 | for i in xrange(1, len(steps)): 308 | step = steps[i] 309 | for j in xrange(len(grad_placeholder)): 310 | feed_grad_apply_dict[grad_placeholder[j][0]] += step[j][0] 311 | 312 | feed_grad_apply_dict[module.lr_ph] = lr 313 | sess.run([train_step], feed_dict=feed_grad_apply_dict) 314 | steps = [] 315 | # print stat - per file just for the first epoch - disabled!! 
316 | if epoch == 1: 317 | obj_accuracy = float(accum_results['entity_correct']) / accum_results['entity_total'] 318 | predicate_pos_accuracy = float(accum_results['relations_pos_correct']) / accum_results[ 319 | 'relations_pos_total'] 320 | relationships_pos_accuracy = float(accum_results['relationships_pos_correct']) / accum_results[ 321 | 'relations_pos_total'] 322 | logger.log("iter %d.%d - obj %f - pred %f - relation %f" % 323 | (epoch, file_index, obj_accuracy, predicate_pos_accuracy, relationships_pos_accuracy)) 324 | 325 | # print stat per epoch 326 | obj_accuracy = float(accum_results['entity_correct']) / accum_results['entity_total'] 327 | predicate_pos_accuracy = float(accum_results['relations_pos_correct']) / accum_results[ 328 | 'relations_pos_total'] 329 | predicate_all_accuracy = float(accum_results['relations_correct']) / accum_results['relations_total'] 330 | relationships_pos_accuracy = float(accum_results['relationships_pos_correct']) / accum_results[ 331 | 'relations_pos_total'] 332 | relationships_all_accuracy = float(accum_results['relationships_correct']) / accum_results[ 333 | 'relations_total'] 334 | 335 | logger.log("iter %d - loss %f - obj %f - pred %f - rela %f - all_pred %f - all rela %f - lr %f" % 336 | (epoch, total_loss, obj_accuracy, predicate_pos_accuracy, relationships_pos_accuracy, 337 | predicate_all_accuracy, relationships_all_accuracy, lr)) 338 | 339 | # run validation 340 | if epoch % TEST_ITERATIONS == 0: 341 | total_test_loss = 0 342 | accum_test_results = None 343 | 344 | for file_name in validation_files_list: 345 | # load data from file 346 | file_path = os.path.join(vg_train_path, str(file_name) + ".p") 347 | file_handle = open(file_path, "rb") 348 | validation_images = cPickle.load(file_handle) 349 | file_handle.close() 350 | 351 | for image in validation_images: 352 | file_path = os.path.join(baseline_path, str(image.image.id) + ".p") 353 | if not os.path.exists(file_path): 354 | continue 355 | file_handle = open(file_path, "rb") 356 | detector_data = cPickle.load(file_handle) 357 | file_handle.close() 358 | 359 | image.predicates_outputs_with_no_activation = detector_data["rel_dist_mapped"] 360 | image.objects_outputs_with_no_activations = detector_data["obj_dist_mapped"] 361 | # set diagonal to be neg 362 | indices = np.arange(image.predicates_outputs_with_no_activation.shape[0]) 363 | image.predicates_outputs_with_no_activation[indices, indices, :] = relation_neg 364 | image.predicates_labels[indices, indices, :] = relation_neg 365 | 366 | # get shape of extended object to be used by the module 367 | extended_confidence_object_shape = np.asarray(image.predicates_outputs_with_no_activation.shape) 368 | extended_confidence_object_shape[2] = NOF_OBJECTS 369 | 370 | # spatial features 371 | entity_bb = np.zeros((len(image.objects), 14)) 372 | for obj_id in range(len(image.objects)): 373 | entity_bb[obj_id][0] = image.objects[obj_id].x / 1200.0 374 | entity_bb[obj_id][1] = image.objects[obj_id].y / 1200.0 375 | entity_bb[obj_id][2] = (image.objects[obj_id].x + image.objects[obj_id].width) / 1200.0 376 | entity_bb[obj_id][3] = (image.objects[obj_id].y + image.objects[obj_id].height) / 1200.0 377 | entity_bb[obj_id][4] = image.objects[obj_id].x 378 | entity_bb[obj_id][5] = -1 * image.objects[obj_id].x 379 | entity_bb[obj_id][6] = image.objects[obj_id].y 380 | entity_bb[obj_id][7] = -1 * image.objects[obj_id].y 381 | entity_bb[obj_id][8] = image.objects[obj_id].width * image.objects[obj_id].height 382 | entity_bb[obj_id][9] = -1 * 
image.objects[obj_id].width * image.objects[obj_id].height 383 | entity_bb[:, 4] = np.argsort(entity_bb[:, 4]) 384 | entity_bb[:, 5] = np.argsort(entity_bb[:, 5]) 385 | entity_bb[:, 6] = np.argsort(entity_bb[:, 6]) 386 | entity_bb[:, 7] = np.argsort(entity_bb[:, 7]) 387 | entity_bb[:, 8] = np.argsort(entity_bb[:, 8]) 388 | entity_bb[:, 9] = np.argsort(entity_bb[:, 9]) 389 | entity_bb[:, 10] = np.argsort(np.max(image.objects_outputs_with_no_activations, axis=1)) 390 | entity_bb[:, 11] = np.argsort(-1 * np.max(image.objects_outputs_with_no_activations, axis=1)) 391 | entity_bb[:, 12] = np.arange(entity_bb.shape[0]) 392 | entity_bb[:, 13] = np.arange(entity_bb.shape[0], 0, -1) 393 | 394 | # filter non mixed cases 395 | relations_neg_labels = image.predicates_labels[:, :, NOF_PREDICATES - 1:] 396 | if np.sum(image.predicates_labels[:, :, :NOF_PREDICATES - 1]) == 0: 397 | continue 398 | 399 | # give lower weight to negatives 400 | coeff_factor = np.ones(relations_neg_labels.shape) 401 | factor = float(np.sum(image.predicates_labels[:, :, :NOF_PREDICATES - 2])) / np.sum( 402 | relations_neg_labels) / pred_pos_neg_ratio 403 | coeff_factor[relations_neg_labels == 1] *= factor 404 | coeff_factor[indices, indices] = 0 405 | coeff_factor[relations_neg_labels == 1] = 0 406 | 407 | if including_object: 408 | in_entity_confidence = image.objects_outputs_with_no_activations 409 | else: 410 | in_entity_confidence = image.objects_labels * 1000 411 | 412 | # create the feed dictionary 413 | feed_dict = {confidence_relation_ph: image.predicates_outputs_with_no_activation, 414 | confidence_entity_ph: in_entity_confidence, 415 | module.entity_bb_ph: entity_bb, 416 | module.word_embed_entities_ph: embed_obj, 417 | module.phase_ph: False, 418 | module.word_embed_relations_ph: embed_pred, 419 | labels_relation_ph: image.predicates_labels, 420 | labels_entity_ph: image.objects_labels, 421 | labels_coeff_loss_ph: coeff_factor.reshape((-1))} 422 | 423 | # run the network 424 | out_relation_probes_val, out_entity_probes_val, loss_val = sess.run( 425 | [out_relation_probes, out_entity_probes, loss], 426 | feed_dict=feed_dict) 427 | 428 | # set diagonal to be neg (in order not to take into account in statistics) 429 | out_relation_probes_val[indices, indices, :] = relation_neg 430 | 431 | # statistic 432 | total_test_loss += loss_val 433 | 434 | # statistics 435 | results = test(image.predicates_labels, image.objects_labels, 436 | out_relation_probes_val, out_entity_probes_val) 437 | 438 | # accumulate results 439 | if accum_test_results is None: 440 | accum_test_results = results 441 | else: 442 | for key in results: 443 | accum_test_results[key] += results[key] 444 | 445 | 446 | # print stat 447 | obj_accuracy = float(accum_test_results['entity_correct']) / accum_test_results['entity_total'] 448 | predicate_pos_accuracy = float(accum_test_results['relations_pos_correct']) / accum_test_results[ 449 | 'relations_pos_total'] 450 | predicate_all_accuracy = float(accum_test_results['relations_correct']) / accum_test_results[ 451 | 'relations_total'] 452 | relationships_pos_accuracy = float(accum_test_results['relationships_pos_correct']) / \ 453 | accum_test_results[ 454 | 'relations_pos_total'] 455 | relationships_all_accuracy = float(accum_test_results['relationships_correct']) / accum_test_results[ 456 | 'relations_total'] 457 | 458 | logger.log("VALIDATION - loss %f - obj %f - pred %f - rela %f - all_pred %f - all rela %f" % 459 | (total_test_loss, obj_accuracy, predicate_pos_accuracy, relationships_pos_accuracy, 460 
| predicate_all_accuracy, relationships_all_accuracy)) 461 | 462 | # save best module so far 463 | if best_test_loss == -1 or total_test_loss < best_test_loss: 464 | module_path_save = os.path.join(module_path, name + "_best_module.ckpt") 465 | save_path = saver.save(sess, module_path_save) 466 | logger.log("Model saved in file: %s" % save_path) 467 | best_test_loss = total_test_loss 468 | 469 | # save module 470 | if epoch % SAVE_MODEL_ITERATIONS == 0: 471 | module_path_save = os.path.join(module_path, name + "_module.ckpt") 472 | save_path = saver.save(sess, module_path_save) 473 | logger.log("Model saved in file: %s" % save_path) 474 | 475 | # learning rate decay 476 | if (epoch % learning_rate_steps) == 0: 477 | lr *= learning_rate_decay 478 | 479 | 480 | if __name__ == "__main__": 481 | filemanager = FilesManager() 482 | 483 | params = filemanager.load_file("sg_module.train.params") 484 | 485 | name = params["name"] 486 | learning_rate = params["learning_rate"] 487 | learning_rate_steps = params["learning_rate_steps"] 488 | learning_rate_decay = params["learning_rate_decay"] 489 | nof_iterations = params["nof_iterations"] 490 | load_model_name = params["load_model_name"] 491 | use_saved_model = params["use_saved_model"] 492 | batch_size = params["batch_size"] 493 | predicate_pos_neg_ratio = params["predicate_pos_neg_ratio"] 494 | lr_object_coeff = params["lr_object_coeff"] 495 | layers = params["layers"] 496 | gpu = params["gpu"] 497 | 498 | train(name, nof_iterations, learning_rate, learning_rate_steps, learning_rate_decay, load_model_name, 499 | use_saved_model, batch_size, predicate_pos_neg_ratio, lr_object_coeff, layers, 500 | gpu) 501 | -------------------------------------------------------------------------------- /Module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/Module/__init__.py -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Mapping Images to Scene Graphs with Permutation-Invariant Structured Prediction
2 | #### [Roei Herzig*](https://roeiherz.github.io/), Moshiko Raboh*, [Gal Chechik](https://chechiklab.biu.ac.il/~gal/), [Jonathan Berant](http://www.cs.tau.ac.il/~joberant/), [Amir Globerson](http://www.cs.tau.ac.il/~gamir/)
3 | 
4 | 
5 | ## Introduction
6 | Scene graph prediction is the task of mapping an image into a set of bounding boxes, along with their categories and relations (e.g., see [2, 3, 4, 5, 6]).
7 | 
8 | 
9 | In the paper [Scene Graphs with Permutation-Invariant Structured Prediction](https://arxiv.org/abs/1802.05451) (2018) [1] we present a new architecture for graph inference that has the **following structural property**:
10 | on the one hand, the architecture is invariant to input permutations;
11 | on the other hand, every permutation-invariant function can be implemented via this architecture.
12 | 
13 | In this repository, we share our architecture implementation for the task of scene graph prediction.
14 | 
15 | ## Model implementation
16 | **Scene Graph Predictor (SGP)** gets as input initial confidence distributions per entity and relation and processes them to obtain improved labels. SGP satisfies the graph permutation invariance property introduced in the paper.
17 | The model is implemented in [TensorFlow](https://www.tensorflow.org/); a minimal usage sketch is shown below.
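The following sketch is illustrative only and is not one of the repository's scripts: it constructs the module and runs a single forward pass with random stand-in inputs. The shapes follow the placeholders defined in Module/Module.py; in practice, Module/Eval.py feeds the baseline detector outputs and the word2vec embeddings, and restores trained weights with `tf.train.Saver` instead of using random initialization.
```
# Minimal sketch (assumed stand-in inputs): one SGP forward pass.
import numpy as np
import tensorflow as tf
from Module.Module import Module

N, NOF_PREDICATES, NOF_OBJECTS, EMBED_SIZE = 5, 51, 150, 300

module = Module(gpi_type="Linguistic", nof_predicates=NOF_PREDICATES,
                nof_objects=NOF_OBJECTS, is_train=False, layers=[500, 500, 500])
relation_ph, entity_ph, bb_ph, embed_relations_ph, embed_entities_ph = module.get_in_ph()
out_relation_probes, out_entity_probes = module.get_output()

feed_dict = {
    relation_ph: np.random.randn(N, N, NOF_PREDICATES),              # initial relation confidences
    entity_ph: np.random.randn(N, NOF_OBJECTS),                       # initial entity confidences
    bb_ph: np.random.rand(N, 14),                                     # spatial features per entity
    embed_relations_ph: np.random.randn(NOF_PREDICATES, EMBED_SIZE),  # predicate word embeddings
    embed_entities_ph: np.random.randn(NOF_OBJECTS, EMBED_SIZE),      # object word embeddings
    module.phase_ph: False,                                           # eval mode (disables dropout)
}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    relation_probes, entity_probes = sess.run([out_relation_probes, out_entity_probes],
                                              feed_dict=feed_dict)
    print(relation_probes.shape)  # (5, 5, 51)
    print(entity_probes.shape)    # (5, 150)
```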
18 | For the initial confidence distributions per entity and relation, we simply re-use features learned by the baseline model from Zellers et al. (2017). (git repository https://github.com/rowanz/neural-motifs)
19 | 
20 | 
21 | 
22 | ## SGP architecture
23 | Our SGP implementation uses an RNN, applied iteratively, to process predictions. Each step outputs improved predictions.
24 | 
25 | 
26 | 
27 | A schematic representation of the architecture. Given an image, a label predictor outputs initial predictions ![equation](http://latex.codecogs.com/gif.latex?Z_%7Bi%7D%2C%20Z_%7Bi%2Cj%7D). Then, our SGP model computes each ![equation](http://latex.codecogs.com/gif.latex?%5Cphi_%7Bi%2Cj%7D) element-wise. Next, they are summed to create the vector ![equation](http://latex.codecogs.com/gif.latex?S_%7Bi%7D), which is concatenated with ![equation](http://latex.codecogs.com/gif.latex?Z_%7Bi%7D). Then, ![equation](http://latex.codecogs.com/gif.latex?%5Calpha) is applied, and another summation creates the graph representation. Finally, ![equation](http://latex.codecogs.com/gif.latex?%5Crho_%7Bentity%7D) classifies objects and ![equation](http://latex.codecogs.com/gif.latex?%5Crho_%7Brelation%7D) classifies relations. The SGP process can be repeated iteratively (in the paper we repeat it 3 times).
28 | 
29 | For more information, please look at the code (Module/Module.py file) and the paper.
30 | 
31 | 
32 | ## Attention with SGP architecture
33 | Our SGP architecture uses attention at the feature level for each node during inference.
34 | We weight the significance of each feature per node, such that the network can choose which features from adjacent nodes contribute the most information.
35 | 
36 | 
37 | 
38 | An example of attention per entity and of global attention over all nodes. The size and location of objects provide a key signal to the attention mechanism. The model assigns a higher confidence to the label "tie" when the label "shirt" is detected (third panel from the left). Similarly, the model assigns a higher confidence to the label "eye" when it is located near "hair".
39 | 
40 | ## Dependencies
41 | To get started with the framework, install the following dependencies:
42 | - [Python 2.7](https://www.python.org/)
43 | - [tensorflow-gpu 1.0.1](https://www.tensorflow.org/)
44 | - [matplotlib 2.0.2](http://matplotlib.org/)
45 | - [h5py 2.7.0](http://www.h5py.org/)
46 | - [numpy 1.12.1](http://www.numpy.org/)
47 | - [pyyaml 3.12](https://pypi.python.org/pypi/PyYAML)
48 | 
49 | Run `"pip install -r requirements.txt"` to install all the requirements.
50 | 
51 | 
52 | ## Usage
53 | 1. Run `"python Run.py download"` to download and extract the train, validation and test data. The data already contains the result of applying the baseline detector over the VisualGenome data.
54 | 2. Run `"python Run.py eval gpi_linguistic_pretrained "` to evaluate the pre-trained model of our best variant, linguistic with multi-head attention (recall@100 SG Classification). A programmatic equivalent of this step is sketched after this list.
55 | 3. Run `"python Run.py train gpi_linguistic "` to train a new model (linguistic with multi-head attention).
56 | 4. Run `"python Run.py eval gpi_linguistic_best "` to evaluate the new model (recall@100 SG Classification).
57 | 
58 | 
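As a condensed sketch of what Module/Eval.py runs as its `__main__` (assuming the data and the pre-trained weights have already been downloaded with `python Run.py download`; the GPU index here is an arbitrary choice):
```
# Sketch: recall@k SG classification evaluation of the pre-trained linguistic variant.
from Module.Eval import eval

load_module_name = "gpi_linguistic_pretrained"  # pre-trained checkpoint shipped with the data
k = 100                                         # recall@k
layers = [500, 500, 500]                        # hidden layer sizes of the SGP classifiers
gpu = 0                                         # GPU index to run on
eval(load_module_name, k, layers, gpu)
```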
59 | ## About this repository
60 | This repository contains an implementation of our best variant (Linguistic with multi-head attention) of the Scene Graph Prediction (SGP) model introduced in the paper [Scene Graphs with Permutation-Invariant Structured Prediction](https://arxiv.org/abs/1802.05451).
61 | (The repository is updated for version 1 of the paper - the results of the latest version will be published in the future).
62 | Specifically, the repository allows running the scene-graph classification (recall@100) evaluation script on our pre-trained model, or alternatively (1) training an SGP model and (2) evaluating the trained model using the scene-graph classification (recall@100) evaluation script.
63 | 
64 | 
65 | ## References
66 | [1] Roei Herzig, Moshiko Raboh, Gal Chechik, Jonathan Berant, Amir Globerson, [Mapping Images to Scene Graphs with Permutation-Invariant Structured Prediction](https://arxiv.org/abs/1802.05451), 2018.
67 | 
68 | [2] Justin Johnson, Ranjay Krishna, Michael Stark, Li Jia Li, David A. Shamma, Michael S. Bernstein, Fei Fei Li, [Image Retrieval using Scene Graphs](http://hci.stanford.edu/publications/2015/scenegraphs/JohnsonCVPR2015.pdf), CVPR, 2015.
69 | 
70 | [3] Cewu Lu, Ranjay Krishna, Michael S. Bernstein, Fei Fei Li, [Visual Relationship Detection with Language Priors](https://cs.stanford.edu/people/ranjaykrishna/vrd/vrd.pdf), ECCV, 2016.
71 | 
72 | [4] Xu, Danfei and Zhu, Yuke and Choy, Christopher and Fei-Fei, Li, [Scene Graph Generation by Iterative Message Passing](https://arxiv.org/pdf/1701.02426.pdf), CVPR, 2017.
73 | 
74 | [5] Alejandro Newell and Jia Deng, [Pixels to Graphs by Associative Embedding](https://papers.nips.cc/paper/6812-pixels-to-graphs-by-associative-embedding.pdf), NIPS, 2017.
75 | 
76 | [6] Rowan Zellers, Mark Yatskar, Sam Thomson, Yejin Choi, [Neural Motifs: Scene Graph Parsing with Global Context](https://arxiv.org/pdf/1711.06640.pdf), CVPR, 2018.
77 | 
78 | ## Cite
79 | Please cite our paper if you use this code in your own work:
80 | ```
81 | @inproceedings{hr18perminvimg2sg,
82 |   author    = {Roei Herzig and
83 |                Moshiko Raboh and
84 |                Gal Chechik and
85 |                Jonathan Berant and
86 |                Amir Globerson},
87 |   title     = {Mapping Images to Scene Graphs with Permutation-Invariant Structured
88 |                Prediction},
89 |   booktitle = {Advances in Neural Information Processing Systems (NIPS)},
90 |   year      = {2018}
91 | }
92 | ```
93 | 
-------------------------------------------------------------------------------- /Run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from Module.Train import train 3 | from Module.Eval import eval 4 | import zipfile 5 | import os 6 | from FilesManager.FilesManager import FilesManager 7 | from Utils.Logger import Logger 8 | import urllib 9 | 10 | if __name__ == "__main__": 11 | version_filename_flag = '.data_ver2' 12 | 13 | # create logger 14 | logger = Logger() 15 | 16 | application = None 17 | name = None 18 | gpu = None 19 | 20 | ## get input parameters 21 | # application 22 | if len(sys.argv) > 1: 23 | application = sys.argv[1] 24 | 25 | # module name 26 | if len(sys.argv) > 2: 27 | name = sys.argv[2] 28 | 29 | # gpu number 30 | if len(sys.argv) > 3: 31 | gpu = sys.argv[3] 32 | 33 | if application == "train": 34 | # check if required data version downloaded 35 | if not os.path.isfile(version_filename_flag): 36 | print("Error: Data wasn't downloaded. Type python Run.py for instructions how to download\n\n") 37 | exit() 38 | logger.log("Command: Train(module_name=%s, gpu=%s" % (name, str(gpu))) 39 | train(name=name, gpu=gpu) 40 | 41 | elif application == "eval": 42 | # check if required data version downloaded 43 | if not os.path.isfile(version_filename_flag): 44 | print("Error: Data wasn't downloaded. 
Type python Run.py for instructions how to download\n\n") 45 | exit() 46 | 47 | logger.log("Command: Eval(module_name=%s, gpu=%s" % (name, str(gpu))) 48 | eval(load_module_name=name, gpu=gpu) 49 | 50 | elif application == "download": 51 | logger.log("Command: Download()") 52 | 53 | filesmanager = FilesManager() 54 | path = filesmanager.get_file_path("data.visual_genome.data") 55 | file_name = os.path.join(path, "data.zip") 56 | 57 | # Download Data 58 | logger.log("Download Data ...") 59 | url = "http://www.cs.tau.ac.il/~taunlp/scene_graph/data.zip" 60 | urllib.urlretrieve(url, file_name) 61 | 62 | # Extract data 63 | logger.log("Extract ZIP file ...") 64 | zip_ref = zipfile.ZipFile(file_name, 'r') 65 | zip_ref.extractall(path) 66 | zip_ref.close() 67 | 68 | # mark data version downloaded 69 | open(version_filename_flag, "wb").close() 70 | 71 | else: 72 | # print usage 73 | print("Error: unexpected usage\n\n") 74 | print("SGP Runner") 75 | print("----------") 76 | print("Download data: \"python Run.py download\"") 77 | print(" Should be run just once, on the first time the module used") 78 | print("Train Module: \"python Run.py train \"") 79 | print(" Train lingustic SGP") 80 | print(" Module weights with the highest score over the validation set will be saved as \"_best\"") 81 | print(" Module weights of the last epoch will be saved as \"\"") 82 | print("Eval Module: \"python Run.py eval \"") 83 | print(" Scene graph classification (recall@100) evaluation for the trained module.") 84 | print(" Use 'gpi_ling_orig_best' for a pre-trained module") 85 | print(" Use \"_best\" for a self-trained module") 86 | -------------------------------------------------------------------------------- /Utils/Logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import sys 4 | from FilesManager.FilesManager import FilesManager 5 | import time 6 | sys.path.append("..") 7 | from Utils.Singleton import Singleton 8 | 9 | 10 | class Logger(object): 11 | """ 12 | Logger class 13 | """ 14 | __metaclass__ = Singleton 15 | 16 | def __init__(self, name="", path=None): 17 | """ 18 | Creating logger, the name of the logger will be printed at each line and log will be saved in path 19 | :param name: name of the logger 20 | :param path: path to log file 21 | """ 22 | self.name = name 23 | if self.name != "": 24 | self.prefix = self.name + ": " 25 | else: 26 | self.prefix = "" 27 | 28 | self.path = path 29 | 30 | # create dir 31 | if self.path is not None: 32 | if not os.path.exists(self.path): 33 | os.makedirs(self.path) 34 | 35 | # create file to log 36 | self.log_file = open(self.path + "/logger.log", "w") 37 | else: 38 | self.log_file = None 39 | 40 | # add logger to file manager 41 | filesmanager = FilesManager() 42 | filesmanager.add_logger(self) 43 | 44 | # default log file 45 | if self.log_file is None: 46 | self.path = filesmanager.get_file_path("logs") 47 | if not os.path.exists(self.path): 48 | os.makedirs(self.path) 49 | self.log_file = open(self.path + "/logger-%s.log" % time.strftime("%c"), "w") 50 | self.log("Start %s" % time.strftime("%c")) 51 | 52 | def log(self, str): 53 | """ 54 | This function write str to the logger 55 | :param str: a string to be written 56 | """ 57 | if self.log_file is not None: 58 | self.log_file.write(self.name + ": " + str + "\n") 59 | self.log_file.flush() 60 | print(self.prefix + str) 61 | 62 | def get_dir(self): 63 | """ 64 | Get the path of the dir 65 | :return: path of the dir 66 | """ 67 | return 
self.path 68 | 69 | @classmethod 70 | def get_logger(cls): 71 | """ 72 | This function returns the logger 73 | :return: return the cls logger 74 | """ 75 | return cls 76 | -------------------------------------------------------------------------------- /Utils/Singleton.py: -------------------------------------------------------------------------------- 1 | class Singleton(type): 2 | """ 3 | Singleton metaclass 4 | """ 5 | 6 | _instances = {} 7 | 8 | def __call__(cls, *args, **kwargs): 9 | if cls not in cls._instances: 10 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 11 | return cls._instances[cls] 12 | 13 | 14 | class Visitor(object): 15 | """ 16 | This class is a base class for the visitor design pattern, 17 | a subclass can implement visit methods for specific types 18 | for example, given two classes A and B, a visitor need to create two functions 19 | visit_A and visit_B and dynamically the visit method will route to the correct visit method 20 | given a subject of type A or B 21 | """ 22 | 23 | def visit(self, subject): 24 | method_name = 'visit_' + type(subject).__name__ 25 | method = getattr(self, method_name) 26 | if method is None: 27 | method = self.generic_visit 28 | return method(subject) 29 | 30 | def generic_visit(self, subject): 31 | print('Class {0} does not have a visit function in {1}'.format(type(subject).__name__, 32 | type(self).__name__)) -------------------------------------------------------------------------------- /Utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/Utils/__init__.py -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/__init__.py -------------------------------------------------------------------------------- /qualitive_results_att_boxes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/qualitive_results_att_boxes.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow_gpu==1.0.1 2 | matplotlib==2.0.2 3 | h5py==2.7.0 4 | numpy==1.13.1 5 | PyYAML==3.12 6 | -------------------------------------------------------------------------------- /sg_example_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/sg_example_final.png -------------------------------------------------------------------------------- /sgp_arch_git.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shikorab/SceneGraph/d2dc960a0d3728e7d2b0eaceb9d0fdaedb2a47d5/sgp_arch_git.png --------------------------------------------------------------------------------