├── .gitignore ├── LICENSE ├── README.md ├── code ├── eda.py ├── mapping.json ├── mapping.py ├── model_triplet.py ├── predict_model_triplet.py ├── search.py ├── search_by_image.py ├── search_by_keywords.py └── tsne_visualization.py ├── output ├── Figure_1.png └── Figure_2.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 CVxTz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Building a Deep Image Search Engine using tf.Keras 2 | 3 | ### Motivation : 4 | 5 | Imagine having a data collection of hundreds of thousands to millions of images 6 | without any metadata describing the content of each image. How can we build a 7 | system that can find the subset of those images that best answers a user’s 8 | search query?
What we need, essentially, is a search engine that is able 9 | to rank image results by how well they correspond to the search query, which 10 | can be expressed either in natural language or as another query image.
The 11 | way we will solve the problem in this post is by training a deep neural model 12 | that learns a fixed-length representation (or embedding) of any input image and 13 | text and makes those representations close in Euclidean space whenever 14 | the text-image or image-image pairs are “similar”. 15 | 16 | ### Data set : 17 | 18 | I could not find a big enough data-set of search result rankings, but I 19 | was able to get this data-set: 20 | [http://jmcauley.ucsd.edu/data/amazon/](http://jmcauley.ucsd.edu/data/amazon/) 21 | which links e-commerce item images to their titles and descriptions. We will use 22 | this metadata as the supervision source to learn meaningful joint text-image 23 | representations. The experiments were limited to fashion (Clothing, Shoes and 24 | Jewelry) items and to 500,000 images in order to manage the computation and 25 | storage costs. 26 | 27 | ### Problem setting : 28 | 29 | The data-set we have links each image with a description written in natural 30 | language. So we define a task in which we want to learn a joint, fixed-length 31 | representation for images and text so that each image representation is close to 32 | the representation of its description. 33 | 34 | ![](https://cdn-images-1.medium.com/max/800/1*_sP4W46aev9txzY4O3z9OA.png) 35 | 36 | ### Model : 37 | 38 | The model takes 3 inputs: the image (the anchor), the image 39 | title+description (the positive example) and some randomly 40 | sampled text (the negative example).
Then we define two sub-models: 41 | 42 | * Image encoder : ResNet50 pre-trained on ImageNet + GlobalMaxPooling2D 43 | * Text encoder : GRU + GlobalMaxPooling1D 44 | 45 | The image sub-model produces the embedding for the anchor **E_a** and the text 46 | sub-model outputs the embedding for the positive title+description **E_p** and 47 | the embedding for the negative text **E_n**. 48 | 49 | We then train by optimizing the following triplet loss: 50 | 51 | **L = max(d(E_a, E_p) - d(E_a, E_n) + alpha, 0)** 52 | 53 | where d is the Euclidean distance and alpha is a hyperparameter equal to 0.4 in 54 | this experiment. 55 | 56 | What this loss does is make **d(E_a, E_p)** small and 57 | **d(E_a, E_n)** large, so that each image embedding is close to the 58 | embedding of its description and far from the embedding of random text. 59 | 60 | ### Visualization Results : 61 | 62 | Once we have trained the image and text embedding models, we can 63 | visualize them by projecting the embeddings into two dimensions using t-SNE 64 | ([https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html) 65 | ). 66 | 67 | ![](https://cdn-images-1.medium.com/max/1200/1*8BU-K6uCnLCAgGu8ft64Hw.png) 68 | Test images and their corresponding text descriptions are linked by green lines 69 | 70 | We can see from the plot that, in the embedding space, images and 71 | their corresponding descriptions are generally close, which is what we would expect 72 | given the training loss that was used. 73 | 74 | ### Text-image Search : 75 | 76 | Here we use a few example text queries to search for the best matches in a set 77 | of 70,000 images. We compute the text embedding for the query and then the 78 | embedding for each image in the collection. Finally, we select the top 9 images 79 | that are closest to the query in the embedding space. 80 | 81 | ![](https://cdn-images-1.medium.com/max/800/1*8LjufL4G3ekhtUfng9ww5w.png) 82 | 83 | ![](https://cdn-images-1.medium.com/max/800/1*FdzSeeHw6exPkyONJFczYg.png) 84 | 85 | These examples show that the embedding models are able to learn useful 86 | representations of images and of simple compositions of words. 87 | 88 | ### Image-Image Search : 89 | 90 | Here we use an image as a query and then search the database of 70,000 91 | images for the examples that are most similar to it. The ranking is determined 92 | by how close each pair of images is in the embedding space, using the Euclidean 93 | distance. 94 | 95 | ![](https://cdn-images-1.medium.com/max/800/1*uIXdCz04c9gg86kkj71FzQ.png) 96 | 97 | ![](https://cdn-images-1.medium.com/max/800/1*fV5UIU79UiJr3xMd_nHJBg.png) 98 | 99 | The results illustrate that the generated embeddings are high-level 100 | representations of images that capture the most important characteristics of the 101 | objects depicted without being excessively influenced by orientation, 102 | lighting or minor local details, even though the model was never explicitly trained to do so. 103 | 104 | ### Conclusion : 105 | 106 | In this project we built the machine learning blocks that allow us to create 107 | a keyword- and image-based search engine for a collection of images. The 108 | basic idea is to learn a meaningful joint embedding function for text and 109 | image and then use the distance between items in the embedding space to rank 110 | search results. 
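In practice, once the representations are computed, this ranking is a nearest-neighbour lookup in the shared embedding space. Below is a minimal sketch of that step, assuming the embeddings have already been exported by `code/predict_model_triplet.py`; it mirrors what `code/search_by_keywords.py` does with scikit-learn:

```python
import pandas as pd
from sklearn.neighbors import NearestNeighbors

# Embeddings exported by code/predict_model_triplet.py
images = pd.read_json("../output/test_representations.json")
queries = pd.read_json("../output/queries_representations.json")

# Index the image embeddings; the default metric is the Euclidean distance
nn = NearestNeighbors(n_neighbors=9)
nn.fit(images["image_repr"].tolist())

# For each text query, retrieve the 9 closest images in the embedding space
neighbors = nn.kneighbors(queries["text_repr"].tolist(), return_distance=False)

for query, idx in zip(queries["text"], neighbors):
    print(query, "->", [images["images"].iloc[i] for i in idx])
```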
111 | 112 | **References :** 113 | 114 | * [Large Scale Online Learning of Image Similarity Through 115 | Ranking](http://www.jmlr.org/papers/volume11/chechik10a/chechik10a.pdf) 116 | * [Ups and downs: Modeling the visual evolution of fashion trends with one-class 117 | collaborative filtering](https://cseweb.ucsd.edu/~jmcauley/pdfs/www16a.pdf) 118 | * [https://github.com/KinWaiCheuk/Triplet-net-keras/blob/master/Triplet%20NN%20Test%20on%20MNIST.ipynb](https://github.com/KinWaiCheuk/Triplet-net-keras/blob/master/Triplet%20NN%20Test%20on%20MNIST.ipynb) -------------------------------------------------------------------------------- /code/eda.py: -------------------------------------------------------------------------------- 1 | import json 2 | from random import shuffle 3 | import pandas as pd 4 | import numpy as np 5 | from matplotlib import pyplot 6 | 7 | 8 | out_name = "../output/test_representations.json" 9 | 10 | data = pd.read_json(out_name) 11 | 12 | data = data.sample(n=10000) 13 | 14 | img_repr = data['image_repr'].tolist() 15 | img_repr_random = data['image_repr'].tolist() 16 | shuffle(img_repr_random) 17 | text_repr = data['text_repr'].tolist() 18 | 19 | target_distances = [] 20 | random_distances = [] 21 | 22 | for img, random_image, text in zip(img_repr, img_repr_random, text_repr): 23 | d_1 = np.linalg.norm(np.array(img)-np.array(text)) 24 | d_2 = np.linalg.norm(np.array(random_image)-np.array(text)) 25 | 26 | target_distances.append(d_1) 27 | random_distances.append(d_2) 28 | 29 | 30 | pyplot.hist(target_distances, bins=100, alpha=0.5, label='matched text') 31 | pyplot.hist(random_distances, bins=100, alpha=0.5, label='random text') 32 | pyplot.legend(loc='upper right') 33 | pyplot.show() -------------------------------------------------------------------------------- /code/mapping.py: -------------------------------------------------------------------------------- 1 | import json 2 | from nltk.tokenize import word_tokenize 3 | from collections import defaultdict 4 | import re 5 | 6 | 7 | def tokenize(x): 8 | x = re.sub('([\\\'".!?,-/])', r' \1 ', x) 9 | x = re.sub('(\d+)', r' \1 ', x) 10 | 11 | x = word_tokenize(x.lower()) 12 | 13 | return x 14 | 15 | 16 | UNK_TOKEN = "unk" 17 | BATCH_SIZE = 16 18 | 19 | 20 | def get_frequency_token_vocab(list_tokenized_sentences, vocab=defaultdict(int)): 21 | for sentence in list_tokenized_sentences: 22 | for token in sentence: 23 | vocab[token] += 1 24 | 25 | vocab[UNK_TOKEN] = 10000 26 | return vocab 27 | 28 | 29 | def get_mapping_dict(vocab, cutoff=10): 30 | i = 1 31 | word_freq = [(k, v) for k, v in vocab.items()] 32 | word_freq.sort(key = lambda x: -x[1]) 33 | mapping = {} 34 | for token, freq in word_freq: 35 | if vocab[token] >= cutoff: 36 | mapping[token] = i 37 | i += 1 38 | return mapping 39 | 40 | 41 | train = json.load(open("../input/filtred_train_data.json", 'r')) 42 | val = json.load(open("../input/filtred_val_data.json", 'r')) 43 | 44 | list_images_train, captions_train = list(zip(*train)) 45 | 46 | list_images_val, captions_val = list(zip(*val)) 47 | 48 | list_tok = [tokenize(x) for x in captions_train+captions_val] 49 | 50 | vocab = get_frequency_token_vocab(list_tok) 51 | mapping = get_mapping_dict(vocab) 52 | 53 | json.dump(mapping, open('mapping.json', 'w'), indent=4) -------------------------------------------------------------------------------- /code/model_triplet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from random import choice, sample 4 
| 5 | import cv2 6 | import numpy as np 7 | import tensorflow.keras.backend as K 8 | from nltk.tokenize import word_tokenize 9 | from tensorflow.keras.applications import ResNet50 10 | from tensorflow.keras.applications.imagenet_utils import preprocess_input 11 | from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau 12 | from tensorflow.keras.layers import Input, GlobalMaxPool2D, GlobalMaxPool1D, Dense, Embedding, GRU, \ 13 | Bidirectional, Concatenate, Lambda, SpatialDropout1D 14 | from tensorflow.keras.models import Model 15 | from tensorflow.keras.optimizers import Adam 16 | 17 | UNK_TOKEN = "unk" 18 | img_shape = (222, 171, 3) 19 | vec_dim = 50 20 | BATCH_SIZE = 32 21 | 22 | 23 | def tokenize(x): 24 | x = re.sub('([\\\'".!?,-/])', r' \1 ', x) 25 | x = re.sub('(\d+)', r' \1 ', x) 26 | 27 | x = word_tokenize(x.lower()) 28 | 29 | return x 30 | 31 | 32 | def triplet_loss(y_true, y_pred, alpha=0.4): 33 | """ 34 | https://github.com/KinWaiCheuk/Triplet-net-keras/blob/master/Triplet%20NN%20Test%20on%20MNIST.ipynb 35 | Implementation of the triplet loss function 36 | Arguments: 37 | y_true -- true labels, required when you define a loss in Keras, you don't need it in this function. 38 | y_pred -- python list containing three objects: 39 | anchor -- the encodings for the anchor data 40 | positive -- the encodings for the positive data (similar to anchor) 41 | negative -- the encodings for the negative data (different from anchor) 42 | Returns: 43 | loss -- real number, value of the loss 44 | """ 45 | 46 | total_lenght = y_pred.shape.as_list()[-1] 47 | anchor = y_pred[:, 0:int(total_lenght * 1 / 3)] 48 | positive = y_pred[:, int(total_lenght * 1 / 3):int(total_lenght * 2 / 3)] 49 | negative = y_pred[:, int(total_lenght * 2 / 3):int(total_lenght * 3 / 3)] 50 | 51 | # distance between the anchor and the positive 52 | pos_dist = K.sum(K.square(anchor - positive), axis=1) 53 | 54 | # distance between the anchor and the negative 55 | neg_dist = K.sum(K.square(anchor - negative), axis=1) 56 | 57 | # compute loss 58 | basic_loss = pos_dist - neg_dist + alpha 59 | loss = K.maximum(basic_loss, 0.0) 60 | 61 | return loss 62 | 63 | 64 | def map_sentence(tokenized_sentence, mapping): 65 | out_sentence = list(map(lambda x: mapping[x if x in mapping else UNK_TOKEN], tokenized_sentence)) 66 | return out_sentence 67 | 68 | 69 | def map_sentences(list_tokenized_sentences, mapping): 70 | mapped = [] 71 | for sentence in list_tokenized_sentences: 72 | out_sentence = map_sentence(sentence, mapping) 73 | mapped.append(out_sentence) 74 | 75 | return mapped 76 | 77 | 78 | def cap_sequence(seq, max_len, append): 79 | if len(seq) < max_len: 80 | if np.random.uniform(0, 1) < 0.5: 81 | return seq + [append] * (max_len - len(seq)) 82 | else: 83 | return [append] * (max_len - len(seq)) + seq 84 | else: 85 | if np.random.uniform(0, 1) < 0.5: 86 | 87 | return seq[:max_len] 88 | else: 89 | return seq[-max_len:] 90 | 91 | 92 | def cap_sequences(list_sequences, max_len, append): 93 | capped = [] 94 | for seq in list_sequences: 95 | out_seq = cap_sequence(seq, max_len, append) 96 | capped.append(out_seq) 97 | 98 | return capped 99 | 100 | 101 | def read_img(path, preprocess=True): 102 | img = cv2.imread(path) 103 | if img is None or img.size<10: 104 | img = np.zeros((222, 171)) 105 | img = cv2.resize(img, (171, 222)) 106 | if preprocess: 107 | img = preprocess_input(img) 108 | return img 109 | 110 | 111 | def gen(list_images, list_captions, batch_size=16, aug=False): 112 | indexes = 
list(range(len(list_images))) 113 | while True: 114 | batch_indexes = sample(indexes, batch_size) 115 | 116 | candidate_images = [list_images[i] for i in batch_indexes] 117 | captions_p = [list_captions[i] for i in batch_indexes] 118 | 119 | captions_n = [choice(list_captions) for _ in batch_indexes] 120 | 121 | X1 = np.array([read_img(x) for x in candidate_images]) 122 | if aug and np.random.uniform(0, 1)<0.5: 123 | X1 = X1[:, :, ::-1, :] 124 | X2 = np.array(captions_p) 125 | X3 = np.array(captions_n) 126 | 127 | yield [X1, X2, X3], np.zeros((batch_size, 3 * vec_dim)) 128 | 129 | 130 | def model(vocab_size, lr=0.0001): 131 | input_1 = Input(shape=(None, None, 3)) 132 | input_2 = Input(shape=(None,)) 133 | input_3 = Input(shape=(None,)) 134 | 135 | base_model = ResNet50(weights='imagenet', include_top=False) 136 | 137 | x1 = base_model(input_1) 138 | x1 = GlobalMaxPool2D()(x1) 139 | 140 | dense_1 = Dense(vec_dim, activation="linear", name="dense_image_1") 141 | 142 | x1 = dense_1(x1) 143 | 144 | embed = Embedding(vocab_size, 50, name="embed") 145 | 146 | gru = Bidirectional(GRU(256, return_sequences=True), name="gru_1") 147 | dense_2 = Dense(vec_dim, activation="linear", name="dense_text_1") 148 | 149 | x2 = embed(input_2) 150 | x2 = SpatialDropout1D(0.1)(x2) 151 | x2 = gru(x2) 152 | x2 = GlobalMaxPool1D()(x2) 153 | x2 = dense_2(x2) 154 | 155 | x3 = embed(input_3) 156 | x3 = SpatialDropout1D(0.1)(x3) 157 | x3 = gru(x3) 158 | x3 = GlobalMaxPool1D()(x3) 159 | x3 = dense_2(x3) 160 | 161 | _norm = Lambda(lambda x: K.l2_normalize(x, axis=-1)) 162 | 163 | x1 = _norm(x1) 164 | x2 = _norm(x2) 165 | x3 = _norm(x3) 166 | 167 | x = Concatenate(axis=-1)([x1, x2, x3]) 168 | 169 | model = Model([input_1, input_2, input_3], x) 170 | 171 | model.compile(loss=triplet_loss, optimizer=Adam(lr)) 172 | 173 | model.summary() 174 | 175 | return model 176 | 177 | 178 | def image_model(lr=0.0001): 179 | input_1 = Input(shape=(None, None, 3)) 180 | 181 | base_model = ResNet50(weights='imagenet', include_top=False) 182 | 183 | x1 = base_model(input_1) 184 | x1 = GlobalMaxPool2D()(x1) 185 | 186 | dense_1 = Dense(vec_dim, activation="linear", name="dense_image_1") 187 | 188 | x1 = dense_1(x1) 189 | 190 | _norm = Lambda(lambda x: K.l2_normalize(x, axis=-1)) 191 | 192 | x1 = _norm(x1) 193 | 194 | model = Model([input_1], x1) 195 | 196 | model.compile(loss="mae", optimizer=Adam(lr)) 197 | 198 | model.summary() 199 | 200 | return model 201 | 202 | 203 | def text_model(vocab_size, lr=0.0001): 204 | input_2 = Input(shape=(None,)) 205 | 206 | embed = Embedding(vocab_size, 50, name="embed") 207 | gru = Bidirectional(GRU(256, return_sequences=True), name="gru_1") 208 | dense_2 = Dense(vec_dim, activation="linear", name="dense_text_1") 209 | 210 | x2 = embed(input_2) 211 | x2 = gru(x2) 212 | x2 = GlobalMaxPool1D()(x2) 213 | x2 = dense_2(x2) 214 | 215 | _norm = Lambda(lambda x: K.l2_normalize(x, axis=-1)) 216 | 217 | x2 = _norm(x2) 218 | 219 | model = Model([input_2], x2) 220 | 221 | model.compile(loss="mae", optimizer=Adam(lr)) 222 | 223 | model.summary() 224 | 225 | return model 226 | 227 | 228 | if __name__ == "__main__": 229 | 230 | mapping = json.load(open('mapping.json', 'r')) 231 | 232 | train = json.load(open("../input/filtred_train_data.json", 'r')) 233 | val = json.load(open("../input/filtred_val_data.json", 'r')) 234 | 235 | list_images_train, captions_train = list(zip(*train)) 236 | captions_train = [tokenize(x) for x in captions_train] 237 | captions_train = map_sentences(captions_train, mapping) 238 | 
captions_train = cap_sequences(captions_train, 70, 0) 239 | 240 | list_images_val, captions_val = list(zip(*val)) 241 | captions_val = [tokenize(x) for x in captions_val] 242 | 243 | captions_val = map_sentences(captions_val, mapping) 244 | captions_val = cap_sequences(captions_val, 70, 0) 245 | 246 | file_path = "model_triplet.h5" 247 | 248 | model = model(vocab_size=len(mapping) + 1, lr=0.00001) 249 | 250 | try: 251 | model.load_weights(file_path, by_name=True) 252 | except: 253 | pass 254 | 255 | checkpoint = ModelCheckpoint(file_path, monitor='val_loss', verbose=1, save_best_only=True, mode='min') 256 | reduce = ReduceLROnPlateau(monitor="val_loss", mode='min', patience=10, min_lr=1e-7) 257 | 258 | model.fit_generator(gen(list_images_train, captions_train, batch_size=BATCH_SIZE, aug=True), 259 | use_multiprocessing=True, 260 | validation_data=gen(list_images_val, captions_val, batch_size=BATCH_SIZE), epochs=10000, 261 | verbose=1, workers=4, steps_per_epoch=300, validation_steps=100, callbacks=[checkpoint, reduce]) 262 | model.save_weights(file_path) 263 | -------------------------------------------------------------------------------- /code/predict_model_triplet.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from model_triplet import tokenize, map_sentences, cap_sequences, read_img, image_model, text_model 7 | from tqdm import tqdm 8 | 9 | def chunker(seq, size): 10 | return (seq[pos:pos + size] for pos in range(0, len(seq), size)) 11 | 12 | 13 | if __name__ == "__main__": 14 | 15 | out_name = "../output/test_representations.json" 16 | 17 | mapping = json.load(open('mapping.json', 'r')) 18 | 19 | test = json.load(open("../input/filtred_test_data.json", 'r')) 20 | 21 | 22 | file_path = "model_triplet.h5" 23 | 24 | t_model = text_model(vocab_size=len(mapping) + 1) 25 | i_model = image_model() 26 | 27 | t_model.load_weights(file_path, by_name=True) 28 | i_model.load_weights(file_path, by_name=True) 29 | 30 | list_images_test, _captions_test = list(zip(*test)) 31 | captions_test = [tokenize(x) for x in _captions_test] 32 | captions_test = map_sentences(captions_test, mapping) 33 | captions_test = cap_sequences(captions_test, 70, 0) 34 | 35 | target_image_encoding = [] 36 | 37 | for img_paths in tqdm(chunker(list_images_test, 128), total=len(list_images_test)//128): 38 | images = np.array([read_img(file_path) for file_path in img_paths]) 39 | e = i_model.predict(images) 40 | target_image_encoding += e.tolist() 41 | 42 | target_text_encoding = t_model.predict(np.array(captions_test), verbose=1, batch_size=128) 43 | 44 | target_text_encoding = target_text_encoding.tolist() 45 | 46 | df = pd.DataFrame({"images": list_images_test, "text": _captions_test, "image_repr": target_image_encoding, 47 | "text_repr": target_text_encoding}) 48 | 49 | df.to_json(out_name, orient='records') 50 | 51 | data = json.load(open(out_name, 'r')) 52 | json.dump(data, open(out_name, 'w'), indent=4) 53 | 54 | # New queries 55 | 56 | out_name = "../output/queries_representations.json" 57 | 58 | _captions_test = ['blue tshirt', 'blue shirt', 'red dress', 'halloween outfit', 'baggy pants', 'gold ring', 59 | 'Black trousers', 'animal pendant', 'black white tshirt', 'women red hoodie', 'animal outfit', 60 | 'men flipflops'] 61 | 62 | captions_test = [tokenize(x) for x in _captions_test] 63 | captions_test = map_sentences(captions_test, mapping) 64 | captions_test = cap_sequences(captions_test, 70, 0) 65 | 66 | 
target_text_encoding = t_model.predict(np.array(captions_test), verbose=2, batch_size=128) 67 | 68 | target_text_encoding = target_text_encoding.tolist() 69 | 70 | df = pd.DataFrame({"text": _captions_test, 71 | "text_repr": target_text_encoding}) 72 | 73 | df.to_json(out_name, orient='records') 74 | 75 | data = json.load(open(out_name, 'r')) 76 | json.dump(data, open(out_name, 'w'), indent=4) 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /code/search.py: -------------------------------------------------------------------------------- 1 | import json 2 | from random import shuffle 3 | import pandas as pd 4 | import numpy as np 5 | from matplotlib import pyplot 6 | from sklearn.neighbors import NearestNeighbors 7 | 8 | repr_json = "../output/test_representations.json" 9 | 10 | data = pd.read_json(repr_json) 11 | 12 | data = data.sample(n=1000) 13 | 14 | img_repr = data['image_repr'].tolist() 15 | text_repr = data['text_repr'].tolist() 16 | 17 | nn = NearestNeighbors(n_jobs=-1, n_neighbors=1000) 18 | 19 | nn.fit(text_repr) 20 | 21 | preds = nn.kneighbors(img_repr, return_distance=False).tolist() 22 | ranks = [] 23 | 24 | for i, x in enumerate(preds): 25 | rank = x.index(i)+1 26 | ranks.append(rank) 27 | 28 | print("Average rank :", np.mean(ranks)) -------------------------------------------------------------------------------- /code/search_by_image.py: -------------------------------------------------------------------------------- 1 | import json 2 | from random import shuffle 3 | import pandas as pd 4 | import numpy as np 5 | from matplotlib import pyplot 6 | from sklearn.neighbors import NearestNeighbors 7 | import matplotlib.pyplot as plt 8 | from model_triplet import read_img 9 | import cv2 10 | from uuid import uuid4 11 | 12 | 13 | repr_json = "../output/test_representations.json" 14 | 15 | data = pd.read_json(repr_json) 16 | 17 | #data = data.sample(n=50000) 18 | 19 | img_repr = data['image_repr'].tolist() 20 | img_paths = data['images'].tolist() 21 | text_repr = data['text_repr'].tolist() 22 | 23 | nn = NearestNeighbors(n_jobs=-1, n_neighbors=9) 24 | 25 | nn.fit(img_repr) 26 | 27 | preds = nn.kneighbors(img_repr[:100], return_distance=False).tolist() 28 | 29 | most_similar_images = [] 30 | query_image = [] 31 | 32 | 33 | for i, x in enumerate(preds): 34 | preds_paths = [img_paths[i] for i in x] 35 | query_image.append(preds_paths[0]) 36 | most_similar_images.append(preds_paths[1:]) 37 | 38 | for q, similar in zip(query_image, most_similar_images): 39 | fig, axes = plt.subplots(3, 3) 40 | all_images = [q]+similar 41 | 42 | for idx, img_path in enumerate(all_images): 43 | i = idx % 3 # Get subplot row 44 | j = idx // 3 # Get subplot column 45 | image = read_img(img_path, preprocess=False) 46 | image = image[:, :, ::-1] 47 | axes[i, j].imshow(image/255) 48 | axes[i, j].axis('off') 49 | axes[i, j].axis('off') 50 | if idx == 0: 51 | axes[i, j].set_title('Query Image') 52 | else: 53 | axes[i, j].set_title('Result Image %s'%idx) 54 | 55 | plt.subplots_adjust(wspace=0.2, hspace=0.2) 56 | plt.savefig('../output/images/%s.png'%uuid4().hex) 57 | 58 | -------------------------------------------------------------------------------- /code/search_by_keywords.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from sklearn.neighbors import NearestNeighbors 3 | import matplotlib.pyplot as plt 4 | from model_triplet import read_img 5 | from uuid import uuid4 6 | from uuid import uuid4 7 
| 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | from sklearn.neighbors import NearestNeighbors 11 | 12 | from model_triplet import read_img 13 | 14 | repr_json = "../output/test_representations.json" 15 | 16 | data = pd.read_json(repr_json) 17 | 18 | queries_repr_json = "../output/queries_representations.json" 19 | 20 | queries_data = pd.read_json(queries_repr_json) 21 | 22 | #data = data.sample(n=50000) 23 | 24 | img_repr = data['image_repr'].tolist() 25 | img_paths = data['images'].tolist() 26 | text_repr = queries_data['text_repr'].tolist() 27 | 28 | nn = NearestNeighbors(n_jobs=-1, n_neighbors=9) 29 | 30 | nn.fit(img_repr) 31 | 32 | preds = nn.kneighbors(text_repr, return_distance=False).tolist() 33 | 34 | most_similar_images = [] 35 | query_image = [] 36 | 37 | 38 | for i, x in enumerate(preds): 39 | preds_paths = [img_paths[i] for i in x] 40 | most_similar_images.append(preds_paths) 41 | 42 | for q, all_images in zip(queries_data['text'], most_similar_images): 43 | fig, axes = plt.subplots(3, 3) 44 | 45 | for idx, img_path in enumerate(all_images): 46 | i = idx % 3 # Get subplot row 47 | j = idx // 3 # Get subplot column 48 | image = read_img(img_path, preprocess=False) 49 | image = image[:, :, ::-1] 50 | axes[i, j].imshow(image/255) 51 | axes[i, j].axis('off') 52 | axes[i, j].axis('off') 53 | axes[i, j].set_title('Result Image %s'%(idx+1)) 54 | 55 | 56 | plt.subplots_adjust(wspace=0.2, hspace=0.2) 57 | fig.suptitle('Query : %s'%q) 58 | 59 | plt.savefig('../output/queries/%s.png'%uuid4().hex) 60 | 61 | -------------------------------------------------------------------------------- /code/tsne_visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from matplotlib import pyplot as plt 4 | from sklearn import manifold 5 | 6 | out_name = "../output/test_representations.json" 7 | 8 | data = pd.read_json(out_name) 9 | 10 | n = 10000 11 | 12 | data = data.sample(n=n) 13 | 14 | img_repr = data['image_repr'].tolist() 15 | text_repr = data['text_repr'].tolist() 16 | 17 | tsne = manifold.TSNE(n_components=2, init='random', 18 | random_state=0) 19 | 20 | X = tsne.fit_transform(img_repr + text_repr).tolist() 21 | 22 | img_repr_2d = X[:n] 23 | text_repr_2d = X[n:] 24 | 25 | distances = [] 26 | for a, b in zip(img_repr_2d, text_repr_2d): 27 | distances.append(np.fabs(a[0] - b[0]) + np.fabs(a[1] - b[1])) 28 | 29 | quantile = np.quantile(distances, q=0.9) 30 | 31 | fig = plt.figure() 32 | plt.scatter([a[0] for a in img_repr_2d], [a[1] for a in img_repr_2d], c="r", s=4, alpha=0.5, label="Images") 33 | plt.scatter([a[0] for a in text_repr_2d], [a[1] for a in text_repr_2d], c="b", s=4, alpha=0.5, label="Texts") 34 | for a, b in zip(img_repr_2d, text_repr_2d): 35 | if np.fabs(a[0] - b[0]) + np.fabs(a[1] - b[1]) < quantile: 36 | plt.plot([a[0], b[0]], [a[1], b[1]], c="g", lw=0.2) 37 | plt.legend() 38 | plt.title('TSNE Visualization') 39 | plt.show() 40 | -------------------------------------------------------------------------------- /output/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVxTz/image_search_engine/30314e92a9e53bced79e1c24da5f8d0210521dae/output/Figure_1.png -------------------------------------------------------------------------------- /output/Figure_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CVxTz/image_search_engine/30314e92a9e53bced79e1c24da5f8d0210521dae/output/Figure_2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==0.25.2 2 | matplotlib==3.1.1 3 | numpy==1.17.3 4 | tensorflow_gpu==2.0.1 5 | nltk==3.4.5 6 | tqdm==4.36.1 7 | opencv_python==4.1.1.26 8 | Pillow==6.2.1 9 | scikit_learn==0.21.3 10 | tensorflow==2.0.1 11 | --------------------------------------------------------------------------------