├── .gitignore
├── CODEOWNERS
├── LICENSE
├── README.md
├── figures
│   ├── eval_heatmap.png
│   └── lossplot_5.png
├── fine-tuning
│   ├── clip_dataset.py
│   ├── evaluate.py
│   ├── plot_losses.py
│   ├── train.py
│   ├── train_configs
│   │   ├── run1.cfg
│   │   ├── run2.cfg
│   │   ├── run3.cfg
│   │   ├── run4.cfg
│   │   └── run5.cfg
│   └── vectorize_images.py
├── image-search
│   ├── app.py
│   ├── image_to_image.py
│   ├── load-vespa-index.py
│   ├── search-vespa-index.py
│   ├── start-streamlit.sh
│   ├── streamlit.nginx
│   ├── streamlit.service
│   ├── text_to_image.py
│   ├── utils.py
│   ├── vespa.service
│   └── vespa
│       └── src
│           └── main
│               └── application
│                   ├── hosts.xml
│                   ├── schemas
│                   │   └── image.sd
│                   ├── search
│                   │   └── query-profiles
│                   │       ├── default.xml
│                   │       └── types
│                   │           └── root.xml
│                   └── services.xml
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # These owners will be the default owners for everything in
2 | # the repo. Unless a later match takes precedence.
3 | * @sujitpal
4 |
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Elsevier Labs Open Source Repository
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # clip-image-search
2 |
3 | Fine-tuning OpenAI CLIP Model for Image Search on medical images.
4 |
5 | _**PLEASE NOTE**: The code in this repository is being provided as-is without any warranty of any kind. While efforts have been made to ensure that the instructions are accurate, some of the underlying software libraries move on a fast update cycle and may have changed, and as a result the code may neither be installable or executable. It is being provided for study purposes only._
6 |
7 | * [Motivation](#motivation)
8 | * [Applications](#applications)
9 | * [References / Previous Work](#references--previous-work)
10 | * [Fine Tuning](#fine-tuning)
11 | * [Environment](#environment)
12 | * [Data Preparation](#data-preparation)
13 | * [Training Hyperparameters](#training-hyperparameters)
14 | * [Evaluation](#evaluation)
15 | * [Image Search](#image-search)
16 | * [Environment](#environment-1)
17 | * [Vespa](#vespa)
18 | * [Streamlit](#streamlit)
19 |
20 |
21 | ## Motivation
22 |
23 | * Model based image search (i.e. using machine learning based models trained on image similarity rather than traditional Lucene based search on captions)
24 | * Unsupervised or self supervised, because
25 | * labeling is expensive
26 | * we have lots of captioned images in our collection, so many (image, caption) pairs available "in the wild"
27 | * This project is **not** about caption prediction; rather, it explores the feasibility of text-to-image search, whereby the user enters a text string to bring up the most appropriate images for that string.
28 |
29 | ## Applications
30 |
31 | * Image search for internal users to get editorial efficiencies
32 | * Image search for customers
33 | * Captioning of new images, concepts in new images
34 | * Image decomposition and caption assignment - an enabling technology for machine learning
35 |
36 | ## References / Previous Work
37 |
38 | * [Contrastive Learning of Medical Visual Representations from Paired Images and Text](https://arxiv.org/abs/2010.00747) (Zhang et al, 2020)
39 | * learns visual representation of medical images and associated text using contrastive learning
40 | * uses different encoders for different specialties, eg, chest image encoder, bone image encoder
41 | * [Learning Transferable Visual Models from Natural Language Supervision](https://arxiv.org/abs/2103.00020) (Radford et al, 2021)
42 | * led to the OpenAI CLIP model on which this project is based
43 | * [CLIP: Connecting Text and Images](https://openai.com/blog/clip/) -- blog post provides high level intro
44 | * [CLIP Hugging Face Model](https://huggingface.co/transformers/model_doc/clip.html)
45 | * [CLIP-rsicd project](https://github.com/arampacha/CLIP-rsicd)
46 | * same idea applied to satellite images
47 | * done with external team as part of Hugging Face Flax/JAX community week
48 |
49 | ## Fine Tuning
50 |
51 | ### Environment
52 |
53 | * AWS EC2 p2.xlarge
54 | * 4 CPUs
55 | * 1 GPU (Tesla K80)
56 | * 64 GB RAM
57 | * 300 GB disk
58 | * Ubuntu 18.04 (bionic)
59 | * [AWS Deep Learning AMI (Ubuntu 18.04)](https://aws.amazon.com/marketplace/pp/prodview-x5nivojpquy6y)
60 | * Conda Environment: pytorch_latest_p37
61 | * Python 3.7.10
62 | * Pytorch 1.8.1+cu111
63 | * Additional packages
64 | * Transformers 4.10.0
65 | * see [requirements.txt](requirements.txt)
66 |
67 | ### Data Preparation
68 |
69 | * Data for task originally sourced from the [ImageCLEF 2017 Caption Prediction task](https://www.imageclef.org/2017/caption).
70 | * Once downloaded, the dataset is exploded into the folder structure shown below: three subfolders, one each for `training`, `test`, and `validation` data. Since the dataset is for caption prediction, the `test` folder does not contain any captions. The `training` and `validation` folders contain a CSV file of image file names and captions, and all three folders contain a nested sub-folder with the images for that split.
71 |
72 | ```
73 | $ mkdir -p ImageCLEF2017-CaptionPrediction; cd ImageCLEF2017-CaptionPrediction
74 | $ tree .
75 | |
76 | *-- test/
77 | | |
78 | | +-- CaptionPredictionTesting2017-List.txt
79 | | +-- CaptionPredictionTesting2017/
80 | | |
81 | | +-- test-image-1.jpg
82 | | +-- ... 9,999 more ...
83 | |
84 | +-- training/
85 | | |
86 | | +-- CaptionPredictionTraining2017-Captions.csv
87 | | +-- CaptionPredictionTraining2017/
88 | | |
89 | | +-- training-image-1.jpg
90 | | +-- ... 164,613 more ...
91 | |
92 | +-- validation/
93 | |
94 | +-- CaptionPredictionValidation2017-Captions.csv
95 | +-- ConceptDetectionValidation2017/
96 | |
97 | +-- validation-image-1.jpg
98 | +-- ... 9,999 more ...
99 | ```
100 |
101 | * Dataset contains following splits
102 | * training: 164,614 images + captions
103 |     * validation: 10,000 images + captions
104 | * test: 10,000 images (no captions)
105 | * We need 3 splits -- training, validation, and test.
106 | * We cannot use test for training or evaluation, so we discard it
107 | * validation data becomes test data
108 |     * training data is split 90:10 into new training and validation (a sketch of such a split is shown after this list)
109 | * End count is as follows:
110 | * training: 148,153 images + captions
111 | * validation: 16,461 images + captions
112 | * test: 10,000 images + captions
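
The 90:10 split is a one-off preprocessing step whose script is not part of this repository; the following is a minimal sketch of how it could be done, assuming the tab-separated caption file shipped with the training split (the output file names are hypothetical):

```
import random

# split the ImageCLEF training captions 90:10 into new training and validation
# caption files (output file names below are illustrative only)
random.seed(42)
with open("training/CaptionPredictionTraining2017-Captions.csv", "r") as fin:
    lines = [line for line in fin if line.strip()]

random.shuffle(lines)
cutoff = int(0.9 * len(lines))

with open("training/TrainSplit-Captions.csv", "w") as ftrain:
    ftrain.writelines(lines[:cutoff])
with open("training/ValidationSplit-Captions.csv", "w") as fval:
    fval.writelines(lines[cutoff:])

print("training: {:d}, validation: {:d}".format(cutoff, len(lines) - cutoff))
```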
113 |
114 | ### Training Hyperparameters
115 |
116 | * Best model hyperparameters
117 | * batch size: 64
118 | * optimizer: ADAM
119 | * learning rate: 5e-5
120 | * number of epochs: 10
121 | * number of training samples: 50,000
122 | * We note that the training loss continues to drop, so it is likely that further training, or training with larger amounts of data, would increase performance. However, the flattening of the validation curve shows that we are in an area of diminishing returns.
123 | * ![Training and validation loss curves for the best run](figures/lossplot_5.png)
124 | * We considered doing image and text augmentation but dropped the idea since the training set is quite large (148k+ image-caption pairs) and we achieve regularization by randomly sampling a subset of this dataset.
125 | * The table in the [Evaluation](#evaluation) section below shows results for 5 different hyperparameter configurations; the configuration for the best run is reproduced after this list.
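
For reference, [run5.cfg](fine-tuning/train_configs/run5.cfg), the configuration used for the best run and consumed by [train.py](fine-tuning/train.py), is:

```
image_data_dir: ../../ImageCLEF2017-CaptionPrediction
train_images_dir: CaptionPredictionTraining2017
train_captions_file: CaptionPredictionTraining2017-Captions.csv
validation_images_dir: ConceptDetectionValidation2017
validation_captions_file: CaptionPredictionValidation2017-Captions.csv
models_dir: ../data/models

train_sample_size: 50000
train_batch_size: 64
validation_batch_size: 64
learning_rate: 5.e-5
num_epochs: 10
```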
126 |
127 |
128 | ### Evaluation
129 |
130 | * We feed in batches of (caption-image) pairs
131 | * Evaluation metrics are based on the intuition illustrated by the heatmap below, i.e. the correct labels for each batch lie along the diagonal
132 | * ![Evaluation heatmap](figures/eval_heatmap.png)
133 | * We compute MRR@k (Mean Reciprocal Rank) for k=1, 3, 5, 10, 20 for image-caption similarity
134 | * Formula for [Mean Reciprocal Rank](https://en.wikipedia.org/wiki/Mean_reciprocal_rank)
135 | * Bounding it by k just means that we only score a caption if it appears among the top k most likely captions for the image (a minimal sketch of the computation is shown after this list).
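
For each batch, [evaluate.py](fine-tuning/evaluate.py) ranks the batch's captions for every image; the correct caption for image i is caption i (the diagonal), so the reciprocal rank is 1/(rank of caption i), or 0 if it falls outside the top k. A minimal stand-alone sketch of MRR@k (the toy matrix below is illustrative only):

```
import numpy as np

def mrr_at_k(probs, k):
    # probs[i, j] = score of caption j for image i; the correct caption is j == i
    ranks = np.argsort(-probs, axis=1)           # captions sorted by decreasing score
    reciprocal_ranks = []
    for i in range(probs.shape[0]):
        hit = np.where(ranks[i, :k] == i)[0]     # position of the true caption within top k
        reciprocal_ranks.append(1.0 / (hit[0] + 1) if hit.size > 0 else 0.0)
    return float(np.mean(reciprocal_ranks))

# toy 3x3 batch: correct captions ranked 1st, 3rd and 2nd respectively
probs = np.array([[0.9, 0.05, 0.05],
                  [0.5, 0.10, 0.40],
                  [0.6, 0.10, 0.30]])
print(mrr_at_k(probs, k=3))                      # (1 + 1/3 + 1/2) / 3 = 0.611
```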
136 |
137 | | Experiment | k=1 | k=3 | k=5 | k=10 | k=20 |
138 | |-------------------------------------|---------|---------|---------|---------|---------|
139 | | baseline | 0.42580 | 0.53402 | 0.55837 | 0.57349 | 0.57829 |
140 | | [run-1](fine-tuning/train_configs/run1.cfg) | 0.69130 | 0.78962 | 0.80113 | 0.80517 | 0.80589 |
141 | | [run-2](fine-tuning/train_configs/run2.cfg) | 0.71200 | 0.80445 | 0.81519 | 0.81912 | 0.81968 |
142 | | [run-3](fine-tuning/train_configs/run3.cfg) | 0.34540 | 0.46338 | 0.49253 | 0.51154 | 0.51753 |
143 | | [run-4](fine-tuning/train_configs/run4.cfg) | 0.78760 | 0.86227 | 0.86870 | 0.87080 | 0.87120 |
144 | | [run-5](fine-tuning/train_configs/run5.cfg) | **0.80200** | **0.87170** | **0.87743** | **0.87966** | **0.88002** |
145 |
146 |
147 | ---
148 |
149 | ## Image Search
150 |
151 | The Image Search demo runs on a standalone CPU-only box, since we only do inference there. The corpus of images + captions is the combination of the training, validation, and unseen test sets provided by the ImageCLEF 2017 Caption Prediction challenge. The captions and image vectors are hosted on the Vespa search engine, which provides both BM25 based text search and HNSW based Approximate Nearest Neighbor search using cosine similarity.
152 |
153 | ### Environment
154 |
155 | * AWS EC2 r5.2xlarge
156 | * 8 CPUs
157 | * 64 GB RAM
158 | * 200 GB disk
159 | * Ubuntu 18.04 (bionic)
160 | * Anaconda 2021.05-Linux
161 | * Python 3.7.10
162 | * PyTorch 1.9.0
163 | * SpaCy (with `en_core_web_sm` Language Model)
164 | * Transformers
165 | * Streamlit
166 |
167 | ### Vespa
168 |
169 | * Docker -- [instructions](https://docs.docker.com/engine/install/ubuntu/)
170 | * Vespa -- [instructions](https://docs.vespa.ai/en/getting-started.html)
171 | * **NOTE -- instructions below are very likely obsolete now that [vespa-cli](https://docs.vespa.ai/en/vespa-cli.html) is available.**
172 | * download [vespa-engine/sample-apps](https://github.com/vespa-engine/sample-apps)
173 |   * create a new app `clip-demo` in `sample-apps`
174 | * copy `image-search/vespa/src` folder to `sample-apps/clip-demo`
175 | * download [sujitpal/vespa-poc](https://github.com/sujitpal/vespa-poc)
176 | * update scripts in `bash-scripts` to point to `clip-demo` sample app
177 | * `launch.sh` to start docker instance
178 | * `deploy.sh` to deploy `clip-demo` to Vespa
179 | * `status.sh` to verify Vespa status
180 | * Prepare data
181 | * Images -- Download ImageCLEF dataset and explode as described in the Fine Tuning section.
182 | ```
183 | $ mkdir CaptionPrediction; cd CaptionPrediction
184 | $ # unzip the downloaded dataset archives here, recreating the folder structure shown above
185 | ```
186 |
187 | * Models -- copy over the folder corresponding to the fine-tuned CLIP model from fine-tuning step that contains the `pytorch_model.bin` file.
188 |
189 | * Vectors -- use fine-tuned model to generate image vectors.
190 |
191 | ```
192 | $ cd fine-tuning
193 | $ python vectorize_images.py /path/to/fine-tuned/model/folder \
194 |       /path/to/folder/containing/vector/files/
195 | ```
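
Each line of the resulting `vectors-training.tsv`, `vectors-validation.tsv`, and `vectors-test.tsv` files contains an image ID, a tab, and the 512 comma-separated components of that image's CLIP vector; for example (the ID and values below are placeholders):

```
<image-id><TAB>0.01234,-0.05678,0.00042,...,0.02178
```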
196 |
197 | * Load data -- run `load-vespa-index.py`
198 | ```
199 | $ cd image-search
200 | $ python load-vespa-index.py /path/to/CaptionPrediction/folder \
201 |       /path/to/folder/containing/vector/files/
202 | ```
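
`load-vespa-index.py` L2-normalizes each vector and posts one document per image to Vespa's document API (`http://localhost:8080/document/v1/clip-demo/image/docid/<doc-id>`). The JSON body it sends has this shape (field values below are placeholders):

```
{
  "fields": {
    "image_id": "<image-id>",
    "image_path": "<path to the .jpg under the CaptionPrediction folder>",
    "caption_text": "<caption, or '(no caption provided)' for the unseen test images>",
    "clip_vector": {
      "values": [0.01234, -0.05678, "... 512 normalized components ..."]
    }
  }
}
```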
203 |
204 | ### Streamlit
205 |
206 | * A Streamlit based demo illustrates the usage of the trained model for the following use cases.
207 |
208 | * text to image search -- user enters a text query to match against image corpus
209 | * using caption text -- this is standard text query, searches images by caption text
210 | * using query text vector -- since CLIP learns to embed images close to their corresponding captions, we do a vector search against the image corpus using the vector representation of the query text. Distance measure uses Cosine Similarity and Vespa provides HNSW based Approximate Nearest Neighbor search.
211 | * using hybrid text + vector search -- relevance is a linear interpolation of the BM25 relevance from caption search and cosine similarity from vector search.
212 |   * image to image search -- this is a search for similar images in the corpus: the user provides an image ID, and the search returns images from the corpus that are similar to the query image. With our infrastructure we could also query the corpus using an external image (compute its embedding with the trained model and find its nearest neighbors in the corpus), but the demo does not support that functionality.
213 |     * image only -- uses cosine similarity between the query image vector and the vectors for images in the corpus
214 |     * image + text -- uses hybrid text + vector search, computing relevance as a linear interpolation of the cosine similarity between image vectors and the BM25 similarity between the source image caption and target image captions (a sketch of the hybrid query appears after this list).
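
Under the hood, both hybrid modes issue a Vespa query that ORs a `nearestNeighbor` clause over `clip_vector` with a `contains` clause over `caption_text`, ranked by the `combined-search` rank profile. A condensed sketch, adapted from `do_combined_search` in [utils.py](image-search/utils.py):

```
import json

import requests

SEARCH_ENDPOINT_URL = "http://localhost:8080/search/"

def hybrid_search(query_text, query_vec, hits=10):
    """Sketch of the hybrid query issued by the demo (see image-search/utils.py).
    query_vec is the L2-normalized 512-d CLIP embedding as a list of floats."""
    yql = ("select * from sources image where "
           "([{\"targetHits\": 10}]nearestNeighbor(clip_vector, query_vector)) "
           "OR caption_text contains '%s';" % query_text)
    params = {
        "yql": yql,
        "hits": hits,
        # bind the query vector to the rank feature used by the rank profiles
        "ranking.features.query(query_vector)": list(query_vec),
        # combined-search ranks by bm25(caption_text) + closeness of the CLIP vectors
        "ranking.profile": "combined-search",
    }
    resp = requests.post(SEARCH_ENDPOINT_URL,
                         headers={"Content-Type": "application/json"},
                         data=json.dumps(params))
    return resp.json()
```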
215 |
216 | * To run the Streamlit server, run the following command; the application will start listening on port 8501.
217 |
218 | ```
219 | $ cd image-search
220 | $ streamlit run app.py
221 | ```
222 |
223 |
224 |
225 |
--------------------------------------------------------------------------------
/figures/eval_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elsevierlabs-os/clip-image-search/965e1bbcaf9b74ddb703428aff9d37508e7fcca9/figures/eval_heatmap.png
--------------------------------------------------------------------------------
/figures/lossplot_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elsevierlabs-os/clip-image-search/965e1bbcaf9b74ddb703428aff9d37508e7fcca9/figures/lossplot_5.png
--------------------------------------------------------------------------------
/fine-tuning/clip_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class ImageCaptionDataset(Dataset):
9 | def __init__(self, image_folder, caption_file):
10 |
11 | super().__init__()
12 | self.image_folder = image_folder
13 | self.caption_file = caption_file
14 |
15 | self.image_to_caption = {}
16 | self.images = []
17 | with open(self.caption_file, "r") as fcap:
18 | for line in fcap:
19 | image_id, caption = line.strip().split('\t')
20 | if os.path.exists(os.path.join(self.image_folder, image_id + ".jpg")):
21 | self.image_to_caption[image_id] = caption
22 | self.images.append(image_id)
23 |
24 | def __len__(self):
25 | return len(self.image_to_caption)
26 |
27 | def __getitem__(self, idx):
28 | image = self._get_image(idx)
29 | caption = self._get_caption(idx)
30 | return {
31 | "image_id": self.images[idx],
32 | "image": image,
33 | "caption": caption
34 | }
35 |
36 | def _get_image(self, idx):
37 | image_id = self.images[idx]
38 | image = Image.open(os.path.join(self.image_folder, image_id + ".jpg"))
39 | image = image.convert("RGB")
40 | return image
41 |
42 | def _get_caption(self, idx):
43 | image_id = self.images[idx]
44 | caption = self.image_to_caption[image_id]
45 | return caption
46 |
47 |
48 | class ImageCaptionCollator(object):
49 | def __init__(self, processor,
50 | image_size=224,
51 | max_caption_length=64):
52 | self.processor = processor
53 | self.image_size = image_size
54 | self.max_caption_length = max_caption_length
55 |
56 | def __call__(self, batch):
57 | image_ids = [row["image_id"] for row in batch]
58 | images = [row["image"] for row in batch]
59 | captions = [row["caption"] for row in batch]
60 | # image preprocessing: feature extractor defaults
61 | # caption preprocessing: pad/truncate + tensor
62 | inputs = self.processor(text=captions,
63 | images=images,
64 | return_tensors="pt",
65 | padding="max_length",
66 |                                 max_length=self.max_caption_length,
67 | truncation=True)
68 | return inputs, image_ids
69 |
70 |
71 |
72 | # from transformers import CLIPProcessor, CLIPModel
73 | # from torch.utils.data import DataLoader
74 |
75 | # DATA_DIR = "../ImageCLEF2017-CaptionPrediction"
76 |
77 | # model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
78 | # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
79 |
80 | # device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
81 | # model.to(device)
82 |
83 | # collator = ImageCaptionCollator(processor)
84 | # train_ds = ImageCaptionDataset(
85 | # os.path.join(DATA_DIR, "training", "CaptionPredictionTraining2017"),
86 | # os.path.join(DATA_DIR, "training", "CaptionPredictionTraining2017-Captions.csv"))
87 | # train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0,
88 | # collate_fn=collator)
89 |
90 |
91 | # for bid, (inputs, _) in enumerate(train_dl):
92 | # inputs.to(device)
93 | # outputs = model(**inputs)
94 |
95 | # logits_per_image = outputs.logits_per_image
96 | # probs = logits_per_image.softmax(dim=1)
97 | # print(probs)
98 | # break
99 |
--------------------------------------------------------------------------------
/fine-tuning/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing as mp
3 | import numpy as np
4 | import os
5 | import torch
6 | import time
7 |
8 | from transformers import CLIPModel, CLIPProcessor
9 | from torch.utils.data import DataLoader
10 |
11 | from clip_dataset import ImageCaptionDataset, ImageCaptionCollator
12 |
13 | DATA_DIR = "../data"
14 | IMAGE_DATA_DIR = "../../ImageCLEF2017-CaptionPrediction"
15 |
16 | OPENAI_CLIP_HF_HUBID = "openai/clip-vit-base-patch32"
17 | K_VALUES = [1, 3, 5, 10, 20]
18 | BATCH_SIZE = 64
19 |
20 | EVAL_REPORT = os.path.join(DATA_DIR, "eval-report.tsv")
21 |
22 | def compute_batch_mrrs(probs, k_values):
23 | ranks = np.argsort(-probs, axis=1)
24 | batch_mrrs = np.zeros((ranks.shape[0], len(k_values)))
25 | for i in range(ranks.shape[0]):
26 | mrr_at_k = []
27 | for j, k in enumerate(k_values):
28 | rank = np.where(ranks[i, 0:k] == i)[0]
29 | if rank.shape[0] == 0:
30 | # item not found in top k, don't add to MRR
31 | batch_mrrs[i, j] = 0
32 | else:
33 |                 # item found in top k, add its reciprocal rank
34 | batch_mrrs[i, j] = 1.0 / (rank[0] + 1)
35 |
36 | return batch_mrrs
37 |
38 |
39 | def write_report(fout, batch_mrr):
40 | for i in range(batch_mrr.shape[0]):
41 | k_scores = ["{:.5f}".format(x) for x in batch_mrr[i, :].tolist()]
42 | fout.write(",".join(k_scores) + "\n")
43 |
44 |
45 | ############################ main ############################
46 |
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument("model_path", help="'baseline' for HF model or path to local model")
49 | args = parser.parse_args()
50 |
51 | if args.model_path == "baseline":
52 | model_path = OPENAI_CLIP_HF_HUBID
53 | else:
54 | model_path = args.model_path
55 |
56 | model = CLIPModel.from_pretrained(model_path)
57 | processor = CLIPProcessor.from_pretrained(OPENAI_CLIP_HF_HUBID)
58 |
59 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
60 | model.to(device)
61 |
62 | collator = ImageCaptionCollator(processor)
63 | test_ds = ImageCaptionDataset(
64 | os.path.join(IMAGE_DATA_DIR, "test", "ConceptDetectionValidation2017"),
65 | os.path.join(IMAGE_DATA_DIR, "test", "CaptionPredictionValidation2017-Captions.csv"))
66 | test_dl = DataLoader(test_ds,
67 | batch_size=BATCH_SIZE,
68 | num_workers=mp.cpu_count() - 1,
69 | collate_fn=collator)
70 |
71 |
72 | model.eval()
73 |
74 | start = time.time()
75 | fout = open(EVAL_REPORT, "w")
76 |
77 | for bid, (inputs, _) in enumerate(test_dl):
78 | if bid % 10 == 0:
79 | print("{:d} batches processed".format(bid))
80 |
81 |     inputs = inputs.to(device)
82 |     with torch.no_grad():
83 |         outputs = model(**inputs)
84 |     logits_per_image = outputs.logits_per_image
85 |     probs = logits_per_image.softmax(dim=1)
86 | batch_mrrs = compute_batch_mrrs(probs.detach().cpu().numpy(), K_VALUES)
87 | write_report(fout, batch_mrrs)
88 | # break
89 |
90 | elapsed = time.time() - start
91 | print("{:d} batches processed, COMPLETE".format(bid))
92 | print("elapsed time: {:.3f} s".format(elapsed))
93 | fout.close()
94 |
95 | mrr_scores = []
96 | with open(EVAL_REPORT, "r") as frep:
97 | for line in frep:
98 | scores = np.array([float(x) for x in line.strip().split(',')])
99 | mrr_scores.append(scores)
100 | mrr_scores = np.array(mrr_scores)
101 | eval_scores = np.mean(mrr_scores, axis=0)
102 | frep.close()
103 |
104 | print(" | ".join(["{:.5f}".format(x) for x in eval_scores.tolist()]))
105 |
106 |
--------------------------------------------------------------------------------
/fine-tuning/plot_losses.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | import os
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("history_file", help="path to history file")
7 | parser.add_argument("--save_to", "-s", help="path to save figure")
8 | args = parser.parse_args()
9 |
10 | epochs, train_losses, val_losses, val_accs = [], [], [], []
11 | with open(args.history_file, "r") as fhist:
12 | for line in fhist:
13 | epoch, train_loss, val_loss, val_acc = line.strip().split('\t')
14 | epochs.append(int(epoch))
15 | train_losses.append(float(train_loss))
16 | val_losses.append(float(val_loss))
17 | val_accs.append(float(val_acc))
18 |
19 | plt.subplot(2, 1, 1)
20 | plt.plot(epochs, train_losses, label="train")
21 | plt.plot(epochs, val_losses, label="val")
22 | plt.legend(loc="best")
23 | plt.xlabel("epochs")
24 | plt.ylabel("loss")
25 |
26 | plt.subplot(2, 1, 2)
27 | plt.plot(epochs, val_accs, label="val")
28 | plt.legend(loc="best")
29 | plt.xlabel("epochs")
30 | plt.ylabel("accuracy")
31 |
32 | if args.save_to is None:
33 | _ = plt.show()
34 | else:
35 | plt.savefig(args.save_to)
--------------------------------------------------------------------------------
/fine-tuning/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing as mp
3 | import numpy as np
4 | import os
5 | import shutil
6 | import torch
7 | import yaml
8 |
9 | from PIL import Image
10 | from transformers import (
11 | CLIPModel, CLIPProcessor,
12 | AdamW, get_scheduler
13 | )
14 | from torch.utils.data import DataLoader, RandomSampler
15 |
16 | from clip_dataset import ImageCaptionDataset, ImageCaptionCollator
17 |
18 |
19 | def do_train(model, train_dl):
20 | train_loss = 0
21 | model.train()
22 | for bid, (batch, _) in enumerate(train_dl):
23 | if bid % 100 == 0:
24 | print("...{:d} training steps complete".format(bid))
25 | batch = {k: v.to(device) for k, v in batch.items()}
26 | outputs = model(**batch, return_loss=True)
27 | loss = outputs.loss
28 | train_loss += loss.detach().cpu().numpy()
29 | loss.backward()
30 |
31 | optimizer.step()
32 | lr_scheduler.step()
33 | optimizer.zero_grad()
34 |
35 | print("...{:d} training steps COMPLETE".format(bid))
36 | return train_loss
37 |
38 |
39 | def do_eval(model, eval_dl):
40 | model.eval()
41 | val_loss, val_acc, num_examples = 0, 0, 0
42 | for bid, (batch, _) in enumerate(eval_dl):
43 | if bid % 100 == 0:
44 | print("... {:d} validation steps complete".format(bid))
45 | batch = {k: v.to(device) for k, v in batch.items()}
46 | with torch.no_grad():
47 | outputs = model(**batch, return_loss=True)
48 |
49 | loss = outputs.loss
50 | val_loss += loss.detach().cpu().numpy()
51 |
52 | logits_per_image = outputs.logits_per_image
53 | probs = logits_per_image.softmax(dim=1)
54 | predictions = torch.argmax(probs, dim=-1)
55 | labels = torch.arange(len(predictions)).to(device)
56 |
57 | accuracy = torch.sum(predictions == labels)
58 | num_examples += len(predictions)
59 | val_acc += accuracy
60 |
61 | print("... {:d} validation steps COMPLETE".format(bid))
62 | val_acc = val_acc.detach().cpu().numpy() / num_examples
63 | return val_loss, val_acc
64 |
65 |
66 | def save_checkpoint(model, model_dir, epoch):
67 | model.save_pretrained(os.path.join(model_dir, "ckpt-{:d}".format(epoch + 1)))
68 |
69 |
70 | def save_training_history(history, model_dir):
71 | fhist = open(os.path.join(model_dir, "history.tsv"), "w")
72 | for epoch, train_loss, val_loss, val_acc in history:
73 | fhist.write("{:d}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
74 | epoch, train_loss, val_loss, val_acc))
75 | fhist.close()
76 |
77 |
78 | ###################### main ###########################
79 |
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument("config_file", help="provides training params")
82 | args = parser.parse_args()
83 |
84 | config_file = args.config_file
85 | with open(config_file, "r") as fcfg:
86 |     config = yaml.safe_load(fcfg)
87 | print(config)
88 |
89 | model_dir = os.path.join(config["models_dir"],
90 | os.path.basename(config_file).split(".")[0])
91 | shutil.rmtree(model_dir, ignore_errors=True)
92 | os.makedirs(model_dir)
93 |
94 | train_ds = ImageCaptionDataset(
95 | os.path.join(config["image_data_dir"], "training", config["train_images_dir"]),
96 | os.path.join(config["image_data_dir"], "training", config["train_captions_file"]))
97 | validation_ds = ImageCaptionDataset(
98 | os.path.join(config["image_data_dir"], "validation", config["validation_images_dir"]),
99 | os.path.join(config["image_data_dir"], "validation", config["validation_captions_file"]))
100 |
101 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
102 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
103 |
104 | collator = ImageCaptionCollator(processor)
105 | train_sampler = RandomSampler(train_ds,
106 | replacement=True,
107 | num_samples=config["train_sample_size"])
108 | train_dl = DataLoader(train_ds,
109 | batch_size=config["train_batch_size"],
110 | # shuffle=True,
111 | sampler=train_sampler,
112 | num_workers=mp.cpu_count() - 1,
113 | collate_fn=collator)
114 | validation_dl = DataLoader(validation_ds,
115 | batch_size=config["validation_batch_size"],
116 | num_workers=mp.cpu_count() - 1,
117 | collate_fn=collator)
118 |
119 | optimizer = AdamW(model.parameters(),
120 | lr=config["learning_rate"])
121 |
122 | num_training_steps = config["num_epochs"] * len(train_dl)
123 | lr_scheduler = get_scheduler(
124 | "linear",
125 | optimizer=optimizer,
126 | num_warmup_steps=0,
127 | num_training_steps=num_training_steps
128 | )
129 |
130 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
131 | model.to(device)
132 |
133 | history = []
134 | for epoch in range(config["num_epochs"]):
135 | train_loss = do_train(model, train_dl)
136 | val_loss, val_acc = do_eval(model, validation_dl)
137 | save_checkpoint(model, model_dir, epoch)
138 | history.append((epoch + 1, train_loss, val_loss, val_acc))
139 | print("EPOCH {:d}, training loss: {:.3f}, validation loss: {:.3f}, accuracy: {:.3f}"
140 | .format(epoch + 1,
141 | train_loss,
142 | val_loss,
143 | val_acc))
144 | save_training_history(history, model_dir)
145 |
--------------------------------------------------------------------------------
/fine-tuning/train_configs/run1.cfg:
--------------------------------------------------------------------------------
1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction
2 | train_images_dir: CaptionPredictionTraining2017
3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv
4 | validation_images_dir: ConceptDetectionValidation2017
5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv
6 | models_dir: ../data/models
7 |
8 | train_sample_size: 20000
9 | train_batch_size: 32
10 | validation_batch_size: 32
11 | learning_rate: 5.e-5
12 | num_epochs: 5
13 |
--------------------------------------------------------------------------------
/fine-tuning/train_configs/run2.cfg:
--------------------------------------------------------------------------------
1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction
2 | train_images_dir: CaptionPredictionTraining2017
3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv
4 | validation_images_dir: ConceptDetectionValidation2017
5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv
6 | models_dir: ../data/models
7 |
8 | train_sample_size: 20000
9 | train_batch_size: 64
10 | validation_batch_size: 64
11 | learning_rate: 5.e-5
12 | num_epochs: 5
13 |
--------------------------------------------------------------------------------
/fine-tuning/train_configs/run3.cfg:
--------------------------------------------------------------------------------
1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction
2 | train_images_dir: CaptionPredictionTraining2017
3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv
4 | validation_images_dir: ConceptDetectionValidation2017
5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv
6 | models_dir: ../data/models
7 |
8 | train_sample_size: 20000
9 | train_batch_size: 64
10 | validation_batch_size: 64
11 | learning_rate: 1.e-4
12 | num_epochs: 5
13 |
--------------------------------------------------------------------------------
/fine-tuning/train_configs/run4.cfg:
--------------------------------------------------------------------------------
1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction
2 | train_images_dir: CaptionPredictionTraining2017
3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv
4 | validation_images_dir: ConceptDetectionValidation2017
5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv
6 | models_dir: ../data/models
7 |
8 | train_sample_size: 20000
9 | train_batch_size: 64
10 | validation_batch_size: 64
11 | learning_rate: 5.e-5
12 | num_epochs: 10
13 |
--------------------------------------------------------------------------------
/fine-tuning/train_configs/run5.cfg:
--------------------------------------------------------------------------------
1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction
2 | train_images_dir: CaptionPredictionTraining2017
3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv
4 | validation_images_dir: ConceptDetectionValidation2017
5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv
6 | models_dir: ../data/models
7 |
8 | train_sample_size: 50000
9 | train_batch_size: 64
10 | validation_batch_size: 64
11 | learning_rate: 5.e-5
12 | num_epochs: 10
13 |
--------------------------------------------------------------------------------
/fine-tuning/vectorize_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing as mp
3 | import numpy as np
4 | import os
5 | import torch
6 | import time
7 |
8 | from transformers import CLIPModel, CLIPProcessor
9 | from torch.utils.data import DataLoader
10 |
11 | from clip_dataset import ImageCaptionDataset, ImageCaptionCollator
12 |
13 | OPENAI_CLIP_HF_HUBID = "openai/clip-vit-base-patch32"
14 | IMAGE_DATA_DIR = "../../ImageCLEF2017-CaptionPrediction"
15 | BATCH_SIZE = 64
16 |
17 | ############################ main ############################
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("model_path", help="path to local model (or 'baseline' for OpenAI CLIP)")
21 | parser.add_argument("output_dir", help="path to folder containing TSV files")
22 | args = parser.parse_args()
23 |
24 | if args.model_path == "baseline":
25 | model_path = OPENAI_CLIP_HF_HUBID
26 | else:
27 | model_path = args.model_path
28 |
29 | if not os.path.exists(args.output_dir):
30 | os.makedirs(args.output_dir)
31 |
32 | model = CLIPModel.from_pretrained(model_path)
33 | processor = CLIPProcessor.from_pretrained(OPENAI_CLIP_HF_HUBID)
34 |
35 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
36 | model.to(device)
37 |
38 | collator = ImageCaptionCollator(processor)
39 | datasets = {
40 | "training": ImageCaptionDataset(
41 | os.path.join(IMAGE_DATA_DIR, "training", "CaptionPredictionTraining2017"),
42 | os.path.join(IMAGE_DATA_DIR, "training", "CaptionPredictionTraining2017-Captions.csv")),
43 | "validation": ImageCaptionDataset(
44 | os.path.join(IMAGE_DATA_DIR, "validation", "ConceptDetectionValidation2017"),
45 | os.path.join(IMAGE_DATA_DIR, "validation", "CaptionPredictionValidation2017-Captions.csv")),
46 | "test": ImageCaptionDataset(
47 | os.path.join(IMAGE_DATA_DIR, "unseen", "CaptionPredictionTesting2017"),
48 | os.path.join(IMAGE_DATA_DIR, "unseen", "CaptionPredictionTesting2017-Captions-dummy.txt"))
49 | }
50 | for split, dataset in datasets.items():
51 | test_dl = DataLoader(dataset,
52 | batch_size=BATCH_SIZE,
53 | num_workers=mp.cpu_count() - 1,
54 | collate_fn=collator)
55 |
56 | fvec = open(
57 | os.path.join(args.output_dir, "vectors-{:s}.tsv".format(split)),
58 | "w")
59 | for bid, (batch, image_ids) in enumerate(test_dl):
60 | if bid % 100 == 0:
61 | print("... {:d} batches (of {:d}) vectors generated for {:s}".format(
62 | bid, BATCH_SIZE, split))
63 | batch = {k: v.to(device) for k, v in batch.items()}
64 | with torch.no_grad():
65 | outputs = model.get_image_features(pixel_values=batch["pixel_values"])
66 | outputs = outputs.cpu().numpy()
67 | for i in range(outputs.shape[0]):
68 | image_id = image_ids[i]
69 | vector = outputs[i].reshape(-1).tolist()
70 | fvec.write("{:s}\t{:s}\n".format(
71 | image_id,
72 | ",".join(["{:.5f}".format(v) for v in vector])))
73 | # break
74 |
75 | print("... {:d} batches (of {:d}) vectors generated for {:s}, COMPLETE".format(
76 | bid, BATCH_SIZE, split))
77 | fvec.close()
78 |
--------------------------------------------------------------------------------
/image-search/app.py:
--------------------------------------------------------------------------------
1 | import text_to_image
2 | import image_to_image
3 |
4 | import streamlit as st
5 | import logging
6 |
7 |
8 | logging.basicConfig(filename="streamlit_logs.txt")
9 |
10 | PAGES = {
11 |     "Text to Image Search": text_to_image,
12 | "Image to Image Search": image_to_image
13 | }
14 |
15 | st.sidebar.title("CLIP Image Search Demo")
16 | st.sidebar.markdown("""
17 | Demo to showcase image search capabilities of a CLIP transformer model
18 | fine-tuned on the ImageCLEF 2017 Caption Prediction dataset.
19 |
20 | CLIP is a transformer model from OpenAI that was trained on a large number
21 | of image + text pairs from the Internet. CLIP learns a joint embedding for
22 | images and their associated captions, such that images and their captions
23 | are pushed closer together in the embedding space. The resulting CLIP model
24 | can be used for text-to-image or image-to-image search, and performs well
25 | on general images.
26 |
27 | OOB performance on medical images is not as good, however. To remedy that,
28 | we fine-tuned the CLIP model with medical images and captions from ImageCLEF,
29 | which resulted in significant performance improvements (as measured by MRR@k,
30 | for k=1, 3, 5, 10, 20).
31 |
32 | The CLIP model trained on ImageCLEF data was used to generate vectors for the
33 | entire ImageCLEF dataset (training, validation, and unseen test images) and
34 | loaded into a Vespa search index, which provides the Approximate Nearest
35 | Neighbors vector search functionality for this demo.
36 | """)
37 |
38 | selection = st.sidebar.radio("Go to", list(PAGES.keys()))
39 | page = PAGES[selection]
40 | page.app()
41 |
--------------------------------------------------------------------------------
/image-search/image_to_image.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import streamlit as st
4 |
5 | from PIL import Image
6 | import utils
7 |
8 |
9 | def app():
10 | model, processor = utils.load_model(utils.MODEL_PATH, utils.BASELINE_MODEL)
11 |
12 | st.title("Retrieve Images given an Image")
13 | image_id = st.text_input("Enter Image-ID:")
14 | query_type = st.radio("Select search type:",
15 | options=["image-only", "image+text"],
16 | index=0)
17 |
18 | if st.button("Search"):
19 | logging.info("image_to_image: {:s} ({:s})".format(image_id, query_type))
20 | st.text("returning results")
21 | image_vec, caption_text = utils.get_image_vector_and_caption_text_by_id(
22 | image_id, utils.SEARCH_ENDPOINT_URL, utils.GET_ENDPOINT_URL)
23 | if query_type == "image-only":
24 | results = utils.do_vector_search(image_vec, utils.SEARCH_ENDPOINT_URL)
25 | else:
26 | results = utils.do_combined_search(caption_text, image_vec, utils.SEARCH_ENDPOINT_URL)
27 |
28 | metadata, data = utils.parse_results(results)
29 | st.markdown("## Results 1-{:d} from {:d} matches".format(
30 | len(data), metadata["num_results"]))
31 | for rid, row in enumerate(data):
32 | image_id = row["image_id"]
33 | image_path = row["image_path"]
34 | image = Image.open(image_path).convert("RGB")
35 | caption = row["caption_text"]
36 | relevance = row["relevance"]
37 | col1, col2, col3 = st.columns([2, 10, 10])
38 | col1.markdown("{:d}.".format(rid + 1))
39 | col2.image(image)
40 | col3.markdown("""
41 | * **Image-ID**: {:s}
42 | * **Caption**: {:s}
43 | * **Relevance**: {:.3f}
44 | """.format(image_id, caption, relevance))
45 | st.markdown("---")
46 |
47 |
--------------------------------------------------------------------------------
/image-search/load-vespa-index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import requests
5 |
6 | # DATA_DIR = "/home/ubuntu/CaptionPrediction"
7 | DATA_SUBDIRS = ["training", "validation", "test"]
8 |
9 | APP_NAME = "clip-demo"
10 | SCHEMA_NAME = "image"
11 | ENDPOINT = "http://localhost:8080/document/v1/{:s}/{:s}/docid/{:d}"
12 |
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("image_dir", help="path to folder containing training, validation and test images and captions")
16 | parser.add_argument("vector_dir", help="path to folder containing vector TSV files")
17 | args = parser.parse_args()
18 |
19 |
20 | # scan directories and compose paths
21 | image_paths = {}
22 | for data_subdir in DATA_SUBDIRS:
23 | for image_folder_cand in os.listdir(os.path.join(args.image_dir, data_subdir)):
24 | # print(data_subdir, subdir_content)
25 | if os.path.isdir(os.path.join(args.image_dir, data_subdir, image_folder_cand)):
26 | for image_file in os.listdir(os.path.join(args.image_dir, data_subdir, image_folder_cand)):
27 | image_path = os.path.join(args.image_dir, data_subdir, image_folder_cand, image_file)
28 | image_id = image_file.replace(".jpg", "")
29 | # if image_id in image_paths:
30 | # print("duplicate image:", image_file)
31 | image_paths[image_id] = image_path
32 |
33 | print("# of image paths:", len(image_paths))
34 |
35 | image_captions = {}
36 | for data_subdir in DATA_SUBDIRS:
37 | for image_folder_cand in os.listdir(os.path.join(args.image_dir, data_subdir)):
38 | if image_folder_cand.find("-Captions") > -1:
39 | with open(os.path.join(args.image_dir, data_subdir, image_folder_cand), "r") as f:
40 | for line in f:
41 | image_id, caption = line.strip().split('\t')
42 | image_captions[image_id] = caption
43 |
44 | print("# of image captions:", len(image_captions))
45 |
46 |
47 | doc_id = 1
48 | failures, successes = 0, 0
49 | headers = { "Content-Type": "application/json" }
50 | for vec_file_cand in os.listdir(args.vector_dir):
51 | if vec_file_cand.startswith("vectors-"):
52 | with open(os.path.join(args.vector_dir, vec_file_cand), "r") as f:
53 | for line in f:
54 | image_id, vec_str = line.strip().split('\t')
55 | vec = np.array([float(x) for x in vec_str.split(',')])
56 | vec_norm = np.linalg.norm(vec, 2)
57 | vec /= vec_norm
58 | vec = vec.tolist()
59 | image_path = image_paths[image_id]
60 | try:
61 | caption_text = image_captions[image_id]
62 | except KeyError:
63 | caption_text = "(no caption provided)"
64 | input_rec = {
65 | "fields": {
66 | "image_id": image_id,
67 | "image_path": image_path,
68 | "caption_text": caption_text,
69 | "clip_vector": {
70 | "values": vec
71 | }
72 | }
73 | }
74 | url = ENDPOINT.format(APP_NAME, SCHEMA_NAME, doc_id)
75 | resp = requests.post(url, headers=headers, json=input_rec)
76 | if resp.status_code != 200:
77 | print("ERROR loading [{:d}] {:s}: {:s}".format(
78 | doc_id, image_id, resp.reason))
79 | failures += 1
80 | else:
81 | successes += 1
82 | print("Inserted document {:20s}, {:6d} ok, {:6d} failed, {:6d} total\r"
83 | .format(image_id, successes, failures, doc_id), end="")
84 | doc_id += 1
85 |
86 | print("\n{:d} documents read, {:d} succeeded, {:d} failed, COMPLETE"
87 |       .format(doc_id - 1, successes, failures))
88 |
--------------------------------------------------------------------------------
/image-search/search-vespa-index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import numpy as np
4 | import requests
5 |
6 | # from requests.sessions import dispatch_hook
7 |
8 | APP_NAME = "clip-demo"
9 | SCHEMA_NAME = "image"
10 | GET_ENDPOINT_URL = "http://localhost:8080/document/v1/{:s}/{:s}/docid/{:d}"
11 | SEARCH_ENDPOINT_URL = "http://localhost:8080/search/"
12 |
13 |
14 | def do_text_search(query_text, endpoint_url):
15 | headers = { "Content-type" : "application/json" }
16 | yql = """ select * from sources image where caption_text contains '{:s}'; """.format(
17 | query_text)
18 | params = {
19 | "yql": yql,
20 | "hits": 10,
21 | "ranking.profile": "caption-search"
22 | }
23 | resp = requests.post(endpoint_url, headers=headers, json=params)
24 | return resp.json()
25 |
26 |
27 | def do_vector_search(query_vec, endpoint_url):
28 | headers = { "Content-Type" : "application/json" }
29 | params = {
30 | "yql": """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)); """,
31 | "hits": 10,
32 | "ranking.features.query(query_vector)": query_vec.tolist(),
33 | "ranking.profile": "image-search"
34 | }
35 | resp = requests.post(endpoint_url, headers=headers, json=params)
36 | return resp.json()
37 |
38 |
39 | def do_combined_search(query_text, query_vec, endpoint_url):
40 | headers = { "Content-Type" : "application/json" }
41 | yql = """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)) OR caption_text contains '%s'; """ % (query_text)
42 | params = {
43 | "yql": yql,
44 | "hits": 10,
45 | "ranking.features.query(query_vector)": query_vec.tolist(),
46 | "ranking.profile": "combined-search"
47 | }
48 | data = json.dumps(params)
49 | resp = requests.post(endpoint_url, headers=headers, data=data)
50 | return resp.json()
51 |
52 |
53 | def get_document_by_id(doc_id, endpoint_url):
54 | headers = { "Content-Type" : "application/json" }
55 | resp = requests.get(endpoint_url.format(APP_NAME, SCHEMA_NAME, doc_id),
56 | headers=headers)
57 | return resp.json()
58 |
59 |
60 | def parse_image_vector_from_response(resp_json):
61 | emb = np.zeros((512), dtype=np.float32)
62 | for cell in resp_json["fields"]["clip_vector"]["cells"]:
63 | pos = int(cell["address"]["x"])
64 | val = cell["value"]
65 | emb[pos] = val
66 | return emb
67 |
68 |
69 | def get_best_vector_from_text_query(query_text, endpoint_url, getdoc_url):
70 | text_results = do_text_search(query_text, endpoint_url)
71 | top_result_id = int(text_results["root"]["children"][0]["id"].split(":")[-1])
72 | docid_json = get_document_by_id(top_result_id, getdoc_url)
73 | emb = parse_image_vector_from_response(docid_json)
74 | return emb
75 |
76 |
77 | def get_image_vector_by_id(image_id, endpoint_url, getdoc_url):
78 | headers = { "Content-Type" : "application/json" }
79 | params = {
80 | "yql": """select documentid from sources image where image_id matches '%s'; """ % (image_id),
81 | "hits": 1,
82 | }
83 | resp = requests.post(endpoint_url, headers=headers, json=params)
84 | doc_id = int(resp.json()["root"]["children"][0]["id"].split(":")[-1])
85 | docid_json = get_document_by_id(doc_id, getdoc_url)
86 | image_vector = parse_image_vector_from_response(docid_json)
87 | return image_vector
88 |
89 |
90 | def parse_results(result_json):
91 | metadata = {
92 | "num_results": result_json["root"]["fields"]["totalCount"]
93 | }
94 | data = []
95 | try:
96 | for child in result_json["root"]["children"]:
97 | data.append({
98 | "id": child["id"],
99 | "image_id": child["fields"]["image_id"],
100 | "image_path": child["fields"]["image_path"],
101 | "caption_text": child["fields"]["caption_text"],
102 | "relevance": child["relevance"]
103 | })
104 | except KeyError:
105 | pass
106 | return metadata, data
107 |
108 |
109 | def pad_truncate(text, length):
110 | pad = " " * length
111 | text = (text + pad)[0:length]
112 | return text
113 |
114 |
115 | def print_results(metadata, data):
116 | print("## Top 10 of {:d} results".format(metadata["num_results"]))
117 | for row in data:
118 | print("{:20s} | {:50s} | {:.3f}".format(
119 | pad_truncate(row["image_id"], 20),
120 | pad_truncate(row["caption_text"], 50),
121 | row["relevance"]
122 | ))
123 |
124 |
125 | ################################## main ##################################
126 |
127 | parser = argparse.ArgumentParser()
128 | parser.add_argument("--query_type", "-t",
129 | choices=["text", "vector", "combined"],
130 | default="text",
131 | help="specify type of query to use")
132 | parser.add_argument("--query", "-q",
133 | default="X-rays",
134 | help="text of query")
135 | args = parser.parse_args()
136 |
137 | query_type = args.query_type
138 | query_text = args.query
139 |
140 | search_results = None
141 | if query_type == "text":
142 | search_results = do_text_search(query_text, SEARCH_ENDPOINT_URL)
143 | elif query_type == "vector":
144 | query_vector = get_best_vector_from_text_query(
145 | query_text, SEARCH_ENDPOINT_URL, GET_ENDPOINT_URL)
146 | search_results = do_vector_search(query_vector, SEARCH_ENDPOINT_URL)
147 | else:
148 | query_vector = get_best_vector_from_text_query(
149 | query_text, SEARCH_ENDPOINT_URL, GET_ENDPOINT_URL)
150 | search_results = do_combined_search(
151 | query_text, query_vector, SEARCH_ENDPOINT_URL)
152 |
153 | print("--- {:s} search results ---".format(query_type))
154 | metadata, data = parse_results(search_results)
155 | print_results(metadata, data)
156 |
157 | # image_vector = get_image_vector_by_id('ORT-1745-3674-80-548-g002',
158 | # SEARCH_ENDPOINT_URL, GET_ENDPOINT_URL)
159 | # print(image_vector)
160 |
--------------------------------------------------------------------------------
/image-search/start-streamlit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | rm -f nohup.out
3 | nohup streamlit run app.py --logger.level=info 2>streamlit_logs.txt &
4 |
5 |
--------------------------------------------------------------------------------
/image-search/streamlit.nginx:
--------------------------------------------------------------------------------
1 | #### begin streamlit app ####
2 |
3 | location /clip-demo {
4 | location ^~ /clip-demo/healthz {
5 | proxy_pass http://127.0.0.1:8501/healthz;
6 | }
7 | location /clip-demo/stream {
8 | proxy_pass http://127.0.0.1:8501/stream;
9 | proxy_http_version 1.1;
10 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
11 | proxy_set_header Host $host;
12 | proxy_set_header Upgrade $http_upgrade;
13 | proxy_set_header Connection "upgrade";
14 | proxy_read_timeout 86400;
15 | }
16 | location /clip-demo/media {
17 | proxy_pass http://127.0.0.1:8501/media;
18 | }
19 | proxy_pass http://127.0.0.1:8501/;
20 | }
21 | location ^~ /static {
22 | proxy_pass http://127.0.0.1:8501/static/;
23 | }
24 | location ^~ /vendor {
25 | proxy_pass http://127.0.0.1:8501/vendor;
26 | }
27 |
28 | #### end streamlit app ####
29 |
--------------------------------------------------------------------------------
/image-search/streamlit.service:
--------------------------------------------------------------------------------
1 | [Unit]
2 | Description=StreamlitService
3 | After=vespa.service
4 |
5 | [Service]
6 | User=ubuntu
7 | Group=ubuntu
8 | WorkingDirectory=/home/ubuntu/labs-clip-imageclef/demo
9 | ExecStart=/home/ubuntu/labs-clip-imageclef/demo/start-streamlit.sh
10 |
11 | [Install]
12 | WantedBy=multi-user.target
13 |
14 |
--------------------------------------------------------------------------------
/image-search/text_to_image.py:
--------------------------------------------------------------------------------
1 | import json
2 | import streamlit as st
3 | import logging
4 |
5 | from PIL import Image
6 | import utils
7 |
8 |
9 | def app():
10 | model, processor = utils.load_model(utils.MODEL_PATH, utils.BASELINE_MODEL)
11 |
12 | st.title("Retrieve Images given Text")
13 | query_text = st.text_input("Enter a text query:")
14 | query_type = st.radio("Select search type:",
15 | options=["text search", "vector search", "combined search"],
16 | index=0)
17 |
18 | if st.button("Search"):
19 | logging.info("text_to_image: {:s} ({:s})".format(query_text, query_type))
20 | st.text("returning results")
21 |
22 | if query_type == "text search":
23 | results = utils.do_text_search(query_text, utils.SEARCH_ENDPOINT_URL)
24 | elif query_type == "vector search":
25 | query_vec = utils.get_query_vec(query_text, model, processor)
26 | results = utils.do_vector_search(query_vec, utils.SEARCH_ENDPOINT_URL)
27 | else:
28 | query_vec = utils.get_query_vec(query_text, model, processor)
29 | results = utils.do_combined_search(query_text, query_vec, utils.SEARCH_ENDPOINT_URL)
30 |
31 | # st.text(json.dumps(results, indent=2))
32 |
33 | metadata, data = utils.parse_results(results)
34 | st.markdown("## Results 1-{:d} from {:d} matches".format(
35 | len(data), metadata["num_results"]))
36 | for rid, row in enumerate(data):
37 | image_id = row["image_id"]
38 | image_path = row["image_path"]
39 | image = Image.open(image_path).convert("RGB")
40 | caption = row["caption_text"]
41 | relevance = row["relevance"]
42 | col1, col2, col3 = st.columns([2, 10, 10])
43 | col1.markdown("{:d}.".format(rid + 1))
44 | col2.image(image)
45 | col3.markdown("""
46 | * **Image-ID**: {:s}
47 | * **Caption**: {:s}
48 | * **Relevance**: {:.3f}
49 | """.format(image_id, caption, relevance))
50 | st.markdown("---")
51 |
52 |
--------------------------------------------------------------------------------
/image-search/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import requests
5 | import streamlit as st
6 | import torch
7 |
8 | from transformers import CLIPModel, CLIPProcessor
9 |
10 | BASELINE_MODEL = "openai/clip-vit-base-patch32"
11 | MODEL_PATH = "/home/ubuntu/clip-model/clip-imageclef-run5-ckpt10"
12 | APP_NAME = "clip-demo"
13 | SCHEMA_NAME = "image"
14 | GET_ENDPOINT_URL = "http://localhost:8080/document/v1/{:s}/{:s}/docid/{:d}"
15 | SEARCH_ENDPOINT_URL = "http://localhost:8080/search/"
16 |
17 |
18 | @st.cache(allow_output_mutation=True)
19 | def load_model(model_path, baseline_model):
20 | model = CLIPModel.from_pretrained(model_path)
21 | # model = CLIPModel.from_pretrained(baseline_model)
22 | processor = CLIPProcessor.from_pretrained(baseline_model)
23 | return model, processor
24 |
25 |
26 | def do_text_search(query_text, endpoint_url):
27 | headers = { "Content-type" : "application/json" }
28 | yql = """ select * from sources image where caption_text contains '{:s}'; """.format(
29 | query_text)
30 | params = {
31 | "yql": yql,
32 | "hits": 10,
33 | "ranking.profile": "caption-search"
34 | }
35 | resp = requests.post(endpoint_url, headers=headers, json=params)
36 | return resp.json()
37 |
38 |
39 | def do_vector_search(query_vec, endpoint_url):
40 | headers = { "Content-Type" : "application/json" }
41 | params = {
42 | "yql": """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)); """,
43 | "hits": 10,
44 | "ranking.features.query(query_vector)": query_vec.tolist(),
45 | "ranking.profile": "image-search"
46 | }
47 | resp = requests.post(endpoint_url, headers=headers, json=params)
48 | return resp.json()
49 |
50 |
51 | def do_combined_search(query_text, query_vec, endpoint_url):
52 | headers = { "Content-Type" : "application/json" }
53 | yql = """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)) OR caption_text contains '%s'; """ % (query_text)
54 | params = {
55 | "yql": yql,
56 | "hits": 10,
57 | "ranking.features.query(query_vector)": query_vec.tolist(),
58 | "ranking.profile": "combined-search"
59 | }
60 | data = json.dumps(params)
61 | resp = requests.post(endpoint_url, headers=headers, data=data)
62 | return resp.json()
63 |
64 |
65 | def get_document_by_id(doc_id, endpoint_url):
66 | headers = { "Content-Type" : "application/json" }
67 | resp = requests.get(endpoint_url.format(APP_NAME, SCHEMA_NAME, doc_id),
68 | headers=headers)
69 | return resp.json()
70 |
71 |
72 | def parse_image_vector_from_response(resp_json):
73 | emb = np.zeros((512), dtype=np.float32)
74 | for cell in resp_json["fields"]["clip_vector"]["cells"]:
75 | pos = int(cell["address"]["x"])
76 | val = cell["value"]
77 | emb[pos] = val
78 | return emb
79 |
80 |
81 | def get_best_vector_from_text_query(query_text, endpoint_url, getdoc_url):
82 | text_results = do_text_search(query_text, endpoint_url)
83 | top_result_id = int(text_results["root"]["children"][0]["id"].split(":")[-1])
84 | docid_json = get_document_by_id(top_result_id, getdoc_url)
85 | emb = parse_image_vector_from_response(docid_json)
86 | return emb
87 |
88 |
89 | def get_image_vector_and_caption_text_by_id(image_id, endpoint_url, getdoc_url):
90 | headers = { "Content-Type" : "application/json" }
91 | params = {
92 | "yql": """select documentid from sources image where image_id matches '%s'; """ % (image_id),
93 | "hits": 1,
94 | }
95 | resp = requests.post(endpoint_url, headers=headers, json=params)
96 | doc_id = int(resp.json()["root"]["children"][0]["id"].split(":")[-1])
97 | docid_json = get_document_by_id(doc_id, getdoc_url)
98 | image_vector = parse_image_vector_from_response(docid_json)
99 | caption_text = docid_json["fields"]["caption_text"]
100 | return image_vector, caption_text
101 |
102 |
103 | def parse_results(result_json):
104 | metadata = {
105 | "num_results": result_json["root"]["fields"]["totalCount"]
106 | }
107 | data = []
108 | try:
109 | for child in result_json["root"]["children"]:
110 | data.append({
111 | "id": child["id"],
112 | "image_id": child["fields"]["image_id"],
113 | "image_path": child["fields"]["image_path"],
114 | "caption_text": child["fields"]["caption_text"],
115 | "relevance": child["relevance"]
116 | })
117 | except KeyError:
118 | pass
119 | return metadata, data
120 |
121 |
122 | def get_query_vec(query_text, model, processor):
123 | inputs = processor([query_text], padding=True, return_tensors="pt")
124 | with torch.no_grad():
125 | query_vec = model.get_text_features(**inputs)
126 | query_vec = query_vec.reshape(-1).numpy()
127 | query_vec /= np.linalg.norm(query_vec, 2)
128 | return query_vec
129 |
--------------------------------------------------------------------------------
/image-search/vespa.service:
--------------------------------------------------------------------------------
1 | [Unit]
2 | Description=VespaService
3 | After=docker.service
4 |
5 | [Service]
6 | User=ubuntu
7 | Group=ubuntu
8 | WorkingDirectory=/home/ubuntu/vespa-poc/bash-scripts
9 | ExecStart=/home/ubuntu/vespa-poc/bash-scripts/start.sh
10 | ExecStop=/home/ubuntu/vespa-poc/bash-scripts/stop.sh
11 |
12 | [Install]
13 | WantedBy=multi-user.target
14 |
15 |
--------------------------------------------------------------------------------
/image-search/vespa/src/main/application/hosts.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8" ?>
2 | <hosts>
3 |   <host name="localhost">
4 |     <alias>node1</alias>
5 |   </host>
6 | </hosts>
7 |
--------------------------------------------------------------------------------
/image-search/vespa/src/main/application/schemas/image.sd:
--------------------------------------------------------------------------------
1 | schema image {
2 | document image {
3 | field image_id type string {
4 | indexing: summary | attribute
5 | }
6 | field image_path type string {
7 | indexing: summary | attribute
8 | }
9 | field caption_text type string {
10 | indexing: summary | index
11 | index: enable-bm25
12 | }
13 | field clip_vector type tensor(x[512]) {
14 | indexing: attribute | index
15 | attribute {
16 | distance-metric: innerproduct
17 | }
18 | index {
19 | hnsw {
20 | max-links-per-node: 32
21 | neighbors-to-explore-at-insert: 500
22 | }
23 | }
24 | }
25 | }
26 |
27 | fieldset default {
28 | fields: image_id, image_path, caption_text, clip_vector
29 | }
30 |
31 | rank-profile caption-search inherits default {
32 | first-phase {
33 | expression: bm25(caption_text)
34 | }
35 | }
36 |
37 | rank-profile image-search inherits default {
38 | first-phase {
39 | expression: closeness(field, clip_vector)
40 | }
41 | }
42 |
43 | rank-profile combined-search inherits default {
44 | first-phase {
45 |       expression: bm25(caption_text) + closeness(field, clip_vector)
46 | }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/image-search/vespa/src/main/application/search/query-profiles/default.xml:
--------------------------------------------------------------------------------
1 |
2 | 1000
3 | 1000
4 | 2s
5 |
6 |
--------------------------------------------------------------------------------
/image-search/vespa/src/main/application/search/query-profiles/types/root.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/image-search/vespa/src/main/application/services.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | 1
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | 1.0
22 |
23 |
24 |
25 |
26 |
27 |
28 | 2
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | pyperclip
3 | spacy
4 | streamlit
5 | torch
6 | torchvision
7 | transformers
8 |
--------------------------------------------------------------------------------