├── model ├── __init__.py ├── prior │ ├── __init__.py │ ├── object_prior.py │ ├── object_prior_example.py │ ├── relation_prior_example.py │ └── relation_prior.py ├── keyframe │ ├── __init__.py │ ├── keyframe_extracion.py │ └── keyframe_extracion_example.py ├── settings.py ├── SGGenModel.py ├── interpret.py └── vis_tuning.py ├── options ├── data ├── svg └── visual_genome ├── vis_result └── fig │ ├── result.png │ └── 3dsg_readme_fig.png ├── requirements.txt ├── .gitmodules ├── etc └── jpg2avi.py ├── .gitignore ├── object_prior_extraction.py ├── README.md ├── visualize_FactorizableNet.py ├── scene_graph_tuning.py └── relation_prior_extraction.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/prior/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/keyframe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options: -------------------------------------------------------------------------------- 1 | ./FactorizableNet/options/ -------------------------------------------------------------------------------- /data/svg: -------------------------------------------------------------------------------- 1 | ../FactorizableNet/data/svg/ -------------------------------------------------------------------------------- /data/visual_genome: -------------------------------------------------------------------------------- 1 | ../FactorizableNet/data/visual_genome/ -------------------------------------------------------------------------------- /vis_result/fig/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uehwan/3-D-Scene-Graph/HEAD/vis_result/fig/result.png -------------------------------------------------------------------------------- /vis_result/fig/3dsg_readme_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uehwan/3-D-Scene-Graph/HEAD/vis_result/fig/3dsg_readme_fig.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchtext==0.2.3 2 | setuptools 3 | git+https://github.com/chickenbestlover/ColorHistogram.git 4 | pyyaml 5 | graphviz 6 | webcolors 7 | pandas 8 | matplotlib 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sort"] 2 | path = sort 3 | url = https://github.com/abewley/sort.git 4 | [submodule "FactorizableNet"] 5 | path = FactorizableNet 6 | url = https://github.com/yikang-li/FactorizableNet 7 | -------------------------------------------------------------------------------- /etc/jpg2avi.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import cv2 3 | import os 4 | from cv2 import VideoWriter, VideoWriter_fourcc 5 | fps=30 6 | format = "MJPG" 7 | fourcc = VideoWriter_fourcc(*format) 8 | RESULT_PWD = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/object_detection' 9 | if not 
os.path.exists(RESULT_PWD): 10 | os.mkdir(RESULT_PWD) 11 | vid = VideoWriter(osp.join(RESULT_PWD,'object_detection_result4.avi'), fourcc, float(fps), (1296,968), True) 12 | 13 | imageFileList = sorted(os.listdir(RESULT_PWD)) 14 | for idx in range(len(imageFileList)-1): 15 | print(idx) 16 | image = cv2.imread(osp.join(RESULT_PWD, str(idx) + '.jpg')) 17 | #vid.write(image.astype('uint8')) 18 | #vid.write(image) 19 | 20 | cv2.imshow('frame',image) 21 | if cv2.waitKey(50) & 0xFF == ord('q'): 22 | break 23 | vid.release() 24 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /model/prior/object_prior.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | 5 | def load_obj(filename): 6 | with open(filename + '.pkl', 'rb') as f: 7 | return pickle.load(f) 8 | 9 | 10 | # p(x, y) 11 | def cal_p_xy_joint(x_class, y_class, key2ind_pair, joint_prob): 12 | p_xy = joint_probability[key2ind_pair[x_class], key2ind_pair[y_class]] / np.sum(joint_probability) 13 | return p_xy 14 | 15 | 16 | # p(x|y) 17 | def cal_p_x_given_y(x_class, y_class, key2ind_pair, joint_prob): 18 | single_prob = np.sum(joint_probability, axis=1) 19 | p_y = single_prob[key2ind_pair[y_class]] 20 | p_xy = joint_probability[key2ind_pair[x_class], key2ind_pair[y_class]] 21 | return p_xy / p_y 22 | 23 | 24 | # p(x|y,z) approximated 25 | def cal_p_x_given_xy(x_class, y_class, z_class, key2ind_pair, joint_prob): 26 | p_x_given_y = cal_p_x_given_y(x_class, y_class, key2ind_pair, joint_prob) 27 | p_x_given_z = cal_p_x_given_y(x_class, z_class, key2ind_pair, joint_prob) 28 | return min(p_x_given_y, p_x_given_z) 29 | 30 | if __name__=="__main__": 31 | 32 | key2ind = load_obj("object_prior_key2ind") 33 | joint_probability = load_obj("object_prior_prob") 34 | print(cal_p_x_given_y('floor', 'rock', key2ind, joint_probability)) 35 | 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | demo.py 10 | *.o 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *,cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv/ 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | .idea 97 | extension-ffi 98 | demo_mot.py 99 | 100 | 101 | # defined by yikang 102 | data/ 103 | data 104 | output 105 | output/ 106 | log/ 107 | *.mat 108 | Debug_Code.ipynb 109 | eval/*.json 110 | 111 | pretrained_models/ 112 | vis_result/ 113 | *.pdf 114 | *.pkl 115 | .vector_cache 116 | data 117 | *.json 118 | -------------------------------------------------------------------------------- /model/prior/object_prior_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | 5 | def load_obj(filename): 6 | with open(filename + '.pkl', 'rb') as f: 7 | return pickle.load(f) 8 | 9 | 10 | # p(x, y) 11 | def cal_p_xy_joint(x_class, y_class, key2ind_pair, joint_prob): 12 | p_xy = joint_probability[key2ind_pair[x_class], key2ind_pair[y_class]] / np.sum(joint_probability) 13 | return p_xy 14 | 15 | 16 | # p(x|y) 17 | def cal_p_x_given_y(x_class, y_class, key2ind_pair, joint_prob): 18 | single_prob = np.sum(joint_probability, axis=1) 19 | p_y = single_prob[key2ind_pair[y_class]] 20 | p_xy = joint_probability[key2ind_pair[x_class], key2ind_pair[y_class]] 21 | return p_xy / p_y 22 | 23 | 24 | # p(x|y,z) approximated 25 | def cal_p_x_given_yz(x_class, y_class, z_class, key2ind_pair, joint_prob): 26 | p_x_given_y = cal_p_x_given_y(x_class, y_class, key2ind_pair, joint_prob) 27 | p_x_given_z = cal_p_x_given_y(x_class, z_class, key2ind_pair, joint_prob) 28 | return min(p_x_given_y, p_x_given_z) 29 | 30 | 31 | key2ind = load_obj("object_prior_key2ind") 32 | joint_probability = load_obj("object_prior_prob") 33 | 34 | candidates = ['bed','apple','ball','keyboard','table','desk','rug','pillow'] 35 | filtered_candidates = [] 36 | for i,candidate in enumerate(candidates): 37 | ''' condition check''' 38 | pillow_and_floor = cal_p_xy_joint('pillow','floor',key2ind,joint_probability) 39 | print('p(x=pillow,z=floor) = ',pillow_and_floor) 40 | pillow_and_bed = cal_p_xy_joint('pillow',candidate,key2ind,joint_probability) 41 | print('p(x=pillow,y={c}) = {p}'.format(c=candidate,p=pillow_and_bed)) 42 | bed_and_floor = cal_p_xy_joint(candidate,'floor',key2ind,joint_probability) 43 | print('p(y={c},z=floor) = {p}'.format(c=candidate,p=bed_and_floor)) 44 | print('p(x=pillow,z=floor)^2 = {p}'.format(p=pillow_and_floor**2)) 45 | print('p(x=pillow,y={c})*p(y={c},z=floor) = {p}'.format(c=candidate,p=pillow_and_bed*pillow_and_floor)) 46 | condition_check = pillow_and_floor**2 max_prob: 38 | max_prob = curr_total_prob 39 | max_pred = item 40 | return max_pred, max_prob / float(sum_count) 41 | 42 | 43 | def 
most_probable_relation_for_unpaired(list_of_two_objs, relation_stat, dist): 44 | """ 45 | returns 46 | 1. most probable relation 47 | 2. corresponding probability 48 | 3. direction (subject, object) 49 | """ 50 | candidate1, prob1 = most_probable_relation_for_paired((list_of_two_objs[1], list_of_two_objs[0]), relation_stat, dist) 51 | candidate2, prob2 = most_probable_relation_for_paired((list_of_two_objs[0], list_of_two_objs[1]), relation_stat, dist) 52 | if prob1 > prob2: 53 | return candidate1, prob1, (list_of_two_objs[1], list_of_two_objs[0]) 54 | else: 55 | return candidate2, prob2, (list_of_two_objs[0], list_of_two_objs[1]) 56 | 57 | def triplet_prob_from_statistics(triplet, relation_stat, dist): 58 | """ 59 | args 60 | 1. triplet: (sbj,pred,obj). ex ('desk','on','floor') 61 | 2. relation_stat: statistics dataset 62 | 3. dist: pixel_wise distance (scalar). ex 50. 63 | returns 64 | 2. corresponding probability 65 | """ 66 | pair_of_objs, predicate = (triplet[0], triplet[2]), triplet[1] 67 | max_prob = -1 68 | max_pred = "" 69 | sum_count = 0 70 | if pair_of_objs in relation_stat and predicate in relation_stat[pair_of_objs]: 71 | keys =relation_stat[pair_of_objs].keys() 72 | dist_probs = [] 73 | total_probs = {} 74 | for key in keys: 75 | triplet_stat = relation_stat[pair_of_objs][key] 76 | curr_count = triplet_stat['count'] 77 | curr_mean = triplet_stat['mean'] 78 | curr_var = triplet_stat['var'] 79 | curr_dist_prob = cal_normal_prob(curr_mean, curr_var, dist) 80 | dist_probs.append(curr_dist_prob) 81 | total_probs[key]=curr_count * curr_dist_prob 82 | sum_count += curr_count 83 | 84 | 85 | return total_probs[predicate]/float(sum_count) 86 | else: 87 | return None, 0.0 88 | 89 | 90 | if __name__=="__main__": 91 | relation_statistics = load_obj("relation_prior_prob") 92 | print(most_probable_relation_for_unpaired(['floor', 'rock'], relation_statistics, 50)) 93 | 94 | -------------------------------------------------------------------------------- /model/prior/relation_prior.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | 5 | def load_obj(filename): 6 | with open(filename + '.pkl', 'rb') as f: 7 | return pickle.load(f) 8 | 9 | 10 | def cal_normal_prob(mean, var, val): 11 | prob = math.exp(-0.5*((val - mean)/math.sqrt(var+0.0000000001))**2) 12 | return prob 13 | 14 | 15 | def most_probable_relation_for_paired(pair_of_objs, relation_stat, dist): 16 | """ 17 | returns 18 | 1. most probable relation 19 | 2. corresponding probability 20 | """ 21 | if not pair_of_objs in relation_stat: 22 | return None, 0.0 23 | 24 | one_case = relation_stat[pair_of_objs] 25 | items = list(one_case) 26 | max_prob = -1 27 | max_pred = "" 28 | sum_count = 0 29 | 30 | for item in items: 31 | curr_count = one_case[item]['count'] 32 | curr_mean = one_case[item]['mean'] 33 | curr_var = one_case[item]['var'] 34 | curr_dist_prob = cal_normal_prob(curr_mean, curr_var, dist) 35 | curr_total_prob = curr_count * curr_dist_prob 36 | sum_count += curr_count 37 | if curr_total_prob > max_prob: 38 | max_prob = curr_total_prob 39 | max_pred = item 40 | return max_pred, max_prob / float(sum_count) 41 | 42 | 43 | def most_probable_relation_for_unpaired(list_of_two_objs, relation_stat, dist): 44 | """ 45 | returns 46 | 1. most probable relation 47 | 2. corresponding probability 48 | 3. 
direction (subject, object) 49 | """ 50 | candidate1, prob1 = most_probable_relation_for_paired((list_of_two_objs[1], list_of_two_objs[0]), relation_stat, dist) 51 | candidate2, prob2 = most_probable_relation_for_paired((list_of_two_objs[0], list_of_two_objs[1]), relation_stat, dist) 52 | 53 | if prob1 > prob2: 54 | return candidate1, prob1, (list_of_two_objs[1], list_of_two_objs[0]) 55 | else: 56 | return candidate2, prob2, (list_of_two_objs[0], list_of_two_objs[1]) 57 | 58 | def most_probable_relation_for_unpaired2(list_of_two_objs, relation_stat, dist): 59 | """ 60 | returns 61 | 1. most probable relation 62 | 2. corresponding probability 63 | 3. direction (subject, object) 64 | """ 65 | candidate1, prob1 = most_probable_relation_for_paired((list_of_two_objs[1], list_of_two_objs[0]), relation_stat, dist) 66 | candidate2, prob2 = most_probable_relation_for_paired((list_of_two_objs[0], list_of_two_objs[1]), relation_stat, dist) 67 | 68 | if prob1 > prob2: 69 | return candidate1, prob1, prob1 > prob2 70 | else: 71 | return candidate2, prob2, prob1 > prob2 72 | 73 | def triplet_prob_from_statistics(triplet, relation_stat, dist): 74 | """ 75 | args 76 | 1. triplet: (sbj,pred,obj). ex ('desk','on','floor') 77 | 2. relation_stat: statistics dataset 78 | 3. dist: pixel_wise distance (scalar). ex 50. 79 | returns 80 | 2. corresponding probability 81 | """ 82 | pair_of_objs, predicate = (triplet[0], triplet[2]), triplet[1] 83 | max_prob = -1 84 | max_pred = "" 85 | sum_count = 0 86 | if pair_of_objs in relation_stat and predicate in relation_stat[pair_of_objs]: 87 | keys =relation_stat[pair_of_objs].keys() 88 | dist_probs = [] 89 | total_probs = {} 90 | for key in keys: 91 | triplet_stat = relation_stat[pair_of_objs][key] 92 | curr_count = triplet_stat['count'] 93 | curr_mean = triplet_stat['mean'] 94 | curr_var = triplet_stat['var'] 95 | curr_dist_prob = cal_normal_prob(curr_mean, curr_var, dist) 96 | dist_probs.append(curr_dist_prob) 97 | total_probs[key]=curr_count * curr_dist_prob 98 | sum_count += curr_count 99 | return total_probs[predicate]/float(sum_count) 100 | else: 101 | return 0.0 102 | 103 | 104 | if __name__=="__main__": 105 | relation_statistics = load_obj("relation_prior_prob") 106 | print(most_probable_relation_for_unpaired(['floor', 'rock'], relation_statistics, 50)) 107 | 108 | -------------------------------------------------------------------------------- /object_prior_extraction.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pickle 4 | import copy 5 | import os 6 | import os.path as osp 7 | 8 | 9 | def save_obj(obj, filename): 10 | with open(filename + '.pkl', 'wb') as f: 11 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 12 | 13 | 14 | def load_obj(filename): 15 | with open(filename + '.pkl', 'rb') as f: 16 | return pickle.load(f) 17 | 18 | 19 | relevant_classes = [u'__background__',u'field', u'zebra', u'sky', u'track', u'train', u'window', u'pole', u'windshield', u'background', u'tree', u'door', u'sheep', u'paint', u'grass', u'baby', u'ear', u'leg', u'eye', u'tail', u'head', u'nose', u'skateboarder', u'arm', u'foot', u'skateboard', u'wheel', u'hand', u'ramp', u'man', u'jeans', u'shirt', u'sneaker', u'writing', u'hydrant', u'cap', u'chain', u'sidewalk', u'curb', u'road', u'line', u'bush', u'sign', u'people', u'car', u'edge', u'bus', u'tire', u'lady', u'letter', u'leaf', u'boy', u'pocket', u'backpack', u'bottle', u'suitcase', u'word', u'ground', u'handle', u'strap', u'jacket', 
u'motorcycle', u'bicycle', u'truck', u'cloud', u'kite', u'pants', u'beach', u'woman', u'rock', u'dress', u'dog', u'building', u'frisbee', u'shoe', u'plant', u'pot', u'hair', u'face', u'shorts', u'stripe', u'bench', u'flower', u'cat', u'post', u'container', u'house', u'ceiling', u'seat', u'back', u'graffiti', u'paper', u'hat', u'tennisracket', u'tennisplayer', u'wall', u'logo', u'girl', u'clock', u'brick', u'white', u'elephant', u'mirror', u'bird', u'glove', u'oven', u'area', u'sticker', u'flag', u'surfboard', u'wetsuit', u'shadow', u'sleeve', u'tenniscourt', u'surface', u'finger', u'string', u'plane', u'wing', u'umbrella', u'snow', u'sunglasses', u'boot', u'coat', u'skipole', u'ski', u'skier', u'black', u'player', u'sock', u'racket', u'wrist', u'band', u'ball', u'light', u'shelf', u'stand', u'vase', u'horse', u'number', u'rug', u'goggles', u'snowboard', u'computer', u'screen', u'button', u'glass', u'bracelet', u'cellphone', u'mountain', u'phone', u'hill', u'fence', u'stone', u'cow', u'tag', u'bear', u'table', u'water', u'ocean', u'trashcan', u'circle', u'river', u'railing', u'design', u'bowl', u'food', u'spoon', u'tablecloth', u'plate', u'bread', u'tomato', u'kid', u'sand', u'dirt', u'mouth', u'hole', u'air', u'distance', u'board', u'feet', u'suit', u'wave', u'guy', u'reflection', u'bathroom', u'toilet', u'sink', u'faucet', u'floor', u'toiletpaper', u'towel', u'sandwich', u'knife', u'bolt', u'boat', u'engine', u'trafficlight', u'wine', u'cup', u'stem', u'base', u'top', u'bottom', u'sofa', u'counter', u'photo', u'frame', u'side', u'paw', u'branch', u'fur', u'forest', u'wire', u'headlight', u'rail', u'front', u'green', u'helmet', u'whiskers', u'pen', u'neck', u'net', u'necklace', u'duck', u'sweater', u'chair', u'horn', u'giraffe', u'spot', u'mane', u'airplane', u'beard', u'speaker', u'sun', u'shore', u'pillar', u'tower', u'jet', u'gravel', u'sauce', u'fork', u'tray', u'awning', u'tent', u'bun', u'teeth', u'camera', u'tile', u'lid', u'kitchen', u'curtain', u'drawer', u'knob', u'box', u'outlet', u'remote', u'couch', u'tie', u'book', u'ring', u'toothbrush', u'balcony', u'stairs', u'doorway', u'stopsign', u'bed', u'pillow', u'corner', u'trim', u'vegetable', u'orange', u'broccoli', u'rope', u'streetlight', u'name', u'pitcher', u'uniform', u'body', u'mouse', u'keyboard', u'desk', u'monitor', u'statue', u'collar', u'candle', u'animal', u'tv', u'donut', u'apple', u'child', u'licenseplate', u'catcher', u'umpire', u'banner', u'bat', u'batter', u'part', u'hotdog', u'object', u'cake', u'bridge', u'patch', u'belt', u'park', u'stick', u'bucket', u'runway', u'lamp', u'tip', u'carpet', u'blanket', u'cover', u'napkin', u'theoutdoors', u'stove', u'pizza', u'cheese', u'crust', u'van', u'beak', u'cord', u'poster', u'purse', u'laptop', u'shoulder', u'dish', u'can', u'pipe', u'key', u'arrow', u'surfer', u'controller', u'blinds', u'bluesky', u'whiteclouds', u'luggage', u'vehicle', u'streetsign', u'pan', u'baseball', u'baseballplayer', u'jersey', u'rack', u'cabinet', u'meat', u'watch', u'refrigerator', u'vest', u'skirt', u'hoof', u'label', u'teddybear', u'fridge', u'snowboarder', u'scarf', u'basket', u'cloth', u'shade', u'blue', u'spectator', u'knee', u'column', u'metal', u'steps', u'firehydrant', u'platform', u'jar', u'fruit', u'hood', u't-shirt', u'cone', u'weeds', u'treetrunk', u'room', u'red', u'television', u'scissors', u'gate', u'tennisball', u'court', u'log', u'star', u'lettuce', u'traincar', u'microwave', u'pepperoni', u'onion', u'chimney', u'concrete', u'mug', u'carrot', u'banana', u'cart', u'wood', 
u'bar', u'ripples', u'holder', u'pepper', u'tusk'] 20 | relevant_classes_ = copy.copy(relevant_classes) 21 | print('Before filtering: ' + str(len(relevant_classes))) 22 | print('After filtering: ' + str(len(relevant_classes))) 23 | 24 | ind2key = {idx: item for idx, item in enumerate(relevant_classes)} 25 | key2ind = {item: idx for idx, item in enumerate(relevant_classes)} 26 | try: 27 | os.makedirs(osp.join('model', 'prior', 'preprocessed')) 28 | except: 29 | pass 30 | file = open('model/prior/raw/relationships.json').read() 31 | 32 | data = json.loads(file) 33 | print("Reading JSON completed!!") 34 | 35 | joint_probability = np.zeros((len(relevant_classes), len(relevant_classes))) 36 | 37 | i = 0 38 | for datum in data: 39 | i += 1 40 | if i % 1000 == 0: 41 | print(str(i) + "th point processing") 42 | 43 | relations = datum['relationships'] 44 | for rel in relations: 45 | obj, sub = rel['object']['name'], rel['subject']['name'] 46 | 47 | if obj in relevant_classes and sub in relevant_classes: 48 | joint_probability[key2ind[obj], key2ind[sub]] += 1 49 | joint_probability[key2ind[sub], key2ind[obj]] += 1 50 | 51 | save_obj(joint_probability, "model/prior/preprocessed/object_prior_prob") 52 | save_obj(ind2key, "model/prior/preprocessed/object_prior_ind2key") 53 | save_obj(key2ind, "model/prior/preprocessed/object_prior_key2ind") 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D-Scene-Graph: *A Sparse and Semantic Representation of Physical Environments for Intelligent Agents* 2 | This work is based on our paper (accepted to IEEE Transactions on Cybernetics, 2019). We propose a new concept, the 3D scene graph, together with a framework for constructing it. Our implementation builds on [FactorizableNet](https://github.com/yikang-li/FactorizableNet) and is written in Pytorch. 3 | 4 | ## 3D Scene Graph Construction Framework 5 | The proposed 3D scene graph construction framework extracts relevant semantics within environments, such as object categories and relations between objects, as well as physical attributes 6 | such as 3D positions and major colors, while generating 3D scene graphs for the given environments. The framework receives a sequence of observations of the environments in the form of RGB-D image frames. 7 | For robust performance, the framework filters out unstable observations (i.e., blurry images) 8 | using the proposed adaptive blurry-image detection algorithm. 9 | Then, the framework factors out keyframe groups to avoid redundant processing of the same information. 10 | Keyframe groups contain reasonably overlapping frames. Next, the framework extracts semantics and physical attributes 11 | within the environments through recognition modules. 12 | During recognition, spurious detections are rejected and missing entities are supplemented. 13 | Finally, the gathered information is fused into a 3D scene graph, and the graph is updated upon new observations. 14 | 15 | ![3D Scene Graph construction framework](./vis_result/fig/3dsg_readme_fig.png) 16 | 17 |
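Both prior modules shown earlier (`model/prior/object_prior.py` and `model/prior/relation_prior.py`) and the extraction scripts (`object_prior_extraction.py` above, `relation_prior_extraction.py` below) reduce to simple count-based formulas. The following is a minimal, self-contained sketch of how those statistics are queried; the counts and distance statistics are made up for illustration and are not the pickled Visual Genome priors:

```
import math
import numpy as np

# Toy co-occurrence counts between three classes (made-up numbers, not the
# statistics produced by object_prior_extraction.py).
classes = ['pillow', 'bed', 'floor']
key2ind = {name: i for i, name in enumerate(classes)}
counts = np.array([[0., 40., 5.],
                   [40., 0., 20.],
                   [5., 20., 0.]])  # symmetric, like joint_probability

def p_xy(x, y):
    # p(x, y): normalized co-occurrence count, cf. cal_p_xy_joint().
    return counts[key2ind[x], key2ind[y]] / np.sum(counts)

def p_x_given_y(x, y):
    # p(x | y) = count(x, y) / count(y), cf. cal_p_x_given_y().
    return counts[key2ind[x], key2ind[y]] / np.sum(counts, axis=1)[key2ind[y]]

print(p_xy('pillow', 'bed'))         # 40 / 130 ~= 0.31
print(p_x_given_y('pillow', 'bed'))  # 40 / 60  ~= 0.67

# Toy relation statistics: for each (subject, object) pair, every predicate
# stores its count and the mean/variance of the subject-object pixel distance,
# the same fields relation_prior_extraction.py accumulates.
relation_stat = {('pillow', 'bed'): {
    'on':   {'count': 30, 'mean': 40.0,  'var': 400.0},
    'near': {'count': 10, 'mean': 150.0, 'var': 2500.0}}}

def most_probable_relation(pair, dist):
    # Score each predicate by count * Gaussian likelihood of the observed
    # distance, normalized by the total count, cf. relation_prior.py.
    stats = relation_stat[pair]
    total = float(sum(s['count'] for s in stats.values()))
    scores = {p: s['count'] * math.exp(-0.5 * (dist - s['mean']) ** 2 / (s['var'] + 1e-10))
              for p, s in stats.items()}
    best = max(scores, key=scores.get)
    return best, scores[best] / total

print(most_probable_relation(('pillow', 'bed'), 35.0))   # 'on' wins at small distances
print(most_probable_relation(('pillow', 'bed'), 160.0))  # 'near' wins when the boxes are far apart
```

The distance-dependent Gaussian term is what lets the same object pair switch from 'on' to 'near' as the observed pixel distance between the two boxes grows.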
18 | ## Requirements 19 | * Ubuntu 16.04+ 20 | * Python 2.7 21 | * [Pytorch 0.3.1](https://pytorch.org/get-started/previous-versions/) 22 | * torchtext 0.2.3 23 | * [FactorizableNet](https://github.com/yikang-li/FactorizableNet) 24 | * [ScanNet](http://www.scan-net.org) dataset (optional, used for testing; an RGB-D video from ScanNet is enough) 25 | 26 | ## Installation 27 | 28 | Install [Pytorch 0.3.1](https://pytorch.org/get-started/previous-versions/). The code has been tested only with Python 2.7 and CUDA 9.0 on Ubuntu 16.04. 29 | You will need to modify a significant amount of code if you want to run it in a different environment (Python 3+ or Pytorch 0.4+). 30 | 31 | 32 | 1. Download the 3D-Scene-Graph repository 33 | 34 | ``` 35 | git clone --recurse-submodules https://github.com/Uehwan/3D-Scene-Graph.git 36 | ``` 37 | 2. Install FactorizableNet 38 | ``` 39 | cd 3D-Scene-Graph/FactorizableNet 40 | ``` 41 | Please follow the installation instructions in the [FactorizableNet](https://github.com/yikang-li/FactorizableNet) repository. 42 | Follow steps 1 through 6. You can skip step 7. Download VG-DR-Net in step 8; you do not need to download the other models. 43 | 44 | 3. Install 3D-Scene-Graph 45 | ``` 46 | cd 3D-Scene-Graph 47 | touch FactorizableNet/__init__.py 48 | ln -s ./FactorizableNet/options/ options 49 | mkdir data 50 | ln -s ./FactorizableNet/data/svg data/svg 51 | ln -s ./FactorizableNet/data/visual_genome data/visual_genome 52 | 53 | pip install torchtext==0.2.3 54 | pip install setuptools pyyaml graphviz webcolors pandas matplotlib 55 | pip install git+https://github.com/chickenbestlover/ColorHistogram.git 56 | ``` 57 | 58 | Alternatively, use the installation script: 59 | ``` 60 | ./build.sh 61 | ``` 62 | 63 | 4. Download the [ScanNet](http://www.scan-net.org) dataset 64 | 65 | In order to use the ScanNet dataset, you need to fill out an agreement to the ScanNet Terms of Use and send it to the ScanNet team at scannet@googlegroups.com. 66 | If the process is successful, they will send you a script for downloading the ScanNet dataset. 67 | 68 | To download a specific scan (e.g. `scene0000_00`) using the script (the script only runs on Python 2.7): 69 | ``` 70 | download-scannet.py -o [directory in which to download] --id scene0000_00 71 | (then press Enter twice) 72 | ``` 73 | After the download is finished, the scan is located in a new folder `scene0000_00`. 74 | In that folder, the `*.sens` file contains the RGB-D video together with the camera poses. 75 | To extract them, we use SensReader, an extraction tool provided by the [ScanNet git repo](https://github.com/ScanNet/ScanNet). 76 | 77 | ``` 78 | git clone https://github.com/ScanNet/ScanNet.git 79 | cd ScanNet/SensReader/python/ 80 | python reader.py \ 81 | --filename [your .sens filepath] \ 82 | --output_path [ROOT of 3D-Scene-Graph]/data/scene0000_00/ \ 83 | --export_depth_images \ 84 | --export_color_images \ 85 | --export_poses \ 86 | --export_intrinsics 87 | 88 | ``` 89 | 90 | 91 | ## Example of usage 92 | 93 | ``` 94 | python scene_graph_tuning.py \ 95 | --scannet_path data/scene0000_00/\ 96 | --obj_thres 0.23\ 97 | --thres_key 0.2\ 98 | --thres_anchor 0.68 \ 99 | --visualize \ 100 | --frame_start 800 \ 101 | --plot_graph \ 102 | --disable_spurious \ 103 | --gain 10 \ 104 | --detect_cnt_thres 2 \ 105 | --triplet_thres 0.065 106 | ``` 107 | 108 | ### Core hyper-parameters 109 | 110 | Data settings: 111 | * *--dataset* : dataset to use, default='scannet'. 112 | * *--scannet_path* : path to a ScanNet scan, default='./data/scene0507/'. 113 | * *--frame_start* : index of the first frame to process, default=0. 114 | * *--frame_end* : index of the last frame to process, default=5000. 115 | 116 | FactorizableNet Output Filtering Settings: 117 | * *--obj_thres* : object recognition threshold score, default=0.25. 118 | * *--triplet_thres* : triplet recognition threshold score, default=0.08. 119 | * *--nms* : NMS threshold for post object NMS (a negative value disables NMS), default=0.2.
120 | * *--triplet_nms* : NMS threshold for post triplet NMS (a negative value disables NMS), default=0.4. 121 | 122 | 123 | Key-frame Extraction Settings: 124 | * *--thres_key* : keyframe threshold score, default=0.1. 125 | * *--thres_anchor* : anchor-frame threshold score, default=0.65. 126 | * *--alpha* : weight for the Exponentially Weighted Summation, default=0.4. 127 | * *--gain* : gain for adaptive thresholding in blurry image detection, default=25. 128 | * *--offset* : offset for adaptive thresholding in blurry image detection, default=1. 129 | 130 | Visualization Settings: 131 | * *--pause_time* : pause interval (sec) after every detection, default=1. 132 | * *--plot_graph* : plot the 3D scene graph if true. 133 | * *--visualize* : enable visualization if true. 134 | * *--format* : resulting image format, pdf or png, default='png'. 135 | * *--draw_color* : draw color nodes in the 3D scene graph if true. 136 | * *--save_image* : save detection result images if true. 137 | 138 | 139 | 140 | 141 | ## Result 142 | 143 | ![scores1](./vis_result/fig/result.png) 144 | 145 | 146 | ## Demo Video 147 | 148 | [![Video Label](http://img.youtube.com/vi/RF4bf7ZlAX4/0.jpg)](https://www.youtube.com/watch?v=RF4bf7ZlAX4) 149 | 150 | ## Citations 151 | 152 | Please consider citing this project in your publications if it helps your research. 153 | The following is a BibTeX reference. 154 | 155 | ``` 156 | @article{kim2019graph3d, 157 | title={3D-Scene-Graph: A Sparse and Semantic Representation of Physical Environments for Intelligent Agents}, 158 | author={Kim, Ue-Hwan and Park, Jin-Man and Song, Taek-Jin and Kim, Jong-Hwan}, 159 | journal={IEEE Transactions on Cybernetics}, 160 | year={2019} 161 | } 162 | ``` 163 | 164 | ## Acknowledgement 165 | This work was supported by the ICT R&D program 166 | of MSIP/IITP.
[2016-0-00563, Research on Adaptive Machine 167 | Learning Technology Development for Intelligent Autonomous 168 | Digital Companion] 169 | -------------------------------------------------------------------------------- /visualize_FactorizableNet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./FactorizableNet') 3 | import random 4 | import numpy as np 5 | import argparse 6 | import yaml 7 | from pprint import pprint 8 | import cv2 9 | import torch 10 | from torch.autograd import Variable 11 | from lib import network 12 | import lib.datasets as datasets 13 | import lib.utils.general_utils as utils 14 | import models as models 15 | from models.HDN_v2.utils import interpret_relationships 16 | import warnings 17 | 18 | 19 | parser = argparse.ArgumentParser('Options for training Hierarchical Descriptive Model in pytorch') 20 | 21 | parser.add_argument('--path_opt', default='options/models/VG-DR-Net.yaml', type=str, 22 | help='path to a yaml options file, VG-DR-Net.yaml or VG-MSDN.yaml') 23 | parser.add_argument('--dataset_option', type=str, default='normal', help='data split selection [small | fat | normal]') 24 | parser.add_argument('--batch_size', type=int, help='#images per batch') 25 | parser.add_argument('--workers', type=int, default=4, help='#idataloader workers') 26 | # model init 27 | parser.add_argument('--pretrained_model', type=str, default = 'FactorizableNet/output/trained_models/Model-VG-DR-Net.h5', 28 | help='path to pretrained_model, Model-VG-DR-Net.h5 or Model-VG-MSDN.h5') 29 | 30 | # structure settings 31 | # Environment Settings 32 | parser.add_argument('--seed', type=int, default=1, help='set seed to some constant value to reproduce experiments') 33 | parser.add_argument('--nms', type=float, default=0.3, help='NMS threshold for post object NMS (negative means not NMS)') 34 | parser.add_argument('--triplet_nms', type=float, default=0.01, help='Triplet NMS threshold for post object NMS (negative means not NMS)') 35 | # testing settings 36 | parser.add_argument('--use_gt_boxes', action='store_true', help='Use ground truth bounding boxes for evaluation') 37 | args = parser.parse_args() 38 | 39 | 40 | def get_class_string(class_index, score, dataset): 41 | class_text = dataset[class_index] if dataset is not None else \ 42 | 'id{:d}'.format(class_index) 43 | return class_text + ' {:0.2f}'.format(score).lstrip('0') 44 | 45 | 46 | def vis_object_detection(image_scene, test_set, 47 | obj_inds, obj_boxes, obj_scores): 48 | for i, obj_ind in enumerate(obj_inds): 49 | cv2.rectangle(image_scene, 50 | (int(obj_boxes[i][0]), int(obj_boxes[i][1])), 51 | (int(obj_boxes[i][2]), int(obj_boxes[i][3])), 52 | colorlist[i], 53 | 2) 54 | font_scale = 0.5 55 | txt = str(i) + '. ' + get_class_string(obj_ind, obj_scores[i], test_set.object_classes) 56 | ((txt_w, txt_h), _) = cv2.getTextSize(txt, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1) 57 | # Place text background. 
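        # The label text is "<detection index>. <class name> <score>" (e.g. "3. chair .87");
        # it is anchored at (x0, y0), the bottom-left corner of the detected box, and a
        # filled rectangle in the same color as the box is drawn behind it so the white
        # text stays readable.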
58 | x0, y0 = int(obj_boxes[i][0]), int(obj_boxes[i][3]) 59 | back_tl = x0, y0 - int(1.3 * txt_h) 60 | back_br = x0 + txt_w, y0 61 | cv2.rectangle(image_scene, back_tl, back_br, colorlist[i], -1) 62 | cv2.putText(image_scene, 63 | txt, 64 | (x0, y0 - 2), 65 | cv2.FONT_HERSHEY_SIMPLEX, 66 | font_scale, 67 | (255, 255, 255), 68 | 1) 69 | 70 | return image_scene 71 | 72 | 73 | if __name__ == '__main__': 74 | # Set options 75 | options = { 76 | 'data':{ 77 | 'dataset_option': args.dataset_option, 78 | 'batch_size': args.batch_size, 79 | }, 80 | } 81 | with open(args.path_opt, 'r') as handle: 82 | options_yaml = yaml.load(handle) 83 | options = utils.update_values(options, options_yaml) 84 | with open(options['data']['opts'], 'r') as f: 85 | data_opts = yaml.load(f) 86 | options['data']['dataset_version'] = data_opts.get('dataset_version', None) 87 | options['opts'] = data_opts 88 | 89 | print '## args' 90 | pprint(vars(args)) 91 | print '## options' 92 | pprint(options) 93 | # To set the random seed 94 | random.seed(args.seed) 95 | torch.manual_seed(args.seed + 1) 96 | torch.cuda.manual_seed(args.seed + 2) 97 | 98 | print("Loading training set and testing set..."), 99 | test_set = getattr(datasets, options['data']['dataset'])(data_opts, 'test', 100 | dataset_option=options['data'].get('dataset_option', None), 101 | batch_size=options['data']['batch_size'], 102 | use_region=options['data'].get('use_region', False)) 103 | print("Done") 104 | 105 | # Model declaration 106 | model = getattr(models, options['model']['arch'])(test_set, opts = options['model']) 107 | print("Done.") 108 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=options['data']['batch_size'], 109 | shuffle=False, num_workers=args.workers, 110 | pin_memory=True, 111 | collate_fn=getattr(datasets, options['data']['dataset']).collate) 112 | 113 | network.set_trainable(model, False) 114 | print('Loading pretrained model: {}'.format(args.pretrained_model)) 115 | args.train_all = True 116 | network.load_net(args.pretrained_model, model) 117 | # Setting the state of the training model 118 | model.cuda() 119 | model.eval() 120 | 121 | for i, sample in enumerate(test_loader): # (im_data, im_info, gt_objects, gt_relationships) 122 | if i < 500: continue 123 | im_data = Variable(sample['visual'].cuda(), volatile=True) 124 | gt_objects = sample['objects'][0] 125 | gt_relationships = sample['relations'][0] 126 | im_info = sample['image_info'] 127 | 128 | 129 | object_result, predicate_result = model.forward_eval(im_data, im_info, ) 130 | cls_prob_object, bbox_object, object_rois, reranked_score = object_result[:4] 131 | cls_prob_predicate, mat_phrase = predicate_result[:2] 132 | region_rois_num = predicate_result[2] 133 | # interpret the model output 134 | obj_boxes, obj_scores, obj_cls, subject_inds, object_inds, \ 135 | subject_boxes, object_boxes, predicate_inds, \ 136 | sub_assignment, obj_assignment, total_score = \ 137 | interpret_relationships(cls_prob_object, bbox_object, object_rois, 138 | cls_prob_predicate, mat_phrase, im_info, 139 | nms=0.3, top_N=1,topk=1, 140 | use_gt_boxes=False, 141 | triplet_nms=0.01, 142 | reranked_score=reranked_score) 143 | 144 | # filter out who has low obj_score 145 | keep_obj=np.where(obj_scores>=0.2)[0] 146 | if keep_obj.size==0: 147 | warnings.warn("no object detected ... 
continue to the next image") 148 | continue 149 | cutline_idx = max(keep_obj) 150 | obj_scores = obj_scores[:cutline_idx+1] 151 | obj_boxes = obj_boxes[:cutline_idx+1] 152 | obj_cls = obj_cls[:cutline_idx+1] 153 | 154 | relationships = np.array(zip(sub_assignment, obj_assignment, predicate_inds, total_score)) 155 | keep_sub_assign = np.where(relationships[:,0]<=cutline_idx)[0] 156 | relationships = relationships[keep_sub_assign] 157 | keep_obj_assign = np.where(relationships[:,1]<=cutline_idx)[0] 158 | relationships = relationships[keep_obj_assign] 159 | 160 | # filter out who has low total_score 161 | keep_rel = np.where(relationships[:,3]>=0.03)[0] 162 | if keep_rel.size == 0: 163 | warnings.warn("no relation detected ... continue to the next image") 164 | continue 165 | cutline_idx = max(keep_rel) 166 | relationships = relationships[:cutline_idx+1] 167 | 168 | print('-------Subject-------|------Predicate-----|--------Object---------|--Score-') 169 | for relation in relationships: 170 | if relation[2] > 0: # '0' is the class 'irrelevant' 171 | print('{sbj_cls:9} {sbj_ID:4} {sbj_score:1.2f} | ' 172 | '{pred_cls:11} {pred_score:1.2f} | ' 173 | '{obj_cls:9} {obj_ID:4} {obj_score:1.2f} | ' 174 | '{triplet_score:1.3f}'.format( 175 | sbj_cls=test_set.object_classes[int(obj_cls[int(relation[0])])], sbj_score=obj_scores[int(relation[0])], 176 | sbj_ID=str(int(relation[0])), 177 | pred_cls=test_set.predicate_classes[int(relation[2])], pred_score=relation[3]/obj_scores[int(relation[0])]/obj_scores[int(relation[1])], 178 | obj_cls=test_set.object_classes[int(obj_cls[int(relation[1])])], obj_score=obj_scores[int(relation[1])], 179 | obj_ID=str(int(relation[1])), 180 | triplet_score=relation[3])) 181 | 182 | 183 | sample_img_path = './data/svg/images/' + test_set.annotations[i]['path'] 184 | img_scene = cv2.imread(sample_img_path) 185 | colorlist = [(random.randint(0, 210), random.randint(0, 210), random.randint(0, 210)) for i in 186 | range(10000)] 187 | img_scene = vis_object_detection(img_scene, test_set, obj_cls, obj_boxes, obj_scores) 188 | cv2.imshow('sample', img_scene) 189 | cv2.waitKey(0) 190 | cv2.destroyAllWindows() 191 | 192 | 193 | result = {'objects': { 194 | 'bbox': obj_boxes, 195 | 'scores': obj_scores, 196 | 'class': obj_cls, }, 197 | 'relationships': relationships 198 | } -------------------------------------------------------------------------------- /scene_graph_tuning.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./FactorizableNet') 3 | import random 4 | import numpy.random as npr 5 | import numpy as np 6 | import yaml 7 | from pprint import pprint 8 | import cv2 9 | import torch 10 | from torch.autograd import Variable 11 | from lib import network 12 | import lib.datasets as datasets 13 | import lib.utils.general_utils as utils 14 | from model.settings import parse_args, testImageLoader 15 | from PIL import Image 16 | import os.path as osp 17 | import os 18 | from model import interpret, vis_tuning 19 | from model.vis_tuning import tools_for_visualizing 20 | from model.keyframe.keyframe_extracion import keyframe_checker 21 | from model.SGGenModel import SGGen_MSDN, SGGen_DR_NET 22 | 23 | args = parse_args() 24 | # Set the random seed 25 | random.seed(args.seed) 26 | torch.manual_seed(args.seed + 1) 27 | torch.cuda.manual_seed(args.seed + 2) 28 | colorlist = [(random.randint(0,230),random.randint(0,230),random.randint(0,230)) for i in range(10000)] 29 | 30 | # Set options 31 | options = { 32 | 'data': { 33 
| 'dataset_option': args.dataset_option, 34 | 'batch_size': args.batch_size, 35 | }, 36 | } 37 | with open(args.path_opt, 'r') as handle: 38 | options_yaml = yaml.load(handle) 39 | options = utils.update_values(options, options_yaml) 40 | with open(options['data']['opts'], 'r') as f: 41 | data_opts = yaml.load(f) 42 | options['data']['dataset_version'] = data_opts.get('dataset_version', None) 43 | options['opts'] = data_opts 44 | 45 | print '## args' 46 | pprint(vars(args)) 47 | print '## options' 48 | pprint(options) 49 | # To set the random seed 50 | random.seed(args.seed) 51 | torch.manual_seed(args.seed + 1) 52 | torch.cuda.manual_seed(args.seed + 2) 53 | 54 | print("Loading training set and testing set..."), 55 | test_set = getattr(datasets, options['data']['dataset'])(data_opts, 'test', 56 | dataset_option=options['data'].get('dataset_option', None), 57 | batch_size=options['data']['batch_size'], 58 | use_region=options['data'].get('use_region', False)) 59 | print("Done") 60 | 61 | # Model declaration 62 | #model = getattr(models, options['model']['arch'])(test_set, opts=options['model']) 63 | if args.path_opt.split('/')[-1].strip() == 'VG-DR-Net.yaml': 64 | model = SGGen_DR_NET(args, test_set, opts=options['model']) 65 | elif args.path_opt.split('/')[-1].strip() == 'VG-MSDN.yaml': 66 | model = SGGen_MSDN(args, test_set, opts=options['model']) 67 | else: 68 | raise NotImplementedError 69 | print("Done.") 70 | network.set_trainable(model, False) 71 | print('Loading pretrained model: {}'.format(args.pretrained_model)) 72 | args.train_all = True 73 | network.load_net(args.pretrained_model, model) 74 | # Setting the state of the training model 75 | model.cuda() 76 | model.eval() 77 | 78 | print('--------------------------------------------------------------------------') 79 | print('3D-Scene-Graph-Generator Demo: Object detection and Scene Graph Generation') 80 | print('--------------------------------------------------------------------------') 81 | imgLoader = testImageLoader(args) 82 | # Initial Sort tracker 83 | interpreter = interpret.interpreter(args, test_set, ENABLE_TRACKING=False) 84 | scene_graph = vis_tuning.scene_graph(args) 85 | keyframe_extractor = keyframe_checker(args, 86 | thresh_key=args.thres_key, 87 | thresh_anchor=args.thres_anchor, 88 | max_group_len=args.max_group_len, 89 | intrinsic_depth=imgLoader.intrinsic_color, 90 | alpha=args.alpha, 91 | blurry_gain=args.gain, 92 | blurry_offset=args.offset, 93 | depth_shape=(480,640)) 94 | 95 | for idx in range(imgLoader.num_frames)[args.frame_start:args.frame_end]: 96 | print('...........................................................................') 97 | ''' 1. Load an color/depth image and camera parameter''' 98 | if args.dataset == 'scannet': 99 | image_scene, depth_img, pix_depth, inv_p_matrix, inv_R, Trans, camera_pose = imgLoader.load_image(frame_idx=idx) 100 | else: 101 | image_scene = imgLoader.load_image(frame_idx=idx) 102 | if type(image_scene) !=np.ndarray: 103 | continue 104 | depth_img, pix_depth, inv_p_matrix, inv_R, Trans, camera_pose = None, None, None, None, None, None 105 | img_original_shape = image_scene.shape 106 | 107 | 108 | ''' 3. 
Pre-processing: Rescale & Normalization ''' 109 | # Resize the image to target scale 110 | if args.dataset == 'scannet': 111 | image_scene= cv2.resize(image_scene, (depth_img.size), interpolation= cv2.INTER_AREA) 112 | target_scale = test_set.opts[test_set.cfg_key]['SCALES'][npr.randint(0, high=len(test_set.opts[test_set.cfg_key]['SCALES']))] 113 | im_data, im_scale = test_set._image_resize(image_scene, target_scale, test_set.opts[test_set.cfg_key]['MAX_SIZE']) 114 | # restore the [image_height, image_width, scale_factor, max_size] 115 | im_info = np.array([[im_data.shape[0], im_data.shape[1], im_scale, 116 | img_original_shape[0], img_original_shape[1]]], dtype=np.float) 117 | im_data = Image.fromarray(im_data) 118 | im_data = test_set.transform(im_data) # normalize the image with the pre-defined min/std. 119 | im_data = Variable(im_data.cuda(), volatile=True).unsqueeze(0) 120 | 121 | 122 | ''' 2. Key-frame Extraction: Check if this frame is key-frame or anchor-frame''' 123 | IS_KEY_OR_ANCHOR, sharp_score, sharp_thres = keyframe_extractor.check_frame(image_scene, depth_img, camera_pose) 124 | winname_scene = '{idx:0004}. sharp_score: {score:05.1f}, sharp_thres: {thres:05.1f}, KEY_OR_ANCHOR: {flag:5}' \ 125 | .format(idx=idx, score=sharp_score, thres=sharp_thres, flag=str(IS_KEY_OR_ANCHOR)) 126 | print(winname_scene) 127 | image_original = image_scene.copy() 128 | if args.visualize: 129 | cv2.namedWindow('detection') # Create a named window 130 | cv2.moveWindow('detection', 1400, 10) 131 | cv2.putText(image_original, 132 | winname_scene, 133 | (1, 11), 134 | cv2.FONT_HERSHEY_SIMPLEX, 135 | 0.5, 136 | # (150, 150, 150), 137 | (30, 30, 200), 138 | 2) 139 | cv2.imshow('detection', image_original) 140 | cv2.waitKey(1) 141 | if not IS_KEY_OR_ANCHOR: 142 | continue 143 | 144 | ''' 3. Object Detection & Scene Graph Generation from the Pre-trained MSDN Model ''' 145 | object_result, predicate_result = model.forward_eval(im_data, im_info, ) 146 | 147 | 148 | ''' 4. Post-processing: Interpret the Model Output ''' 149 | # interpret the model output 150 | obj_boxes, obj_scores, obj_cls, \ 151 | subject_inds, object_inds, \ 152 | subject_boxes, object_boxes, \ 153 | subject_IDs, object_IDs, \ 154 | predicate_inds, triplet_scores, relationships = \ 155 | interpreter.interpret_graph(object_result, predicate_result,im_info) 156 | 157 | 158 | ''' 5. Print 2D Object Detection ''' 159 | # original image_scene 160 | img_obj_detected = tools_for_visualizing.vis_object_detection(image_scene.copy(), test_set, obj_cls[:, 0], obj_boxes, obj_scores[:, 0]) 161 | 162 | if args.visualize: 163 | cv2.namedWindow('detection') # Create a named window 164 | cv2.moveWindow('detection', 1400, 10) 165 | cv2.putText(img_obj_detected, 166 | winname_scene, 167 | (1, 11), 168 | cv2.FONT_HERSHEY_SIMPLEX, 169 | 0.5, 170 | (30, 200, 30), 171 | 2) 172 | cv2.imshow('detection', image_original) 173 | 174 | if args.save_image: 175 | scene_name = args.scannet_path.split('/')[-1] 176 | try: os.makedirs(osp.join(args.vis_result_path, scene_name,'original')) 177 | except: pass 178 | try: os.makedirs(osp.join(args.vis_result_path, scene_name,'detection')) 179 | except: pass 180 | cv2.imwrite(osp.join(args.vis_result_path, scene_name,'original',str(idx) + '.jpg'), image_scene) 181 | cv2.imwrite(osp.join(args.vis_result_path, scene_name,'detection',str(idx) + '.jpg'), img_obj_detected) 182 | 183 | ''' 6. 
Merge Relations into 3D Scene Graph''' 184 | updated_image_scene = scene_graph.vis_scene_graph(image_scene.copy(), camera_pose, idx, test_set, 185 | obj_cls, obj_boxes, obj_scores, 186 | subject_inds, predicate_inds, object_inds, 187 | subject_IDs, object_IDs, triplet_scores,relationships, 188 | pix_depth, inv_p_matrix, inv_R, Trans, dataset=args.dataset) 189 | 190 | if args.visualize: 191 | cv2.namedWindow('updated') # Create a named window 192 | cv2.moveWindow('updated', 1400, 520) 193 | cv2.putText(updated_image_scene, 194 | winname_scene, 195 | (1, 11), 196 | cv2.FONT_HERSHEY_SIMPLEX, 197 | 0.5, 198 | (30, 200, 30), 199 | 2) 200 | cv2.imshow('updated', updated_image_scene) 201 | 202 | if args.save_image: 203 | scene_name = args.scannet_path.split('/')[-1] 204 | try: os.makedirs(osp.join(args.vis_result_path, scene_name,'updated')) 205 | except: pass 206 | cv2.imwrite(osp.join(args.vis_result_path, scene_name,'updated','updated'+str(idx) +'.jpg'), updated_image_scene) 207 | 208 | if args.visualize: 209 | cv2.waitKey(args.pause_time) 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /model/settings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./FactorizableNet') 3 | import argparse 4 | import numpy as np 5 | import os.path as osp 6 | import os 7 | from PIL import Image 8 | import cv2 9 | import yaml 10 | import lib.utils.general_utils as utils 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser('Options for Running 3D-Scene-Graph-Generator in pytorch') 15 | 16 | '''Hyper-params in FactorizableNet''' 17 | # Pre-trained Model Settings 18 | parser.add_argument('--pretrained_model', type=str, 19 | default='FactorizableNet/output/trained_models/Model-VG-DR-Net.h5', 20 | help='path to pretrained_model, Model-VG-DR-Net.h5 or Model-VG-MSDN.h5') 21 | parser.add_argument('--path_opt', default='options/models/VG-DR-Net.yaml', type=str, 22 | help='path to a yaml options file, VG-DR-Net.yaml or VG-MSDN.yaml') 23 | # Data loader Settings 24 | parser.add_argument('--dataset_option', type=str, default='normal', 25 | help='data split selection [small | fat | normal]') 26 | parser.add_argument('--batch_size', type=int, help='#images per batch') 27 | parser.add_argument('--workers', type=int, default=4, help='#idataloader workers') 28 | # Environment Settings 29 | parser.add_argument('--seed', type=int, default=1, help='set seed to some constant value to reproduce experiments') 30 | parser.add_argument('--nms', type=float, default=0.2, 31 | help='NMS threshold for post object NMS (negative means not NMS)') 32 | parser.add_argument('--triplet_nms', type=float, default=0.4, 33 | help='Triplet NMS threshold for post object NMS (negative means not NMS)') 34 | # testing settings 35 | parser.add_argument('--use_gt_boxes', action='store_true', help='Use ground truth bounding boxes for evaluation') 36 | 37 | '''Demo Settings in 3D-Scene-Graph''' 38 | # Data loader Settings 39 | parser.add_argument('--dataset' ,type=str, default='scannet', 40 | help='choose a dataset. 
Example: "visual_genome", "scannet","ETH-Pedcross2", "ETH-Sunnyday"') 41 | parser.add_argument('--scannet_path', type=str, 42 | default='./data/scene0507/', help='scene0507') 43 | parser.add_argument('--mot_benchmark_path', type=str, 44 | default='./data/mot_benchmark/') 45 | parser.add_argument('--vis_result_path',type=str,default='./vis_result') 46 | # FactorizableNet Output Filtering Settings 47 | parser.add_argument('--obj_thres', type=float, default=0.25, 48 | help='object recognition threshold score') 49 | parser.add_argument('--triplet_thres', type=float, default=0.08, 50 | help='Triplet recognition threshold score ') 51 | # Key-frame Extractor Settings 52 | parser.add_argument('--thres_key', type=float, default=0.1, 53 | help=' ') 54 | parser.add_argument('--thres_anchor', type=float, default=0.65, 55 | help=' ') 56 | parser.add_argument('--max_group_len', type=float, default=10, 57 | help=' ') 58 | parser.add_argument('--alpha', type=float, default=0.4, 59 | help='weight for Exponentially Weighted Summation') 60 | parser.add_argument('--gain', type=float, default=25, 61 | help='gain for adaptive thresholding in blurry image detection. scene0507:45 scene0000:25') 62 | parser.add_argument('--offset', type=float, default=1, 63 | help='offset for adaptive thresholding in blurry image detection') 64 | parser.add_argument('--detect_cnt_thres', type=float, default=2, 65 | help='scene graph threshold') 66 | parser.add_argument('--frame_start', type=int, default=0, 67 | help='frame_start') 68 | parser.add_argument('--frame_end', type=int, default=5000, 69 | help='frame_end') 70 | parser.add_argument('--disable_keyframe', action='store_true', 71 | help='disable keyframe extraction if true') 72 | parser.add_argument('--disable_spurious', action='store_true', 73 | help='disable spurious rejection if true') 74 | parser.add_argument('--disable_samenode', action='store_true', 75 | help='disable same node detection if true') 76 | parser.add_argument('--pause_time', type=int, default=1, 77 | help='pause time') 78 | parser.add_argument('--plot_graph', action='store_true', 79 | help='plot graph if true') 80 | parser.add_argument('--visualize', action='store_true', 81 | help='enable visualization') 82 | parser.add_argument('--format', type=str, default='png', 83 | help='scene graph image format, pdf or png') 84 | parser.add_argument('--draw_color', action='store_true', 85 | help='draw color node in scene graph if true') 86 | parser.add_argument('--save_image', action='store_true', 87 | help='save detection result image if true') 88 | 89 | args = parser.parse_args() 90 | 91 | return args 92 | 93 | 94 | class testImageLoader(object): 95 | """ 96 | Description 97 | - Loads images for test 98 | - Specify which dataset to use for test 99 | Functions 100 | - load_image: loads one image of given frame id 101 | """ 102 | def __init__(self,args): 103 | self.args = args 104 | self.mot_benchmark_train = ['ADL-Rundle-6', 'ADL-Rundle-8','ETH-Bahnhof','ETH-Pedcross2','ETH-Sunnyday', 105 | 'KITTI-13', 'KITTI-17','PETS09-S2L1','TUD-Campus','TUD-Stadtmitte','Venice-2'] 106 | self.mot_benchmark_test = ['ADL-Rundle-1', 'ADL-Rundle-3','AVG-TownCentre','ETH-Crossing','ETH-Jelmoli', 107 | 'ETH-Linthescher', 'KITTI-16', 'KITTI-19','PETS09-S2L2','TUD-Crossing','Venice-1'] 108 | if self.args.dataset == 'scannet': 109 | self.scannet_img_path = osp.join(args.scannet_path, 'color') 110 | self.scannet_depth_path = osp.join(args.scannet_path, 'depth') 111 | self.scannet_intrinsic_path = osp.join(args.scannet_path, 'intrinsic') 
112 | self.scannet_pose_path = osp.join(args.scannet_path, 'pose') 113 | # Load Camera intrinsic parameter 114 | self.intrinsic_color = open(osp.join(self.scannet_intrinsic_path, 'intrinsic_color.txt')).read() 115 | self.intrinsic_depth = open(osp.join(self.scannet_intrinsic_path, 'intrinsic_depth.txt')).read() 116 | self.intrinsic_color = [item.split() for item in self.intrinsic_color.split('\n')[:-1]] 117 | self.intrinsic_depth = [item.split() for item in self.intrinsic_depth.split('\n')[:-1]] 118 | self.intrinsic_depth = np.matrix(self.intrinsic_depth, dtype='float') 119 | self.img_folder_path = osp.join(args.scannet_path, 'color') 120 | elif self.args.dataset == 'visual_genome': 121 | self.img_folder_path = 'data/visual_genome/images' 122 | self.intrinsic_depth = None 123 | elif self.args.dataset in self.mot_benchmark_train: 124 | mot_train_path = osp.join(args.mot_benchmark_path, 'train') 125 | self.img_folder_path = osp.join(mot_train_path,self.args.dataset,'img1') 126 | self.intrinsic_depth = None 127 | elif self.args.dataset in self.mot_benchmark_test: 128 | mot_test_path = osp.join(args.mot_benchmark_path, 'test') 129 | self.img_folder_path = osp.join(mot_test_path, self.args.dataset,'img1') 130 | self.intrinsic_depth = None 131 | else: 132 | raise NotImplementedError 133 | 134 | self.num_frames = len(os.listdir(self.img_folder_path)) 135 | 136 | def load_image(self,frame_idx): 137 | if self.args.dataset == 'scannet': 138 | # Load an image from ScanNet Dataset 139 | img_path = osp.join(self.img_folder_path, str(frame_idx+1) + '.jpg') 140 | camera_pose = open(osp.join(self.scannet_pose_path, str(frame_idx+1) + '.txt')).read() 141 | depth_img = Image.open(osp.join(self.scannet_depth_path,str(frame_idx+1) + '.png')) 142 | # Preprocess loaded camera parameter and depth info 143 | depth_pix = depth_img.load() 144 | pix_depth = [] 145 | for ii in range(depth_img.size[1]): 146 | pix_row = [] 147 | for jj in range(depth_img.size[0]): 148 | pix_row.append(depth_pix[jj, ii]) 149 | pix_depth.append(pix_row) 150 | 151 | camera_pose = [item.split() for item in camera_pose.split('\n')[:-1]] 152 | camera_pose = np.array(camera_pose,dtype=float) 153 | p_matrix = [ self.intrinsic_color[0][:], self.intrinsic_color[1][:], self.intrinsic_color[2][:]] 154 | p_matrix = np.matrix(p_matrix, dtype='float') 155 | inv_p_matrix = np.linalg.pinv(p_matrix) 156 | R = np.matrix([camera_pose[0][0:3], camera_pose[1][0:3], camera_pose[2][0:3]], dtype='float') 157 | inv_R = np.linalg.inv(R) 158 | Trans = np.matrix([camera_pose[0][3], camera_pose[1][3], camera_pose[2][3]], dtype='float') 159 | img_scene = cv2.imread(img_path) 160 | return img_scene, depth_img, pix_depth, inv_p_matrix, inv_R, Trans, camera_pose 161 | 162 | elif self.args.dataset == 'visual_genome': 163 | # Load an image from Visual Genome Dataset 164 | frame_idx += 1 165 | img_path = osp.join(self.img_folder_path, str(frame_idx)+'.jpg') 166 | elif self.args.dataset in self.mot_benchmark_train or self.args.dataset in self.mot_benchmark_test: 167 | frame_idx += 1 168 | img_path = osp.join(self.img_folder_path, '%06d.jpg' % (frame_idx)) 169 | else: 170 | raise NotImplementedError 171 | img_scene = cv2.imread(img_path) 172 | return img_scene 173 | 174 | 175 | def set_options(args): 176 | # Set options 177 | options = { 178 | 'data': { 179 | 'dataset_option': args.dataset_option, 180 | 'batch_size': args.batch_size, 181 | }, 182 | } 183 | with open(args.path_opt, 'r') as handle: 184 | options_yaml = yaml.load(handle) 185 | options = 
utils.update_values(options, options_yaml) 186 | with open(options['data']['opts'], 'r') as f: 187 | data_opts = yaml.load(f) 188 | options['data']['dataset_version'] = data_opts.get('dataset_version', None) 189 | options['opts'] = data_opts 190 | 191 | return options 192 | -------------------------------------------------------------------------------- /relation_prior_extraction.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./FactorizableNet') 3 | import json 4 | import math 5 | import pickle 6 | import numpy as np 7 | import yaml 8 | import lib.datasets as datasets 9 | import lib.utils.general_utils as utils 10 | import os 11 | import os.path as osp 12 | from model.settings import parse_args 13 | from model.SGGenModel import VG_DR_NET_PRED_IGNORES 14 | 15 | 16 | def distance(o_x, o_y, s_x, s_y): 17 | return math.sqrt((o_x - s_x) ** 2 + (o_y - s_y) ** 2) 18 | 19 | def update_normal(prev_mean, prev_var, prev_num, new_val): 20 | total = prev_mean * prev_num + new_val 21 | new_mean = total / (prev_num + 1) 22 | prev_square_mean = prev_var + prev_mean ** 2 23 | new_square_mean = (prev_square_mean * prev_num + new_val ** 2) / (prev_num + 1) 24 | new_var = new_square_mean - new_mean ** 2 25 | return new_mean, new_var 26 | 27 | def save_obj(obj, filename): 28 | with open(filename + '.pkl', 'wb') as f: 29 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 30 | 31 | def load_obj(filename): 32 | with open(filename + '.pkl', 'rb') as f: 33 | return pickle.load(f) 34 | 35 | 36 | if __name__=="__main__": 37 | 38 | args = parse_args() 39 | # Set options 40 | options = { 41 | 'data': { 42 | 'dataset_option': args.dataset_option, 43 | 'batch_size': args.batch_size, 44 | }, 45 | } 46 | with open(args.path_opt, 'r') as handle: 47 | options_yaml = yaml.load(handle) 48 | options = utils.update_values(options, options_yaml) 49 | with open(options['data']['opts'], 'r') as f: 50 | data_opts = yaml.load(f) 51 | options['data']['dataset_version'] = data_opts.get('dataset_version', None) 52 | options['opts'] = data_opts 53 | print("Loading training set and testing set..."), 54 | test_set = getattr(datasets, options['data']['dataset'])(data_opts, 'test', 55 | dataset_option='normal', 56 | batch_size=1, 57 | use_region=False) 58 | print("Done") 59 | 60 | try: os.makedirs(osp.join('model','prior', 'preprocessed')) 61 | except: pass 62 | '''0. 
Choose OBJECTS & PREDICTES to be extracted''' 63 | relevant_classes = sorted([u'field', u'zebra', u'sky', u'track', u'train', u'window', u'pole', u'windshield', u'background', u'tree', u'door', u'sheep', u'paint', u'grass', u'baby', u'ear', u'leg', u'eye', u'tail', u'head', u'nose', u'skateboarder', u'arm', u'foot', u'skateboard', u'wheel', u'hand', u'ramp', u'man', u'jeans', u'shirt', u'sneaker', u'writing', u'hydrant', u'cap', u'chain', u'sidewalk', u'curb', u'road', u'line', u'bush', u'sign', u'people', u'car', u'edge', u'bus', u'tire', u'lady', u'letter', u'leaf', u'boy', u'pocket', u'backpack', u'bottle', u'suitcase', u'word', u'ground', u'handle', u'strap', u'jacket', u'motorcycle', u'bicycle', u'truck', u'cloud', u'kite', u'pants', u'beach', u'woman', u'rock', u'dress', u'dog', u'building', u'frisbee', u'shoe', u'plant', u'pot', u'hair', u'face', u'shorts', u'stripe', u'bench', u'flower', u'cat', u'post', u'container', u'house', u'ceiling', u'seat', u'back', u'graffiti', u'paper', u'hat', u'tennisracket', u'tennisplayer', u'wall', u'logo', u'girl', u'clock', u'brick', u'white', u'elephant', u'mirror', u'bird', u'glove', u'oven', u'area', u'sticker', u'flag', u'surfboard', u'wetsuit', u'shadow', u'sleeve', u'tenniscourt', u'surface', u'finger', u'string', u'plane', u'wing', u'umbrella', u'snow', u'sunglasses', u'boot', u'coat', u'skipole', u'ski', u'skier', u'black', u'player', u'sock', u'racket', u'wrist', u'band', u'ball', u'light', u'shelf', u'stand', u'vase', u'horse', u'number', u'rug', u'goggles', u'snowboard', u'computer', u'screen', u'button', u'glass', u'bracelet', u'cellphone', u'mountain', u'phone', u'hill', u'fence', u'stone', u'cow', u'tag', u'bear', u'table', u'water', u'ocean', u'trashcan', u'circle', u'river', u'railing', u'design', u'bowl', u'food', u'spoon', u'tablecloth', u'plate', u'bread', u'tomato', u'kid', u'sand', u'dirt', u'mouth', u'hole', u'air', u'distance', u'board', u'feet', u'suit', u'wave', u'guy', u'reflection', u'bathroom', u'toilet', u'sink', u'faucet', u'floor', u'toiletpaper', u'towel', u'sandwich', u'knife', u'bolt', u'boat', u'engine', u'trafficlight', u'wine', u'cup', u'stem', u'base', u'top', u'bottom', u'sofa', u'counter', u'photo', u'frame', u'side', u'paw', u'branch', u'fur', u'forest', u'wire', u'headlight', u'rail', u'front', u'green', u'helmet', u'whiskers', u'pen', u'neck', u'net', u'necklace', u'duck', u'sweater', u'chair', u'horn', u'giraffe', u'spot', u'mane', u'airplane', u'beard', u'speaker', u'sun', u'shore', u'pillar', u'tower', u'jet', u'gravel', u'sauce', u'fork', u'tray', u'awning', u'tent', u'bun', u'teeth', u'camera', u'tile', u'lid', u'kitchen', u'curtain', u'drawer', u'knob', u'box', u'outlet', u'remote', u'couch', u'tie', u'book', u'ring', u'toothbrush', u'balcony', u'stairs', u'doorway', u'stopsign', u'bed', u'pillow', u'corner', u'trim', u'vegetable', u'orange', u'broccoli', u'rope', u'streetlight', u'name', u'pitcher', u'uniform', u'body', u'mouse', u'keyboard', u'desk', u'monitor', u'statue', u'collar', u'candle', u'animal', u'tv', u'donut', u'apple', u'child', u'licenseplate', u'catcher', u'umpire', u'banner', u'bat', u'batter', u'part', u'hotdog', u'object', u'cake', u'bridge', u'patch', u'belt', u'park', u'stick', u'bucket', u'runway', u'lamp', u'tip', u'carpet', u'blanket', u'cover', u'napkin', u'theoutdoors', u'stove', u'pizza', u'cheese', u'crust', u'van', u'beak', u'cord', u'poster', u'purse', u'laptop', u'shoulder', u'dish', u'can', u'pipe', u'key', u'arrow', u'surfer', u'controller', u'blinds', u'bluesky', 
u'whiteclouds', u'luggage', u'vehicle', u'streetsign', u'pan', u'baseball', u'baseballplayer', u'jersey', u'rack', u'cabinet', u'meat', u'watch', u'refrigerator', u'vest', u'skirt', u'hoof', u'label', u'teddybear', u'fridge', u'snowboarder', u'scarf', u'basket', u'cloth', u'shade', u'blue', u'spectator', u'knee', u'column', u'metal', u'steps', u'firehydrant', u'platform', u'jar', u'fruit', u'hood', u't-shirt', u'cone', u'weeds', u'treetrunk', u'room', u'red', u'television', u'scissors', u'gate', u'tennisball', u'court', u'log', u'star', u'lettuce', u'traincar', u'microwave', u'pepperoni', u'onion', u'chimney', u'concrete', u'mug', u'carrot', u'banana', u'cart', u'wood', u'bar', u'ripples', u'holder', u'pepper', u'tusk']) 64 | print('Before OBJECT filtering: ' + str(len(relevant_classes))) 65 | to_exclude = ['1', 'field', '2', 'zebra', '3', 'sky', '4', 'track', '5', 'train', '9', 'background', '12', 'sheep', '14', 'grass', '15', 'baby', '16', 'ear', '17', 'leg', '18', 'eye', '19', 'tail', '20', 'head', '21', 'nose', '22', 'skateboarder', '23', 'arm', '24', 'foot', '25', 'skateboard', '27', 'hand', '29', 'man', '34', 'hydrant', '37', 'sidewalk', '38', 'curb', '39', 'road', '43', 'people', '44', 'car', '46', 'bus', '47', 'tire', '48', 'lady', '49', 'letter', '50', 'leaf', '51', 'boy', '64', 'cloud', '65', 'kite', '66', 'pants', '67', 'beach', '68', 'woman', '71', 'dog', '72', 'building', '73', 'frisbee', '77', 'hair', '78', 'face', '79', 'shorts', '83', 'cat', '94', 'tennisplayer', '97', 'girl', '101', 'elephant', '103', 'bird', '111', 'shadow', '112', 'sleeve', '113', 'tenniscourt', '114', 'surface', '115', 'finger', '120', 'snow', '121', 'sunglasses', '126', 'skier', '128', 'player', '131', 'wrist', '138', 'horse', '151', 'hill', '152', 'fence', '154', 'cow', '156', 'bear', '162', 'river', '163', 'railing', '172', 'kid', '175', 'mouth', '177', 'air', '178', 'distance', '180', 'feet', '182', 'wave', '183', 'guy', '184', 'reflection', '200', 'stem', '209', 'paw', '210', 'branch', '212', 'forest', '215', 'rail', '219', 'whiskers', '221', 'neck', '223', 'necklace', '224', 'duck', '228', 'giraffe', '230', 'mane', '232', 'beard', '234', 'sun', '235', 'shore', '237', 'tower', '239', 'gravel', '243', 'awning', '244', 'tent', '246', 'teeth', '276', 'pitcher', '277', 'uniform', '278', 'body', '286', 'animal', '290', 'child', '291', 'licenseplate', '292', 'catcher', '293', 'umpire', '296', 'batter', '301', 'bridge', '304', 'park', '307', 'runway', '314', 'theoutdoors', '319', 'van', '320', 'beak', '325', 'shoulder', '331', 'surfer', '334', 'bluesky', '335', 'whiteclouds', '337', 'vehicle', '340', 'baseball', '341', 'baseballplayer', '350', 'hoof', '354', 'snowboarder', '358', 'shade', '360', 'spectator', '361', 'knee', '366', 'platform', '372', 'weeds', '373', 'treetrunk', '396', 'ripples', '399', 'tusk'] 66 | for ex in to_exclude: 67 | if ex in relevant_classes: 68 | relevant_classes.remove(ex) 69 | obj_ind2key = {idx: item for idx, item in enumerate(relevant_classes)} 70 | obj_key2ind = {item: idx for idx, item in enumerate(relevant_classes)} 71 | save_obj(obj_ind2key,filename='model/prior/preprocessed/object_prior_ind2key') 72 | save_obj(obj_key2ind,filename='model/prior/preprocessed/object_prior_key2ind') 73 | print('After OBJECT filtering: ' + str(len(relevant_classes))) 74 | 75 | print('Before PREDICATE filtering: ' + str(len(test_set.predicate_classes))) 76 | pred_cls = range(len(test_set.predicate_classes)) 77 | pred_cls = [p for p in pred_cls if not p in VG_DR_NET_PRED_IGNORES] 78 | pred_cls 
= [test_set.predicate_classes[x] for x in pred_cls] 79 | pred_ind2key = {idx: item for idx, item in enumerate(pred_cls)} 80 | pred_key2ind = {item: idx for idx, item in enumerate(pred_cls)} 81 | save_obj(pred_ind2key,filename='model/prior/preprocessed/pred_prior_ind2key')  # save_obj appends '.pkl' itself 82 | save_obj(pred_key2ind,filename='model/prior/preprocessed/pred_prior_key2ind') 83 | print('After PREDICATE filtering: ' + str(len(pred_cls))) 84 | 85 | '''1. start constructing relation prior''' 86 | extract_statistics = True 87 | if extract_statistics: 88 | file = open('model/prior/raw/relationships.json').read() 89 | data = json.loads(file) 90 | print("Reading JSON completed!!") 91 | 92 | statistics = {} 93 | predicates = {} 94 | i = 0 95 | for datum in data: 96 | i += 1 97 | if i % 1000 == 0: 98 | print(str(i) + "th point processing") 99 | 100 | relations = datum['relationships'] 101 | 102 | for rel in relations: 103 | pred, obj, sub = rel['predicate'], rel['object']['name'], rel['subject']['name'] 104 | pred = pred.lower() 105 | obj_x, obj_y, sub_x, sub_y = rel['object']['x'], rel['object']['y'], rel['subject']['x'], rel['subject']['y'] 106 | if obj in relevant_classes and sub in relevant_classes and pred in pred_cls: 107 | curr_dist = distance(obj_x, obj_y, sub_x, sub_y) 108 | if (sub, obj) in statistics: 109 | if pred in statistics[(sub, obj)]: 110 | statistics[(sub, obj)][pred]['count'] += 1 111 | statistics[(sub, obj)][pred]['dist'].append(curr_dist) 112 | else: 113 | statistics[(sub, obj)][pred] = {'count': 1, 'dist': [curr_dist]} 114 | else: 115 | statistics[(sub, obj)] = {} 116 | statistics[(sub, obj)][pred] = {'count': 1, 'dist': [curr_dist]} 117 | if pred in predicates: 118 | predicates[pred] += 1 119 | else: 120 | predicates[pred] = 1 121 | save_obj(statistics, "model/prior/preprocessed/relation_prior") 122 | 123 | '''2. 
processing relation prior''' 124 | statistics = load_obj("model/prior/preprocessed/relation_prior") 125 | keys = list(statistics) 126 | for item in keys: 127 | for predicate in statistics[item]: 128 | statistics[item][predicate]['mean'] = np.mean(statistics[item][predicate]['dist']) 129 | statistics[item][predicate]['var'] = np.var(statistics[item][predicate]['dist']) 130 | del statistics[item][predicate]['dist'] 131 | save_obj(statistics, 'model/prior/preprocessed/relation_prior_prob') 132 | -------------------------------------------------------------------------------- /model/SGGenModel.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./FactorizableNet') 3 | import models as models 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import torch 7 | 8 | 9 | VG_DR_NET_PRED_IGNORES=(0,6,10,18,19,20,22,23,24) 10 | VG_DR_NET_OBJ_IGNORES = (0,1,2,3,4,5,9,12,14,15,16,17,18,19,20,21,22,23,24,25,27,29,34,37,38,39,43,44,46,47,48,51,64,65,\ 11 | 67,68,71,72,73,77,78,79,83,94,97,103,111,112,113,114,115,120,121,126,127,128,131,138,151,152,154,156,162,163,172,175,177,\ 12 | 178,180,182,183,184,200,209,210,212,215,219,221,223,224,228,230,232,234,235,237,239,243,244,246,276,277,278,\ 13 | 286,290,291,292,293,296,301,304,307,314,319,\ 14 | 320,325,331,334,335,337,340,341,350,354,358,360,361,366,372,373,380,384,396,399,93,124,338,370,379 ) 15 | ''' 16 | objects_ignored (VG-DR-NET) 17 | 0 __background__ 18 | 1 field 2 zebra 3 sky 4 track 5 train 9 background 19 | 12 sheep 14 grass 15 baby 16 ear 17 leg 18 eye 20 | 19 tail 20 head 21 nose 22 skateboarder 23 arm 24 foot 21 | 25 skateboard 27 hand 29 man 34 hydrant 37 sidewalk 38 curb 22 | 39 road 43 people 44 car 46 bus 47 tire 48 lady 23 | 49 letter 50 leaf 51 boy 64 cloud 65 kite 66 pants 24 | 67 beach 68 woman 71 dog 72 building 73 frisbee 77 hair 25 | 78 face 79 shorts 83 cat 94 tennisplayer 97 girl 101 elephant 26 | 103 bird 111 shadow 112 sleeve 113 tenniscourt 114 surface 115 finger 27 | 120 snow 121 sunglasses 126 skier 128 player 131 wrist 138 horse 28 | 151 hill 152 fence 154 cow 156 bear 162 river 163 railing 29 | 172 kid 175 mouth 177 air 178 distance 180 feet 182 wave 30 | 183 guy 184 reflection 200 stem 209 paw 210 branch 212 forest 31 | 215 rail 219 whiskers 221 neck 223 necklace 224 duck 228 giraffe 32 | 230 mane 232 beard 234 sun 235 shore 237 tower 239 gravel 33 | 243 awning 244 tent 246 teeth 276 pitcher 277 uniform 278 body 34 | 286 animal 290 child 291 licenseplate 292 catcher 293 umpire 296 batter 35 | 301 bridge 304 park 307 runway 314 theoutdoors 319 van 320 beak 36 | 325 shoulder 331 surfer 334 bluesky 335 whiteclouds 337 vehicle 340 baseball 37 | 341 baseballplayer 350 hoof 354 snowboarder 358 shade 360 spectator 361 knee 38 | 366 platform 372 weeds 373 treetrunk 396 ripples 399 tusk 39 | 93 tennisracket 124 skipole 338 streetsign 370 t-shirt 379 tennisball 40 | ''' 41 | 42 | VG_MSDN_PRED_IGNORES=(9,12, 18,20,22, 27, 28, 30, 31, 32, 35, 48) 43 | VG_MSDN_OBJ_IGNORES=(0,2,11,18,20,22,24,25,26,27,29,30,31,34,35, 37, 39, 40, 41, 43, 44, 45, 49,50,52,54, 55,\ 44 | 56, 58, 60, 67, 68, 72, 78, 79, 80, 81, 83, 84, 85, 88, 90, 92, 93, 95, 96, 97, 100, 103, 104,\ 45 | 107, 113, 115, 118, 119, 121, 127, 128, 129, 130, 133, 135, 136, 142, 143, 145, 147, 150) 46 | ''' 47 | objects_ignored (VG-MSDN) 48 | 0 __background__ 2 kite 11 sky 18 hill 20 woman 22 animal 49 | 24 bear 25 wave 26 giraffe 27 background 29 foot 30 shadow 50 | 31 lady 34 sand 35 nose 37 
sidewalk 39 fence 43 hair 51 | 44 street 45 zebra 49 girl 50 arm 52 leaf 54 dirt 52 | 55 boat 56 bird 58 leg 60 surfer 67 boy 68 cow 53 | 72 road 78 cloud 79 sheep 80 horse 81 eye 83 neck 54 | 84 tail 85 vehicle 88 head 90 bus 92 train 93 child 55 | 95 ear 96 reflection 97 car 100 cat 103 grass 104 toilet 56 | 107 ocean 113 snow 115 field 118 branch 119 elephant 121 beach 57 | 127 mountain 128 track 129 hand 130 plane 133 skier 135 man 58 | 136 building 142 dog 143 face 145 person 147 truck 150 wing 59 | ''' 60 | 61 | 62 | class SGGen_DR_NET(models.HDN_v2.factorizable_network_v4s.Factorizable_network): 63 | """ 64 | Description 65 | - Detect and recognize objects with relations 66 | - DR_net extends factorizable network 67 | Functions 68 | - forward_eval: generate detection & recognition results for the given image 69 | """ 70 | def __init__(self,args, trainset, opts, 71 | ): 72 | super(SGGen_DR_NET,self).__init__(trainset,opts) 73 | print(args.path_opt.split('/')[-1]) 74 | if args.path_opt.split('/')[-1].strip() == 'VG-DR-Net.yaml' and args.dataset =='scannet': 75 | predicates_ignored = VG_DR_NET_PRED_IGNORES 76 | self.predicates_mask = torch.ByteTensor(40000,25).fill_(0).cuda() 77 | self.predicates_mask[:,predicates_ignored] = 1 78 | objects_ignored=VG_DR_NET_OBJ_IGNORES 79 | self.objects_mask = torch.ByteTensor(200,400).fill_(0) 80 | self.objects_mask[:,objects_ignored] = 1 81 | self.objects_mask = self.objects_mask.cuda() 82 | 83 | else: 84 | self.predicates_mask = torch.ByteTensor(40000, 25).fill_(0).cuda() 85 | self.objects_mask = torch.ByteTensor(200, 400).fill_(0).cuda() 86 | 87 | self.object_class_filter =[] 88 | self.predicate_class_filter =[] 89 | 90 | def forward_eval(self, im_data, im_info, gt_objects=None): 91 | # Currently, RPN support batch but not for MSDN 92 | features, object_rois = self.rpn(im_data, im_info) 93 | if gt_objects is not None: 94 | gt_rois = np.concatenate([np.zeros((gt_objects.shape[0], 1)), 95 | gt_objects[:, :4], 96 | np.ones((gt_objects.shape[0], 1))], 1) 97 | else: 98 | gt_rois = None 99 | object_rois, region_rois, mat_object, mat_phrase, mat_region = self.graph_construction(object_rois, gt_rois=gt_rois) 100 | # roi pool 101 | pooled_object_features = self.roi_pool_object(features, object_rois).view(len(object_rois), -1) 102 | pooled_object_features = self.fc_obj(pooled_object_features) 103 | pooled_region_features = self.roi_pool_region(features, region_rois) 104 | pooled_region_features = self.fc_region(pooled_region_features) 105 | bbox_object = self.bbox_obj(F.relu(pooled_object_features)) 106 | 107 | for i, mps in enumerate(self.mps_list): 108 | pooled_object_features, pooled_region_features = \ 109 | mps(pooled_object_features, pooled_region_features, mat_object, mat_region, object_rois, region_rois) 110 | 111 | pooled_phrase_features = self.phrase_inference(pooled_object_features, pooled_region_features, mat_phrase) 112 | pooled_object_features = F.relu(pooled_object_features) 113 | pooled_phrase_features = F.relu(pooled_phrase_features) 114 | 115 | cls_score_object = self.score_obj(pooled_object_features) 116 | cls_score_object.data.masked_fill_(self.objects_mask, -float('inf')) 117 | cls_prob_object = F.softmax(cls_score_object, dim=1) 118 | cls_score_predicate = self.score_pred(pooled_phrase_features) 119 | cls_score_predicate.data.masked_fill_(self.predicates_mask, -float('inf')) 120 | cls_prob_predicate = F.softmax(cls_score_predicate, dim=1) 121 | 122 | if self.learnable_nms: 123 | selected_prob, _ = cls_prob_object[:, 1:].max(dim=1, 
keepdim=False) 124 | reranked_score = self.nms(pooled_object_features, selected_prob, object_rois) 125 | else: 126 | reranked_score = None 127 | 128 | return (cls_prob_object, bbox_object, object_rois, reranked_score), \ 129 | (cls_prob_predicate, mat_phrase, region_rois.size(0)), 130 | 131 | 132 | class SGGen_MSDN(models.HDN_v2.factorizable_network_v4.Factorizable_network): 133 | """ 134 | Description 135 | - Detect and recognize objects with relations 136 | - MSDN extends factorizable network 137 | Functions 138 | - forward_eval: generate detection & recognition results for the given image 139 | """ 140 | def __init__(self,args, trainset, opts, 141 | ): 142 | super(SGGen_MSDN,self).__init__(trainset,opts) 143 | print(args.path_opt.split('/')[-1]) 144 | if args.path_opt.split('/')[-1].strip() == 'VG-MSDN.yaml' and args.dataset =='scannet': 145 | predicates_ignored = VG_MSDN_PRED_IGNORES 146 | self.predicates_mask = torch.ByteTensor(40000,51).fill_(0).cuda() 147 | self.predicates_mask[:,predicates_ignored] = 1 148 | objects_ignored=VG_MSDN_OBJ_IGNORES 149 | self.objects_mask = torch.ByteTensor(200,151).fill_(0) 150 | self.objects_mask[:,objects_ignored] = 1 151 | self.objects_mask = self.objects_mask.cuda() 152 | 153 | else: 154 | self.predicates_mask = torch.ByteTensor(40000, 51).fill_(0).cuda() 155 | self.objects_mask = torch.ByteTensor(200, 151).fill_(0).cuda() 156 | 157 | def forward_eval(self, im_data, im_info, gt_objects=None): 158 | # Currently, RPN support batch but not for MSDN 159 | features, object_rois = self.rpn(im_data, im_info) 160 | if gt_objects is not None: 161 | gt_rois = np.concatenate([np.zeros((gt_objects.shape[0], 1)), 162 | gt_objects[:, :4], 163 | np.ones((gt_objects.shape[0], 1))], 1) 164 | else: 165 | gt_rois = None 166 | object_rois, region_rois, mat_object, mat_phrase, mat_region = self.graph_construction(object_rois, gt_rois=gt_rois) 167 | # roi pool 168 | pooled_object_features = self.roi_pool_object(features, object_rois).view(len(object_rois), -1) 169 | pooled_object_features = self.fc_obj(pooled_object_features) 170 | pooled_region_features = self.roi_pool_region(features, region_rois) 171 | pooled_region_features = self.fc_region(pooled_region_features) 172 | bbox_object = self.bbox_obj(F.relu(pooled_object_features)) 173 | 174 | for i, mps in enumerate(self.mps_list): 175 | pooled_object_features, pooled_region_features = \ 176 | mps(pooled_object_features, pooled_region_features, mat_object, mat_region) 177 | 178 | pooled_phrase_features = self.phrase_inference(pooled_object_features, pooled_region_features, mat_phrase) 179 | pooled_object_features = F.relu(pooled_object_features) 180 | pooled_phrase_features = F.relu(pooled_phrase_features) 181 | 182 | cls_score_object = self.score_obj(pooled_object_features) 183 | cls_score_object.data.masked_fill_(self.objects_mask, -float('inf')) 184 | cls_prob_object = F.softmax(cls_score_object, dim=1) 185 | cls_score_predicate = self.score_pred(pooled_phrase_features) 186 | cls_score_predicate.data.masked_fill_(self.predicates_mask, -float('inf')) 187 | cls_prob_predicate = F.softmax(cls_score_predicate, dim=1) 188 | 189 | if self.learnable_nms: 190 | selected_prob, _ = cls_prob_object[:, 1:].max(dim=1, keepdim=False) 191 | reranked_score = self.nms(pooled_object_features, selected_prob, object_rois) 192 | else: 193 | reranked_score = None 194 | 195 | 196 | return (cls_prob_object, bbox_object, object_rois, reranked_score), \ 197 | (cls_prob_predicate, mat_phrase, region_rois.size(0)), 198 | 
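# Illustrative sketch (not part of SGGenModel.py): both networks above filter
# unwanted classes by writing -inf into the ignored logits before the softmax,
# which gives those classes exactly zero probability. Toy tensor sizes are
# assumed here; the real masks are shaped to the networks' ROI and class counts.
import torch
import torch.nn.functional as F

cls_score = torch.randn(3, 5)                 # 3 detections, 5 candidate classes
ignored = [0, 2]                              # e.g. __background__ plus one filtered class
mask = torch.zeros(3, 5, dtype=torch.bool)
mask[:, ignored] = True

cls_score = cls_score.masked_fill(mask, float('-inf'))
cls_prob = F.softmax(cls_score, dim=1)        # ignored columns become exactly 0
assert torch.allclose(cls_prob[:, ignored], torch.zeros(3, len(ignored)))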
-------------------------------------------------------------------------------- /model/keyframe/keyframe_extracion.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import os 4 | import os.path as osp 5 | import cv2 6 | import math 7 | import time 8 | import matplotlib.pyplot as plt 9 | import random 10 | 11 | PATH_IMG = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/color/' 12 | PATH_DEPTH = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/depth/' 13 | PATH_POSE = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/pose/' 14 | PATH_INTRINSIC = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/intrinsic/intrinsic_depth.txt' 15 | 16 | 17 | def read_files(path): 18 | """ 19 | Description 20 | - Read files in the given directory and return the list of the paths of the files 21 | 22 | Parameter 23 | - path: path to the directory to be read 24 | 25 | Return 26 | - file_name: list of file paths 27 | 28 | """ 29 | file_name = os.listdir(path) 30 | file_name = sorted(file_name, key=lambda x: int(x.split('.')[0])) 31 | file_name = [path + fn for fn in file_name] 32 | return file_name 33 | 34 | 35 | def blurryness(image): 36 | return cv2.Laplacian(image, cv2.CV_64F).var() 37 | 38 | 39 | def warp_image(target, depth, rel_pose, intrinsic): 40 | height, width, _ = target.shape 41 | output = np.zeros((height, width, 3), np.uint8) 42 | 43 | for i in range(height): 44 | for j in range(width): 45 | temp = np.dot(np.linalg.inv(intrinsic), np.array([i, j, 1], dtype=float).reshape(3, 1)) 46 | temp = (depth[i, j]) * temp 47 | temp = np.dot(rel_pose, np.append(temp, [1]).reshape(4, 1)) 48 | temp = temp / temp[3] 49 | temp = np.dot(intrinsic, temp[:3]) 50 | temp = temp / temp[2] 51 | x, y = int(round(temp[0])), int(round(temp[1])) 52 | if x >= 0 and x < height and y >= 0 and y < width: 53 | output[i, j, :] = target[x, y, :] 54 | return output 55 | 56 | 57 | def calculate_overlap(depth_org, rel_pose, intrinsic_org, pixel_coordinates,num_coordinate_samples=1000): 58 | """ 59 | Description 60 | - Calculate overlap between two images based on projection 61 | - Projection of img2 to img1 62 | - p' = K * T_21 * depth * K^(-1) * p 63 | 64 | Parameter 65 | - depth: information on depth image 66 | - pose: relative pose to the reference image 67 | - intrinsic: camera intrinsic 68 | 69 | Return 70 | - amount_overlap: estimated amount of overlap in percentage 71 | 72 | """ 73 | 74 | ## Pixel coordinates (p in the above eq.) 
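# The code below realizes the projection chain from the docstring,
#   p' = K * T_21 * (depth * K^(-1) * p):
#   (1) K^(-1) * p  back-projects a pixel p = (u, v, 1) onto a viewing ray,
#   (2) multiplying by depth[i, j] turns the ray into a 3-D point,
#   (3) T_21 (rel_pose, in homogeneous form) moves the point into the other camera's frame,
#   (4) K projects it back to pixel coordinates, followed by dehomogenization.
# A pixel counts as overlapping when its reprojection lands inside the image bounds.
# The intrinsics and the depth map are first downscaled (x_ratio = y_ratio = 0.1)
# so that the per-pixel double loop stays tractable.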
75 | intrinsic = np.copy(intrinsic_org) 76 | x_ratio = 0.1 77 | y_ratio = 0.1 78 | intrinsic[0] *= x_ratio 79 | intrinsic[1] *= y_ratio 80 | depth = cv2.resize(depth_org, None, fx=x_ratio, fy=y_ratio) 81 | height, width = depth.shape 82 | 83 | ## Calculate the amount of the overlapping area 84 | num_total = height * width 85 | num_overlap = 0 86 | 87 | for i in range(height): # y-direction 88 | for j in range(width): # x-direction 89 | temp = np.dot(np.linalg.inv(intrinsic), np.array([j, i, 1], dtype=float).reshape(3, 1)) # temp = (X, Y, 1) 90 | temp = (depth[i, j]) * temp # temp = (X', Y', Z') 91 | temp = np.dot(rel_pose, np.append(temp, [1]).reshape(4, 1)) 92 | temp = temp / float(temp[3] + 1e-10) 93 | temp = np.dot(intrinsic, temp[:3]) 94 | temp = temp / float(temp[2] + 1e-10) 95 | x, y = int(temp[0]), int(temp[1]) 96 | if x >= 0 and x < width and y >= 0 and y < height: 97 | num_overlap += 1 98 | 99 | overlapping_area = num_overlap / num_total 100 | 101 | return overlapping_area 102 | 103 | 104 | def relative_pose(pose1, pose2): 105 | """ 106 | Description 107 | - Calculate relative pose between a pair of poses 108 | - To avoid calculating matrix inverse, the calculation is based on 109 | - P_12 = [R_2^(-1) R_2^(-1)(t_1 - t_2); 0, 0, 0, 1], 110 | - where R_2^(-1) = R_2.T 111 | 112 | Parameter 113 | - pose1, pose2: 4 x 4 pose matrix 114 | 115 | Return 116 | - p_2_to_1 (relative_pose): estimated relative pose 117 | 118 | """ 119 | """ 120 | R_1, R_2 = pose1[:3, :3], pose2[:3, :3] 121 | t_1, t_2 = pose1[:, -1][:-1], pose2[:, -1][:-1] 122 | R = np.dot(R_2.T, R_1) 123 | T = np.dot(R_2.T, t_1 - t_2) 124 | p_1_to_2 = np.zeros((4, 4)) 125 | p_1_to_2[:3, :3] = R 126 | p_1_to_2[:3, -1] = T 127 | p_1_to_2[-1, -1] = 1 128 | """ 129 | p_2_to_1 = np.dot(np.linalg.inv(pose2), pose1) 130 | return p_2_to_1 131 | 132 | 133 | def read_matrix_from_txt(matrix_file): 134 | """ 135 | Description 136 | - Read a matrix from .txt file 137 | 138 | Parameter 139 | - matrix_file: .txt file containing n x m matrix 140 | 141 | Return 142 | - matrix_array: numpy array of (n, m) shape 143 | 144 | """ 145 | f = open(matrix_file).readlines() 146 | matrix_array = [row.split() for row in f] 147 | matrix_array = np.array(matrix_array, dtype=float) 148 | return matrix_array 149 | 150 | 151 | def key_frame_extractor(file_name_img, file_name_depth, file_name_pose, intrinsic): 152 | """ 153 | Description 154 | - Extract keyframe groups by calculating overlapping areas 155 | 156 | Parameter 157 | - file_name_depth: list of depth paths 158 | - file_name_pose: list of pose paths 159 | - cam_intrinsic: scaled camera intrinsic matrix 160 | 161 | Return 162 | - key_frame_groups: extracted key frame groups 163 | 164 | """ 165 | # assert len(file_name_img) == len(file_name_depth), "Number of image != number of depth" 166 | # assert len(file_name_img) == len(file_name_pose), "Number of image != number of pose" 167 | # assert len(file_name_depth) == len(file_name_pose), "Number of depth != number of pose" 168 | 169 | # initialize variables 170 | key_frame_groups, curr_key_frame_group = [], [0] 171 | key_pose, anchor_pose = [read_matrix_from_txt(file_name_pose[0])] * 2 172 | thresh_key, thresh_anchor = 0.1, 0.65 173 | average_of_blurryness = blurryness(cv2.imread(file_name_img[0])) 174 | alpha = 0.7 175 | height, width = 480, 640 176 | pixel_coordinates = np.array([[x, y, 1] for x in np.arange(height) for y in np.arange(width)]) 177 | pixel_coordinates = np.swapaxes(pixel_coordinates, 0, 1) 178 | 179 | for ind, (pose_path, depth_path) in 
enumerate(zip(file_name_pose, file_name_depth)): 180 | print('FrameNum: {ind:4}'.format(ind=ind)) 181 | # 1. reject blurry images 182 | time_start = time.time() 183 | rgb_image =cv2.imread(file_name_img[ind]) 184 | depth = cv2.imread(depth_path, -1) 185 | # hi, wi, _ = rgb_image.shape 186 | hd, wd = depth.shape 187 | # 188 | # x_ratio = wd / wi 189 | # y_ratio = hd / hi 190 | # 191 | # intrinsic[0] *= x_ratio 192 | # intrinsic[1] *= y_ratio 193 | # 194 | #rgb_image = cv2.resize(rgb_image, (wd,hd), interpolation= cv2.INTER_AREA) 195 | #cv2.imshow('sample',rgb_image) 196 | #cv2.waitKey(1) 197 | 198 | curr_blurry = blurryness(rgb_image) 199 | average_of_blurryness = alpha * average_of_blurryness + (1 - alpha) * curr_blurry 200 | threshold = 25 * math.log(average_of_blurryness) + 25 201 | print('Blurry Score (the higher the better): {score:3.1f} , Threshold: {threshold:3.1f}' 202 | .format(score=curr_blurry, threshold=threshold)) 203 | if curr_blurry < threshold: 204 | continue 205 | ckpt_blurry = time.time() 206 | 207 | # 2. calculate the relative pose to key & anchor frames 208 | curr_pose = read_matrix_from_txt(pose_path) 209 | rel_pose_to_key = relative_pose(curr_pose, key_pose) 210 | rel_pose_to_anchor = relative_pose(curr_pose, anchor_pose) 211 | ckpt_calc_pose = time.time() 212 | 213 | # 3. calculate the ratio of the overlapping area 214 | 215 | overlap_with_key = calculate_overlap(depth, rel_pose_to_key, intrinsic, pixel_coordinates) 216 | overlap_with_anchor = calculate_overlap(depth, rel_pose_to_anchor, intrinsic, pixel_coordinates) 217 | ckpt_calc_overlap = time.time() 218 | 219 | # 4. update anchor and key frames 220 | if overlap_with_anchor < thresh_anchor: 221 | curr_key_frame_group.append(ind) 222 | anchor_pose = curr_pose 223 | 224 | if overlap_with_key < thresh_key or len(curr_key_frame_group) > 10: 225 | key_frame_groups.append(curr_key_frame_group) 226 | curr_key_frame_group = [] 227 | key_pose, anchor_pose = [curr_pose] * 2 228 | 229 | ckpt_update = time.time() 230 | 231 | print('overlap_with_key: {ov1:1.2f} overlap_with_anchor: {ov2:1.2f} ' 232 | .format(ov1=overlap_with_key, ov2=overlap_with_anchor)) 233 | 234 | # print('--------------Elapsed Time---------------') 235 | print('blurry pose overlap update') 236 | print('{blurry:1.2f} {pose:1.2f} {overlap:1.2f} {update:1.2f} \n'.format( 237 | blurry=ckpt_blurry - time_start, pose=ckpt_calc_pose - ckpt_blurry, 238 | overlap=ckpt_calc_overlap - ckpt_calc_pose, update=ckpt_update - ckpt_calc_overlap)) 239 | 240 | # cv2.namedWindow('sample') # Create a named window 241 | # cv2.moveWindow('sample', 700, 10) 242 | # cv2.imshow('sample', rgb_image) 243 | # cv2.waitKey(1) 244 | 245 | 246 | 247 | return key_frame_groups 248 | 249 | 250 | class keyframe_checker(object): 251 | def __init__(self,args, 252 | intrinsic_depth = None, 253 | thresh_key=0.1, 254 | thresh_anchor=0.65, 255 | max_group_len = 10, 256 | blurry_gain=30, 257 | blurry_offset=10, 258 | alpha=0.4, depth_shape=(480, 640) , 259 | num_coordinate_samples=1000, 260 | BLURRY_REJECTION_ONLY = None): 261 | self.args = args 262 | self.frame_num = 0 263 | if BLURRY_REJECTION_ONLY == None: 264 | self.BLURRY_REJECTION_ONLY = True if self.args.dataset != 'scannet' else False 265 | else: self.BLURRY_REJECTION_ONLY=BLURRY_REJECTION_ONLY 266 | 267 | # Blurry Image Rejection: Hyper parameters 268 | self.blurry_gain = blurry_gain 269 | self.blurry_offset = blurry_offset 270 | self.alpha = alpha 271 | 272 | # Key frame/ anchor frame Selection: Hyper parameters 273 | if not 
self.BLURRY_REJECTION_ONLY: 274 | self.intrinsic_depth = np.array(intrinsic_depth[:3,:3]) 275 | self.thresh_key = thresh_key 276 | self.thresh_anchor = thresh_anchor 277 | self.max_group_len = max_group_len 278 | self.depth_shape = depth_shape 279 | pixel_coordinates = np.array([[x, y, 1] for x in np.arange(depth_shape[0]) for y in np.arange(depth_shape[1])]) 280 | self.pixel_coordinates = np.swapaxes(pixel_coordinates, 0, 1) 281 | self.key_frame_groups, self.curr_key_frame_group = [], [0] 282 | self.num_cooridnate_samples = num_coordinate_samples 283 | 284 | 285 | 286 | def check_frame(self,img, depth, pose): 287 | if self.args.disable_keyframe: return True, 0.0, 0.0 288 | if self.frame_num == 0: 289 | self.average_of_blurryness = blurryness(img) 290 | self.key_pose, self.anchor_pose, = [pose] * 2 291 | 292 | # 1. reject blurry images 293 | curr_blurry = blurryness(img) 294 | self.average_of_blurryness = self.alpha * self.average_of_blurryness + (1 - self.alpha) * curr_blurry 295 | threshold = self.blurry_gain * math.log(self.average_of_blurryness) + self.blurry_offset 296 | #threshold = self.blurry_gain * self.average_of_blurryness + self.blurry_offset 297 | if curr_blurry < threshold: return False, curr_blurry, threshold 298 | #if self.BLURRY_REJECTION_ONLY: return curr_blurry > threshold, curr_blurry, threshold 299 | 300 | # 2. calculate the relative pose to key & anchor frames 301 | rel_pose_to_key = relative_pose(pose, self.key_pose) 302 | rel_pose_to_anchor = relative_pose(pose, self.anchor_pose) 303 | 304 | # 3. calculate the ratio of the overlapping area 305 | depth = np.asarray(depth,dtype='uint16') 306 | 307 | overlap_with_key = calculate_overlap(depth, rel_pose_to_key, self.intrinsic_depth, 308 | self.pixel_coordinates,self.num_cooridnate_samples) 309 | overlap_with_anchor = calculate_overlap(depth, rel_pose_to_anchor, self.intrinsic_depth, 310 | self.pixel_coordinates,self.num_cooridnate_samples) 311 | 312 | # 4. 
update anchor and key frames 313 | if overlap_with_anchor < self.thresh_anchor: 314 | self.curr_key_frame_group.append(self.frame_num) 315 | self.anchor_pose = pose 316 | IS_ANCHOR = True 317 | else: IS_ANCHOR = False 318 | 319 | if overlap_with_key < self.thresh_key or len(self.curr_key_frame_group) > self.max_group_len: 320 | self.key_frame_groups.append(self.curr_key_frame_group) 321 | self.curr_key_frame_group = [] 322 | self.key_pose, self.anchor_pose = [pose] * 2 323 | IS_KEY = True 324 | else: IS_KEY = False 325 | 326 | self.frame_num +=1 327 | 328 | return IS_KEY or IS_ANCHOR, curr_blurry, threshold 329 | 330 | 331 | 332 | 333 | if __name__=="__main__": 334 | file_name_img, file_name_depth, file_name_pose = read_files(PATH_IMG), read_files(PATH_DEPTH), read_files(PATH_POSE) 335 | intrinsic = read_matrix_from_txt(PATH_INTRINSIC)[:3, :3] # 3x3 matrix 336 | keyframe_groups = key_frame_extractor(file_name_img[:1000], file_name_depth[:1000], file_name_pose[:1000], intrinsic) 337 | -------------------------------------------------------------------------------- /model/keyframe/keyframe_extracion_example.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import os 4 | import cv2 5 | import math 6 | import time 7 | import matplotlib.pyplot as plt 8 | import random 9 | PATH_IMG = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/color/' 10 | PATH_DEPTH = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/depth/' 11 | PATH_POSE = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/pose/' 12 | PATH_INTRINSIC = '/media/mil2/HDD/mil2/scannet/ScanNet/SensReader/python/exported/intrinsic/intrinsic_color.txt' 13 | 14 | 15 | def read_files(path): 16 | """ 17 | Description 18 | - Read files in the given directory and return the list of the paths of the files 19 | 20 | Parameter 21 | - path: path to the directory to be read 22 | 23 | Return 24 | - file_name: list of file paths 25 | 26 | """ 27 | file_name = os.listdir(path) 28 | file_name = sorted(file_name, key=lambda x: int(x.split('.')[0])) 29 | file_name = [path + fn for fn in file_name] 30 | return file_name 31 | 32 | 33 | 34 | def blurryness(image): 35 | return cv2.Laplacian(image, cv2.CV_64F).var() 36 | 37 | def overlap(bb_test,bb_gt): 38 | """ 39 | Computes IUO between two bboxes in the form [x1,y1,x2,y2] 40 | """ 41 | xx1 = np.maximum(bb_test[0], bb_gt[0]) 42 | yy1 = np.maximum(bb_test[1], bb_gt[1]) 43 | xx2 = np.minimum(bb_test[2], bb_gt[2]) 44 | yy2 = np.minimum(bb_test[3], bb_gt[3]) 45 | w = np.maximum(0., xx2 - xx1) 46 | h = np.maximum(0., yy2 - yy1) 47 | wh = w * h 48 | o = wh / ((bb_test[2]-bb_test[0])*(bb_test[3]-bb_test[1]) 49 | + (bb_gt[2]-bb_gt[0])*(bb_gt[3]-bb_gt[1]) - wh) 50 | return(o) 51 | 52 | def key_frame_extractor(file_name_img, file_name_depth, file_name_pose, intrinsic): 53 | """ 54 | Description 55 | - Extract keyframe groups by calculating overlapping areas 56 | 57 | Parameter 58 | - file_name_depth: list of depth paths 59 | - file_name_pose: list of pose paths 60 | - cam_intrinsic: scaled camera intrinsic matrix 61 | 62 | Return 63 | - key_frame_groups: extracted key frame groups 64 | 65 | """ 66 | # assert len(file_name_img) == len(file_name_depth), "Number of image != number of depth" 67 | # assert len(file_name_img) == len(file_name_pose), "Number of image != number of pose" 68 | # assert len(file_name_depth) == len(file_name_pose), "Number of depth != number of pose" 69 | 70 | # 
initialize variables 71 | key_frame_groups, curr_key_frame_group = [], [0] 72 | key_pose, anchor_pose = [read_matrix_from_txt(file_name_pose[0])] * 2 73 | thresh_key, thresh_anchor = 0.1, 0.65 74 | average_of_blurryness = blurryness(cv2.imread(file_name_img[0])) 75 | alpha = 0.7 76 | height, width = 480, 640 77 | pixel_coordinates = np.array([[x, y, 1] for x in np.arange(height) for y in np.arange(width)]) 78 | pixel_coordinates = np.swapaxes(pixel_coordinates, 0, 1) 79 | 80 | for ind, (pose_path, depth_path) in enumerate(zip(file_name_pose, file_name_depth)): 81 | print('FrameNum: {ind:4}'.format(ind=ind)) 82 | # 1. reject blurry images 83 | time_start = time.time() 84 | curr_blurry = blurryness(cv2.imread(file_name_img[ind])) 85 | average_of_blurryness = alpha * average_of_blurryness + (1 - alpha) * curr_blurry 86 | threshold = 25 * math.log(average_of_blurryness) + 25 87 | print('Blurry Score (the higher the better): {score:3.1f} , Threshold: {threshold:3.1f}' 88 | .format(score=curr_blurry,threshold=threshold)) 89 | #if curr_blurry < threshold: 90 | # continue 91 | ckpt_blurry = time.time() 92 | 93 | # 2. calculate the relative pose to key & anchor frames 94 | curr_pose = read_matrix_from_txt(pose_path) 95 | rel_pose_to_key = relative_pose(curr_pose, key_pose) 96 | rel_pose_to_anchor = relative_pose(curr_pose, anchor_pose) 97 | ckpt_calc_pose = time.time() 98 | 99 | 100 | # 3. calculate the ratio of the overlapping area 101 | depth = cv2.imread(depth_path, -1) 102 | overlap_with_key,coordinates = calculate_overlap(depth, rel_pose_to_key, intrinsic, pixel_coordinates) 103 | overlap_with_anchor,coordinates = calculate_overlap(depth, rel_pose_to_anchor, intrinsic, pixel_coordinates) 104 | ckpt_calc_overlap = time.time() 105 | 106 | # 4. update anchor and key frames 107 | if overlap_with_anchor < thresh_anchor: 108 | curr_key_frame_group.append(ind) 109 | anchor_pose = curr_pose 110 | 111 | if overlap_with_key < thresh_key or len(curr_key_frame_group) > 10: 112 | key_frame_groups.append(curr_key_frame_group) 113 | curr_key_frame_group = [] 114 | key_pose, anchor_pose = [curr_pose] * 2 115 | 116 | ckpt_update = time.time() 117 | #print('--------------Elapsed Time---------------') 118 | print('blurry pose overlap update') 119 | print('{blurry:1.2f} {pose:1.2f} {overlap:1.2f} {update:1.2f} \n'.format( 120 | blurry=ckpt_blurry-time_start, pose=ckpt_calc_pose-ckpt_blurry, 121 | overlap = ckpt_calc_overlap - ckpt_calc_pose, update = ckpt_update-ckpt_calc_overlap)) 122 | 123 | return key_frame_groups 124 | 125 | 126 | def warp_image(target, depth, rel_pose, intrinsic): 127 | height, width, _ = target.shape 128 | output = np.zeros((height, width, 3), np.uint8) 129 | 130 | for i in range(height): 131 | for j in range(width): 132 | temp = np.dot(np.linalg.inv(intrinsic), np.array([i, j, 1], dtype=float).reshape(3, 1)) 133 | temp = (depth[i, j]) * temp 134 | temp = np.dot(rel_pose, np.append(temp, [1]).reshape(4, 1)) 135 | temp = temp / temp[3] 136 | temp = np.dot(intrinsic, temp[:3]) 137 | temp = temp / temp[2] 138 | x, y = int(round(temp[0])), int(round(temp[1])) 139 | if x >= 0 and x < height and y >= 0 and y < width: 140 | output[i, j, :] = target[x, y, :] 141 | return output 142 | 143 | 144 | def calculate_overlap(depth, pose, intrinsic,pixel_coordinates): 145 | """ 146 | Description 147 | - Calculate overlap between two images based on projection 148 | - Projection of img2 to img1 149 | - p' = K * T_21 * depth * K^(-1) * p 150 | 151 | Parameter 152 | - depth: information on depth image 153 | - 
pose: relative pose to the reference image 154 | - intrinsic: camera intrinsic 155 | 156 | Return 157 | - amount_overlap: estimated amount of overlap in percentage 158 | 159 | """ 160 | ## Step 1. Pixel coordinates (p in the above eq.) 161 | start_overlap = time.time() 162 | height, width = depth.shape 163 | #x, y = np.arange(width), np.arange(height) 164 | #X, Y = np.meshgrid(x, y) 165 | #pixel_coordinates = [list(zip(x, y)) for x, y in zip(X, Y)] 166 | #pixel_coordinates = [np.array([x, y, 1]).reshape(3, 1) for sub in pixel_coordinates for x, y in sub] 167 | #pixel_coordinates = np.array([[x, y, 1] for x in np.arange(height) for y in np.arange(width)]) 168 | # pixel_coordinates = np.array(pixel_coordinates).squeeze() 169 | #pixel_coordinates = np.swapaxes(pixel_coordinates, 0, 1) 170 | ckpt_step1 = time.time() 171 | 172 | 173 | ## Step 2. pixel coordinate => camera coordinate (real-world coordinate) 174 | intrinsic_inv = np.linalg.inv(intrinsic) 175 | coordinates = np.dot(intrinsic_inv, pixel_coordinates) 176 | coordinates = np.swapaxes(coordinates, 0, 1) 177 | # coordinates = [np.dot(intrinsic_inv, pixel_coord) for pixel_coord in pixel_coordinates] 178 | # coordinates = np.array(coordinates).squeeze() 179 | ckpt_step2 = time.time() 180 | 181 | 182 | ## Step 3. 3-D position reconstruction 183 | depth = depth.flatten().reshape(-1, 1) 184 | coordinates = np.multiply(coordinates, depth) 185 | ckpt_step3 = time.time() 186 | 187 | 188 | 189 | ## Step 4. Reprojection 190 | # homogeneous coordinate for 3-D points 191 | coordinates = np.hstack((coordinates, np.ones((coordinates.shape[0], 1)))) 192 | coordinates = np.swapaxes(coordinates, 0, 1) 193 | coordinates = np.swapaxes(np.dot(pose, coordinates), 0, 1) 194 | # normalization for 3-D points 195 | coordinates = coordinates[:, :3] / (coordinates[:, 3][:, None] + 1e-10) 196 | # reprojection 197 | coordinates = np.swapaxes(coordinates, 0, 1) 198 | coordinates = np.swapaxes(np.dot(intrinsic, coordinates), 0, 1) 199 | # normalization for 2-D points 200 | coordinates = coordinates[:, :2] / (coordinates[:, 2][:, None] + 1e-10) 201 | ckpt_step4 = time.time() 202 | 203 | 204 | ## Step 5. 
Calculate the amount of the overlapping area 205 | ov2_start = time.time() 206 | overlapping_area2 = sum((coordinates[:, 0] < width) & (coordinates[:, 1] < height) 207 | & (coordinates[:, 0] > 0) & (coordinates[:, 1] > 0)) 208 | overlapping_area2 /= (width * height) 209 | ov2_end = time.time() 210 | 211 | 212 | ov1_start = time.time() 213 | # Randomly sample 1000 points 214 | coordinates = random.sample(coordinates, 1000) 215 | coordinates = np.stack(coordinates) 216 | 217 | overlapping_points = coordinates[(coordinates[:, 0] < width) & (coordinates[:, 1] < height)\ 218 | & (coordinates[:, 0] > 0) & (coordinates[:, 1] > 0)] 219 | minX, minY = overlapping_points.min(axis=0) 220 | maxX, maxY = overlapping_points.max(axis=0) 221 | 222 | overlapping_area1 = (maxX-minX)*(maxY-minY)/(width*height) 223 | overlapping_area1 = min(overlapping_area1, 1.0) 224 | ov1_end = time.time() 225 | 226 | 227 | 228 | 229 | ckpt_step5 = time.time() 230 | 231 | #mean,std = coordinates.mean(axis=0), coordinates.std(axis=0) 232 | #coordinates = coordinates[coordinates[:,0] < mean[0] + 1 * std[0]] 233 | #coordinates = coordinates[coordinates[:,1] > mean[1] - 1 * std[1]] 234 | 235 | # p1 = coordinates[np.argmax(coordinates[:, 0])] 236 | # p2 = coordinates[np.argmax(coordinates[:, 1])] 237 | # p3 = coordinates[np.argmin(coordinates[:, 0])] 238 | # p4 = coordinates[np.argmin(coordinates[:, 1])] 239 | # 240 | # quadrangle = np.stack([p1,p2,p3,p4]) 241 | 242 | print('overlap1: {ov1:1.2f} overlap2: {ov2:1.2f} '.format(ov1=overlapping_area1, ov2= overlapping_area2)) 243 | 244 | print('--------------Elapsed Time---------------') 245 | print('method1: {ov1:1.2f} method2: {ov2:1.2f} '.format(ov1=ov1_end-ov1_start, ov2=ov2_end-ov2_start)) 246 | 247 | print('--------------Elapsed Time---------------') 248 | print('step1 step2 step3 step4 step5') 249 | print('{s1:1.3f} {s2:1.3f} {s3:1.3f} {s4:1.3f} {s5:1.3f} \n'.format( 250 | s1=ckpt_step1 - start_overlap, s2=ckpt_step2 - ckpt_step1, 251 | s3=ckpt_step3 - ckpt_step2, s4=ckpt_step4 - ckpt_step3,s5=ckpt_step5 - ckpt_step4)) 252 | 253 | 254 | # plt.clf() 255 | # plt.scatter(coordinates[:, 0].tolist(), coordinates[:, 1].tolist(),s=1,color='b') 256 | # plt.scatter(quadrangle[:,0].tolist(),quadrangle[:,1].tolist(),s=10,color='r') 257 | # #plt.xlim([-3000,3000]) 258 | # #plt.ylim([-3000,3000]) 259 | # #plt.show() 260 | # plt.draw() 261 | # plt.pause(0.0001) 262 | 263 | 264 | return overlapping_area2, coordinates 265 | 266 | 267 | def relative_pose(pose1, pose2): 268 | """ 269 | Description 270 | - Calculate relative pose between a pair of poses 271 | - To avoid calculating matrix inverse, the calculation is based on 272 | - P_12 = [R_2^(-1) R_2^(-1)(t_1 - t_2); 0, 0, 0, 1], 273 | - where R_2^(-1) = R_2.T 274 | 275 | Parameter 276 | - pose1, pose2: 4 x 4 pose matrix 277 | 278 | Return 279 | - p_2_to_1 (relative_pose): estimated relative pose 280 | 281 | """ 282 | """ 283 | R_1, R_2 = pose1[:3, :3], pose2[:3, :3] 284 | t_1, t_2 = pose1[:, -1][:-1], pose2[:, -1][:-1] 285 | R = np.dot(R_2.T, R_1) 286 | T = np.dot(R_2.T, t_1 - t_2) 287 | p_1_to_2 = np.zeros((4, 4)) 288 | p_1_to_2[:3, :3] = R 289 | p_1_to_2[:3, -1] = T 290 | p_1_to_2[-1, -1] = 1 291 | """ 292 | p_2_to_1 = np.dot(np.linalg.inv(pose2), pose1) 293 | return p_2_to_1 294 | 295 | 296 | def read_matrix_from_txt(matrix_file): 297 | """ 298 | Description 299 | - Read a matrix from .txt file 300 | 301 | Parameter 302 | - matrix_file: .txt file containing n x m matrix 303 | 304 | Return 305 | - matrix_array: numpy array of (n, m) shape 
306 | 307 | """ 308 | f = open(matrix_file).readlines() 309 | matrix_array = [row.split() for row in f] 310 | matrix_array = np.array(matrix_array, dtype=float) 311 | return matrix_array 312 | 313 | 314 | file_name_img, file_name_depth, file_name_pose = read_files(PATH_IMG), read_files(PATH_DEPTH), read_files(PATH_POSE) 315 | # 316 | # ref_idx = 10 317 | intrinsic = read_matrix_from_txt(PATH_INTRINSIC)[:3, :3] # 3x3 matrix 318 | # pose2 = read_matrix_from_txt(file_name_pose[ref_idx]) # 4x4 matrix 319 | # img_test = cv2.imread(file_name_img[ref_idx]) # 480x640x3 RGB image 320 | # test_depth= cv2.imread(file_name_depth[ref_idx], -1) # 480x640 Depth image 321 | # hi, wi, _ = img_test.shape 322 | # hd, wd = test_depth.shape 323 | # 324 | # x_ratio = wd / wi 325 | # y_ratio = hd / hi 326 | # 327 | # intrinsic[0] *= x_ratio 328 | # intrinsic[1] *= y_ratio 329 | # 330 | # img_test = cv2.resize(img_test, (wd, hd)) 331 | # 332 | # test_idx = 300 333 | # img_test2 = cv2.imread(file_name_img[test_idx]) 334 | # depth, pose = cv2.imread(file_name_depth[test_idx], -1), read_matrix_from_txt(file_name_pose[test_idx]) 335 | # rel_pose = relative_pose(pose2, pose) # pose * rel_pose = pose2 336 | # test = calculate_overlap(depth, rel_pose, intrinsic) 337 | # rel_pose2 = relative_pose(pose, pose2) 338 | # 339 | # img_test2 = cv2.resize(img_test2, (wd, hd)) 340 | # 341 | # weird_pose = read_matrix_from_txt(file_name_pose[1000]) 342 | # rel_pose3 = relative_pose(pose, weird_pose) 343 | 344 | # out_img1 = warp_image(img_test2, test_depth, rel_pose, intrinsic) 345 | # out_img2 = warp_image(img_test2, test_depth, rel_pose2, intrinsic) 346 | 347 | # cv2.imshow('target_image', img_test2) 348 | # cv2.imshow('source_image', img_test) 349 | # cv2.imshow('warped_image1 with rel_pose', out_img1) 350 | # cv2.imshow('warped_image1 with rel_pose2', out_img2) 351 | # cv2.waitKey(0) 352 | # cv2.destroyAllWindows() 353 | 354 | ''' 355 | for test_idx in range(1, 5570, 10): 356 | depth, pose = cv2.imread(file_name_depth[test_idx], -1), read_matrix_from_txt(file_name_pose[test_idx]) 357 | rel_pose = relative_pose(pose, pose2) # pose * rel_pose = pose2 358 | test = calculate_overlap(depth, rel_pose, intrinsic) 359 | print('test_idx: {}, overlap: {}'.format(test_idx, test)) 360 | ''' 361 | 362 | plt.axis([-3000, 3000, -3000, 3000]) 363 | plt.ion() 364 | plt.show() 365 | keyframe_groups = key_frame_extractor(file_name_img, file_name_depth, file_name_pose, intrinsic) 366 | -------------------------------------------------------------------------------- /model/interpret.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./FactorizableNet') 3 | from lib.fast_rcnn.nms_wrapper import nms 4 | from lib.fast_rcnn.bbox_transform import bbox_transform_inv_hdn, clip_boxes 5 | from lib.utils.nms import triplet_nms as triplet_nms_py 6 | from sort.sort import Sort, iou 7 | import numpy as np 8 | from torch.autograd import Variable 9 | import torchtext 10 | import torch 11 | from prior import relation_prior 12 | from SGGenModel import VG_DR_NET_OBJ_IGNORES 13 | from torch.nn.functional import cosine_similarity 14 | 15 | 16 | def filter_untracted(ref_bbox, tobefiltered_bbox): 17 | keep = [] 18 | for bbox in ref_bbox: 19 | ious = [iou(bbox[:4], obj_box) for obj_box in tobefiltered_bbox] 20 | keep.append(np.argmax(ious)) 21 | return keep 22 | 23 | 24 | def nms_detections(pred_boxes, scores, nms_thresh, inds=None): 25 | keep = range(scores.shape[0]) 26 | keep, scores, pred_boxes = 
zip(*sorted(zip(keep, scores, pred_boxes), key=lambda x: x[1][0])[::-1]) 27 | keep, scores, pred_boxes = np.array(keep), np.array(scores), np.array(pred_boxes) 28 | dets = np.hstack((pred_boxes, scores[:,0][:, np.newaxis])).astype(np.float32) 29 | keep_keep = nms(dets, nms_thresh) 30 | keep_keep = keep_keep[:min(100, len(keep_keep))] 31 | keep = keep[keep_keep] 32 | if inds is None: 33 | return pred_boxes[keep_keep], scores[keep_keep], keep 34 | return pred_boxes[keep_keep], scores[keep_keep], inds[keep], keep 35 | 36 | 37 | class interpreter(object): 38 | """ 39 | Description 40 | - Interpret and analyze the results of recognition module 41 | - Statistical and semantic priors are used to filter out spurious detections 42 | - Statistical prior: generated from visual genome dataset 43 | - Semantic prior: generated from word2vec 44 | - Relations are interpreted (inferred and rejected) and missing objects are inferred 45 | Major Functions 46 | - functions for calculating proabilities: cal_p_xy_joint, cal_p_x_given_y, cal_p_x_given_yz, check_prob_condition 47 | - spurious_relation_rejection: spurious detections get rejected based on priors 48 | - missing_object_inference: missing objects are inferred and added to 3D scene graph 49 | - missing_relation_inference: missing relations are inferred and added to 3D scene graph 50 | """ 51 | def __init__(self,args, data_set,ENABLE_TRACKING=None): 52 | self.tracker = Sort() 53 | self.args = args 54 | self.nms_thres = args.nms 55 | self.triplet_nms_thres =args.triplet_nms 56 | self.obj_thres = args.obj_thres 57 | self.triplet_thres = args.triplet_thres 58 | self.tobefiltered_objects = [26, 53, 134, 247, 179, 74, 226, 135, 145, 300, 253, 95, 11, 102,87] 59 | # 26: wheel, 53: backpack, 143:light, 247:camera, 179:board 60 | # 74:shoe, 226:chair, 135:shelf, 145:button, 300:cake, 253:knob, 95:wall, 11:door, 102:mirror,87:ceiling 61 | if ENABLE_TRACKING == None: 62 | self.ENABLE_TRACKING = False if self.args.dataset == 'visual_genome' else True 63 | else: 64 | self.ENABLE_TRACKING = ENABLE_TRACKING 65 | if self.ENABLE_TRACKING and self.args.path_opt.split('/')[-1] == 'VG-DR-Net.yaml': 66 | self.tobefiltered_predicates = [0,6,10,18,19,20,22,23,24] 67 | # 0:backgrounds, 6:eat,10:wear, 18:ride, 19:watch, 20:play, 22:enjoy, 23:read, 24:cut 68 | 69 | elif self.ENABLE_TRACKING and self.args.path_opt.split('/')[-1] == 'VG-MSDN.yaml': 70 | self.tobefiltered_predicates = [12, 18, 27, 28, 30, 31, 32, 35] 71 | else: 72 | self.tobefiltered_predicates = [] 73 | 74 | # Params for Statistics Based Scene Graph Inference 75 | self.relation_statistics = relation_prior.load_obj("model/prior/preprocessed/relation_prior_prob") 76 | self.joint_probability = relation_prior.load_obj("model/prior/preprocessed/object_prior_prob") 77 | self.spurious_rel_thres = 0.07 78 | self.rel_infer_thres = 0.9 79 | self.obj_infer_thres = 0.001 80 | self.data_set = data_set 81 | self.detected_obj_set = set() 82 | self.fasttext = torchtext.vocab.FastText() 83 | self.word_vecs, self.word_itos,self.word_stoi = self.prepare_wordvecs(num_vocabs=400,ignores=VG_DR_NET_OBJ_IGNORES) 84 | self.pred_stoi = {self.data_set.predicate_classes[i]: i for i in range(len(self.data_set.predicate_classes))} 85 | 86 | # p(x, y) 87 | def cal_p_xy_joint(self,x_ind, y_ind): 88 | p_xy = self.joint_probability[x_ind, y_ind] / np.sum(self.joint_probability) 89 | return p_xy 90 | 91 | # p(x|y) 92 | def cal_p_x_given_y(self,x_ind, y_ind): 93 | single_prob = np.sum(self.joint_probability, axis=1) 94 | p_y = single_prob[y_ind] 95 
| p_xy = self.joint_probability[x_ind, y_ind] 96 | return p_xy / p_y 97 | 98 | # p(x|y,z) approximated 99 | def cal_p_x_given_yz(self,x_ind, y_ind, z_ind): 100 | p_x_given_y = self.cal_p_x_given_y(x_ind, y_ind) 101 | p_x_given_z = self.cal_p_x_given_y(x_ind, z_ind) 102 | return min(p_x_given_y, p_x_given_z) 103 | 104 | # True if p(x, z)^2 < p(x,y)*p(y,z) 105 | def check_prob_condition(self,x_ind,y_ind,z_ind): 106 | p_xz = self.cal_p_xy_joint(x_ind,z_ind) 107 | p_xy = self.cal_p_xy_joint(x_ind,y_ind) 108 | p_yz = self.cal_p_xy_joint(y_ind,z_ind) 109 | return p_xz**2 < p_xy*p_yz 110 | 111 | def prepare_wordvecs(self,num_vocabs = 400, ignores = VG_DR_NET_OBJ_IGNORES): 112 | word_inds = range(num_vocabs) 113 | word_inds = [x for x in word_inds if x not in ignores] 114 | word_txts = [self.data_set.object_classes[x] for x in word_inds] 115 | self.word_ind2vec = {ind:self.fasttext.vectors[self.fasttext.stoi[x]] for ind,x in zip(word_inds,word_txts)} 116 | 117 | word_vecs = torch.stack([self.fasttext.vectors[self.fasttext.stoi[x]] for x in word_txts]).cuda() 118 | word_itos = {i: self.data_set.object_classes[x] for i, x in enumerate(word_inds)} 119 | word_stoi = {self.data_set.object_classes[x]:i for i, x in enumerate(word_inds)} 120 | return word_vecs, word_itos, word_stoi 121 | 122 | def update_obj_set(self,obj_inds): 123 | for obj_ind in obj_inds[:,0]: self.detected_obj_set.add(obj_ind) 124 | 125 | def find_disconnected_pairs(self,obj_inds, relationships): 126 | connected_pairs = set(tuple(x) for x in relationships[:, :2].astype(int).tolist()) 127 | disconnected_pairs = set() 128 | for i in range(len(obj_inds)): 129 | for j in range(len(obj_inds)): 130 | if i == j: continue 131 | if (i,j) in connected_pairs or (j,i) in connected_pairs: continue 132 | disconnected_pairs.add((i,j)) 133 | return disconnected_pairs 134 | 135 | def missing_relation_inference(self,obj_inds,obj_boxes,disconnected_pairs): 136 | infered_relation=set() 137 | #print('discon:',disconnected_pairs) 138 | for i in range(len(disconnected_pairs)): 139 | pair = disconnected_pairs.pop() 140 | node1_box, node2_box = obj_boxes[pair[0]], obj_boxes[pair[1]] 141 | distance = self.distance_between_boxes(np.stack([node1_box, node2_box], axis=0))[0, 1] 142 | pair_txt = [self.data_set.object_classes[obj_inds[pair[0]][0]], 143 | self.data_set.object_classes[obj_inds[pair[1]][0]]] 144 | candidate, prob, direction = relation_prior.most_probable_relation_for_unpaired(pair_txt, self.relation_statistics, int(distance)) 145 | if candidate !=None and prob > self.rel_infer_thres: 146 | if not direction: pair = (pair[1],pair[0]) 147 | infered_relation.add((pair[0],pair[1],self.pred_stoi[candidate],prob)) 148 | pair_txt = [self.data_set.object_classes[obj_inds[pair[0]][0]], 149 | self.data_set.object_classes[obj_inds[pair[1]][0]]] 150 | infered_relation= np.array(list(infered_relation)).reshape(-1, 4) 151 | return infered_relation 152 | 153 | def missing_object_inference(self,obj_inds,disconnected_pairs): 154 | detected_obj_list = np.array(list(self.detected_obj_set)) 155 | candidate_searchspace = [self.word_ind2vec[x] for x in detected_obj_list] 156 | candidate_searchspace = torch.stack(candidate_searchspace,dim=0).cuda() 157 | search_size = candidate_searchspace.shape[0] 158 | infered_obj_list = [] 159 | 160 | for i in range(len(disconnected_pairs)): 161 | pair = disconnected_pairs.pop() 162 | ''' wordvec based candidate objects filtering ''' 163 | sbj_vec = self.word_ind2vec[obj_inds[pair[0]][0]].cuda() 164 | obj_vec = 
self.word_ind2vec[obj_inds[pair[1]][0]].cuda() 165 | sim_sbj_obj = cosine_similarity(sbj_vec,obj_vec,dim=0) 166 | 167 | sbj_vec = sbj_vec.expand_as(candidate_searchspace) 168 | obj_vec = obj_vec.expand_as(candidate_searchspace) 169 | sim_cans_sbj = cosine_similarity(candidate_searchspace,sbj_vec, dim=1) 170 | sim_cans_obj = cosine_similarity(candidate_searchspace,obj_vec, dim=1) 171 | sim_sbj_obj = sim_sbj_obj.expand_as(sim_cans_obj) 172 | keep = (sim_cans_sbj + sim_cans_obj > 2 * sim_sbj_obj).nonzero().view(-1).cpu().numpy() 173 | candidate_obj_list = detected_obj_list[keep] 174 | if len(candidate_obj_list) == 0: continue 175 | 176 | ''' statistics based candidate objects filtering ''' 177 | keep=[] 178 | for i,obj_ind in enumerate(candidate_obj_list): 179 | if self.check_prob_condition(obj_inds[pair[0]][0],obj_ind,obj_inds[pair[1]][0]): keep.append(i) 180 | candidate_obj_list = candidate_obj_list[keep] 181 | if len(candidate_obj_list) == 0: continue 182 | 183 | ''' choose a candidate with best score above threshold''' 184 | probs = [self.cal_p_x_given_yz(candidate, obj_inds[pair[0]][0], obj_inds[pair[1]][0]) for candidate in candidate_obj_list] 185 | chosen_obj = candidate_obj_list[(np.array(probs)).argmax()] 186 | infered_obj_list.append(chosen_obj) 187 | 188 | def get_box_centers(self,boxes): 189 | # Define bounding box info 190 | center_x = (boxes[:, 0] + boxes[:, 2]) / 2 191 | center_y = (boxes[:, 1] + boxes[:, 3]) / 2 192 | centers = np.concatenate([center_x.reshape(-1, 1), center_y.reshape(-1, 1)], axis=1) 193 | return centers 194 | 195 | def distance_between_boxes(self,boxes): 196 | ''' 197 | returns all possible distances between boxes 198 | 199 | :param boxes: 200 | :return: dist: distance between boxes[1] and boxes[2] ==> dist[1,2] 201 | ''' 202 | centers = self.get_box_centers(boxes) 203 | centers_axis1 = np.repeat(centers,centers.shape[0],axis=0).reshape(-1,2) 204 | centers_axis2 = np.stack([centers for _ in range(centers.shape[0])]).reshape(-1, 2) 205 | dist = np.linalg.norm(centers_axis1 - centers_axis2, axis=1).reshape(-1,centers.shape[0]) 206 | return dist 207 | 208 | def spurious_relation_rejection(self,obj_boxes,obj_cls,relationships): 209 | if self.args.disable_spurious: return range(len(relationships)) 210 | subject_inds = obj_cls[relationships.astype(int)[:,0]][:, 0] 211 | pred_inds = relationships.astype(int)[:, 2] 212 | object_inds = obj_cls[relationships.astype(int)[:, 1]][:, 0] 213 | 214 | subject_boxes = obj_boxes[relationships.astype(int)[:,0]] 215 | object_boxes = obj_boxes[relationships.astype(int)[:,1]] 216 | 217 | keep = [] 218 | for i, (sbj_ind, pred_ind, obj_ind, sbj_box, obj_box) in enumerate(zip(subject_inds,pred_inds,object_inds, 219 | subject_boxes,object_boxes)): 220 | relation_txt = [self.data_set.object_classes[sbj_ind], 221 | self.data_set.predicate_classes[pred_ind], 222 | self.data_set.object_classes[obj_ind]] 223 | distance = self.distance_between_boxes(np.stack([sbj_box,obj_box],axis=0))[0,1] 224 | prob = relation_prior.triplet_prob_from_statistics(relation_txt, self.relation_statistics, int(distance)) 225 | print('prob: {prob:3.2f} {sbj:15}{rel:15}{obj:15}'.format(prob=prob, 226 | sbj=relation_txt[0], 227 | rel=relation_txt[1], 228 | obj=relation_txt[2])) 229 | 230 | if prob > self.spurious_rel_thres: keep.append(i) 231 | 232 | return keep 233 | 234 | def interpret_graph(self,object_result, predicate_result,im_info): 235 | cls_prob_object, bbox_object, object_rois, reranked_score = object_result[:4] 236 | cls_prob_predicate, mat_phrase = 
predicate_result[:2] 237 | region_rois_num = predicate_result[2] 238 | 239 | obj_boxes, obj_scores, obj_cls, \ 240 | subject_inds, object_inds, \ 241 | subject_boxes, object_boxes, \ 242 | subject_IDs, object_IDs, \ 243 | predicate_inds, triplet_scores, relationships = \ 244 | self.interpret_graph_(cls_prob_object, bbox_object, object_rois, 245 | cls_prob_predicate, mat_phrase, im_info, 246 | reranked_score) 247 | 248 | ''' missing object inference ''' 249 | # self.update_obj_set(obj_cls) 250 | # disconnected_pairs = self.find_disconnected_pairs(obj_cls, relationships) 251 | # self.missing_object_inference(obj_cls,disconnected_pairs) 252 | ''' missing object infernce (end) ''' 253 | 254 | ''' missing relation inference ''' 255 | # infered_relations = self.missing_relation_inference(obj_cls,obj_boxes,disconnected_pairs) 256 | # print('size:',relationships.shape,infered_relations.shape) 257 | # 258 | # relationships = np.concatenate([relationships,infered_relations],axis=0) 259 | # 260 | # predicate_inds = relationships[:, 2].astype(int) 261 | # subject_boxes = obj_boxes[relationships[:, 0].astype(int)] 262 | # object_boxes = obj_boxes[relationships[:, 1].astype(int)] 263 | # subject_IDs = np.array([int(obj_boxes[int(relation[0])][4]) for relation in relationships]) 264 | # object_IDs = np.array([int(obj_boxes[int(relation[1])][4]) for relation in relationships]) 265 | # subject_inds = obj_cls[relationships[:, 0].astype(int)] 266 | # object_inds = obj_cls[relationships[:, 1].astype(int)] 267 | # subject_scores = [obj_scores[int(relation[0])] for relation in relationships] 268 | # pred_scores = [relation[3] / obj_scores[int(relation[0])] / obj_scores[int(relation[1])] for relation in 269 | # relationships] 270 | # object_scores = [obj_scores[int(relation[1])] for relation in relationships] 271 | # triplet_scores = np.array(zip(subject_scores, pred_scores, object_scores)) 272 | ''' missing relation inference (end) ''' 273 | 274 | 275 | keep = self.spurious_relation_rejection(obj_boxes, obj_cls, relationships) 276 | 277 | return obj_boxes, obj_scores, obj_cls, \ 278 | subject_inds[keep], object_inds[keep], \ 279 | subject_boxes[keep], object_boxes[keep], \ 280 | subject_IDs[keep], object_IDs[keep], \ 281 | predicate_inds[keep], triplet_scores[keep], relationships[keep] 282 | 283 | def interpret_graph_(self,cls_prob_object, bbox_object, object_rois, 284 | cls_prob_predicate, mat_phrase, im_info, 285 | reranked_score=None): 286 | 287 | obj_boxes, obj_scores, obj_cls, subject_inds, object_inds, \ 288 | subject_boxes, object_boxes, predicate_inds, \ 289 | sub_assignment, obj_assignment, total_score = \ 290 | self.interpret_relationships(cls_prob_object, bbox_object, object_rois, 291 | cls_prob_predicate, mat_phrase, im_info, 292 | nms=self.nms_thres, topk_pred=2, topk_obj=3, 293 | use_gt_boxes=False, 294 | triplet_nms=self.triplet_nms_thres, 295 | reranked_score=reranked_score) 296 | 297 | obj_boxes, obj_scores, obj_cls, \ 298 | subject_inds, object_inds, \ 299 | subject_boxes, object_boxes, \ 300 | subject_IDs, object_IDs, \ 301 | predicate_inds, triplet_scores, relationships = self.filter_and_tracking(obj_boxes, obj_scores, obj_cls, 302 | subject_inds, object_inds, 303 | subject_boxes, object_boxes, 304 | predicate_inds, 305 | sub_assignment, obj_assignment, 306 | total_score) 307 | 308 | return obj_boxes, obj_scores, obj_cls, \ 309 | subject_inds, object_inds, \ 310 | subject_boxes, object_boxes, \ 311 | subject_IDs, object_IDs, \ 312 | predicate_inds, triplet_scores, relationships 313 | 314 | 
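# --- Editor's illustrative sketch (not part of the original file) -------------------
# The object co-occurrence prior used by cal_p_x_given_yz() and check_prob_condition()
# above boils down to a few NumPy operations. The 3x3 matrix below is a made-up,
# symmetric toy co-occurrence table standing in for self.joint_probability, and the
# toy_* names and class indices 0..2 are hypothetical.
import numpy as np

toy_joint = np.array([[10., 4., 1.],
                      [ 4., 8., 6.],
                      [ 1., 6., 9.]])

def toy_p_xy(x, y):                    # joint probability p(x, y)
    return toy_joint[x, y] / toy_joint.sum()

def toy_p_x_given_y(x, y):             # conditional p(x | y) from the marginal of y
    return toy_joint[x, y] / toy_joint[y].sum()

def toy_p_x_given_yz(x, y, z):         # approximation used above: min(p(x|y), p(x|z))
    return min(toy_p_x_given_y(x, y), toy_p_x_given_y(x, z))

def toy_prob_condition(x, y, z):       # True if p(x, z)^2 < p(x, y) * p(y, z)
    return toy_p_xy(x, z) ** 2 < toy_p_xy(x, y) * toy_p_xy(y, z)

# e.g. toy_prob_condition(0, 1, 2) evaluates to True for this toy table, so class 1
# would be kept as a candidate object between classes 0 and 2 before the candidates
# are ranked by the approximated p(x|y,z).
# --- end of sketch -------------------------------------------------------------------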
def interpret_relationships(self, cls_prob, bbox_pred, rois, cls_prob_predicate, 315 | mat_phrase, im_info, nms=-1., clip=True, min_score=0.01, 316 | top_N=100, use_gt_boxes=False, triplet_nms=-1., topk_pred=2,topk_obj=3, 317 | reranked_score=None): 318 | 319 | scores, inds = cls_prob[:, 1:].data.topk(k=topk_obj,dim=1) 320 | if reranked_score is not None: 321 | if isinstance(reranked_score, Variable): 322 | reranked_score = reranked_score.data 323 | scores *= reranked_score 324 | inds += 1 325 | scores, inds = scores.cpu().numpy(), inds.cpu().numpy() 326 | # filter out objects with wrong class 327 | for i,ind in enumerate(inds): 328 | if ind[0] in self.tobefiltered_objects: 329 | scores[i].fill(0) 330 | 331 | 332 | predicate_scores, predicate_inds = cls_prob_predicate[:, 1:].data.topk(dim=1, k=topk_pred) 333 | predicate_inds += 1 334 | predicate_scores, predicate_inds = predicate_scores.cpu().numpy().reshape( 335 | -1), predicate_inds.cpu().numpy().reshape(-1) 336 | 337 | # Apply bounding-box regression deltas 338 | box_deltas = bbox_pred.data.cpu().numpy() 339 | box_deltas = np.asarray([ 340 | box_deltas[i, (inds[i][0] * 4): (inds[i][0] * 4 + 4)] for i in range(len(inds)) 341 | ], dtype=np.float) 342 | keep = range(scores.shape[0]) 343 | if use_gt_boxes: 344 | triplet_nms = -1. 345 | pred_boxes = rois.data.cpu().numpy()[:, 1:5] / im_info[0][2] 346 | else: 347 | pred_boxes = bbox_transform_inv_hdn(rois.data.cpu().numpy()[:, 1:5], box_deltas) / im_info[0][2] 348 | pred_boxes = clip_boxes(pred_boxes, im_info[0][:2] / im_info[0][2]) 349 | 350 | # nms 351 | if nms > 0. and pred_boxes.shape[0] > 0: 352 | assert nms < 1., 'Wrong nms parameters' 353 | pred_boxes, scores, inds, keep = nms_detections(pred_boxes, scores, nms, inds=inds) 354 | 355 | sub_list = np.array([], dtype=int) 356 | obj_list = np.array([], dtype=int) 357 | pred_list = np.array([], dtype=int) 358 | 359 | # mapping the object id 360 | mapping = np.ones(cls_prob.size(0), dtype=np.int64) * -1 361 | mapping[keep] = range(len(keep)) 362 | 363 | sub_list = mapping[mat_phrase[:, 0]] 364 | obj_list = mapping[mat_phrase[:, 1]] 365 | pred_remain = np.logical_and(sub_list >= 0, obj_list >= 0) 366 | pred_list = np.where(pred_remain)[0] 367 | sub_list = sub_list[pred_remain] 368 | obj_list = obj_list[pred_remain] 369 | 370 | # expand the sub/obj and pred list to k-column 371 | pred_list = np.vstack([pred_list * topk_pred + i for i in range(topk_pred)]).transpose().reshape(-1) 372 | sub_list = np.vstack([sub_list for i in range(topk_pred)]).transpose().reshape(-1) 373 | obj_list = np.vstack([obj_list for i in range(topk_pred)]).transpose().reshape(-1) 374 | 375 | if use_gt_boxes: 376 | total_scores = predicate_scores[pred_list] 377 | else: 378 | total_scores = predicate_scores[pred_list] * scores[sub_list][:,0] * scores[obj_list][:,0] 379 | 380 | top_N_list = total_scores.argsort()[::-1][:10000] 381 | total_scores = total_scores[top_N_list] 382 | pred_ids = predicate_inds[pred_list[top_N_list]] # category of predicates 383 | sub_assignment = sub_list[top_N_list] # subjects assignments 384 | obj_assignment = obj_list[top_N_list] # objects assignments 385 | sub_ids = inds[:,0][sub_assignment] # category of subjects 386 | obj_ids = inds[:,0][obj_assignment] # category of objects 387 | sub_boxes = pred_boxes[sub_assignment] # boxes of subjects 388 | obj_boxes = pred_boxes[obj_assignment] # boxes of objects 389 | 390 | if triplet_nms > 0.: 391 | sub_ids, obj_ids, pred_ids, sub_boxes, obj_boxes, keep = triplet_nms_py(sub_ids, obj_ids, pred_ids, 
392 | sub_boxes, obj_boxes, triplet_nms) 393 | sub_assignment = sub_assignment[keep] 394 | obj_assignment = obj_assignment[keep] 395 | total_scores = total_scores[keep] 396 | if len(sub_list) == 0: 397 | print('No Relationship remains') 398 | # pdb.set_trace() 399 | 400 | return pred_boxes, scores, inds, sub_ids, obj_ids, sub_boxes, obj_boxes, pred_ids, sub_assignment, obj_assignment, total_scores 401 | 402 | def filter_and_tracking(self, obj_boxes, obj_scores, obj_cls, 403 | subject_inds, object_inds, 404 | subject_boxes, object_boxes, predicate_inds, 405 | sub_assignment, obj_assignment, total_score): 406 | 407 | relationships = np.array(zip(sub_assignment, obj_assignment, predicate_inds, total_score)) 408 | 409 | 410 | 411 | # filter out bboxes whose obj_score is low 412 | keep_obj = np.where(obj_scores[:,0] >= self.obj_thres)[0] 413 | if keep_obj.size == 0: 414 | print("no object detected ...") 415 | keep_obj= [0] 416 | cutline_idx = max(keep_obj) 417 | obj_scores = obj_scores[:cutline_idx + 1] 418 | obj_boxes = obj_boxes[:cutline_idx + 1] 419 | obj_cls = obj_cls[:cutline_idx + 1] 420 | 421 | # filter out triplets whose obj/sbj have low obj_score 422 | if relationships.size > 0: 423 | keep_sub_assign = np.where(relationships[:, 0] <= cutline_idx)[0] 424 | relationships = relationships[keep_sub_assign] 425 | if relationships.size > 0: 426 | keep_obj_assign = np.where(relationships[:, 1] <= cutline_idx)[0] 427 | relationships = relationships[keep_obj_assign] 428 | 429 | # filter out triplets with a low total_score 430 | if relationships.size > 0: 431 | keep_rel = np.where(relationships[:, 3] >= self.triplet_thres)[0] # MSDN:0.02, DR-NET:0.03 432 | # if keep_rel.size > 0: 433 | # cutline_idx = max(keep_rel) 434 | # relationships = relationships[:cutline_idx + 1] 435 | relationships = relationships[keep_rel] 436 | 437 | # filter out triplets whose subject equals the object 438 | if relationships.size > 0: 439 | keep_rel = [] 440 | for i,relation in enumerate(relationships): 441 | if relation[0] != relation[1]: 442 | keep_rel.append(i) 443 | keep_rel = np.array(keep_rel).astype(int) 444 | relationships = relationships[keep_rel] 445 | 446 | # filter out triplets whose predicate is related to human behavior. 447 | if relationships.size > 0: 448 | keep_rel = [] 449 | for i,relation in enumerate(relationships): 450 | if int(relation[2]) not in self.tobefiltered_predicates: 451 | keep_rel.append(i) 452 | keep_rel = np.array(keep_rel).astype(int) 453 | relationships = relationships[keep_rel] 454 | 455 | # Object tracking 456 | # Filter out all un-tracked objects and triplets 457 | if self.ENABLE_TRACKING: 458 | print(obj_boxes.shape) 459 | tracking_input = np.concatenate((obj_boxes, obj_scores[:,0].reshape(len(obj_scores), 1)), axis=1) 460 | bboxes_and_uniqueIDs = self.tracker.update(tracking_input) 461 | keep = filter_untracted(bboxes_and_uniqueIDs, obj_boxes) 462 | print(relationships.shape) 463 | 464 | # filter out triplets whose obj/sbj are untracked. 
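# Editor's note on the block below: `keep` holds the indices of boxes that
# filter_untracted() matched to a SORT track. Relations whose subject or object index
# is not in `keep` are dropped, the surviving relation indices are remapped (k -> i)
# onto the compacted, tracked-only ordering, and relations are re-sorted by total score.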
465 | if relationships.size >0: 466 | keep_sub_assign = [np.where(relationships[:, 0] == keep_idx) for keep_idx in keep] 467 | if len(keep_sub_assign) > 0: 468 | keep_sub_assign = np.concatenate(keep_sub_assign, axis=1).flatten() 469 | relationships = relationships[keep_sub_assign] 470 | else: 471 | relationships = relationships[np.array([]).astype(int)] 472 | if relationships.size > 0: 473 | keep_obj_assign = [np.where(relationships[:, 1] == keep_idx) for keep_idx in keep] 474 | if len(keep_obj_assign) > 0: 475 | keep_obj_assign = np.concatenate(keep_obj_assign, axis=1).flatten() 476 | relationships = relationships[keep_obj_assign] 477 | else: 478 | relationships = relationships[np.array([]).astype(int)] 479 | # 480 | print('filter3') 481 | print(relationships.astype(int)) 482 | print(keep) 483 | rel = relationships.copy() 484 | for i, k in enumerate(keep): 485 | relationships[:,:2][rel[:,:2] == k] = i 486 | 487 | sorted = relationships[:,3].argsort()[::-1] 488 | relationships = relationships[sorted] 489 | 490 | subject_inds = obj_cls[relationships[:, 0].astype(int)] 491 | object_inds = obj_cls[relationships[:, 1].astype(int)] 492 | obj_boxes = np.concatenate([obj_boxes, np.zeros([obj_boxes.shape[0], 1])], axis=1) 493 | for i, keep_idx in enumerate(keep): 494 | obj_boxes[keep_idx] = bboxes_and_uniqueIDs[i] 495 | obj_scores = obj_scores[keep] 496 | obj_cls = obj_cls[keep] 497 | obj_boxes = obj_boxes[keep] 498 | 499 | print(obj_scores.shape) 500 | print(obj_cls.shape) 501 | print(obj_boxes.shape) 502 | print(relationships.shape) 503 | else: 504 | obj_boxes = np.concatenate([obj_boxes, np.zeros([obj_boxes.shape[0], 1])], axis=1) 505 | for i in range(len(obj_boxes)): 506 | obj_boxes[i][4] = i 507 | subject_inds = obj_cls[relationships[:, 0].astype(int)] 508 | object_inds = obj_cls[relationships[:, 1].astype(int)] 509 | 510 | predicate_inds = relationships[:, 2].astype(int) 511 | subject_boxes = obj_boxes[relationships[:, 0].astype(int)] 512 | object_boxes = obj_boxes[relationships[:, 1].astype(int)] 513 | subject_IDs = np.array([int(obj_boxes[int(relation[0])][4]) for relation in relationships]) 514 | object_IDs = np.array([int(obj_boxes[int(relation[1])][4]) for relation in relationships]) 515 | 516 | 517 | subject_scores = [obj_scores[int(relation[0])] for relation in relationships] 518 | pred_scores = [relation[3] / obj_scores[int(relation[0])] / obj_scores[int(relation[1])] for relation in 519 | relationships] 520 | object_scores = [obj_scores[int(relation[1])] for relation in relationships] 521 | triplet_scores = np.array(zip(subject_scores, pred_scores, object_scores)) 522 | 523 | return obj_boxes, obj_scores, obj_cls, \ 524 | subject_inds, object_inds, \ 525 | subject_boxes, object_boxes, \ 526 | subject_IDs, object_IDs, \ 527 | predicate_inds, triplet_scores, relationships -------------------------------------------------------------------------------- /model/vis_tuning.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | from pandas import DataFrame 5 | import pandas as pd 6 | from graphviz import Digraph 7 | import webcolors 8 | import pprint 9 | import math 10 | from scipy.stats import norm 11 | from color_histogram.core.hist_3d import Hist3D 12 | #import pcl # cd python-pcl -> python setup.py build-ext -i -> python setup.py install 13 | import matplotlib.pyplot as plt 14 | from mpl_toolkits.mplot3d import Axes3D 15 | import torchtext #0. 
install torchtext==0.2.3 (pip install torchtext==0.2.3) 16 | from torch.nn.functional import cosine_similarity 17 | from collections import Counter 18 | import pcl 19 | import os.path as osp 20 | import os 21 | fasttext = torchtext.vocab.FastText() 22 | _GRAY = (218, 227, 218) 23 | _GREEN = (18, 127, 15) 24 | _WHITE = (255, 255, 255) 25 | 26 | 27 | class same_node_detection(object): 28 | def __init__(self): 29 | self.compare_all = False 30 | self.class_weight = 10.0/20.0 31 | self.pose_weight = 8.0/20.0 32 | self.color_weight = 2.0/20.0 33 | 34 | def compare_class(self, curr_cls, prev_cls, cls_score ): 35 | similar_cls = False 36 | same_cls = 0 37 | score = 0 38 | for cls in curr_cls: 39 | if cls in prev_cls: 40 | similar_cls = True 41 | same_cls += 1 42 | 43 | if (similar_cls): 44 | #score = float(same_cls) * cls_score 45 | if (same_cls == 3): 46 | score = 1.0 47 | if (same_cls == 2): 48 | score = 0.9 49 | if (same_cls == 1): 50 | score = 0.8 51 | 52 | # max similarity = 1 (same word), min similarity = 0 (no relation) 53 | # similarity = cosine_similarity(fasttext.vectors[fasttext.stoi['hello']],fasttext.vectors[fasttext.stoi['hi']],dim=0) 54 | else: 55 | if self.compare_all: 56 | similarity = 0 57 | for i in range(3): 58 | for j in range(3): 59 | similarity += cosine_similarity(fasttext.vectors[fasttext.stoi[curr_cls[i]]].cuda(), 60 | fasttext.vectors[fasttext.stoi[prev_cls[j]]].cuda(),dim=0).cpu()[0] 61 | similarity /= 9.0 62 | #print(similarity) 63 | else: 64 | similarity = cosine_similarity(fasttext.vectors[fasttext.stoi[curr_cls[0]]].cuda(), 65 | fasttext.vectors[fasttext.stoi[prev_cls[0]]].cuda(), dim=0).cpu()[0] 66 | 67 | #score = similarity * cls_score 68 | score = 0. 69 | 70 | return score 71 | 72 | def compare_position(self, curr_mean, curr_var, prev_mean, prev_var, prev_pt_num, new_pt_num): 73 | I_x, I_y, I_z = TFCO.check_distance(curr_mean,curr_var, prev_mean, prev_var) 74 | #score = (I_x * I_y * I_z) 75 | score = (I_x/3.0) + (I_y/3.0) + (I_z/3.0) 76 | score = float(score) 77 | return score 78 | 79 | def compare_color(self, curr_hist, prev_hist): 80 | curr_rgb = webcolors.name_to_rgb(curr_hist[0][1]) 81 | prev_rgb = webcolors.name_to_rgb(prev_hist[0][1]) 82 | dist = np.sqrt(np.sum(np.power(np.subtract(curr_rgb, prev_rgb),2))) / (255*np.sqrt(3)) 83 | score = 1-dist 84 | return score 85 | 86 | def node_update(self, window_3d_pts, global_node, curr_mean, curr_var, curr_cls, cls_score, curr_color_hist,test_set ): 87 | try: 88 | new_pt_num = len(window_3d_pts) 89 | global_node_num = len(global_node) 90 | print(global_node_num) 91 | score = [] 92 | score_pose = [] 93 | cls_score = cls_score[0] 94 | w1, w2, w3 = self.class_weight, self.pose_weight, self.color_weight 95 | #print("current object : {cls:3}".format(cls=curr_cls[0])) 96 | for i in range(global_node_num): 97 | prev_cls = (-global_node.ix[i]["class"]).argsort()[:3] # choose top 3 index 98 | prev_cls = [test_set.object_classes[ind] for ind in prev_cls] # index to text 99 | #print("compare object : {comp_cls:3}".format(comp_cls=prev_cls[0])) 100 | prev_mean, prev_var, prev_pt_num = global_node.ix[i]["mean"], global_node.ix[i]["var"], global_node.ix[i]["pt_num"] 101 | prev_color_hist = global_node.ix[i]["color_hist"] 102 | cls_sc = SND.compare_class(curr_cls, prev_cls, cls_score) 103 | pos_sc = SND.compare_position(curr_mean,curr_var, prev_mean, prev_var, prev_pt_num, new_pt_num) 104 | col_sc = SND.compare_color(curr_color_hist, prev_color_hist) 105 | #print("class_score {cls_score:3.2f}".format(cls_score=cls_sc)) 106 | 
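# Editor's note: the class/position/colour scores computed here are fused into a single
# node-match score, tot_sc = w1*cls_sc + w2*pos_sc + w3*col_sc, with the weights set in
# __init__ (class 10/20, position 8/20, colour 2/20), so class agreement dominates the
# decision of whether this detection is merged into an existing node.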
#print("pose_score {pos_score:3.2f}".format(pos_score=pos_sc)) 107 | #print("color_score {col_score:3.2f}".format(col_score=col_sc)) 108 | tot_sc = (w1 * cls_sc) + (w2 * pos_sc) + (w3 * col_sc) 109 | #print("total_score {tot_score:3.2f}".format(tot_score=tot_sc)) 110 | score.append(tot_sc) 111 | #score_pose.append(pos_sc) 112 | node_score = max(score) 113 | print("node_score {score:3.4f}".format(score=node_score)) 114 | max_score_index = score.index(max(score)) 115 | #node_score_pose = score_pose[max_score_index] 116 | #print("node_score_pose {score_pose:3.2f}".format(score_pose=node_score_pose)) 117 | return node_score, max_score_index 118 | except: 119 | return 0,0 120 | 121 | 122 | 123 | class find_objects_class_and_color(object): 124 | def __init__(self): 125 | self.power = 2 126 | 127 | def get_class_string(self, class_index, score, dataset): 128 | class_text = dataset[class_index] if dataset is not None else \ 129 | 'id{:d}'.format(class_index) 130 | return class_text + ' {:0.2f}'.format(score).lstrip('0') 131 | 132 | def closest_colour(self, requested_colour): 133 | min_colours = {} 134 | for key, name in webcolors.css3_hex_to_names.items(): 135 | r_c, g_c, b_c = webcolors.hex_to_rgb(key) 136 | rd = (r_c - requested_colour[0]) ** self.power 137 | gd = (g_c - requested_colour[1]) ** self.power 138 | bd = (b_c - requested_colour[2]) ** self.power 139 | min_colours[(rd + gd + bd)] = name 140 | return min_colours[min(min_colours.keys())] 141 | 142 | def get_colour_name(self, requested_colour): 143 | try: 144 | closest_name = actual_name = webcolors.rgb_to_name(requested_colour) 145 | except ValueError: 146 | closest_name = FOCC.closest_colour(requested_colour) 147 | actual_name = None 148 | return actual_name, closest_name 149 | 150 | 151 | 152 | 153 | class tools_for_compare_objects(object): 154 | def __init__(self): 155 | self.meter = 5000. 
156 | self.th_x = 0.112 157 | self.th_y = 0.112 158 | self.th_z = 0.112 159 | 160 | def check_distance(self, x,curr_var, mean, var): 161 | Z_x = (x[0]-mean[0])/self.meter 162 | Z_y = (x[1]-mean[1])/self.meter 163 | Z_z = (x[2]-mean[2])/self.meter 164 | #Z_x = (x[0]-mean[0])/np.sqrt(curr_var[0]) 165 | #Z_y = (x[1]-mean[1])/np.sqrt(curr_var[1]) 166 | #Z_z = (x[2]-mean[2])/np.sqrt(curr_var[2]) 167 | # In Standardized normal gaussian distribution 168 | # Threshold : 0.9 --> -1.65 < Z < 1.65 169 | # : 0.8 --> -1.29 < Z < 1.29 170 | # : 0.7 --> -1.04 < Z < 1.04 171 | # : 0.6 --> -0.845 < Z < 0.845 172 | # : 0.5 --> -0.675 < Z < 0.675 173 | # : 0.4 --> -0.53 < Z < 0.53 174 | #print(" pos {pos_x:3.2f} {pose_y:3.2f} {pose_z:3.2f}".format(pos_x=Z_x, pose_y=Z_y, pose_z=Z_z)) 175 | #print("pos_y {pos_y:3.2f}".format(pos_y=Z_y)) 176 | #print("pos_z {pos_z:3.2f}".format(pos_z=Z_z)) 177 | #th_x = np.sqrt(np.abs(var[0])) *beta 178 | #th_y = np.sqrt(np.abs(var[1])) *beta 179 | #th_z = np.sqrt(np.abs(var[2])) *beta 180 | x_check = -self.th_x < Z_x < self.th_x 181 | y_check = -self.th_y < Z_y < self.th_y 182 | z_check = -self.th_z < Z_z < self.th_z 183 | 184 | if (x_check): 185 | I_x = 1.0 186 | else: 187 | #I_x = norm.cdf(-np.abs(Z_x)) / norm.cdf(-self.th_x) 188 | #I_x = (norm.cdf(self.th_x) - norm.cdf(-self.th_x)) / (norm.cdf(np.abs(Z_x)) - norm.cdf(-np.abs(Z_x))) 189 | I_x = self.th_x / np.abs(Z_x) 190 | # if (np.abs(self.th_x - Z_x)<1): 191 | # I_x = np.abs(self.th_x - Z_x) 192 | # else: 193 | # I_x = 1/np.abs(self.th_x-Z_x) 194 | if (y_check): 195 | I_y = 1.0 196 | else: 197 | #I_y = norm.cdf(-np.abs(Z_y)) / norm.cdf(-self.th_y) 198 | #I_y = (norm.cdf(self.th_y) - norm.cdf(-self.th_y)) / (norm.cdf(np.abs(Z_y)) - norm.cdf(-np.abs(Z_y))) 199 | I_y = self.th_y / np.abs(Z_y) 200 | # if (np.abs(self.th_y - Z_y)<1): 201 | # I_y = np.abs(self.th_y - Z_y) 202 | # else: 203 | # I_y = 1/np.abs(self.th_y-Z_y) 204 | if (z_check): 205 | I_z = 1.0 206 | else: 207 | #I_z = norm.cdf(-np.abs(Z_z)) / norm.cdf(-self.th_z) 208 | #I_z = (norm.cdf(self.th_z) - norm.cdf(-self.th_z)) / (norm.cdf(np.abs(Z_z)) - norm.cdf(-np.abs(Z_z))) 209 | I_z = self.th_z / np.abs(Z_z) 210 | # if (np.abs(self.th_z - Z_z)<1): 211 | # I_z = np.abs(self.th_z - Z_z) 212 | # else: 213 | # I_z = 1/np.abs(self.th_x-Z_z) 214 | 215 | #print(" score {score_x:3.2f} {score_y:3.2f} {score_z:3.2f}".format(score_x=I_x, score_y=I_y, score_z=I_z)) 216 | #print(" tot_score {score:3.2f} ".format(score=(I_x+I_y+I_z)/3.)) 217 | #print("pose_score_y {pos_score_y:3.2f}".format(pos_score_y=I_y)) 218 | #print("pose_score_z {pos_score_z:3.2f}".format(pos_score_z=I_z)) 219 | return I_x, I_y, I_z 220 | 221 | def Measure_new_Gaussian_distribution(self, new_pts): 222 | try: 223 | pt_num = len(new_pts) 224 | mu = np.sum(new_pts, axis=0)/pt_num 225 | mean = [int(mu[0]), int(mu[1]), int(mu[2])] 226 | var = np.sum(np.power(new_pts, 2), axis=0)/pt_num - np.power(mu,2) 227 | var = [int(var[0]), int(var[1]), int(var[2])] 228 | return pt_num, mean, var 229 | except: 230 | return 1, [0,0,0], [1,1,1] 231 | 232 | def Measure_added_Gaussian_distribution(self, new_pts, prev_mean, prev_var, prev_pt_num, new_pt_num): 233 | # update mean and variance 234 | pt_num = prev_pt_num + new_pt_num 235 | mu = np.sum(new_pts, axis=0) 236 | mean = np.divide((np.multiply(prev_mean,prev_pt_num) + mu),pt_num) 237 | mean = [int(mean[0]), int(mean[1]), int(mean[2])] 238 | var = np.subtract(np.divide((np.multiply((prev_var + np.power(prev_mean,2)),prev_pt_num) + np.sum(np.power(new_pts,2),axis=0)) 
,pt_num), np.power(mean,2)) 239 | var = [int(var[0]), int(var[1]), int(var[2])] 240 | return pt_num, mean, var 241 | 242 | def get_color_hist(self, img): 243 | ''' 244 | # return color_hist 245 | # format: [[num_pixels1,color1],[num_pixels2,color2],...,[num_pixelsN,colorN]] 246 | # ex: [[362 ,'red' ],[2 ,'blue'],...,[3 ,'gray']] 247 | ''' 248 | 249 | img = img[..., ::-1] # BGR to RGB 250 | img = img.flatten().reshape(-1, 3).tolist() # shape: ((640x480)*3) 251 | 252 | color_hist = [] 253 | start = 0 254 | new_color = False 255 | actual_name, closest_name = FOCC.get_colour_name(img[0]) 256 | if (actual_name == None): 257 | color_hist.append([0,closest_name]) 258 | else: 259 | color_hist.append([0,actual_name]) 260 | 261 | for i in range(len(img)): 262 | actual_name, closest_name = FOCC.get_colour_name(img[i]) 263 | for k in range(len(color_hist)): 264 | if(color_hist[k][1] == actual_name or color_hist[k][1] == closest_name): 265 | color_hist[k][0]+=1 266 | new_color = False 267 | break 268 | else: 269 | new_color = True 270 | if (new_color == True): 271 | if (actual_name == None): 272 | color_hist.append([1, closest_name]) 273 | new_color = False 274 | else: 275 | color_hist.append([1, actual_name]) 276 | new_color = False 277 | color_hist = sorted(color_hist, reverse = True) 278 | return color_hist 279 | 280 | def get_color_hist2(self, img): 281 | ''' 282 | # return color_hist 283 | # format: [[density1,color1],[density2,color2],[density3,color3]] 284 | # ex: [[362 ,'red' ],[2 ,'blue'],[3 ,'gray']] 285 | ''' 286 | try: 287 | hist3D = Hist3D(img[..., ::-1], num_bins=8, color_space='rgb')# BGR to RGB 288 | # print('sffsd:', img.shape) 289 | # cv2.imshow('a',img) 290 | # cv2.waitKey(1) 291 | except: 292 | 293 | return TFCO.get_color_hist(img) 294 | else: 295 | densities = hist3D.colorDensities() 296 | order = densities.argsort()[::-1] 297 | densities = densities[order] 298 | colors = (255*hist3D.rgbColors()[order]).astype(int) 299 | color_hist = [] 300 | for density, color in zip(densities,colors)[:4]: 301 | actual_name, closest_name = FOCC.get_colour_name(color.tolist()) 302 | if (actual_name == None): 303 | color_hist.append([density, closest_name]) 304 | else: 305 | color_hist.append([density, actual_name]) 306 | 307 | return color_hist 308 | 309 | 310 | class resampling_boundingbox_size(object): 311 | def __init__(self): 312 | self.range = 10.0 313 | self.mean_k = 10 314 | self.thres = 1.0 315 | 316 | def isNoisyPoint(self, point): 317 | return -self.range< point[0]=min(cnt_thres,idx))] 451 | relation_list = [rel for rel in relation_list if (node_feature.loc[node_feature['idx'] == int(rel[2])]['detection_cnt'].item()>=min(cnt_thres,idx))] 452 | 453 | relation_set = set(relation_list) # remove duplicate relations 454 | 455 | repeated_idx = [] 456 | relation_array = np.array(list(relation_set)) 457 | for i in range(len(relation_array)): 458 | for j in range(len(relation_array)): 459 | res = relation_array[i] == relation_array[j] 460 | if res[0] and res[2]and i!=j: 461 | repeated_idx.append(i) 462 | repeated_idx = set(repeated_idx) 463 | repeated_idx = list(repeated_idx) 464 | if len(repeated_idx)>0: 465 | repeated = relation_array[repeated_idx] 466 | #print repeated.shape, repeated_idx 467 | for i, (pos, x, y) in enumerate(zip(repeated_idx, repeated[:, 0], repeated[:, 2])): 468 | position = np.where((x == repeated[:, 0]) & (y == repeated[:, 2]))[0] 469 | triplets = repeated[position].astype(int).tolist() 470 | preds = [t[1] for t in triplets] 471 | counted = Counter(preds) 472 | voted_pred = 
counted.most_common(1) 473 | #print(i, idx, triplets, voted_pred) 474 | relation_array[pos, 1] = voted_pred[0][0] 475 | 476 | relation_set =[tuple(rel)for rel in relation_array.astype(int).tolist()] 477 | relation_set = set(relation_set) 478 | #print(len(relation_set)) 479 | 480 | #pale_rgb = [152,251,152] 481 | pale_rgb = [112,191,64] 482 | pale_hex = webcolors.rgb_to_hex(pale_rgb) 483 | for rel_num in range(len(relation_set)): 484 | rel = relation_set.pop() 485 | tile = False 486 | handle = False 487 | for t_i in tile_idx: 488 | if (str(rel[0]) == t_i or str(rel[2]) == t_i): 489 | tile = True 490 | for h_i in handle_idx: 491 | if (str(rel[0]) == h_i or str(rel[2]) == h_i): 492 | handle = True 493 | if ( (not tile) and (not handle)): 494 | sg.node('rel'+str(rel_num), shape= 'box', style= 'filled, rounded', fillcolor= pale_hex, fontcolor= 'black', 495 | margin = '0.11, 0.0001', width = '0.11', height='0' , label= str(test_set.predicate_classes[rel[1]])) 496 | sg.edge('struct'+str(rel[0]), 'rel'+str(rel_num)) 497 | sg.edge('rel'+str(rel_num), 'struct'+str(rel[2])) 498 | 499 | 500 | if view and sg.format =='pdf': 501 | sg.render(osp.join(save_path,'scene_graph'+str(idx)), view=view) 502 | elif view and sg.format == 'png': 503 | sg.render(osp.join(save_path, 'scene_graph' + str(idx)), view=False) 504 | img = cv2.imread(osp.join(save_path, 'scene_graph' + str(idx)+'.png'),cv2.IMREAD_COLOR) 505 | resize_x = 0.65 506 | resize_y = 0.9 507 | if img.shape[1]0: 680 | # ax.scatter(-arr[:,0],-arr[:,1],-arr[:,2],) 681 | # #ax.set_xlim(-2000, 2000) 682 | # #ax.set_ylim(-2000, 2000) 683 | # #ax.set_zlim(-2000, 2000) 684 | # 685 | # self.fig.show() 686 | # plt.pause(0.01) 687 | # plt.hold(True) 688 | # cv2.waitKey(0) 689 | 690 | 691 | 692 | 693 | 694 | '''4. Get a 3D position of the Center Patch's Center point''' 695 | # find 3D point of the bounding box(the center patch)'s center 696 | curr_pt_num, curr_mean, curr_var = TFCO.Measure_new_Gaussian_distribution(window_3d_pts) 697 | # ex: np.matrix([[X_1],[Y_1],[Z_1]]) 698 | 699 | # get object class names as strings 700 | box_cls = [test_set.object_classes[obj_ind[0]], 701 | test_set.object_classes[obj_ind[1]], 702 | test_set.object_classes[obj_ind[2]]] 703 | # box_cls: ['pillow','bag','cat'] 704 | box_score = obj_scores[i] 705 | # box_score: [0.2,0.1,0.01] 706 | cls_scores = np.zeros(400) 707 | for cls_idx, cls_score in zip(obj_ind, obj_scores[i]): 708 | cls_scores[cls_idx] += cls_score # check 709 | 710 | 711 | '''5. 
Save Object Recognition Results in DataFrame Format''' 712 | if(self.img_count ==0): 713 | # first image -> make new node 714 | box_id = i 715 | self.pt_num, self.mean, self.var = TFCO.Measure_new_Gaussian_distribution(window_3d_pts) 716 | # check 717 | start_data = {"class":cls_scores, "idx":box_id, "score":box_score, 718 | "bounding_box":[box_center_x,box_center_y,width,height], 719 | "3d_pose": [int(self.mean[0]),int(self.mean[1]),int(self.mean[2])], 720 | "mean":self.mean, 721 | "var":self.var, 722 | "pt_num":self.pt_num, 723 | "color_hist":color_hist, 724 | "detection_cnt":1 725 | } 726 | obj_boxes[i][4] =box_id 727 | self.data.loc[len(self.data)] = start_data 728 | if (i==0): 729 | self.data.drop(self.data.index[0], inplace=True) 730 | self.data.rename(index={1:0}, inplace=True) 731 | #print(self.data) 732 | 733 | else: 734 | # get node similarity score 735 | node_score, max_score_index = SND.node_update(window_3d_pts, self.data, curr_mean,curr_var, 736 | box_cls, obj_scores[i], color_hist,test_set) 737 | threshold = 0.8127 738 | 739 | if node_score > threshold and not self.disable_samenode: 740 | # change value of global_node 741 | # change global_node[max_score_index] 742 | print("node updated!!!") 743 | for cls_idx,cls_score in zip(obj_ind,obj_scores[i]): 744 | self.data.at[max_score_index,'class'][cls_idx]+= cls_score # check 745 | 746 | #self.data.at[max_score_index, "class"] = box_cls 747 | self.data.at[max_score_index, "score"] = node_score 748 | self.pt_num, self.mean, self.var = TFCO.Measure_added_Gaussian_distribution(window_3d_pts, 749 | self.data.ix[max_score_index]["mean"], 750 | self.data.ix[max_score_index]["var"], 751 | self.data.ix[max_score_index]["pt_num"], 752 | len(window_3d_pts)) 753 | self.data.at[max_score_index, "mean"] = self.mean 754 | self.data.at[max_score_index, "var"] = self.var 755 | self.data.at[max_score_index, "pt_num"] = self.pt_num 756 | self.data.at[max_score_index, "color_hist"] = color_hist 757 | self.data.at[max_score_index, "detection_cnt"] = self.data.ix[max_score_index]["detection_cnt"]+1 758 | box_id = self.data.ix[max_score_index]["idx"] 759 | obj_boxes[i][4] = box_id 760 | else: 761 | # make new_node in global_node 762 | # [class, index, score, bounding_box, 3d_pose, mean, var, pt_number, color_hist] 763 | box_id = len(self.data)+1 764 | obj_boxes[i][4] = box_id 765 | self.pt_num, self.mean, self.var = TFCO.Measure_new_Gaussian_distribution(window_3d_pts) 766 | global_node_num = len(self.data) 767 | add_node_list = [cls_scores, box_id, box_score, [box_center_x, box_center_y, width, height], 768 | [self.mean[0], self.mean[1], self.mean[2]], 769 | self.mean, self.var, self.pt_num, color_hist,1] 770 | self.data.loc[len(self.data)] = add_node_list 771 | 772 | # if object index was changed, update relation's object index also 773 | 774 | 775 | '''6. Print object info''' 776 | print('{obj_ID:5} {obj_cls:15} {obj_score:4.2f} {object_3d_pose:20} {obj_var:20} {obj_color:15}' 777 | .format(obj_ID= box_id, 778 | obj_cls= box_cls[0], 779 | obj_score= box_score[0], 780 | object_3d_pose= [self.mean[0], self.mean[1], self.mean[2]], 781 | obj_var= self.var, 782 | obj_color = color_hist[0][1] )) 783 | 784 | 785 | else: # TODO: for visual_genome 786 | raise NotImplementedError 787 | 788 | '''7. 
Plot ''' 789 | # updated object_detection 790 | cv2.rectangle(updated_image_scene, 791 | (int(obj_boxes[i][0]), int(obj_boxes[i][1])), 792 | (int(obj_boxes[i][2]), int(obj_boxes[i][3])), 793 | colorlist[int(obj_boxes[i][4])], 794 | 2) 795 | font_scale=0.5 796 | txt = str(box_id) + '. ' + str(box_cls[0]) + ' ' + str(round(box_score[0],2)) 797 | ((txt_w, txt_h), _) = cv2.getTextSize(txt, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1) 798 | # Place text background. 799 | x0, y0 = int(obj_boxes[i][0]),int(obj_boxes[i][3]) 800 | back_tl = x0, y0 - int(1.3 * txt_h) 801 | back_br = x0 + txt_w, y0 802 | cv2.rectangle(updated_image_scene, back_tl, back_br, colorlist[int(obj_boxes[i][4])], -1) 803 | cv2.putText(updated_image_scene, 804 | txt, 805 | (x0,y0-2), 806 | cv2.FONT_HERSHEY_SIMPLEX, 807 | font_scale, 808 | (255,255,255), 809 | 1) 810 | 811 | # add ID per bbox 812 | 813 | rel_prev_num = len(self.rel_data) 814 | print('-------Subject--------|-------Predicate-----|--------Object---------|--Score-') 815 | for i, relation in enumerate(relationships): 816 | # update relation's class also 817 | 818 | # accumulate relation_list 819 | if str(int(obj_boxes[int(relation[0])][4])) != str(int(obj_boxes[int(relation[1])][4])): 820 | # filter out triplets whose sbj == obj 821 | self.rel_data.loc[len(self.rel_data)] = [[str(int(obj_boxes[int(relation[0])][4])), int(relation[2]), str(int(obj_boxes[int(relation[1])][4]))]] 822 | 823 | print('{sbj_cls:9} {sbj_ID:4} {sbj_score:1.3f} | ' 824 | '{pred_cls:11} {pred_score:1.3f} | ' 825 | '{obj_cls:9} {obj_ID:4} {obj_score:1.3f} | ' 826 | '{triplet_score:1.3f}'.format( 827 | sbj_cls = test_set.object_classes[obj_inds[:,0][int(relation[0])]], sbj_score = obj_scores[:,0][int(relation[0])], 828 | sbj_ID = str(int(obj_boxes[int(relation[0])][4])), 829 | pred_cls = test_set.predicate_classes[int(relation[2])] , pred_score = relation[3] / obj_scores[:,0][int(relation[0])] / obj_scores[:,0][int(relation[1])], 830 | obj_cls = test_set.object_classes[obj_inds[:,0][int(relation[1])]], obj_score = obj_scores[:,0][int(relation[1])], 831 | obj_ID = str(int(obj_boxes[int(relation[1])][4])), 832 | triplet_score = relation[3])) 833 | 834 | rel_new_num = len(self.rel_data) 835 | 836 | # Draw scene graph 837 | if ( rel_prev_num != rel_new_num): 838 | TFV.Draw_connected_scene_graph(self.data, self.rel_data, self.img_count, test_set, sg, idx, 839 | self.detect_cnt_thres,self.args.plot_graph,self.save_path) 840 | #sg.view() 841 | 842 | # it's help to select starting point of first image manually 843 | self.img_count+=1 844 | 845 | return updated_image_scene 846 | 847 | --------------------------------------------------------------------------------
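A minimal sketch (an editor's addition, separate from the repository code) of the incremental mean/variance update that Measure_added_Gaussian_distribution() applies to the accumulated 3-D points of a node. The helper name update_gaussian and the random test points are made up; unlike the original, the sketch keeps floating-point values instead of truncating to int, so merging two batches reproduces the statistics of the concatenated point set exactly.

import numpy as np

def update_gaussian(prev_mean, prev_var, prev_n, new_pts):
    """Merge a new batch of points into a running (biased) mean/variance without the old points."""
    prev_mean = np.asarray(prev_mean, dtype=float)
    prev_var = np.asarray(prev_var, dtype=float)
    new_pts = np.asarray(new_pts, dtype=float)
    n = prev_n + len(new_pts)
    mean = (prev_mean * prev_n + new_pts.sum(axis=0)) / n
    # E[x^2] of the old batch is prev_var + prev_mean^2; combine, renormalise, subtract mean^2.
    var = ((prev_var + prev_mean ** 2) * prev_n + (new_pts ** 2).sum(axis=0)) / n - mean ** 2
    return n, mean, var

if __name__ == "__main__":
    a = np.random.rand(40, 3) * 100.0   # first batch of 3-D points for a node
    b = np.random.rand(25, 3) * 100.0   # later batch assigned to the same node
    n, mean, var = update_gaussian(a.mean(axis=0), a.var(axis=0), len(a), b)
    both = np.concatenate([a, b])
    assert np.allclose(mean, both.mean(axis=0)) and np.allclose(var, both.var(axis=0))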