├── README.md
├── data
│   └── README.md
├── data_produce.py
├── images
│   └── README.md
├── json
│   └── README.md
├── requirements.txt
├── tools
│   ├── README.md
│   ├── data.py
│   ├── mn.py
│   └── train.py
├── top_fact.py
├── word2vec_model
│   └── README.md
└── word2vector.py

/README.md:
--------------------------------------------------------------------------------
1 | # KB-REF_dataset (Knowledge Based Referring Expression)
2 | ## Description
3 | KB-REF is a referring expression comprehension dataset. Unlike other referring expression datasets, it requires that each referring expression rely on at least one piece of external knowledge (information that cannot be obtained from the image alone). There are 31,284 expressions with 9,925 images in the training set, 4,000 expressions with 2,290 images in the validation set, and 8,000 expressions with 4,702 images in the test set. The dataset also covers a wide range of object categories.
4 | ## Download
5 | The dataset can be downloaded from [BaiduYun Drive (code: 3vze)](https://pan.baidu.com/s/1iC9SqkOSVu0XsNnP9-PKQg). The images of KB-REF come from [VisualGenome](http://visualgenome.org/). The download contains the following files:
6 | * expression.json: The main part of the dataset. It is a dictionary: each key is composed of an image id (before the '\_') and an object id (after the '\_'), and the value is composed of the referring expression (first) and the corresponding fact (second).
7 | * candidate.json: The ground truth objects for each image. For each image, we choose 10 ground truth objects as the candidate bounding boxes that a model reasons over. The key is the image id, and the value is the list of object ids in that image.
8 | * image.json: The width and height of each image. The key is the image id, and the value is the width (first) and height (second).
9 | * objects.json: Detailed information for each object instance. It is a two-tier dictionary: the first-tier key is the image id and the second-tier key is the object instance id. The value contains the object category, the object name, the x and y of the top-left corner of the bounding box, and its width and height.
10 | * train.json, val.json, test.json: The dataset is split by image. These files describe which images are used for training, validation, and testing.
11 | * Vocabulary.json: The vocabulary file.
12 | * Wikipedia.json, ConceptNet.json, WebChild.json: The knowledge we collected. The key is the object category and the value is the corresponding facts.
13 | 
14 | **If you have any questions about this dataset or code, please email dongyang.liu0705@qq.com directly, and I will respond as soon as possible.**
15 | 
16 | **The code for the baseline model is coming soon. I'm busy preparing for graduation at the moment, so please wait.
Thank you!** 17 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | dataset 2 | -------------------------------------------------------------------------------- /data_produce.py: -------------------------------------------------------------------------------- 1 | from nltk.corpus import wordnet 2 | from nltk import word_tokenize, pos_tag 3 | import json 4 | import enchant 5 | from collections import defaultdict 6 | import numpy as np 7 | import time 8 | 9 | if __name__ == '__main__': 10 | d = enchant.Dict("en_US") 11 | with open('./data/KB-REF/Vocabulary.json') as file: 12 | word_dict = json.load(file) 13 | def get_wordnet_pos(treebank_tag): 14 | if treebank_tag.startswith('J'): 15 | return wordnet.ADJ 16 | elif treebank_tag.startswith('V'): 17 | return wordnet.VERB 18 | elif treebank_tag.startswith('N'): 19 | return wordnet.NOUN 20 | elif treebank_tag.startswith('R'): 21 | return wordnet.ADV 22 | else: 23 | return None 24 | 25 | def lemmatize_sentence(sentence, flag=True): 26 | res = [] 27 | for word, pos in pos_tag(word_tokenize(sentence)): 28 | if flag: 29 | res.append(word) 30 | else: 31 | wordnet_pos = get_wordnet_pos(pos) 32 | if wordnet_pos == wordnet.NOUN or wordnet_pos == wordnet.ADJ or wordnet_pos == wordnet.ADV: 33 | res.append(word) 34 | return res 35 | 36 | 37 | def sentence(sentence, flag=True): 38 | a = [] 39 | for e in lemmatize_sentence(sentence, flag=flag): 40 | if d.check(e): 41 | e = e.lower() 42 | if e in word_dict and word_dict[e] < len(word_dict)-1: 43 | a.append(word_dict[e]) 44 | else: 45 | a.append(len(word_dict)-1) 46 | if len(a) > 50: 47 | c = 50 48 | else: 49 | c = len(a) 50 | while len(a) < 50: 51 | a.append(len(word_dict)-1) 52 | return a[0:50], c 53 | 54 | 55 | with open('./data/KB-REF/expression.json') as file: 56 | data = json.load(file) 57 | with open('./data/KB-REF/candidate.json') as file: 58 | cand = json.load(file) 59 | with open('./json/top_facts.json') as file: 60 | facts = json.load(file) 61 | with open('./data/KB-REF/objects.json') as file: 62 | objects = json.load(file) 63 | with open('./data/KB-REF/image.json') as file: 64 | w_h = json.load(file) 65 | with open('./data/KB-REF/train.json') as file: 66 | train_set = json.load(file) 67 | with open('./data/KB-REF/val.json') as file: 68 | val_set = json.load(file) 69 | with open('./data/KB-REF/test.json') as file: 70 | test_set = json.load(file) 71 | train = [] 72 | val = [] 73 | test = [] 74 | length = [] 75 | print(len(data)) 76 | for k in data: 77 | try: 78 | start = time.time() 79 | label = cand[k.split('_')[0]].index(k.split('_')[1]) 80 | img = k.split('_')[0] 81 | expression, leng = sentence(data[k][0], flag=True) 82 | e_mask = leng 83 | bbox = [] 84 | final_f = [] 85 | length = [] 86 | c_mask = len(cand[k.split('_')[0]]) - cand[k.split('_')[0]].count('-1') 87 | for c in cand[k.split('_')[0]]: 88 | if c != '-1': 89 | lg = [] 90 | bbox.append( 91 | [objects[img][c][2], objects[img][c][3], objects[img][c][4], 92 | objects[img][c][5]]) 93 | try: 94 | fact = facts[k][c] 95 | f = [] 96 | if len(fact) >= 50: 97 | for i in range(50): 98 | f1, leng = sentence(fact[i]) 99 | f.append(f1) 100 | lg.append(leng) 101 | else: 102 | for i in range(len(fact)): 103 | f1, leng = sentence(fact[i]) 104 | f.append(f1) 105 | lg.append(leng) 106 | while len(f) < 50: 107 | a = np.ones(50) + len(word_dict)-2 108 | f.append(a.tolist()) 109 | lg.append(0) 110 | final_f.append(f) 111 | 
length.append(lg) 112 | except: 113 | f = [] 114 | while len(f) < 50: 115 | a = np.ones(50) + len(word_dict)-2 116 | f.append(a.tolist()) 117 | lg.append(0) 118 | final_f.append(f) 119 | length.append(lg) 120 | else: 121 | lg = [] 122 | bbox.append([0, 0, 0, 0]) 123 | f = [] 124 | while len(f) < 50: 125 | a = np.ones(50) + len(word_dict)-2 126 | f.append(a.tolist()) 127 | lg.append(0) 128 | final_f.append(f) 129 | length.append(lg) 130 | 131 | if img in train_set: 132 | train.append({'image': img, 133 | 'label': label, 134 | 'expression': expression, 135 | 'e_mask': e_mask, 136 | 'bbox': bbox, 137 | 'w_h': w_h[k.split('_')[0]], 138 | 'facts': final_f, 139 | 'mask': length, 140 | 'c_mask': c_mask}) 141 | elif img in val_set: 142 | val.append({'image': img, 143 | 'label': label, 144 | 'expression': expression, 145 | 'e_mask': e_mask, 146 | 'bbox': bbox, 147 | 'w_h': w_h[k.split('_')[0]], 148 | 'facts': final_f, 149 | 'mask': length, 150 | 'c_mask': c_mask}) 151 | else: 152 | test.append({'image': img, 153 | 'label': label, 154 | 'expression': expression, 155 | 'e_mask': e_mask, 156 | 'bbox': bbox, 157 | 'w_h': w_h[k.split('_')[0]], 158 | 'facts': final_f, 159 | 'mask': length, 160 | 'c_mask': c_mask}) 161 | print(time.time() - start) 162 | except: 163 | continue 164 | 165 | 166 | with open('./json/train.json', 'w') as file: 167 | json.dump(train, file) 168 | with open('./json/val.json', 'w') as file: 169 | json.dump(val, file) 170 | with open('./json/test.json', 'w') as file: 171 | json.dump(test, file) 172 | 173 | -------------------------------------------------------------------------------- /images/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /json/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | boto==2.49.0 3 | boto3==1.12.39 4 | botocore==1.15.39 5 | certifi==2019.6.16 6 | chardet==3.0.4 7 | click==7.1.1 8 | cycler==0.10.0 9 | docutils==0.15.2 10 | gensim==3.8.2 11 | grpcio==1.23.0 12 | idna==2.9 13 | jmespath==0.9.5 14 | joblib==0.14.1 15 | kiwisolver==1.1.0 16 | Markdown==3.1.1 17 | matplotlib==3.1.1 18 | nltk==3.5 19 | numpy==1.17.2 20 | opencv-python==4.2.0.34 21 | Pillow==6.1.0 22 | protobuf==3.9.1 23 | pyenchant==3.0.1 24 | pyparsing==2.4.2 25 | python-dateutil==2.8.0 26 | regex==2020.4.4 27 | requests==2.23.0 28 | s3transfer==0.3.3 29 | scipy==1.4.1 30 | six==1.12.0 31 | smart-open==1.11.1 32 | tensorboard==2.0.0 33 | tensorboardX==1.8 34 | torch==1.0.0 35 | torchvision==0.2.1 36 | tqdm==4.45.0 37 | urllib3==1.25.8 38 | Werkzeug==0.16.0 39 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tools/data.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | import torch 4 | import json 5 | import os 6 | import numpy as np 7 | from torchvision import transforms 8 | import random 9 | 10 | 11 | def dataLoader(path): 12 | return Image.open(path).convert("RGB") 13 | 14 | 15 | class 
data(Dataset): 16 | def __init__(self, jpg_path, label_path, data_transform=None, loader=dataLoader): 17 | super(data, self).__init__() 18 | jpg_name = [] 19 | jpg_label = [] 20 | f_label = [] 21 | expression = [] 22 | bbox = [] 23 | facts = [] 24 | w_h = [] 25 | mask = [] 26 | e_mask = [] 27 | c_mask = [] 28 | self.loader = dataLoader 29 | self.transform = data_transform 30 | with open(label_path) as file: 31 | Data = json.load(file) 32 | for x in Data: 33 | jpg_name.append(os.path.join(jpg_path, x['image'])+'.jpg') 34 | jpg_label.append(x['label']) 35 | f_label.append(0) 36 | expression.append(x['expression']) 37 | bbox.append(x['bbox']) 38 | facts.append(x['facts']) 39 | w_h.append(x['w_h']) 40 | mask.append(x['mask']) 41 | e_mask.append(x['e_mask']) 42 | c_mask.append(x['c_mask']) 43 | self.jpg_name = jpg_name 44 | self.jpg_label = jpg_label 45 | self.f_label = f_label 46 | self.expression = expression 47 | self.bbox = bbox 48 | self.facts = facts 49 | self.w_h = w_h 50 | self.mask = mask 51 | self.e_mask = e_mask 52 | self.c_mask = c_mask 53 | 54 | def __getitem__(self, item): 55 | jpg_name = self.jpg_name[item] 56 | jpg_label = self.jpg_label[item] 57 | jpg = self.loader(jpg_name) 58 | if self.transform is not None: 59 | jpg = self.transform(jpg) 60 | label = torch.LongTensor(1) 61 | label[0] = jpg_label 62 | f_label = self.f_label[item] 63 | f_label = torch.from_numpy(np.eye(500)[jpg_label*50+f_label].reshape(10, 50)).type(torch.FloatTensor) 64 | bboxs = self.bbox[item] 65 | local = [] 66 | locations = [] 67 | for bbox in bboxs: 68 | can = Image.open(jpg_name).crop([bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]]).convert("RGB") 69 | locations.append([bbox[0] / self.w_h[item][0], bbox[1] /self.w_h[item][1], (bbox[0]+bbox[2]) / self.w_h[item][0], (bbox[1]+bbox[3]) / self.w_h[item][1], bbox[2]*bbox[3] / self.w_h[item][0] /self.w_h[item][1]]) 70 | if self.transform is not None: 71 | can = self.transform(can) 72 | #print(np.array(can).shape) 73 | local.append(np.array(can)) 74 | local = torch.from_numpy(np.array(local)).view(-1, 224, 224) 75 | expression = self.expression[item][:self.e_mask[item]] 76 | #print(expression, type(expression)) 77 | random.shuffle(expression) 78 | #print(np.array(expression)) 79 | expression = np.pad(np.array(expression), (0, 50-self.e_mask[item]), 'constant', constant_values=(0, 15732)) 80 | expression = torch.from_numpy(np.array(expression)).type(torch.LongTensor) 81 | #print(np.array(self.facts[item]).shape) 82 | facts = torch.from_numpy(np.array(self.facts[item])).type(torch.LongTensor).view(-1, 50) 83 | locations = torch.from_numpy(np.array(locations)).type(torch.FloatTensor) 84 | #print(self.mask[item]) 85 | #print(np.array(self.mask[item]).shape) 86 | mask = [] 87 | #print(len(self.mask[item][0])) 88 | #print(len(self.mask[item])) 89 | f_mask = [] 90 | ff_mask = [] 91 | for i in range(len(self.mask[item])): 92 | middle = [] 93 | l = 0 94 | #print(len(self.mask[item][i])) 95 | #print(self.mask[item][i]) 96 | for j in range(len(self.mask[item][0])):#len(self.mask[item][i])): 97 | if self.mask[item][i][j] > 0: 98 | middle.append(list(np.pad(np.ones(self.mask[item][i][j]), (0, 50-self.mask[item][i][j])) / self.mask[item][i][j])) 99 | #middle.append(list(np.eye(50)[self.mask[item][i][j]-1])) 100 | l += 1 101 | else: 102 | middle.append(list(np.zeros(50))) 103 | mask.append(middle) 104 | f_mask.append(list(np.pad(np.ones(l), (0, 50-l), 'constant'))) 105 | ff_mask.append(list(np.eye(50)[l-1])) 106 | mask = np.array(mask) 107 | #print(mask.shape) 108 | 
mask = torch.from_numpy(mask).type(torch.FloatTensor) 109 | f_mask = np.array(f_mask) 110 | f_mask = torch.from_numpy(f_mask).type(torch.FloatTensor) 111 | ff_mask = np.array(ff_mask) 112 | ff_mask = torch.from_numpy(ff_mask).type(torch.FloatTensor) 113 | #mask = torch.from_numpy(np.array(self.mask[item])).type(torch.FloatTensor) 114 | #e_mask = np.eye(50)[self.e_mask[item]-1] 115 | e_mask = np.pad(np.ones(self.e_mask[item]), (0, 50-self.e_mask[item]))# / self.e_mask[item] 116 | e_mask = torch.from_numpy(e_mask).type(torch.FloatTensor) 117 | #e_mask = torch.from_numpy(np.array(self.e_mask[item])).type(torch.FloatTensor) 118 | c_mask = np.pad(np.ones(self.c_mask[item]), (0, 10-self.c_mask[item]))# / self.e_mask[item] 119 | c_mask = torch.from_numpy(c_mask).type(torch.FloatTensor) 120 | return jpg, label, f_label, expression, e_mask, local, locations, facts, mask, f_mask, ff_mask, jpg_name, c_mask 121 | 122 | def __len__(self): 123 | return len(self.jpg_name) 124 | 125 | 126 | if __name__ == '__main__': 127 | test_transform = transforms.Compose([ 128 | transforms.Resize((224, 224)), 129 | transforms.ToTensor(), 130 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 131 | ]) 132 | -------------------------------------------------------------------------------- /tools/mn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | from torch.nn import functional as F 5 | import pdb 6 | import numpy as np 7 | from PIL import Image 8 | import matplotlib.pyplot as plt 9 | import time 10 | from torch.autograd import Variable 11 | import argparse 12 | 13 | 14 | class Basenet(nn.Module): 15 | def __init__(self, config): 16 | super(Basenet, self).__init__() 17 | self.w_embed_size = config.vocab_size 18 | self.text_len = config.text_len 19 | self.cand_len = config.cand_len 20 | self.fact_len = config.fact_len 21 | self.max_episodic = config.max_episodic 22 | self.q_lstm_dim = config.q_lstm_dim 23 | self.s_lstm_dim = config.s_lstm_dim 24 | self.g_dim = 512 25 | self.g_attnd_dim = 512 26 | self.o_dim = config.o_dim 27 | self.l_dim = config.l_dim 28 | self.s_attnd_dim = 512 29 | 30 | 31 | self.f_global = models.vgg16(pretrained=True).features 32 | #self.global = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2]) 33 | 34 | self.w_embedding = nn.Embedding(self.w_embed_size, self.q_lstm_dim) 35 | 36 | self.f_expression = nn.LSTM(self.q_lstm_dim, self.q_lstm_dim, batch_first=True) 37 | self.exp_bn = nn.BatchNorm1d(self.q_lstm_dim) 38 | self.f_fact = nn.LSTM(self.q_lstm_dim, self.s_lstm_dim, batch_first=True) 39 | self.fact_bn = nn.BatchNorm1d(self.s_lstm_dim) 40 | 41 | self.g_attnd_q = nn.Linear(self.q_lstm_dim, self.g_attnd_dim) 42 | self.g_attnd_g = nn.Linear(self.g_dim, self.g_attnd_dim) 43 | self.g_fc = nn.Linear(self.g_attnd_dim, 1) 44 | 45 | self.f_local = nn.Sequential( 46 | nn.Linear(self.g_dim, self.o_dim), 47 | nn.BatchNorm1d(self.o_dim), 48 | nn.ReLU(inplace=True), 49 | ) 50 | 51 | self.f_location = nn.Sequential( 52 | nn.Linear(5, self.l_dim), 53 | nn.BatchNorm1d(self.l_dim), 54 | nn.ReLU(inplace=True), 55 | ) 56 | self.g1 = nn.Sequential( 57 | nn.Linear(self.s_lstm_dim*4, config.g1_dim), 58 | ) 59 | self.g2 = nn.Linear(config.g1_dim, 1) 60 | self.mn = nn.LSTMCell(self.s_lstm_dim, self.s_lstm_dim) 61 | self.m_Cell = nn.LSTMCell(self.s_lstm_dim, self.s_lstm_dim) 62 | self.f_final = nn.Sequential( 63 | 
nn.Linear(self.g_dim+self.s_lstm_dim+self.o_dim+self.l_dim, self.q_lstm_dim), 64 | ) 65 | 66 | def attention_image(self, whole, f_expression): 67 | with torch.no_grad(): 68 | f_global = self.f_global(whole).view(-1, self.g_dim, 7*7).permute(0, 2, 1).contiguous() 69 | g_attnd_g = self.g_attnd_g(f_global) 70 | g_attnd_q = self.g_attnd_q(f_expression).unsqueeze(1).expand_as(g_attnd_g) 71 | g_attnd = self.g_fc(F.tanh(g_attnd_g+g_attnd_q)) 72 | weight = F.softmax(g_attnd, dim=1).unsqueeze(1).squeeze(3) 73 | return torch.bmm(weight, f_global).squeeze(1) 74 | 75 | def candidate_visual(self, locals): 76 | bs = locals.size()[0] 77 | locals = locals.view(bs, -1, 3, 224, 224) 78 | locals = locals.view(-1, 3, 224, 224).contiguous() 79 | with torch.no_grad(): 80 | f_local = self.f_global(locals) 81 | f_local = F.avg_pool2d(f_local, kernel_size=(7, 7)).squeeze(2).squeeze(2) 82 | f_local = self.f_local(f_local).view(bs, -1, self.g_dim) 83 | return f_local 84 | 85 | def episodic_memory(self, f_facts, f_expression, m, f_mask, ff_mask): 86 | f_expression = f_expression.unsqueeze(1).expand_as(f_facts) 87 | m = m.unsqueeze(1).expand_as(f_facts) 88 | z = torch.cat([f_facts*f_expression, f_facts*m, torch.abs(f_facts-f_expression), torch.abs(f_facts-m)], dim=2).view(-1, self.s_lstm_dim*4) 89 | Z = self.g2(F.tanh(self.g1(z))).view(-1, self.fact_len).masked_fill_(f_mask, -9999999) 90 | weights = F.softmax(Z, 1) 91 | h_pre = Variable(torch.zeros(Z.size()[0], self.s_lstm_dim).cuda()) 92 | c_pre = Variable(torch.zeros(Z.size()[0], self.s_lstm_dim).cuda()) 93 | hs = Variable(torch.zeros(Z.size()[0], 1, self.s_lstm_dim).cuda()) 94 | cs = Variable(torch.zeros(Z.size()[0], 1, self.s_lstm_dim).cuda()) 95 | for i in range(self.fact_len): 96 | h, c = self.mn(f_facts[:, i, :].squeeze(1), (h_pre, c_pre)) 97 | h_pre = weights[:, i].unsqueeze(1) * h + (1-weights[:, i]).unsqueeze(1) * h_pre 98 | c_pre = c 99 | hs = torch.cat((hs, h_pre.unsqueeze(1)), 1) 100 | cs = torch.cat((cs, c_pre.unsqueeze(1)), 1) 101 | return torch.bmm(ff_mask.view(-1, self.fact_len).unsqueeze(1), hs[:,1:,:]).squeeze(1), torch.bmm(ff_mask.view(-1, self.fact_len).unsqueeze(1), cs[:,1:,:]).squeeze(1) 102 | 103 | def forward(self, whole, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, c_mask): 104 | #f_expression 105 | bs = expression.size()[0] 106 | f_expression = F.dropout(self.w_embedding(expression), 0.1) 107 | self.f_expression.flatten_parameters() 108 | x, _ = self.f_expression(f_expression) 109 | f_expression = torch.bmm(e_mask.unsqueeze(1), x).squeeze(1) 110 | f_expression = self.exp_bn(f_expression) 111 | 112 | #top-down attention 113 | f_global = self.attention_image(whole, f_expression) 114 | 115 | #f_local 116 | f_local = self.candidate_visual(locals) 117 | 118 | #f_location 119 | f_location = self.f_location(locations.view(-1, 5)).view(bs, -1, 128) 120 | 121 | #f_facts 122 | f_facts = F.dropout(self.w_embedding(facts), 0.1) 123 | #pdb.set_trace() 124 | self.f_fact.flatten_parameters() 125 | x, _ = self.f_fact(f_facts.contiguous().view(-1, self.text_len, self.s_lstm_dim)) 126 | f_facts = torch.bmm(mask.view(-1, self.text_len).unsqueeze(1), x.view(-1, self.text_len, self.s_lstm_dim)) 127 | f_facts = self.fact_bn(f_facts.squeeze(1).contiguous()).view(-1, self.cand_len*self.fact_len, self.s_lstm_dim).view(-1, self.cand_len, self.fact_len, self.s_lstm_dim) 128 | 129 | #memory network 130 | f_m = f_expression.unsqueeze(1).expand(bs, self.cand_len, self.q_lstm_dim).contiguous().view(-1, self.q_lstm_dim) 131 | m = f_m 132 | f_mask = 
torch.eq(f_mask.view(-1, self.fact_len), 0) 133 | for i in range(self.max_episodic): 134 | h, c = self.episodic_memory(f_facts.view(-1, self.fact_len, self.s_lstm_dim), f_m, m, f_mask, ff_mask) 135 | m, _ = self.m_Cell(h, (m, c)) 136 | m = m.view(-1, self.cand_len, self.s_lstm_dim) 137 | 138 | #prediction 139 | f_global = f_global.unsqueeze(1).expand(bs, self.cand_len, self.g_dim) 140 | f = self.f_final(torch.cat((m, f_global, f_local, f_location), 2).contiguous().view(-1, self.g_dim+self.s_lstm_dim+self.o_dim+self.l_dim)).contiguous().view(bs, -1, self.q_lstm_dim) 141 | f_expression = f_expression.unsqueeze(1).expand(bs, self.cand_len, self.q_lstm_dim) 142 | scores = torch.sum(f_expression * f, dim=2) 143 | c_mask = torch.eq(c_mask.view(-1, self.cand_len), 0) 144 | scores = F.softmax(scores.masked_fill_(c_mask, -9999999), dim=1) 145 | return scores 146 | 147 | 148 | if __name__ == '__main__': 149 | argparser = argparse.ArgumentParser() 150 | argparser.add_argument('--vocab_size', type=int, default=15733) 151 | argparser.add_argument('--cand_len', type=int, default=10) 152 | argparser.add_argument('--fact_len', type=int, default=50) 153 | argparser.add_argument('--text_len', type=int, default=50) 154 | argparser.add_argument('--max_episodic', type=int, default=5) 155 | argparser.add_argument('--q_lstm_dim', type=int, default=2048) 156 | argparser.add_argument('--s_lstm_dim', type=int, default=2048) 157 | argparser.add_argument('--o_dim', type=int, default=512) 158 | argparser.add_argument('--l_dim', type=int, default=128) 159 | argparser.add_argument('--g1_dim', type=int, default=512) 160 | args = argparser.parse_args() 161 | net = Basenet(args) 162 | print(net) 163 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.autograd import Variable 4 | from torchvision import transforms 5 | from mn import Basenet 6 | from data import data 7 | from PIL import Image 8 | from tensorboardX import SummaryWriter 9 | import json 10 | import torch.nn.functional as F 11 | from matplotlib import pyplot as plt 12 | import time 13 | import pdb 14 | import numpy as np 15 | import torch.nn.functional as F 16 | import argparse 17 | 18 | 19 | argparser = argparse.ArgumentParser() 20 | argparser.add_argument('--vocab_size', type=int, default=15733) 21 | argparser.add_argument('--cand_len', type=int, default=10) 22 | argparser.add_argument('--fact_len', type=int, default=50) 23 | argparser.add_argument('--text_len', type=int, default=50) 24 | argparser.add_argument('--max_episodic', type=int, default=5) 25 | argparser.add_argument('--q_lstm_dim', type=int, default=2048) 26 | argparser.add_argument('--s_lstm_dim', type=int, default=2048) 27 | argparser.add_argument('--o_dim', type=int, default=512) 28 | argparser.add_argument('--l_dim', type=int, default=128) 29 | argparser.add_argument('--g1_dim', type=int, default=512) 30 | args = argparser.parse_args() 31 | 32 | 33 | 34 | def train(epoch, model, critertion, f_loss, optimizer, use_gpu): 35 | model.train() 36 | correct = 0 37 | train_loss = 0 38 | for batch_id, (jpg, label, f_label, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, path, c_mask) in enumerate(trainDataloader): 39 | if use_gpu: 40 | jpg = Variable(jpg.cuda()) 41 | label = Variable(label.cuda()) 42 | f_label = Variable(f_label.cuda()) 43 | expression = Variable(expression.cuda()) 
44 | e_mask = Variable(e_mask.cuda()) 45 | locals = Variable(locals.cuda()) 46 | locations = Variable(locations.cuda()) 47 | facts = Variable(facts.cuda()) 48 | mask = Variable(mask.cuda()) 49 | f_mask = Variable(f_mask.cuda()) 50 | ff_mask = Variable(ff_mask.cuda()) 51 | c_mask = Variable(c_mask.cuda()) 52 | else: 53 | jpg, label = Variable(jpg), Variable(label) 54 | f_label = Variable(f_label) 55 | expression = Variable(expression) 56 | e_mask = Variable(e_mask) 57 | locals = Variable(locals) 58 | locations = Variable(locations) 59 | facts = Variable(facts) 60 | mask = Variable(mask) 61 | f_mask = Variable(f_mask) 62 | ff_mask = Variable(ff_mask) 63 | c_mask = Variable(c_mask) 64 | if len(label.size()) == 2: 65 | label = label.squeeze(1) 66 | optimizer.zero_grad() 67 | output = model(jpg, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, c_mask) 68 | #output, i_weight = model(jpg, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, c_mask) 69 | #print(3, time.time()-start) 70 | #print(output.size(), label.size()) 71 | loss = critertion(output, label) 72 | #print(label) 73 | #print(f_weight) 74 | #print(output) 75 | train_loss += loss.item() 76 | writer.add_scalar('train_loss', loss.item(), (epoch-1)*len(trainDataloader)+batch_id) 77 | pred = output.data.max(1)[1] 78 | correct += pred.eq(label.data).cpu().sum() 79 | #print(correct) 80 | loss.backward() 81 | #pdb.set_trace() 82 | optimizer.step() 83 | if batch_id % 20 == 0: 84 | print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format( 85 | epoch, batch_id * len(jpg), len(trainDataloader.dataset), 86 | 100. * batch_id / len(trainDataloader), loss.item() 87 | )) 88 | train_loss /= len(trainDataloader) 89 | writer.add_scalar('train_acc', correct / len(trainDataloader.dataset), epoch) 90 | print('\nTrain Set: Average loss:{:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 91 | train_loss, correct, len(trainDataloader.dataset), 92 | 100. 
* correct / len(trainDataloader.dataset) 93 | )) 94 | 95 | 96 | def test(epoch, model, critertion, f_loss, use_gpu, Dataset): 97 | model.eval() 98 | val_loss = 0 99 | correct = 0 100 | result = [] 101 | with torch.no_grad(): 102 | for batch_id, (jpg, label, f_label, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, path, c_mask) in enumerate(Dataset): 103 | if use_gpu: 104 | jpg = Variable(jpg.cuda()) 105 | label = Variable(label.cuda()) 106 | f_label = Variable(f_label.cuda()) 107 | expression = Variable(expression.cuda()) 108 | e_mask = Variable(e_mask.cuda()) 109 | locals = Variable(locals.cuda()) 110 | locations = Variable(locations.cuda()) 111 | facts = Variable(facts.cuda()) 112 | mask = Variable(mask.cuda()) 113 | f_mask = Variable(f_mask.cuda()) 114 | ff_mask = Variable(ff_mask.cuda()) 115 | c_mask = Variable(c_mask.cuda()) 116 | else: 117 | jpg, label = Variable(jpg), Variable(label) 118 | f_label = Variable(f_label) 119 | expression = Variable(expression) 120 | e_mask = Variable(e_mask) 121 | locals = Variable(locals) 122 | locations = Variable(locations) 123 | facts = Variable(facts) 124 | mask = Variable(mask) 125 | f_mask = Variable(f_mask) 126 | ff_mask = Variable(ff_mask) 127 | c_mask = Variable(c_mask) 128 | if len(label.size()) == 2: 129 | label = label.squeeze(1) 130 | output = model(jpg, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, c_mask) 131 | #print(label) 132 | #output, i_weight = model(jpg, expression, e_mask, locals, locations, facts, mask, f_mask, ff_mask, c_mask) 133 | #print(output.size(), label.size()) 134 | loss = critertion(output, label) #+ f_loss(f_weight, f_label) 135 | val_loss += loss.item() 136 | writer.add_scalar('val_loss', loss.item(), (epoch/5-1)*len(Dataset)+batch_id) 137 | pred = output.data.max(1)[1] 138 | correct += pred.eq(label.data).cpu().sum() 139 | #result.append({'image': path,'gt': label.data.cpu().numpy().tolist(),'pred': output.data.cpu().numpy().tolist(),'f_weight': f_weight.detach().cpu().numpy().tolist()}) 140 | result.append({'image': path,'gt': label.data.cpu().numpy().tolist(),'pred': output.data.cpu().numpy().tolist()}) 141 | 142 | with open('result_mn.json', 'w') as file: 143 | json.dump(result, file) 144 | #''' 145 | val_loss /= len(Dataset) 146 | writer.add_scalar('val_acc', correct / len(Dataset.dataset), epoch) 147 | print('\nTest Set: Average loss:{:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 148 | val_loss, correct, len(Dataset.dataset), 149 | 100. 
* correct / len(Dataset.dataset) 150 | )) 151 | return val_loss 152 | 153 | 154 | if __name__ == '__main__': 155 | writer = SummaryWriter('{}/{}'.format('runs', 'mn')) 156 | train_transform = transforms.Compose([ 157 | transforms.Resize((224, 224)), 158 | transforms.RandomHorizontalFlip(), 159 | transforms.ToTensor(), 160 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 161 | ]) 162 | test_transform = transforms.Compose([ 163 | transforms.Resize((224, 224)), 164 | transforms.ToTensor(), 165 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 166 | ]) 167 | net = Basenet(args) 168 | trainDataset = data('../image','../json/train.json', data_transform=train_transform) 169 | valDataset = data('../image', '../json/val.json', data_transform=test_transform) 170 | testDataset = data('../image','../json/test.json', data_transform=test_transform) 171 | trainDataloader = DataLoader(trainDataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True) 172 | valDataloader = DataLoader(valDataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True) 173 | testDataloader = DataLoader(testDataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True) 174 | critertion = torch.nn.CrossEntropyLoss() 175 | f_loss = torch.nn.BCELoss() 176 | use_gpu = torch.cuda.is_available() 177 | no_params = list(map(id, net.f_global.parameters())) 178 | base_params = filter(lambda x: id(x) not in no_params, net.parameters()) 179 | optimizer = torch.optim.SGD([ 180 | {'params': base_params}, 181 | {'params': net.f_global.parameters(), 'lr': 0} 182 | ], lr=1e-4, momentum=0.9, weight_decay=1e-5) 183 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=1) 184 | use_gpu = torch.cuda.is_available() 185 | 186 | if use_gpu: 187 | net.cuda() 188 | net = torch.nn.DataParallel(net) 189 | for epoch in range(1, 41): 190 | train(epoch, net, critertion, f_loss, optimizer, use_gpu) 191 | val_loss = test(epoch, net, critertion, f_loss, use_gpu, valDataloader) 192 | scheduler.step(val_loss) 193 | if epoch % 10 == 0: 194 | torch.save(net, './train'+str(epoch)+'_mn.pth') 195 | test_loss = test(epoch, net, critertion, f_loss, use_gpu, testDataloader) 196 | #''' 197 | -------------------------------------------------------------------------------- /top_fact.py: -------------------------------------------------------------------------------- 1 | from gensim.models import word2vec 2 | import json 3 | from collections import defaultdict 4 | from nltk import sent_tokenize, word_tokenize 5 | import numpy as np 6 | import random 7 | import time 8 | 9 | 10 | if __name__ == '__main__': 11 | model = word2vec.Word2Vec.load('./word2vec_model/facts.model') 12 | with open('./data/KB-REF/expression.json') as file: 13 | expression = json.load(file) 14 | with open('./data/KB-REF/candidate.json') as file: 15 | cand = json.load(file) 16 | with open('./data/KB-REF/Wikipedia.json') as file: 17 | Wikipedia = json.load(file) 18 | with open('./data/KB-REF/ConceptNet.json') as file: 19 | ConceptNet = json.load(file) 20 | with open('./data/KB-REF/WebChild.json') as file: 21 | WebChild = json.load(file) 22 | with open('./data/KB-REF/objects.json') as file: 23 | objects = json.load(file) 24 | 25 | top_facts = {} 26 | for k in expression: 27 | start = time.time() 28 | middle = {} 29 | img = k.split('_')[0] 30 | e = expression[k][0] 31 | candidates = cand[img] 32 | j = 0 33 | em = np.zeros(300) 34 | for f in word_tokenize(e): 35 | try: 36 | em += 
np.array(model.wv.get_vector(f.lower())) 37 | j += 1 38 | except: 39 | continue 40 | em /= j 41 | for c in candidates: 42 | if c != '-1': 43 | sims = [] 44 | fs = [] 45 | final = [] 46 | o = objects[img][c][0].split('.')[0] 47 | try: 48 | facts = sent_tokenize(Wikipedia[o.lower()]) 49 | for fact in facts: 50 | j = 0 51 | nm = np.zeros(300) 52 | for f in word_tokenize(fact): 53 | try: 54 | nm += np.array(model.wv.get_vector(f.lower())) 55 | j += 1 56 | except: 57 | continue 58 | if j!= 0: 59 | nm /= j 60 | sim = np.dot(em, nm) / (np.linalg.norm(em) * np.linalg.norm(nm)) 61 | sim = 0.5 + 0.5 * sim 62 | fs.append(fact) 63 | sims.append(sim) 64 | except: 65 | continue 66 | try: 67 | facts = sent_tokenize(ConceptNet[o.lower()].replace('.', '. ').replace('has/have ', '')) 68 | for fact in facts: 69 | j = 0 70 | nm = np.zeros(300) 71 | for f in word_tokenize(fact): 72 | try: 73 | nm += np.array(model.wv.get_vector(f.lower())) 74 | j += 1 75 | except: 76 | continue 77 | if j!= 0: 78 | nm /= j 79 | sim = np.dot(em, nm) / (np.linalg.norm(em) * np.linalg.norm(nm)) 80 | sim = 0.5 + 0.5 * sim 81 | fs.append(fact) 82 | sims.append(sim) 83 | except: 84 | continue 85 | try: 86 | facts = sent_tokenize(WebChild[o.lower()]) 87 | for fact in facts: 88 | j = 0 89 | nm = np.zeros(300) 90 | for f in word_tokenize(fact): 91 | try: 92 | nm += np.array(model.wv.get_vector(f.lower())) 93 | j += 1 94 | except: 95 | continue 96 | if j!= 0: 97 | nm /= j 98 | sim = np.dot(em, nm) / (np.linalg.norm(em) * np.linalg.norm(nm)) 99 | sim = 0.5 + 0.5 * sim 100 | fs.append(fact) 101 | sims.append(sim) 102 | except: 103 | continue 104 | sims = np.array(sims) 105 | inxs = np.argsort(-sims)[0:50] 106 | for ix in inxs: 107 | final.append(fs[ix]) 108 | random.shuffle(final) 109 | middle = dict(middle, **{c: final}) 110 | 111 | top_facts = dict(top_facts, **{k: middle}) 112 | print(time.time()-start) 113 | 114 | with open('./json/top_facts.json', 'w') as file: 115 | json.dump(top_facts, file) 116 | 117 | -------------------------------------------------------------------------------- /word2vec_model/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /word2vector.py: -------------------------------------------------------------------------------- 1 | import json 2 | from gensim.models import word2vec 3 | from nltk import word_tokenize, sent_tokenize 4 | import enchant 5 | import time 6 | 7 | 8 | if __name__ == '__main__': 9 | with open('./data/KB-REF/Wikipedia.json') as file: 10 | Wikipedia = json.load(file) 11 | with open('./data/KB-REF/ConceptNet.json') as file: 12 | ConceptNet = json.load(file) 13 | with open('./data/KB-REF/WebChild.json') as file: 14 | WebChild = json.load(file) 15 | out = '' 16 | l = 0 17 | for k in Wikipedia: 18 | fact = sent_tokenize(Wikipedia[k]) 19 | l += len(fact) 20 | for f in fact: 21 | out += f 22 | out += '\n' 23 | for k in ConceptNet: 24 | fact = sent_tokenize(ConceptNet[k].replace('.', '. 
').replace('has/have ', '')) 25 | l += len(fact) 26 | for f in fact: 27 | out += f 28 | out += '\n' 29 | for k in WebChild: 30 | fact = sent_tokenize(WebChild[k]) 31 | l += len(fact) 32 | for f in fact: 33 | out += f 34 | out += '\n' 35 | with open('./data/KB-REF/f.txt', 'w', encoding='utf-8') as file: 36 | file.write(out) 37 | file.close() 38 | 39 | start = time.time() 40 | sentences = word2vec.Text8Corpus('./data/KB-REF/f.txt') 41 | model = word2vec.Word2Vec(sentences, size=300, hs=1, sg=1, min_count=5, window=5, iter=100, workers=4) 42 | model.save('./word2vec_model/facts.model') 43 | print(time.time()-start) 44 | --------------------------------------------------------------------------------
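
For anyone getting oriented in the dataset, the sketch below ties together the annotation files described in the top-level README. It is a minimal, illustrative example only: it assumes the BaiduYun download has been unpacked under `./data/KB-REF/` (the location the scripts above hard-code), and the variable names are arbitrary.

```python
import json

# Paths mirror the ones hard-coded in data_produce.py / top_fact.py; adjust if the
# download was unpacked elsewhere.
ROOT = './data/KB-REF/'

with open(ROOT + 'expression.json') as f:
    expression = json.load(f)   # {"<image_id>_<object_id>": [referring_expression, fact]}
with open(ROOT + 'candidate.json') as f:
    candidate = json.load(f)    # {"<image_id>": [object_id, ...]} -- 10 entries, padded with '-1'
with open(ROOT + 'objects.json') as f:
    objects = json.load(f)      # {"<image_id>": {"<object_id>": [category, name, x, y, w, h]}}
with open(ROOT + 'image.json') as f:
    image_wh = json.load(f)     # {"<image_id>": [width, height]}

# The key of expression.json encodes both the image and the target object.
key = next(iter(expression))
image_id, object_id = key.split('_')
referring_expression, supporting_fact = expression[key]

# The target's index among the 10 candidate boxes is what data_produce.py uses as the
# classification label (it wraps this in try/except, since a few entries may not resolve).
label = candidate[image_id].index(object_id)

# Bounding box and category of the target object, plus the image size.
category, name, x, y, w, h = objects[image_id][object_id][:6]
width, height = image_wh[image_id]

print(f'image {image_id} ({width}x{height}): "{referring_expression}"')
print(f'  -> candidate #{label}: {name} ({category}) at box x={x}, y={y}, w={w}, h={h}')
```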
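The episodic memory network in `tools/mn.py` (the `episodic_memory` method together with the `m_Cell` loop in `forward`) is fairly dense because it is fully batched over 10 candidates and 50 padded facts. Below is a stripped-down sketch of the same mechanism for a single candidate, with batching, masking, and BatchNorm removed. The class name, toy dimensions, and the `__main__` demo are illustrative assumptions; the real model uses the config in `mn.py` (`s_lstm_dim=2048`, `fact_len=50`, `max_episodic=5`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EpisodicMemorySketch(nn.Module):
    """Stripped-down version of episodic_memory() plus the memory loop in tools/mn.py.

    Operates on one candidate object: `facts` is (num_facts, dim) and `query` is (dim,).
    The real model batches this over all candidates and masks padded facts.
    """

    def __init__(self, dim=64, gate_hidden=128, episodes=3):
        super().__init__()
        self.g1 = nn.Linear(dim * 4, gate_hidden)   # corresponds to self.g1 in mn.py
        self.g2 = nn.Linear(gate_hidden, 1)         # corresponds to self.g2 in mn.py
        self.fact_cell = nn.LSTMCell(dim, dim)      # corresponds to self.mn
        self.mem_cell = nn.LSTMCell(dim, dim)       # corresponds to self.m_Cell
        self.episodes = episodes

    def _episode(self, facts, query, memory):
        # Interaction features between every fact, the expression, and the current memory.
        q = query.expand_as(facts)
        m = memory.expand_as(facts)
        z = torch.cat([facts * q, facts * m, (facts - q).abs(), (facts - m).abs()], dim=1)
        gates = F.softmax(self.g2(torch.tanh(self.g1(z))).squeeze(1), dim=0)

        # Attention-gated recurrence over the facts: strongly gated facts overwrite the
        # running state, weakly gated ones leave it almost unchanged.
        h = facts.new_zeros(1, facts.size(1))
        c = facts.new_zeros(1, facts.size(1))
        for i in range(facts.size(0)):
            h_i, c = self.fact_cell(facts[i].unsqueeze(0), (h, c))
            h = gates[i] * h_i + (1 - gates[i]) * h
        return h, c

    def forward(self, facts, query):
        memory = query.unsqueeze(0)                 # memory starts as the expression feature
        for _ in range(self.episodes):
            h, c = self._episode(facts, query.unsqueeze(0), memory)
            memory, _ = self.mem_cell(h, (memory, c))
        return memory.squeeze(0)                    # scored against f_expression in mn.py


if __name__ == '__main__':
    torch.manual_seed(0)
    net = EpisodicMemorySketch(dim=64)
    facts = torch.randn(5, 64)      # 5 toy fact embeddings
    query = torch.randn(64)         # toy expression embedding
    print(net(facts, query).shape)  # torch.Size([64])
```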