├── README.md
├── datasets.py
├── evaluation.py
├── homogeneous_data.py
├── model.py
├── pre_transforms.py
├── query_dump.py
├── server.py
├── static
├── dataset
│ └── arch
│ │ └── annotations
│ │ └── dump_data_pair.py
└── web
│ ├── bootstrap.min.css
│ ├── index.css
│ ├── index.js
│ ├── jquery.min.js
│ └── jumbotron-narrow.css
├── templates
└── index.html
├── test.py
├── tools.py
├── train.py
├── utils.py
└── vocab.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### How to run the code
3 |
4 |
5 | ##### 1. Set up the environment
6 |
7 | ```bash
8 | pip install http://download.pytorch.org/whl/cu75/torch-0.1.12.post2-cp27-none-linux_x86_64.whl
9 | pip install torchvision
10 | pip install gensim
11 | pip install hyperboard
12 | ```
13 |
14 |
15 | ##### 2. Run the project code
16 |
17 |
18 | ```bash
19 | python dump_data_pair.py
20 | python pre_transforms.py
21 | hyperboard-run --port 5020
22 | python test.py
23 | python query_dump.py
24 | python server.py
25 |
26 | ```
27 |
28 |
29 |
30 | ### Detailed explanation of the code
31 |
32 | `dump_data_pair.py` extracts the fields title, detail, image (and url) from the raw data jianzhu_tag.json and shuffles the data at the same time.
33 |
34 | `pre_transforms.py` preprocesses every image into a 2048-dimensional vector; the text is one-hot encoded. For the training set this yields the image vectors images_train and the captions captions_train. The validation set is used to compute Recall@K, so duplicate captions have to be removed, and caps_obj_id and imgs_obj_id are stored to decide whether an image and a caption match. caps_url, imgs_url and imgs_path are mainly used to display supplementary information. The test set is the union of the training and validation sets and is used for user queries and for finding good cases. Finding good cases also involves verifying model quality (do image and text match?), so caps_obj_id and imgs_obj_id have to be stored here as well. **After preprocessing, the generated files need to be moved.** All train and dev files go to the corresponding locations under the data directory; the test caps.txt and imgs.npy also go under the data directory, so that load_dataset can read them. The test caps_url.json, imgs_url.json and imgs_path.json go to the corresponding server subdirectory under the vse directory and are used to display query results.
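
Judging from the paths read by datasets.py, evaluation.py and server.py further down, the layout after moving the files should roughly look like this for the arch dataset (evaluation.py additionally reads arch_test_imgs_url.json from data/arch when printing good cases):

```
data/arch/
    arch_train_caps.txt     arch_train_ims.npy
    arch_dev_caps.txt       arch_dev_ims.npy
    arch_dev_caps_id.npy    arch_dev_imgs_id.npy
    arch_test_caps.txt      arch_test_ims.npy
    arch_test_caps_id.npy   arch_test_imgs_id.npy
    arch_test_imgs_url.json
vse/arch_server/
    arch_test_caps_url.json
    arch_test_imgs_url.json
    arch_test_imgs_path.json
```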
35 |
36 | `test.py` drives the training; the trained model and the corresponding hyperparameters are saved. This step covers reading the data and building the dictionary, handling sentences of different lengths, computing the pairwise ranking loss, computing Recall@K, and so on, all of which are explained in detail later in this document.
37 |
38 | `query_dump.py` loads the current best trained model and its hyperparameters, converts the images and captions in the dataset into image vectors and text vectors, and saves them in the corresponding server subdirectory under the vse directory so they can be queried later. The model and hyperparameters are saved again here to record that the image and text vectors were produced with exactly this model and these hyperparameters.
39 |
40 | `server.py` sets up the image-text cross-search website: users submit a building description or a building image and get the corresponding query results back.
41 |
42 |
43 | Here we focus on train.py, which is the core code of the image-text cross-search project.
44 |
45 | ##### 1. Handling sentences of different lengths
46 |
47 | Read the training and validation data and build the dictionary from the captions of both. build_dictionary returns worddict and wordcount, both sorted by how often each word occurs. worddict maps each word to its id; wordcount maps each word to its number of occurrences.
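
vocab.py is not included in this dump, so as orientation here is a minimal sketch of what build_dictionary presumably does; the reserved ids are an assumption based on prepare_data in homogeneous_data.py, which maps rare/unknown words to id 1:

```Python
# Minimal sketch (assumption; vocab.py itself is not shown here):
# count word frequencies, sort by frequency, and assign ids starting
# at 2, since id 1 is reserved for rare/unknown words in prepare_data.
from collections import Counter

def build_dictionary(captions):
    wordcount = Counter()
    for cc in captions:
        wordcount.update(cc.split())
    worddict = {}
    for idx, (word, _) in enumerate(wordcount.most_common()):
        worddict[word] = idx + 2   # most frequent word gets the smallest id
    return worddict, dict(wordcount)
```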
48 |
49 |
50 | HomogeneousData yields batches, **and all captions in one batch have the same length**. `prepare()` counts how many sentences there are of each length and records the positions of those sentences. `reset()` shuffles the order of the sentence lengths and shuffles the order of the sentences within each length. len_curr_counts stores how many sentences of each length have not been used yet; correspondingly, len_indices_pos stores how far we have advanced within each length.
51 |
52 | `next()` If the current sentence length still has unvisited sentences, break out of the while loop, visit some of the remaining sentences of that length, and then move on to the next length (note that it does not exhaust all sentences of one length before moving on to the next length). Otherwise, check whether the next sentence length still has unvisited sentences and repeat the procedure above. Once every sentence of every length has been visited, the epoch is over and reset is called to reinitialize the relevant variables.
53 |
54 | When visiting the unvisited sentences of a given length, first use len_curr_counts to find how many sentences of that length remain, take the smaller of that count and batch_size, and call it curr_batch_size. Then use len_indices_pos to see how far this length has been consumed, take the next curr_batch_size sentences from that position, and look their positions up in len_indices. Update len_indices_pos and len_curr_counts, and finally return the sentences and images at those positions.
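
Since train.py itself is not shown in this dump, the following is only a hedged sketch of how HomogeneousData and prepare_data are wired together; the variable names and hyperparameter values are assumptions:

```Python
# Hedged sketch of the training loop around HomogeneousData;
# captions_train, images_train, worddict and the numbers are assumptions.
from homogeneous_data import HomogeneousData, prepare_data

train_iter = HomogeneousData(data=(captions_train, images_train),
                             batch_size=128, maxlen=100)
for caps, ims in train_iter:          # one same-length batch per step
    x, im = prepare_data(caps, ims, worddict, maxlen=100, n_words=len(worddict))
    if x is None:                     # every sentence in the batch exceeded maxlen
        continue
    # x:  (seq_len, batch) int64 word ids, zero-padded
    # im: (batch, dim_image) float32 image features
    # ... forward through ImgSenRanking and PairwiseRankingLoss here ...
```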
55 |
56 | `prepare_data()` takes a batch containing the captions caps and the image features features. The captions are tokenized and converted to word ids via worddict. Sentences longer than maxlen are discarded. Finally the text vectors and image feature vectors are converted from lists into numpy arrays.
57 |
58 | `encode_sentences()` encodes the validation sentences into vectors. ds is a dictionary that lets sentences be accessed by length. [minibatch::numbatches] means: starting at minibatch, take every numbatches-th element. The words are then converted to word ids, and finally the text vectors are obtained via f_senc ==> build_sentence_encoder. While building the batches the sentence indices are shuffled within each length, but the resulting features are written back according to the sentence ids, so images and texts line up with each other again.
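
The stride slicing can be illustrated with a toy list: the sentences of one length are split into numbatches interleaved mini-batches.

```Python
# Illustration of the [minibatch::numbatches] slicing used above.
xs = ['s0', 's1', 's2', 's3', 's4', 's5', 's6']
numbatches = 3
print([xs[minibatch::numbatches] for minibatch in range(numbatches)])
# [['s0', 's3', 's6'], ['s1', 's4'], ['s2', 's5']]
```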
59 |
60 |
61 | ##### 2. Computing the PairwiseRankingLoss
62 |
63 | PairwiseRankingLoss takes the image embeddings (batch_size, dim) and the text embeddings (batch_size, dim) as input. Multiplying the image matrix by the transposed text matrix gives the similarity matrix. Its diagonal holds the similarities of the matching image-text pairs; each row holds the similarities between one image and all captions (including the matching one), and each column holds the similarities between one caption and all images (including the matching one). With these in hand, the pairwise ranking loss can be computed.
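
As a small numeric illustration (all numbers and the margin value are made up), the loss implemented in model.py below penalizes every mismatched pair whose score comes within the margin of the matching score:

```Python
import numpy as np

margin = 0.2
scores = np.array([[0.9, 0.8],    # image 0 vs caption 0 / caption 1
                   [0.3, 0.8]])   # image 1 vs caption 0 / caption 1
diag = scores.diagonal()          # similarities of the matching pairs

# column-wise: each caption's matching score vs. all other images
cost_s  = np.maximum(0, margin - diag[np.newaxis, :] + scores)
# row-wise: each image's matching score vs. all other captions
cost_im = np.maximum(0, margin - diag[:, np.newaxis] + scores)
np.fill_diagonal(cost_s, 0)       # matching pairs contribute nothing
np.fill_diagonal(cost_im, 0)
print(cost_s.sum() + cost_im.sum())   # ~0.3: caption 1 is too close to image 0
```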
64 |
65 | ##### 3. Computing Recall@K
66 |
67 | The arch dataset is not as regular as the public datasets, where every image comes with 5 captions and the natural index alone tells whether an image and a caption match. Recall@K only uses the validation data and has no effect on the training set. When the validation data is preprocessed, every example read in contains an image-text pair together with the obj_id it belongs to. All captions are then deduplicated to give captions_dev, and the corresponding array caps_obj_id records the obj_id of each caption. Likewise, images_dev comes with imgs_obj_id recording the obj_id of each image. When computing Recall@K we therefore decide whether an image and a caption match by comparing their obj_id.
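
A toy illustration of the obj_id matching (all values are made up): the rank of a query image is the position of the first caption with a matching obj_id after sorting by similarity.

```Python
import numpy as np

caps_obj_id = np.array([3, 7, 3, 5])   # obj_id of every deduplicated caption
imgs_obj_id = np.array([5, 3, 7])      # obj_id of every image
inds = np.array([1, 3, 0, 2])          # captions sorted by similarity to image 0
rank = np.where(caps_obj_id[inds] == imgs_obj_id[0])[0][0]
print(rank)  # 1 -> counts towards Recall@5 and Recall@10, but not Recall@1
```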
68 |
69 | i2t_arch() computes Recall@K for image-to-text search. Its inputs are all image vectors and all text vectors. As mentioned above, converting the captions to text vectors does not change their order, and the same holds for the images. We iterate over all images and compute each image's similarity to every caption. inds holds the caption indices sorted by similarity. Because the caption order never changed, an index directly gives us the corresponding obj_id: caps_obj_id[inds] is the list of obj_id sorted by similarity, and numpy.where gives the positions at which the image's obj_id appears in caps_obj_id[inds]. t2i_arch() works the same way, so it is not repeated here.
70 |
71 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset loading
3 | """
4 | import numpy
5 |
6 | path_to_data = 'data/'
7 |
8 | def load_dataset(name='f8k', load_test=False):
9 | """
10 | Load captions and image features
11 | """
12 | loc = path_to_data + name + '/'
13 |
14 | if load_test:
15 | # Captions
16 | test_caps = []
17 | with open(loc+name+'_test_caps.txt', 'rb') as f:
18 | for line in f:
19 | test_caps.append(line.strip())
20 | # Image features
21 | test_ims = numpy.load(loc+name+'_test_ims.npy')
22 | return (test_caps, test_ims)
23 | else:
24 | # Captions
25 | train_caps, dev_caps = [], []
26 | with open(loc+name+'_train_caps.txt', 'rb') as f:
27 | for line in f:
28 | train_caps.append(line.strip())
29 |
30 | with open(loc+name+'_dev_caps.txt', 'rb') as f:
31 | for line in f:
32 | dev_caps.append(line.strip())
33 |
34 | # Image features
35 | train_ims = numpy.load(loc+name+'_train_ims.npy')
36 | dev_ims = numpy.load(loc+name+'_dev_ims.npy')
37 |
38 | return (train_caps, train_ims), (dev_caps, dev_ims)
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | from datasets import load_dataset
4 | from tools import encode_sentences, encode_images
5 | import json
6 |
7 | def evalrank(model, data, split='dev'):
8 | """
9 |     Evaluate a trained model on either dev or test
10 | """
11 |
12 | print 'Loading dataset'
13 | if split == 'dev':
14 | X = load_dataset(data)[1]
15 | else:
16 | X = load_dataset(data, load_test=True)
17 |
18 |
19 | print 'Computing results...'
20 | ls = encode_sentences(model, X[0])
21 | lim = encode_images(model, X[1])
22 |
23 | if data == 'arch':
24 | # Find the good case in test dataset
25 | (r1, r5, r10, medr) = i2t_arch_case(lim, ls, X[0])
26 | print "Image to text: %.1f, %.1f, %.1f, %.1f" % (r1, r5, r10, medr)
27 | (r1i, r5i, r10i, medri) = t2i_arch_case(lim, ls, X[0])
28 | print "Text to image: %.1f, %.1f, %.1f, %.1f" % (r1i, r5i, r10i, medri)
29 | else:
30 | (r1, r5, r10, medr) = i2t(lim, ls)
31 | print "Image to text: %.1f, %.1f, %.1f, %.1f" % (r1, r5, r10, medr)
32 | (r1i, r5i, r10i, medri) = t2i(lim, ls)
33 | print "Text to image: %.1f, %.1f, %.1f, %.1f" % (r1i, r5i, r10i, medri)
34 |
35 |
36 | def i2t(images, captions, npts=None):
37 | """
38 | Images->Text (Image Annotation)
39 | Images: (5N, K) matrix of images
40 | Captions: (5N, K) matrix of captions
41 | """
42 | if npts == None:
43 | npts = images.size()[0] / 5
44 |
45 | ranks = numpy.zeros(npts)
46 | for index in range(npts):
47 |
48 | # Get query image
49 | im = images[5 * index].unsqueeze(0)
50 |
51 | # Compute scores
52 | d = torch.mm(im, captions.t())
53 | d_sorted, inds = torch.sort(d, descending=True)
54 | inds = inds.data.squeeze(0).cpu().numpy()
55 |
56 | # Score
57 | rank = 1e20
58 | # find the highest ranking
59 | for i in range(5*index, 5*index + 5, 1):
60 | tmp = numpy.where(inds == i)[0][0]
61 | if tmp < rank:
62 | rank = tmp
63 | ranks[index] = rank
64 |
65 | # Compute metrics
66 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
67 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
68 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
69 | medr = numpy.floor(numpy.median(ranks)) + 1
70 | return (r1, r5, r10, medr)
71 |
72 |
73 | def t2i(images, captions, npts=None, data='f8k'):
74 | """
75 | Text->Images (Image Search)
76 | Images: (5N, K) matrix of images
77 | Captions: (5N, K) matrix of captions
78 | """
79 | if npts == None:
80 | npts = images.size()[0] / 5
81 |
82 | ims = torch.cat([images[i].unsqueeze(0) for i in range(0, len(images), 5)])
83 |
84 | ranks = numpy.zeros(5 * npts)
85 | for index in range(npts):
86 |
87 | # Get query captions
88 | queries = captions[5*index : 5*index + 5]
89 |
90 | # Compute scores
91 | d = torch.mm(queries, ims.t())
92 | for i in range(d.size()[0]):
93 | d_sorted, inds = torch.sort(d[i], descending=True)
94 | inds = inds.data.squeeze(0).cpu().numpy()
95 | ranks[5 * index + i] = numpy.where(inds == index)[0][0]
96 |
97 | # Compute metrics
98 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
99 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
100 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
101 | medr = numpy.floor(numpy.median(ranks)) + 1
102 | return (r1, r5, r10, medr)
103 |
104 |
105 | def i2t_arch(images, captions):
106 | npts = images.size()[0]
107 | ranks = numpy.zeros(npts)
108 | caps_obj_id = numpy.load(open('data/arch/arch_dev_caps_id.npy'))
109 | imgs_obj_id = numpy.load(open('data/arch/arch_dev_imgs_id.npy'))
110 | for index in range(npts):
111 | # Get query image
112 | im = images[index:index+1]
113 | # Compute scores
114 | d = torch.mm(im, captions.t())
115 | d_sorted, inds = torch.sort(d, descending=True)
116 | inds = inds.data.squeeze(0).cpu().numpy()
117 | ranks[index] = numpy.where(caps_obj_id[inds] == imgs_obj_id[index])[0][0]
118 |
119 | # Compute metrics
120 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
121 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
122 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
123 | medr = numpy.floor(numpy.median(ranks)) + 1
124 | return (r1, r5, r10, medr)
125 |
126 |
127 | def t2i_arch(images, captions):
128 | npts = captions.size()[0]
129 | ranks = numpy.zeros(npts)
130 | caps_obj_id = numpy.load(open('data/arch/arch_dev_caps_id.npy'))
131 | imgs_obj_id = numpy.load(open('data/arch/arch_dev_imgs_id.npy'))
132 | for index in range(npts):
133 | # Get query caption
134 | cap = captions[index:index+1]
135 | # Compute scores
136 | d = torch.mm(cap, images.t())
137 | d_sorted, inds = torch.sort(d, descending=True)
138 | inds = inds.data.squeeze(0).cpu().numpy()
139 | ranks[index] = numpy.where(imgs_obj_id[inds] == caps_obj_id[index])[0][0]
140 |
141 | # Compute metrics
142 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
143 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
144 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
145 | medr = numpy.floor(numpy.median(ranks)) + 1
146 | return (r1, r5, r10, medr)
147 |
148 | def i2t_arch_case(images, captions, caps_orig):
149 | npts = images.size()[0]
150 | ranks = numpy.zeros(npts)
151 | caps_obj_id = numpy.load(open('data/arch/arch_test_caps_id.npy'))
152 | imgs_obj_id = numpy.load(open('data/arch/arch_test_imgs_id.npy'))
153 | imgs_url = json.load(open('data/arch/arch_test_imgs_url.json'))
154 |
155 | print_num = 10
156 | for index in range(npts):
157 | # Get query image
158 | im = images[index:index+1]
159 | # Compute scores
160 | d = torch.mm(im, captions.t())
161 | d_sorted, inds = torch.sort(d, descending=True)
162 | inds = inds.data.squeeze(0).cpu().numpy()
163 | ranks[index] = numpy.where(caps_obj_id[inds] == imgs_obj_id[index])[0][0]
164 | temp_rank = int(ranks[index])
165 | if temp_rank == 0 and print_num > 0:
166 | print 'i2t: %d' %(10-print_num)
167 | print 'image_url: ', imgs_url[index]
168 | print 'captions ', caps_orig[inds[0]]
169 | print '\n\n'
170 | print_num -= 1
171 |
172 | # Compute metrics
173 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
174 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
175 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
176 | medr = numpy.floor(numpy.median(ranks)) + 1
177 | return (r1, r5, r10, medr)
178 |
179 |
180 | def t2i_arch_case(images, captions, caps_orig):
181 | npts = captions.size()[0]
182 | ranks = numpy.zeros(npts)
183 | caps_obj_id = numpy.load(open('data/arch/arch_test_caps_id.npy'))
184 | imgs_obj_id = numpy.load(open('data/arch/arch_test_imgs_id.npy'))
185 | imgs_url = json.load(open('data/arch/arch_test_imgs_url.json'))
186 | print_num = 10
187 | for index in range(npts):
188 | # Get query caption
189 | cap = captions[index:index+1]
190 | # Compute scores
191 | d = torch.mm(cap, images.t())
192 | d_sorted, inds = torch.sort(d, descending=True)
193 | inds = inds.data.squeeze(0).cpu().numpy()
194 | ranks[index] = numpy.where(imgs_obj_id[inds] == caps_obj_id[index])[0][0]
195 | temp_rank = int(ranks[index])
196 | if temp_rank == 0 and print_num > 0:
197 | print 't2i: %d' %(10-print_num)
198 | print 'caption: ', caps_orig[index]
199 | print 'img_url: ', imgs_url[inds[0]]
200 | print '\n\n'
201 | print_num -= 1
202 |
203 | # Compute metrics
204 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
205 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
206 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
207 | medr = numpy.floor(numpy.median(ranks)) + 1
208 | return (r1, r5, r10, medr)
209 |
--------------------------------------------------------------------------------
/homogeneous_data.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import copy
3 | import sys
4 |
5 |
6 | class HomogeneousData():
7 |
8 | def __init__(self, data, batch_size=128, maxlen=None):
9 | self.data = data
10 | self.batch_size = batch_size
11 | self.maxlen = maxlen
12 |
13 | self.prepare()
14 | self.reset()
15 |
16 | def prepare(self):
17 | # self.caps = [cap[:self.maxlen] for cap in self.data[0]]
18 | self.caps = self.data[0]
19 | self.feats = self.data[1]
20 |
21 | # find the unique lengths
22 | self.lengths = [len(cc.split()) for cc in self.caps]
23 | self.len_unique = numpy.unique(self.lengths)
24 | # remove any overly long sentences
25 | if self.maxlen:
26 | self.len_unique = [ll for ll in self.len_unique if ll <= self.maxlen]
27 |
28 | # indices of unique lengths
29 | self.len_indices = dict()
30 | self.len_counts = dict()
31 | for ll in self.len_unique:
32 | self.len_indices[ll] = numpy.where(self.lengths == ll)[0]
33 | self.len_counts[ll] = len(self.len_indices[ll])
34 |
35 | # current counter
36 | self.len_curr_counts = copy.copy(self.len_counts)
37 |
38 | def reset(self):
39 | self.len_curr_counts = copy.copy(self.len_counts)
40 | self.len_unique = numpy.random.permutation(self.len_unique)
41 | self.len_indices_pos = dict()
42 | for ll in self.len_unique:
43 | self.len_indices_pos[ll] = 0
44 | self.len_indices[ll] = numpy.random.permutation(self.len_indices[ll])
45 | self.len_idx = -1
46 |
47 | def next(self):
48 | count = 0
49 | while True:
50 | self.len_idx = numpy.mod(self.len_idx+1, len(self.len_unique))
51 | if self.len_curr_counts[self.len_unique[self.len_idx]] > 0:
52 | break
53 | count += 1
54 | if count >= len(self.len_unique):
55 | break
56 | if count >= len(self.len_unique):
57 | self.reset()
58 | raise StopIteration()
59 |
60 | # get the batch size
61 | curr_batch_size = numpy.minimum(self.batch_size, self.len_curr_counts[self.len_unique[self.len_idx]])
62 | curr_pos = self.len_indices_pos[self.len_unique[self.len_idx]]
63 | # get the indices for the current batch
64 | curr_indices = self.len_indices[self.len_unique[self.len_idx]][curr_pos:curr_pos+curr_batch_size]
65 | self.len_indices_pos[self.len_unique[self.len_idx]] += curr_batch_size
66 | self.len_curr_counts[self.len_unique[self.len_idx]] -= curr_batch_size
67 |
68 | caps = [self.caps[ii] for ii in curr_indices]
69 | feats = [self.feats[ii] for ii in curr_indices]
70 |
71 | return caps, feats
72 |
73 | def __iter__(self):
74 | return self
75 |
76 |
77 | def prepare_data(caps, features, worddict, maxlen=None, n_words=10000):
78 | """
79 | Put data into format useable by the model
80 | """
81 | seqs = []
82 | feat_list = []
83 | for i, cc in enumerate(caps):
84 | seqs.append([worddict[w] if worddict[w] < n_words else 1 for w in cc.split()])
85 | feat_list.append(features[i])
86 |
87 | lengths = [len(s) for s in seqs]
88 |
89 | if maxlen != None and numpy.max(lengths) >= maxlen:
90 | new_seqs = []
91 | new_feat_list = []
92 | new_lengths = []
93 | for l, s, y in zip(lengths, seqs, feat_list):
94 | if l < maxlen:
95 | new_seqs.append(s)
96 | new_feat_list.append(y)
97 | new_lengths.append(l)
98 | lengths = new_lengths
99 | feat_list = new_feat_list
100 | seqs = new_seqs
101 |
102 | if len(lengths) < 1:
103 | return None, None
104 |
105 | # Why not use the following code?
106 | # y_np = numpy.asarray(feat_list, dtype=numpy.float32)
107 | y = numpy.zeros((len(feat_list), len(feat_list[0]))).astype('float32')
108 | for idx, ff in enumerate(feat_list):
109 | y[idx,:] = ff
110 |
111 | n_samples = len(seqs)
112 | maxlen = numpy.max(lengths)+1
113 |
114 | x = numpy.zeros((maxlen, n_samples)).astype('int64')
115 | for idx, s in enumerate(seqs):
116 | x[:lengths[idx],idx] = s
117 |
118 | return x, y
119 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import torch
3 | from utils import l2norm, xavier_weight
4 | from torch.autograd import Variable
5 | import torch.nn.init as init
6 | from gensim.models.word2vec import Word2Vec
7 | import numpy
8 |
9 | wvModel = Word2Vec.load('static/word2vec-chi/word2vec_news.model')
10 |
11 | class ImgSenRanking(torch.nn.Module):
12 | def __init__(self, model_options):
13 | super(ImgSenRanking, self).__init__()
14 | self.linear = torch.nn.Linear(model_options['dim_image'], model_options['dim'])
15 | self.lstm = torch.nn.LSTM(model_options['dim_word'], model_options['dim'], 1)
16 | self.embedding = torch.nn.Embedding(model_options['n_words'], model_options['dim_word'])
17 | self.model_options = model_options
18 | self.init_weights()
19 |
20 | def init_weights(self):
21 | xavier_weight(self.linear.weight)
22 | # init.xavier_normal(self.linear.weight)
23 | self.linear.bias.data.fill_(0)
24 |
25 | def forward(self, x_id, im, x):
26 | x_id_emb = self.embedding(x_id)
27 | im = self.linear(im)
28 |
29 | x_w2v = torch.zeros(*x_id_emb.size())
30 | x_cat = None
31 | if self.model_options['concat']:
32 | for i, text in enumerate(x):
33 | for j, word in enumerate(text.split()):
34 | try:
35 | x_w2v[j, i] = torch.from_numpy(wvModel[word.decode('utf8')])
36 | except KeyError:
37 | pass
38 | x_w2v = Variable(x_w2v.cuda())
39 | x_cat = torch.cat([x_id_emb, x_w2v])
40 | else:
41 | x_cat = x_id_emb
42 |
43 |
44 | if self.model_options['encoder'] == 'bow':
45 | x_cat = x_cat.sum(0).squeeze(0)
46 | else:
47 | _, (x_cat, _) = self.lstm(x_cat)
48 | x_cat = x_cat.squeeze(0)
49 |
50 | return l2norm(x_cat), l2norm(im)
51 |
52 | def forward_sens(self, x_id, x):
53 | x_id_emb = self.embedding(x_id)
54 |
55 | x_w2v = torch.zeros(*x_id_emb.size())
56 | x_cat = None
57 | if self.model_options['concat']:
58 | for i, text in enumerate(x):
59 | for j, word in enumerate(text):
60 | try:
61 | x_w2v[j, i] = torch.from_numpy(wvModel[word.decode('utf8')])
62 | except KeyError:
63 | pass
64 |
65 | x_w2v = Variable(x_w2v.cuda())
66 | x_cat = torch.cat([x_id_emb, x_w2v])
67 | else:
68 | x_cat = x_id_emb
69 |
70 | if self.model_options['encoder'] == 'bow':
71 | x_cat = x_cat.sum(0).squeeze(0)
72 | else:
73 | _, (x_cat, _) = self.lstm(x_cat)
74 | x_cat = x_cat.squeeze(0)
75 | return l2norm(x_cat)
76 |
77 | def forward_imgs(self, im):
78 | im = self.linear(im)
79 | return l2norm(im)
80 |
81 | class PairwiseRankingLoss(torch.nn.Module):
82 |
83 | def __init__(self, margin=1.0):
84 | super(PairwiseRankingLoss, self).__init__()
85 | self.margin = margin
86 |
87 | def forward(self, im, s):
88 | margin = self.margin
89 | # compute image-sentence score matrix
90 | scores = torch.mm(im, s.transpose(1, 0))
91 | diagonal = scores.diag()
92 |
93 | # compare every diagonal score to scores in its column (i.e, all contrastive images for each sentence)
94 | cost_s = torch.max(Variable(torch.zeros(scores.size()[0], scores.size()[1]).cuda()), (margin-diagonal).expand_as(scores)+scores)
95 | # compare every diagonal score to scores in its row (i.e, all contrastive sentences for each image)
96 | cost_im = torch.max(Variable(torch.zeros(scores.size()[0], scores.size()[1]).cuda()), (margin-diagonal).expand_as(scores).transpose(1, 0)+scores)
97 |
98 | for i in xrange(scores.size()[0]):
99 | cost_s[i, i] = 0
100 | cost_im[i, i] = 0
101 |
102 | return cost_s.sum() + cost_im.sum()
103 |
--------------------------------------------------------------------------------
/pre_transforms.py:
--------------------------------------------------------------------------------
1 |
2 | # coding: utf-8
3 |
4 | # In[1]:
5 |
6 | import jieba.analyse
7 | jieba.analyse.set_stop_words('static/dataset/stopwords.txt')
8 |
9 | from torchvision import transforms
10 | import json, os
11 | import numpy as np
12 | import torch
13 | from PIL import Image, ImageFile
14 | import torchvision.models as models
15 | from torch.autograd import Variable
16 |
17 | dataset_dir = 'static/dataset/arch'
18 | max_num = 1000
19 |
20 | # In[2]:
21 |
22 | image_transform = transforms.Compose([
23 | transforms.Scale([224, 224]),
24 | transforms.ToTensor(),
25 | transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
26 | std = [ 0.229, 0.224, 0.225 ]),
27 | ])
28 |
29 | resnet = models.resnet152(pretrained=True)
30 | resnet.fc = torch.nn.Dropout(p=0)
31 | resnet = resnet.eval()
32 | resnet = resnet.cuda()
33 |
34 |
35 | def normalize(v):
36 | norm=np.linalg.norm(v)
37 | if norm==0:
38 | return v
39 | return v/norm
40 |
41 |
42 | def pre_transforms():
43 |
44 | ImageFile.LOAD_TRUNCATED_IMAGES = True
45 |
46 | print 'Pre-transforming train ...'
47 |
48 | json_file = 'annotations/data_pair_train.json'
49 | train_data_pair = json.load(open(os.path.join(dataset_dir, json_file), 'r'))
50 |
51 | captions_train, images_train = [], []
52 | for k, item in enumerate(train_data_pair):
53 | if k % 100 == 0:
54 | print 'Processing %d/%d' %(k, len(train_data_pair))
55 | caption = ' '.join(jieba.analyse.extract_tags(item['caption'], topK=20, withWeight=False, allowPOS=()))
56 | if len(caption) == 0:
57 | continue
58 |
59 | captions_train.append(caption.encode('utf8')+'\n')
60 |
61 | img_vec = image_transform(Image.open(item['img_path']).convert('RGB')).unsqueeze(0)
62 | img_vec = resnet(Variable(img_vec.cuda())).data.squeeze(0).cpu().numpy()
63 | images_train.append(img_vec)
64 |
65 | if k > max_num:
66 | break
67 |
68 | with open('arch_train_caps.txt', 'w') as f_write:
69 | f_write.writelines(captions_train)
70 |
71 | images_train = np.asarray(images_train, dtype=np.float32)
72 | images_train = normalize(images_train)
73 | np.save('arch_train_ims.npy', images_train)
74 |
75 | print 'Pre-transforming train Done'
76 |
77 | print 'Pre-transforming dev ...'
78 |
79 | json_file = 'annotations/data_pair_val.json'
80 | dev_data_pair = json.load(open(os.path.join(dataset_dir, json_file), 'r'))
81 |
82 | captions_dev, images_dev = [], []
83 | caps_obj_id, imgs_obj_id = [], []
84 | caps_url, imgs_url, imgs_path = [], [], []
85 | for k, item in enumerate(dev_data_pair):
86 | if k % 100 == 0:
87 | print 'Processing %d/%d' %(k, len(dev_data_pair))
88 | if item['obj_id'] not in caps_obj_id:
89 | caption = ' '.join(jieba.analyse.extract_tags(item['caption'], topK=20, withWeight=False, allowPOS=()))
90 | if len(caption) == 0:
91 | continue
92 |
93 | captions_dev.append(caption.encode('utf8')+'\n')
94 | caps_obj_id.append(item['obj_id'])
95 | caps_url.append(item['url'])
96 |
97 | img_vec = image_transform(Image.open(item['img_path']).convert('RGB')).unsqueeze(0)
98 | img_vec = resnet(Variable(img_vec.cuda())).data.squeeze(0).cpu().numpy()
99 | images_dev.append(img_vec)
100 | imgs_obj_id.append(item['obj_id'])
101 | imgs_url.append(item['url'])
102 | imgs_path.append(item['img_path'])
103 |
104 | if k > max_num:
105 | break
106 |
107 | with open('arch_dev_caps.txt', 'w') as f_write:
108 | f_write.writelines(captions_dev)
109 |
110 | json.dump(caps_url, open('arch_dev_caps_url.json', 'w'))
111 | json.dump(imgs_url, open('arch_dev_imgs_url.json', 'w'))
112 | json.dump(imgs_path, open('arch_dev_imgs_path.json', 'w'))
113 |
114 | images_dev = np.asarray(images_dev, dtype=np.float32)
115 | images_dev = normalize(images_dev)
116 | np.save('arch_dev_ims.npy', images_dev)
117 |
118 | caps_obj_id = np.asarray(caps_obj_id, dtype=np.float32)
119 | imgs_obj_id = np.asarray(imgs_obj_id, dtype=np.float32)
120 | np.save('arch_dev_caps_id.npy', caps_obj_id)
121 | np.save('arch_dev_imgs_id.npy', imgs_obj_id)
122 |
123 | print 'Pre-transforming dev Done'
124 |
125 | print 'Pre-transforming test ...'
126 |
127 | test_data_pair = train_data_pair + dev_data_pair
128 |
129 | captions_test, images_test = [], []
130 | caps_obj_id_test, imgs_obj_id_test = [], []
131 | caps_url_test, imgs_url_test, imgs_path_test = [], [], []
132 | for k, item in enumerate(test_data_pair):
133 | if k % 100 == 0:
134 | print 'Processing %d/%d' %(k, len(test_data_pair))
135 | if item['obj_id'] not in caps_obj_id_test:
136 | caption = ' '.join(jieba.analyse.extract_tags(item['caption'], topK=20, withWeight=False, allowPOS=()))
137 | if len(caption) == 0:
138 | continue
139 |
140 | captions_test.append(caption.encode('utf8')+'\n')
141 | caps_obj_id_test.append(item['obj_id'])
142 | caps_url_test.append(item['url'])
143 |
144 | img_vec = image_transform(Image.open(item['img_path']).convert('RGB')).unsqueeze(0)
145 | img_vec = resnet(Variable(img_vec.cuda())).data.squeeze(0).cpu().numpy()
146 | images_test.append(img_vec)
147 | imgs_obj_id_test.append(item['obj_id'])
148 | imgs_url_test.append(item['url'])
149 | imgs_path_test.append(item['img_path'])
150 |
151 | if k > max_num:
152 | break
153 |
154 | with open('arch_test_caps.txt', 'w') as f_write:
155 | f_write.writelines(captions_test)
156 |
157 | json.dump(caps_url_test, open('arch_test_caps_url.json', 'w'))
158 | json.dump(imgs_url_test, open('arch_test_imgs_url.json', 'w'))
159 | json.dump(imgs_path_test, open('arch_test_imgs_path.json', 'w'))
160 |
161 | images_test = np.asarray(images_test, dtype=np.float32)
162 | images_test = normalize(images_test)
163 | np.save('arch_test_ims.npy', images_test)
164 |
165 | caps_obj_id_test = np.asarray(caps_obj_id_test, dtype=np.float32)
166 | imgs_obj_id_test = np.asarray(imgs_obj_id_test, dtype=np.float32)
167 | np.save('arch_test_caps_id.npy', caps_obj_id_test)
168 | np.save('arch_test_imgs_id.npy', imgs_obj_id_test)
169 |
170 | print 'Pre-transforming test Done'
171 |
172 | if __name__ == "__main__":
173 | pre_transforms()
174 |
--------------------------------------------------------------------------------
/query_dump.py:
--------------------------------------------------------------------------------
1 | from model import ImgSenRanking
2 | import cPickle as pkl
3 | import torch
4 | from datasets import load_dataset
5 | import numpy as np
6 | from tools import encode_sentences, encode_images
7 | import json
8 |
9 | data = 'arch'
10 | loadfrom = 'vse/' + data
11 | saveto = 'vse/%s_server/%s' %(data, data)
12 | hyper_params = '%s_params.pkl' % loadfrom
13 | model_params = '%s_model.pkl' % loadfrom
14 |
15 | print 'Building model ... ',
16 | model_options = pkl.load(open(hyper_params, 'r'))
17 | model = ImgSenRanking(model_options).cuda()
18 | model.load_state_dict(torch.load(model_params))
19 | print 'Done'
20 |
21 | test = load_dataset(data, load_test=True)
22 |
23 | print 'Dumping data ... '
24 |
25 | curr_model = {}
26 | curr_model['options'] = model_options
27 | curr_model['worddict'] = model_options['worddict']
28 | curr_model['word_idict'] = model_options['word_idict']
29 | curr_model['img_sen_model'] = model
30 |
31 | ls, lim = encode_sentences(curr_model, test[0]), encode_images(curr_model, test[1])
32 |
33 | # save the using params and model when dumping data
34 | torch.save(ls, '%s_ls.pkl'%saveto)
35 | torch.save(lim, '%s_lim.pkl'%saveto)
36 | pkl.dump(model_options, open('%s_params_dump.pkl'%saveto, 'wb'))
37 | torch.save(model.state_dict(), '%s_model_dump.pkl'%saveto)
38 | json.dump(test[0], open('%s_caps.json'%saveto, 'w'))
39 |
40 | print 'ls: ', ls.data.size()
41 | print 'lim: ', lim.data.size()
42 |
43 |
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import json
4 | import os
5 | import numpy as np
6 | import torch.nn
7 | from torch.autograd import Variable
8 | from model import ImgSenRanking
9 | from PIL import Image, ImageFile
10 | from flask import Flask, request, render_template, jsonify
11 | from tools import encode_sentences, encode_images
12 | from pre_transforms import image_transform, resnet
13 | import cPickle as pkl
14 | import torch
15 | # TODO: Define text_transforms in pre_transforms.py
16 | import jieba.analyse
17 | jieba.analyse.set_stop_words('static/dataset/stopwords.txt')
18 |
19 | app = Flask(__name__)
20 |
21 |
22 | ImageFile.LOAD_TRUNCATED_IMAGES = True
23 |
24 | UPLOAD_FOLDER = 'static/upload/'
25 | dump_path = 'vse/arch_server/'
26 |
27 | app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
28 |
29 | print 'loading image_dump.json'
30 | images_dump = torch.load(os.path.join(dump_path, 'arch_lim.pkl'))
31 | images_path = json.load(open(os.path.join(dump_path, 'arch_test_imgs_path.json')))
32 | images_url = json.load(open(os.path.join(dump_path, 'arch_test_imgs_url.json')))
33 |
34 | print 'loading text_dump.json'
35 | texts_dump = torch.load(os.path.join(dump_path, 'arch_ls.pkl'))
36 | texts_orig = json.load(open(os.path.join(dump_path, 'arch_caps.json')))
37 | texts_url = json.load(open(os.path.join(dump_path, 'arch_test_caps_url.json')))
38 |
39 | print 'loading jianzhu model'
40 | model_options = pkl.load(open(os.path.join(dump_path, 'arch_params_dump.pkl')))
41 | model = ImgSenRanking(model_options).cuda()
42 | model.load_state_dict(torch.load(os.path.join(dump_path, 'arch_model_dump.pkl')))
43 |
44 | curr_model = {}
45 | curr_model['options'] = model_options
46 | curr_model['worddict'] = model_options['worddict']
47 | curr_model['word_idict'] = model_options['word_idict']
48 | curr_model['img_sen_model'] = model
49 |
50 |
51 | @app.route('/')
52 | def index():
53 | return render_template('index.html')
54 |
55 | @app.route('/query', methods=['POST'])
56 | def query():
57 | query_sen = request.form.get('query_sentence', '')
58 | k_input = int(request.form.get('k_input', ''))
59 | query_img = request.files['query_image']
60 | img_name = query_img.filename
61 | upload_img = os.path.join(app.config['UPLOAD_FOLDER'], img_name)
62 | sim_images, sim_images_url = [], []
63 | sim_texts, sim_texts_url = [], []
64 | if img_name:
65 | query_img.save(upload_img)
66 | img_vec = image_transform(Image.open(upload_img).convert('RGB')).unsqueeze(0)
67 | image_emb = encode_images(curr_model, resnet(Variable(img_vec.cuda())).data.cpu().numpy())
68 | d = torch.mm(image_emb, texts_dump.t())
69 | d_sorted, inds = torch.sort(d, descending=True)
70 | inds = inds.data.squeeze(0).cpu().numpy()
71 | # sim_text_degree = 1-distance[0][:k_input]/distance[0][-1]
72 | sim_texts = np.array(texts_orig)[inds[:k_input]]
73 | sim_texts_url = np.array(texts_url)[inds[:k_input]]
74 | # sim_texts, sim_text_degree = sim_texts.tolist(), sim_text_degree.tolist()
75 | sim_texts, sim_texts_url = sim_texts.tolist(), sim_texts_url.tolist()
76 | if query_sen:
77 | query_sen = ' '.join(jieba.analyse.extract_tags(query_sen, topK=100, withWeight=False, allowPOS=()))
78 | query_sen = [query_sen.encode('utf8')]
79 | sentence = encode_sentences(curr_model, query_sen)
80 | d = torch.mm(sentence, images_dump.t())
81 | d_sorted, inds = torch.sort(d, descending=True)
82 | inds = inds.data.squeeze(0).cpu().numpy()
83 | # sim_image_degree = 1-distance[0][:k_input]/distance[0][-1]
84 | sim_images = np.array(images_path)[inds[:k_input]]
85 | sim_images_url = np.array(images_url)[inds[:k_input]]
86 | # sim_images, sim_image_degree = sim_images.tolist(), sim_image_degree.tolist()
87 | sim_images, sim_images_url = sim_images.tolist(), sim_images_url.tolist()
88 |
89 | upload_img = upload_img if img_name else 'no_upload_img'
90 | return jsonify(sim_images=sim_images, sim_images_url=sim_images_url,
91 | upload_img=upload_img, sim_texts=sim_texts, sim_texts_url=sim_texts_url)
92 |
93 |
94 | if __name__ == "__main__":
95 | app.run(host='0.0.0.0', port=2333)
96 |
--------------------------------------------------------------------------------
/static/dataset/arch/annotations/dump_data_pair.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import json
4 | import random
5 |
6 | dev_num = 10000
7 |
8 | data_pair = []
9 |
10 | # Keep pace with coco dataset in
11 | # - json field : caption, img_path, obj_id
12 | # - dataset split : train/val=2/1
13 |
14 | # Feel free to choose title or detail as caption
15 |
16 | with open('jianzhu_tag.json') as f_read:
17 | source_data = map(json.loads, f_read.readlines())
18 |
19 | for k, line in enumerate(source_data):
20 | if k % 1000 == 0:
21 | print 'Processing %d / %d' %(k, len(source_data))
22 | detail = line['detail'].strip()
23 | title, tag = line['title'].strip(), line['tag'].strip()
24 | title = title if title != '' else tag
25 | url = line['other']['url'] if line['other'].has_key('url') else ''
26 | poster = line['poster'].replace('/data/crawler/', '')
27 | if poster == '':
28 | continue
29 | poster = os.path.join('../../', poster)
30 |
31 | if os.path.exists(poster):
32 | for img_name in os.listdir(poster):
33 | poster = poster.replace('../../', 'static/dataset/')
34 | img_path = os.path.join(poster, img_name)
35 | data_pair.append({'caption': detail, 'img_path': img_path,
36 | 'obj_id': k, 'url': url})
37 |
38 | random.shuffle(data_pair)
39 | print 'Remaining data_pair: ', len(data_pair)
40 |
41 | json.dump(data_pair[:dev_num], open('data_pair_val.json', 'w'))
42 | json.dump(data_pair[dev_num:], open('data_pair_train.json', 'w'))
43 |
--------------------------------------------------------------------------------
/static/web/index.css:
--------------------------------------------------------------------------------
1 | div.header {
2 | text-align: center;
3 | }
4 |
5 | @media (min-width: 768px) {
6 | .container {
7 | max-width: 100000px;
8 | }
9 | }
10 |
11 | div.query_input {
12 | overflow: hidden;
13 | margin-bottom: 50px;
14 | }
15 |
16 |
17 | div.left {
18 | float: left;
19 | width: 50%;
20 | padding-right: 80px;
21 | }
22 |
23 | div.left textarea#query_sentence {
24 | float: right;
25 | }
26 |
27 | div.right {
28 | float: right;
29 | width: 50%;
30 | padding-left: 80px;
31 | }
32 |
33 | #query_sentence {
34 | padding: 20px;
35 | }
36 |
37 | img {
38 | height: 200px;
39 | margin: 10px;
40 | }
41 |
42 | div.query_other {
43 | text-align: center;
44 | }
45 |
46 | div.query_other input {
47 | width: 80px;
48 | padding: 3px;
49 | margin-right: 20px;
50 | }
51 |
52 | div.query_other button {
53 | margin-right: 20px;
54 | }
55 |
56 | form {
57 | margin-bottom: 50px;
58 | }
59 |
60 | div#img_result {
61 | margin-bottom: 50px;
62 | }
63 |
64 | div#text_result {
65 | margin-bottom: 50px;
66 | }
67 |
68 | div#text_result div {
69 | margin: 10px;
70 | overflow: hidden;
71 | }
72 |
73 | div#text_result div a {
74 | float: left;
75 | margin: 10px;
76 | }
77 |
78 |
--------------------------------------------------------------------------------
/static/web/index.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Created by lindayong on 17-5-9.
3 | */
4 |
5 | $(function () {
6 | $('#query_button').click(query);
7 | $('#reset_button').click(function () {
8 | $('form')[0].reset();
9 | $('#upload_img').remove();
10 | });
11 |
12 | function query() {
13 | // var formData = new FormData(document.querySelector('form'));
14 | var formData = new FormData($('form')[0]);
15 | var query_button = $('#query_button');
16 | query_button.attr('disabled', 'true');
17 | query_button.text('正在查询中,请稍后 ...');
18 | $.ajax({
19 | url: '/query',
20 | data: formData,
21 | type: 'POST',
22 | processData: false,
23 | contentType: false,
24 | success: function (response) {
25 | $('#img_result').empty();
26 | $('#upload_img').remove();
27 | $('#text_result').empty();
28 | query_button.removeAttr('disabled');
29 | query_button.text('查询');
30 | console.log(response);
31 | var imgs_dir = response['sim_images'];
32 | var upload_img = response['upload_img'];
33 | var texts = response['sim_texts'];
34 | var imgs_url = response['sim_images_url'];
35 | var texts_url = response['sim_texts_url'];
36 | if (upload_img !== 'no_upload_img') {
37 | upload_img = '<img id="upload_img" src="' + upload_img + '">';
38 | $('#query_image').after(upload_img)
39 | }
40 | for (num in imgs_dir) {
41 | var img_dir = '../'+imgs_dir[num].replace(/%/g, '%25');
42 | var img_html = '<img src="' + img_dir + '">';
43 | var img_url = '<a href="' + imgs_url[num] + '">' + img_html + '</a>';
44 | $('#img_result').append(img_url);
45 | }
46 | for (num in texts) {
47 | var text_html = '