├── .gitignore ├── CODEOWNERS ├── LICENSE ├── README.md ├── figures ├── eval_heatmap.png └── lossplot_5.png ├── fine-tuning ├── clip_dataset.py ├── evaluate.py ├── plot_losses.py ├── train.py ├── train_configs │ ├── run1.cfg │ ├── run2.cfg │ ├── run3.cfg │ ├── run4.cfg │ └── run5.cfg └── vectorize_images.py ├── image-search ├── app.py ├── image_to_image.py ├── load-vespa-index.py ├── search-vespa-index.py ├── start-streamlit.sh ├── streamlit.nginx ├── streamlit.service ├── text_to_image.py ├── utils.py ├── vespa.service └── vespa │ └── src │ └── main │ └── application │ ├── hosts.xml │ ├── schemas │ └── image.sd │ ├── search │ └── query-profiles │ │ ├── default.xml │ │ └── types │ │ └── root.xml │ └── services.xml └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in 2 | # the repo. 
Unless a later match takes precedence. 3 | * @sujitpal 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Elsevier Labs Open Source Repository 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # clip-image-search 2 | 3 | Fine-tuning the OpenAI CLIP model for image search on medical images. 4 | 5 | _**PLEASE NOTE**: The code in this repository is being provided as-is without any warranty of any kind. While efforts have been made to ensure that the instructions are accurate, some of the underlying software libraries move on a fast update cycle and may have changed, and as a result the code may be neither installable nor executable. It is being provided for study purposes only._ 6 | 7 | * [Motivation](#motivation) 8 | * [Applications](#applications) 9 | * [References / Previous Work](#references--previous-work) 10 | * [Fine Tuning](#fine-tuning) 11 | * [Environment](#environment) 12 | * [Data Preparation](#data-preparation) 13 | * [Training Hyperparameters](#training-hyperparameters) 14 | * [Evaluation](#evaluation) 15 | * [Image Search](#image-search) 16 | * [Environment](#environment-1) 17 | * [Vespa](#vespa) 18 | * [Streamlit](#streamlit) 19 | 20 | 21 | ## Motivation 22 | 23 | * Model-based image search (i.e., using machine learning models trained on image similarity rather than traditional Lucene-based search on captions) 24 | * Unsupervised or self-supervised, because 25 | * labeling is expensive 26 | * we have lots of captioned images in our collection, so many (image, caption) pairs are available "in the wild" 27 | * This project is **not** about caption prediction; rather, it explores the feasibility of text-to-image search, whereby the user enters a text string to bring up the most appropriate images for that string (see the sketch below). 
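The snippet below is a minimal sketch of this text-to-image retrieval idea, using the same Hugging Face checkpoint (`openai/clip-vit-base-patch32`) that the fine-tuning code in this repository starts from. The query string and image file names are placeholders for illustration only; a checkpoint directory saved by `fine-tuning/train.py` can be loaded the same way via `from_pretrained`.

```
# Minimal text-to-image retrieval sketch with Hugging Face CLIP.
# The checkpoint matches the one used elsewhere in this repo; the image
# files and the query text are placeholders.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

images = [Image.open(p).convert("RGB") for p in ["image1.jpg", "image2.jpg"]]
query = "axial CT scan of the abdomen"

with torch.no_grad():
    image_emb = model.get_image_features(**processor(images=images, return_tensors="pt"))
    text_emb = model.get_text_features(**processor(text=[query], return_tensors="pt", padding=True))

# cosine similarity = dot product of L2-normalized embeddings; higher is more similar
image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
scores = (text_emb @ image_emb.T).squeeze(0)
print(scores.argsort(descending=True))   # image indices, best match first
```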
28 | 29 | ## Applications 30 | 31 | * Image search for internal users to get editorial efficiencies 32 | * Image search for customers 33 | * Captioning of new images, concepts in new images 34 | * Image decomposition and caption assignment - an enabling technology for machine learning 35 | 36 | ## References / Previous Work 37 | 38 | * [Contrastive Learning of Medical Visual Representations from Paired Images and Text](https://arxiv.org/abs/2010.00747) (Zhang et al, 2020) 39 | * learns visual representation of medical images and associated text using contrastive learning 40 | * uses different encoders for different specialties, eg, chest image encoder, bone image encoder 41 | * [Learning Transferable Visual Models from Natural Language Supervision](https://arxiv.org/abs/2103.00020) (Radford et al, 2021) 42 | * led to the OpenAI CLIP model on which this project is based 43 | * [CLIP: Connecting Text and Images](https://openai.com/blog/clip/) -- blog post provides high level intro 44 | * [CLIP Hugging Face Model](https://huggingface.co/transformers/model_doc/clip.html) 45 | * [CLIP-rsicd project](https://github.com/arampacha/CLIP-rsicd) 46 | * same idea applied to satellite images 47 | * done with external team as part of Hugging Face Flax/JAX community week 48 | 49 | ## Fine Tuning 50 | 51 | ### Environment 52 | 53 | * AWS EC2 p2.xlarge 54 | * 4 CPUs 55 | * 1 GPU (Tesla K80) 56 | * 64 GB RAM 57 | * 300 GB disk 58 | * Ubuntu 18.04 (bionic) 59 | * [AWS Deep Learning AMI (Ubuntu 18.04)](https://aws.amazon.com/marketplace/pp/prodview-x5nivojpquy6y) 60 | * Conda Environment: pytorch_latest_p37 61 | * Python 3.7.10 62 | * Pytorch 1.8.1+cu111 63 | * Additional packages 64 | * Transformers 4.10.0 65 | * see [requirements.txt](requirements.txt) 66 | 67 | ### Data Preparation 68 | 69 | * Data for task originally sourced from the [ImageCLEF 2017 Caption Prediction task](https://www.imageclef.org/2017/caption). 70 | * Once downloaded, the dataset is exploded and results in the folder structure as shown below. Essentially, three subfolders, one each for `training`, `test` and `validation` data. Since the dataset is for caption prediction, the `test` folder does not contain any captions. The `train` and `validation` folders contains a CSV file of image file names and captions, and all the three folders contain a nested sub-folder containing the images corresponding to each split. 71 | 72 | ``` 73 | $ mkdir -p ImageCLEF2017-CaptionPrediction; cd ImageCLEF2017-CaptionPrediction 74 | $ tree . 75 | | 76 | *-- test/ 77 | | | 78 | | +-- CaptionPredictionTesting2017-List.txt 79 | | +-- CaptionPredictionTesting2017/ 80 | | | 81 | | +-- test-image-1.jpg 82 | | +-- ... 9,999 more ... 83 | | 84 | +-- training/ 85 | | | 86 | | +-- CaptionPredictionTraining2017-Captions.csv 87 | | +-- CaptionPredictionTraining2017/ 88 | | | 89 | | +-- training-image-1.jpg 90 | | +-- ... 164,613 more ... 91 | | 92 | +-- validation/ 93 | | 94 | +-- CaptionPredictionValidation2017-Captions.csv 95 | +-- ConceptDetectionValidation2017/ 96 | | 97 | +-- validation-image-1.jpg 98 | +-- ... 9,999 more ... 99 | ``` 100 | 101 | * Dataset contains following splits 102 | * training: 164,614 images + captions 103 | * validtion: 10,000 images + captions 104 | * test: 10,000 images (no captions) 105 | * We need 3 splits -- training, validation, and test. 
106 | * We cannot use test for training or evaluation, so we discard it 107 | * validation data becomes test data 108 | * training data is split 90:10 into new training and validation 109 | * End count is as follows: 110 | * training: 148,153 images + captions 111 | * validation: 16,461 images + captions 112 | * test: 10,000 images + captions 113 | 114 | ### Training Hyperparameters 115 | 116 | * Best model hyperparameters 117 | * batch size: 64 118 | * optimizer: ADAM 119 | * learning rate: 5e-5 120 | * number of epochs: 10 121 | * number of training samples: 50,000 122 | * We note that loss continues to drop so it is likely that further training or with larger amounts of data will increase performance. However, the flattening of the validation curve shows that we are in an area of diminishing returns. 123 | * 124 | * We considered doing image and text augmentation but dropped the idea since training set size is quite large (148k+ images+captions) and we achieve regularization through random sampling a subset of this dataset. 125 | * Table showing evaluation results tests 5 different hyperparameter configurations. 126 | 127 | 128 | ### Evaluation 129 | 130 | * We feed in batches of (caption-image) pairs 131 | * Evaluation metrics based on the intuition illustrated by heatmap, i.e. labels for each batch are along the diagonal 132 | * 133 | * We compute MRR@k (Mean Reciprocal Rank) for k=1, 3, 5, 10, 20 for image-caption similarity 134 | * Formula for [Mean Reciprocal Rank](https://en.wikipedia.org/wiki/Mean_reciprocal_rank) 135 | * Bounding it by k just means that we will only score a caption if it appears in the most likely captions for the image. 136 | 137 | | Experiment | k=1 | k=3 | k=5 | k=10 | k=20 | 138 | |-------------------------------------|---------|---------|---------|---------|---------| 139 | | baseline | 0.42580 | 0.53402 | 0.55837 | 0.57349 | 0.57829 | 140 | | [run-1](fine-tuning/train_configs/run1.cfg) | 0.69130 | 0.78962 | 0.80113 | 0.80517 | 0.80589 | 141 | | [run-2](fine-tuning/train_configs/run2.cfg) | 0.71200 | 0.80445 | 0.81519 | 0.81912 | 0.81968 | 142 | | [run-3](fine-tuning/train_configs/run3.cfg) | 0.34540 | 0.46338 | 0.49253 | 0.51154 | 0.51753 | 143 | | [run-4](fine-tuning/train_configs/run4.cfg) | 0.78760 | 0.86227 | 0.86870 | 0.87080 | 0.87120 | 144 | | [run-5](fine-tuning/train_configs/run5.cfg) | **0.80200** | **0.87170** | **0.87743** | **0.87966** | **0.88002** | 145 | 146 | 147 | --- 148 | 149 | ## Image Search 150 | 151 | The Image Search demo is located on a standalone CPU-only box since we are only doing inference. The corpus of images + captions used is the combination of the training, validation, and unseen test sets provided by ImageCLEF 2017 Caption Prediction challenge. The captions and image vectors are hosted on the Vespa search engine, which provides both BM25 based text search services and HNSW and Cosine similarity based Approximate Nearest Neighbor services. 
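For orientation, the vector (ANN) leg of the search looks roughly like the sketch below. It mirrors `do_vector_search` in `image-search/utils.py` further down in this repository, and assumes Vespa is running locally on port 8080 with the `image` schema deployed; the ranking profile names come from `image.sd`. The trailing-semicolon YQL syntax reflects the Vespa version this project was built against and may need adjusting for current releases.

```
# Sketch of the nearestNeighbor query the demo sends to Vespa, mirroring
# do_vector_search() in image-search/utils.py. query_vec is assumed to be
# an L2-normalized 512-d numpy vector from the (fine-tuned) CLIP model.
import requests

SEARCH_ENDPOINT_URL = "http://localhost:8080/search/"

def vector_search(query_vec, hits=10):
    params = {
        "yql": """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)); """,
        "hits": hits,
        "ranking.features.query(query_vector)": query_vec.tolist(),
        # profiles defined in image.sd: caption-search (BM25),
        # image-search (vector closeness), combined-search (both)
        "ranking.profile": "image-search"
    }
    return requests.post(SEARCH_ENDPOINT_URL, json=params).json()
```

The hybrid mode simply ORs a `caption_text contains` clause into the YQL and switches to the `combined-search` ranking profile, as in `utils.do_combined_search`.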
152 | 153 | ### Environment 154 | 155 | * AWS EC2 r5.2xlarge 156 | * 8 CPUs 157 | * 64 GB RAM 158 | * 200 GB disk 159 | * Ubuntu 18.04 (bionic) 160 | * Anaconda 2021.05-Linux 161 | * Python 3.7.10 162 | * PyTorch 1.9.0 163 | * SpaCy (with `en_core_web_sm` Language Model) 164 | * Transformers 165 | * Streamlit 166 | 167 | ### Vespa 168 | 169 | * Docker -- [instructions](https://docs.docker.com/engine/install/ubuntu/) 170 | * Vespa -- [instructions](https://docs.vespa.ai/en/getting-started.html) 171 | * **NOTE -- instructions below are very likely obsolete now that [vespa-cli](https://docs.vespa.ai/en/vespa-cli.html) is available.** 172 | * download [vespa-engine/sample-apps](https://github.com/vespa-engine/sample-apps) 173 | * create new app `clip-demo` as new app in `sample-apps` 174 | * copy `image-search/vespa/src` folder to `sample-apps/clip-demo` 175 | * download [sujitpal/vespa-poc](https://github.com/sujitpal/vespa-poc) 176 | * update scripts in `bash-scripts` to point to `clip-demo` sample app 177 | * `launch.sh` to start docker instance 178 | * `deploy.sh` to deploy `clip-demo` to Vespa 179 | * `status.sh` to verify Vespa status 180 | * Prepare data 181 | * Images -- Download ImageCLEF dataset and explode as described in the Fine Tuning section. 182 | ``` 183 | $ mkdir CaptionPrediction; cd CaptionPrediction 184 | $ unzip folder structure 185 | ``` 186 | 187 | * Models -- copy over the folder corresponding to the fine-tuned CLIP model from fine-tuning step that contains the `pytorch_model.bin` file. 188 | 189 | * Vectors -- use fine-tuned model to generate image vectors. 190 | 191 | ``` 192 | $ cd fine-tuning 193 | $ python vectorize-images --model_path /path/to/fine-tuned/pytorch_model.bin \ 194 | --output-dir /path/to/folder/containing/vector/files/ 195 | ``` 196 | 197 | * Load data -- run `load-vespa-index.py` 198 | ``` 199 | $ cd image-search 200 | $ python load-vespa-index.py --image_dir /path/to/CaptionPrediction/folder \ 201 | --vector_dir /path/to/folder/containing/vector/files/ 202 | ``` 203 | 204 | ### Streamlit 205 | 206 | * A Streamlit based demo illustrates the usage of the trained model for the following use cases. 207 | 208 | * text to image search -- user enters a text query to match against image corpus 209 | * using caption text -- this is standard text query, searches images by caption text 210 | * using query text vector -- since CLIP learns to embed images close to their corresponding captions, we do a vector search against the image corpus using the vector representation of the query text. Distance measure uses Cosine Similarity and Vespa provides HNSW based Approximate Nearest Neighbor search. 211 | * using hybrid text + vector search -- relevance is a linear interpolation of the BM25 relevance from caption search and cosine similarity from vector search. 212 | * image to image search -- this is more like a search for similar images in the corpus, user provides an image ID and search will return similar images to the query image in the corpus. We could also query the corpus using an external image with our infrastructure (compute image embedding from trained model and find nearest neighbors in the corpus), but the demo does not support that functionality. 
213 | * image only -- use consine similarity between the query image vector and vectors for images in the corpus 214 | * image + text -- use hybrid text + vector search, computing relevance as a linear interpolation of cosine similarity between image vectors and BM25 similarity between source image caption and target image captions. 215 | 216 | * To run streamlit server, run following command, application will start listening on port 8501. 217 | 218 | ``` 219 | $ cd image-search 220 | $ streamlit run app.py 221 | ``` 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /figures/eval_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elsevierlabs-os/clip-image-search/965e1bbcaf9b74ddb703428aff9d37508e7fcca9/figures/eval_heatmap.png -------------------------------------------------------------------------------- /figures/lossplot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elsevierlabs-os/clip-image-search/965e1bbcaf9b74ddb703428aff9d37508e7fcca9/figures/lossplot_5.png -------------------------------------------------------------------------------- /fine-tuning/clip_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class ImageCaptionDataset(Dataset): 9 | def __init__(self, image_folder, caption_file): 10 | 11 | super().__init__() 12 | self.image_folder = image_folder 13 | self.caption_file = caption_file 14 | 15 | self.image_to_caption = {} 16 | self.images = [] 17 | with open(self.caption_file, "r") as fcap: 18 | for line in fcap: 19 | image_id, caption = line.strip().split('\t') 20 | if os.path.exists(os.path.join(self.image_folder, image_id + ".jpg")): 21 | self.image_to_caption[image_id] = caption 22 | self.images.append(image_id) 23 | 24 | def __len__(self): 25 | return len(self.image_to_caption) 26 | 27 | def __getitem__(self, idx): 28 | image = self._get_image(idx) 29 | caption = self._get_caption(idx) 30 | return { 31 | "image_id": self.images[idx], 32 | "image": image, 33 | "caption": caption 34 | } 35 | 36 | def _get_image(self, idx): 37 | image_id = self.images[idx] 38 | image = Image.open(os.path.join(self.image_folder, image_id + ".jpg")) 39 | image = image.convert("RGB") 40 | return image 41 | 42 | def _get_caption(self, idx): 43 | image_id = self.images[idx] 44 | caption = self.image_to_caption[image_id] 45 | return caption 46 | 47 | 48 | class ImageCaptionCollator(object): 49 | def __init__(self, processor, 50 | image_size=224, 51 | max_caption_length=64): 52 | self.processor = processor 53 | self.image_size = image_size 54 | self.max_caption_length = max_caption_length 55 | 56 | def __call__(self, batch): 57 | image_ids = [row["image_id"] for row in batch] 58 | images = [row["image"] for row in batch] 59 | captions = [row["caption"] for row in batch] 60 | # image preprocessing: feature extractor defaults 61 | # caption preprocessing: pad/truncate + tensor 62 | inputs = self.processor(text=captions, 63 | images=images, 64 | return_tensors="pt", 65 | padding="max_length", 66 | max_length=64, 67 | truncation=True) 68 | return inputs, image_ids 69 | 70 | 71 | 72 | # from transformers import CLIPProcessor, CLIPModel 73 | # from torch.utils.data import DataLoader 74 | 75 | # DATA_DIR = "../ImageCLEF2017-CaptionPrediction" 
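# NOTE: the commented-out code in this section is a quick smoke test -- it builds the dataset and collator from DATA_DIR, runs a single batch through the pretrained CLIP checkpoint, and prints the per-image caption softmax probabilities.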
76 | 77 | # model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 78 | # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 79 | 80 | # device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 81 | # model.to(device) 82 | 83 | # collator = ImageCaptionCollator(processor) 84 | # train_ds = ImageCaptionDataset( 85 | # os.path.join(DATA_DIR, "training", "CaptionPredictionTraining2017"), 86 | # os.path.join(DATA_DIR, "training", "CaptionPredictionTraining2017-Captions.csv")) 87 | # train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0, 88 | # collate_fn=collator) 89 | 90 | 91 | # for bid, (inputs, _) in enumerate(train_dl): 92 | # inputs.to(device) 93 | # outputs = model(**inputs) 94 | 95 | # logits_per_image = outputs.logits_per_image 96 | # probs = logits_per_image.softmax(dim=1) 97 | # print(probs) 98 | # break 99 | -------------------------------------------------------------------------------- /fine-tuning/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | import numpy as np 4 | import os 5 | import torch 6 | import time 7 | 8 | from transformers import CLIPModel, CLIPProcessor 9 | from torch.utils.data import DataLoader 10 | 11 | from clip_dataset import ImageCaptionDataset, ImageCaptionCollator 12 | 13 | DATA_DIR = "../data" 14 | IMAGE_DATA_DIR = "../../ImageCLEF2017-CaptionPrediction" 15 | 16 | OPENAI_CLIP_HF_HUBID = "openai/clip-vit-base-patch32" 17 | K_VALUES = [1, 3, 5, 10, 20] 18 | BATCH_SIZE = 64 19 | 20 | EVAL_REPORT = os.path.join(DATA_DIR, "eval-report.tsv") 21 | 22 | def compute_batch_mrrs(probs, k_values): 23 | ranks = np.argsort(-probs, axis=1) 24 | batch_mrrs = np.zeros((ranks.shape[0], len(k_values))) 25 | for i in range(ranks.shape[0]): 26 | mrr_at_k = [] 27 | for j, k in enumerate(k_values): 28 | rank = np.where(ranks[i, 0:k] == i)[0] 29 | if rank.shape[0] == 0: 30 | # item not found in top k, don't add to MRR 31 | batch_mrrs[i, j] = 0 32 | else: 33 | # item found in top k, add dimension 34 | batch_mrrs[i, j] = 1.0 / (rank[0] + 1) 35 | 36 | return batch_mrrs 37 | 38 | 39 | def write_report(fout, batch_mrr): 40 | for i in range(batch_mrr.shape[0]): 41 | k_scores = ["{:.5f}".format(x) for x in batch_mrr[i, :].tolist()] 42 | fout.write(",".join(k_scores) + "\n") 43 | 44 | 45 | ############################ main ############################ 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("model_path", help="'baseline' for HF model or path to local model") 49 | args = parser.parse_args() 50 | 51 | if args.model_path == "baseline": 52 | model_path = OPENAI_CLIP_HF_HUBID 53 | else: 54 | model_path = args.model_path 55 | 56 | model = CLIPModel.from_pretrained(model_path) 57 | processor = CLIPProcessor.from_pretrained(OPENAI_CLIP_HF_HUBID) 58 | 59 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 60 | model.to(device) 61 | 62 | collator = ImageCaptionCollator(processor) 63 | test_ds = ImageCaptionDataset( 64 | os.path.join(IMAGE_DATA_DIR, "test", "ConceptDetectionValidation2017"), 65 | os.path.join(IMAGE_DATA_DIR, "test", "CaptionPredictionValidation2017-Captions.csv")) 66 | test_dl = DataLoader(test_ds, 67 | batch_size=BATCH_SIZE, 68 | num_workers=mp.cpu_count() - 1, 69 | collate_fn=collator) 70 | 71 | 72 | model.eval() 73 | 74 | start = time.time() 75 | fout = open(EVAL_REPORT, "w") 76 | 77 | for bid, (inputs, _) in enumerate(test_dl): 78 
| if bid % 10 == 0: 79 | print("{:d} batches processed".format(bid)) 80 | 81 | inputs.to(device) 82 | outputs = model(**inputs) 83 | 84 | logits_per_image = outputs.logits_per_image 85 | probs = logits_per_image.softmax(dim=1) 86 | batch_mrrs = compute_batch_mrrs(probs.detach().cpu().numpy(), K_VALUES) 87 | write_report(fout, batch_mrrs) 88 | # break 89 | 90 | elapsed = time.time() - start 91 | print("{:d} batches processed, COMPLETE".format(bid)) 92 | print("elapsed time: {:.3f} s".format(elapsed)) 93 | fout.close() 94 | 95 | mrr_scores = [] 96 | with open(EVAL_REPORT, "r") as frep: 97 | for line in frep: 98 | scores = np.array([float(x) for x in line.strip().split(',')]) 99 | mrr_scores.append(scores) 100 | mrr_scores = np.array(mrr_scores) 101 | eval_scores = np.mean(mrr_scores, axis=0) 102 | frep.close() 103 | 104 | print(" | ".join(["{:.5f}".format(x) for x in eval_scores.tolist()])) 105 | 106 | -------------------------------------------------------------------------------- /fine-tuning/plot_losses.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("history_file", help="path to history file") 7 | parser.add_argument("--save_to", "-s", help="path to save figure") 8 | args = parser.parse_args() 9 | 10 | epochs, train_losses, val_losses, val_accs = [], [], [], [] 11 | with open(args.history_file, "r") as fhist: 12 | for line in fhist: 13 | epoch, train_loss, val_loss, val_acc = line.strip().split('\t') 14 | epochs.append(int(epoch)) 15 | train_losses.append(float(train_loss)) 16 | val_losses.append(float(val_loss)) 17 | val_accs.append(float(val_acc)) 18 | 19 | plt.subplot(2, 1, 1) 20 | plt.plot(epochs, train_losses, label="train") 21 | plt.plot(epochs, val_losses, label="va;") 22 | plt.legend(loc="best") 23 | plt.xlabel("epochs") 24 | plt.ylabel("loss") 25 | 26 | plt.subplot(2, 1, 2) 27 | plt.plot(epochs, val_accs, label="val") 28 | plt.legend(loc="best") 29 | plt.xlabel("epochs") 30 | plt.ylabel("accuracy") 31 | 32 | if args.save_to is None: 33 | _ = plt.show() 34 | else: 35 | plt.savefig(args.save_to) -------------------------------------------------------------------------------- /fine-tuning/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | import numpy as np 4 | import os 5 | import shutil 6 | import torch 7 | import yaml 8 | 9 | from PIL import Image 10 | from transformers import ( 11 | CLIPModel, CLIPProcessor, 12 | AdamW, get_scheduler 13 | ) 14 | from torch.utils.data import DataLoader, RandomSampler 15 | 16 | from clip_dataset import ImageCaptionDataset, ImageCaptionCollator 17 | 18 | 19 | def do_train(model, train_dl): 20 | train_loss = 0 21 | model.train() 22 | for bid, (batch, _) in enumerate(train_dl): 23 | if bid % 100 == 0: 24 | print("...{:d} training steps complete".format(bid)) 25 | batch = {k: v.to(device) for k, v in batch.items()} 26 | outputs = model(**batch, return_loss=True) 27 | loss = outputs.loss 28 | train_loss += loss.detach().cpu().numpy() 29 | loss.backward() 30 | 31 | optimizer.step() 32 | lr_scheduler.step() 33 | optimizer.zero_grad() 34 | 35 | print("...{:d} training steps COMPLETE".format(bid)) 36 | return train_loss 37 | 38 | 39 | def do_eval(model, eval_dl): 40 | model.eval() 41 | val_loss, val_acc, num_examples = 0, 0, 0 42 | for bid, (batch, _) in enumerate(eval_dl): 43 | if bid % 
100 == 0: 44 | print("... {:d} validation steps complete".format(bid)) 45 | batch = {k: v.to(device) for k, v in batch.items()} 46 | with torch.no_grad(): 47 | outputs = model(**batch, return_loss=True) 48 | 49 | loss = outputs.loss 50 | val_loss += loss.detach().cpu().numpy() 51 | 52 | logits_per_image = outputs.logits_per_image 53 | probs = logits_per_image.softmax(dim=1) 54 | predictions = torch.argmax(probs, dim=-1) 55 | labels = torch.arange(len(predictions)).to(device) 56 | 57 | accuracy = torch.sum(predictions == labels) 58 | num_examples += len(predictions) 59 | val_acc += accuracy 60 | 61 | print("... {:d} validation steps COMPLETE".format(bid)) 62 | val_acc = val_acc.detach().cpu().numpy() / num_examples 63 | return val_loss, val_acc 64 | 65 | 66 | def save_checkpoint(model, model_dir, epoch): 67 | model.save_pretrained(os.path.join(model_dir, "ckpt-{:d}".format(epoch + 1))) 68 | 69 | 70 | def save_training_history(history, model_dir): 71 | fhist = open(os.path.join(model_dir, "history.tsv"), "w") 72 | for epoch, train_loss, val_loss, val_acc in history: 73 | fhist.write("{:d}\t{:.5f}\t{:.5f}\t{:.5f}\n".format( 74 | epoch, train_loss, val_loss, val_acc)) 75 | fhist.close() 76 | 77 | 78 | ###################### main ########################### 79 | 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("config_file", help="provides training params") 82 | args = parser.parse_args() 83 | 84 | config_file = args.config_file 85 | with open(config_file, "r") as fcfg: 86 | config = yaml.load(fcfg) 87 | print(config) 88 | 89 | model_dir = os.path.join(config["models_dir"], 90 | os.path.basename(config_file).split(".")[0]) 91 | shutil.rmtree(model_dir, ignore_errors=True) 92 | os.makedirs(model_dir) 93 | 94 | train_ds = ImageCaptionDataset( 95 | os.path.join(config["image_data_dir"], "training", config["train_images_dir"]), 96 | os.path.join(config["image_data_dir"], "training", config["train_captions_file"])) 97 | validation_ds = ImageCaptionDataset( 98 | os.path.join(config["image_data_dir"], "validation", config["validation_images_dir"]), 99 | os.path.join(config["image_data_dir"], "validation", config["validation_captions_file"])) 100 | 101 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 102 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 103 | 104 | collator = ImageCaptionCollator(processor) 105 | train_sampler = RandomSampler(train_ds, 106 | replacement=True, 107 | num_samples=config["train_sample_size"]) 108 | train_dl = DataLoader(train_ds, 109 | batch_size=config["train_batch_size"], 110 | # shuffle=True, 111 | sampler=train_sampler, 112 | num_workers=mp.cpu_count() - 1, 113 | collate_fn=collator) 114 | validation_dl = DataLoader(validation_ds, 115 | batch_size=config["validation_batch_size"], 116 | num_workers=mp.cpu_count() - 1, 117 | collate_fn=collator) 118 | 119 | optimizer = AdamW(model.parameters(), 120 | lr=config["learning_rate"]) 121 | 122 | num_training_steps = config["num_epochs"] * len(train_dl) 123 | lr_scheduler = get_scheduler( 124 | "linear", 125 | optimizer=optimizer, 126 | num_warmup_steps=0, 127 | num_training_steps=num_training_steps 128 | ) 129 | 130 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 131 | model.to(device) 132 | 133 | history = [] 134 | for epoch in range(config["num_epochs"]): 135 | train_loss = do_train(model, train_dl) 136 | val_loss, val_acc = do_eval(model, validation_dl) 137 | save_checkpoint(model, model_dir, epoch) 138 | history.append((epoch 
+ 1, train_loss, val_loss, val_acc)) 139 | print("EPOCH {:d}, training loss: {:.3f}, validation loss: {:.3f}, accuracy: {:.3f}" 140 | .format(epoch + 1, 141 | train_loss, 142 | val_loss, 143 | val_acc)) 144 | save_training_history(history, model_dir) 145 | -------------------------------------------------------------------------------- /fine-tuning/train_configs/run1.cfg: -------------------------------------------------------------------------------- 1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction 2 | train_images_dir: CaptionPredictionTraining2017 3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv 4 | validation_images_dir: ConceptDetectionValidation2017 5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv 6 | models_dir: ../data/models 7 | 8 | train_sample_size: 20000 9 | train_batch_size: 32 10 | validation_batch_size: 32 11 | learning_rate: 5.e-5 12 | num_epochs: 5 13 | -------------------------------------------------------------------------------- /fine-tuning/train_configs/run2.cfg: -------------------------------------------------------------------------------- 1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction 2 | train_images_dir: CaptionPredictionTraining2017 3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv 4 | validation_images_dir: ConceptDetectionValidation2017 5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv 6 | models_dir: ../data/models 7 | 8 | train_sample_size: 20000 9 | train_batch_size: 64 10 | validation_batch_size: 64 11 | learning_rate: 5.e-5 12 | num_epochs: 5 13 | -------------------------------------------------------------------------------- /fine-tuning/train_configs/run3.cfg: -------------------------------------------------------------------------------- 1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction 2 | train_images_dir: CaptionPredictionTraining2017 3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv 4 | validation_images_dir: ConceptDetectionValidation2017 5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv 6 | models_dir: ../data/models 7 | 8 | train_sample_size: 20000 9 | train_batch_size: 64 10 | validation_batch_size: 64 11 | learning_rate: 1.e-4 12 | num_epochs: 5 13 | -------------------------------------------------------------------------------- /fine-tuning/train_configs/run4.cfg: -------------------------------------------------------------------------------- 1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction 2 | train_images_dir: CaptionPredictionTraining2017 3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv 4 | validation_images_dir: ConceptDetectionValidation2017 5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv 6 | models_dir: ../data/models 7 | 8 | train_sample_size: 20000 9 | train_batch_size: 64 10 | validation_batch_size: 64 11 | learning_rate: 5.e-5 12 | num_epochs: 10 13 | -------------------------------------------------------------------------------- /fine-tuning/train_configs/run5.cfg: -------------------------------------------------------------------------------- 1 | image_data_dir: ../../ImageCLEF2017-CaptionPrediction 2 | train_images_dir: CaptionPredictionTraining2017 3 | train_captions_file: CaptionPredictionTraining2017-Captions.csv 4 | validation_images_dir: ConceptDetectionValidation2017 5 | validation_captions_file: CaptionPredictionValidation2017-Captions.csv 6 | models_dir: ../data/models 7 | 8 | train_sample_size: 50000 9 
| train_batch_size: 64 10 | validation_batch_size: 64 11 | learning_rate: 5.e-5 12 | num_epochs: 10 13 | -------------------------------------------------------------------------------- /fine-tuning/vectorize_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | import numpy as np 4 | import os 5 | import torch 6 | import time 7 | 8 | from transformers import CLIPModel, CLIPProcessor 9 | from torch.utils.data import DataLoader 10 | 11 | from clip_dataset import ImageCaptionDataset, ImageCaptionCollator 12 | 13 | OPENAI_CLIP_HF_HUBID = "openai/clip-vit-base-patch32" 14 | IMAGE_DATA_DIR = "../../ImageCLEF2017-CaptionPrediction" 15 | BATCH_SIZE = 64 16 | 17 | ############################ main ############################ 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("model_path", help="path to local model (or 'baseline' for OpenAI CLIP)") 21 | parser.add_argument("output_dir", help="path to folder containing TSV files") 22 | args = parser.parse_args() 23 | 24 | if args.model_path == "baseline": 25 | model_path = OPENAI_CLIP_HF_HUBID 26 | else: 27 | model_path = args.model_path 28 | 29 | if os.path.exists(args.output_dir): 30 | os.makedirs(args.output_dir) 31 | 32 | model = CLIPModel.from_pretrained(model_path) 33 | processor = CLIPProcessor.from_pretrained(OPENAI_CLIP_HF_HUBID) 34 | 35 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 36 | model.to(device) 37 | 38 | collator = ImageCaptionCollator(processor) 39 | datasets = { 40 | "training": ImageCaptionDataset( 41 | os.path.join(IMAGE_DATA_DIR, "training", "CaptionPredictionTraining2017"), 42 | os.path.join(IMAGE_DATA_DIR, "training", "CaptionPredictionTraining2017-Captions.csv")), 43 | "validation": ImageCaptionDataset( 44 | os.path.join(IMAGE_DATA_DIR, "validation", "ConceptDetectionValidation2017"), 45 | os.path.join(IMAGE_DATA_DIR, "validation", "CaptionPredictionValidation2017-Captions.csv")), 46 | "test": ImageCaptionDataset( 47 | os.path.join(IMAGE_DATA_DIR, "unseen", "CaptionPredictionTesting2017"), 48 | os.path.join(IMAGE_DATA_DIR, "unseen", "CaptionPredictionTesting2017-Captions-dummy.txt")) 49 | } 50 | for split, dataset in datasets.items(): 51 | test_dl = DataLoader(dataset, 52 | batch_size=BATCH_SIZE, 53 | num_workers=mp.cpu_count() - 1, 54 | collate_fn=collator) 55 | 56 | fvec = open( 57 | os.path.join(args.output_dir, "vectors-{:s}.tsv".format(split)), 58 | "w") 59 | for bid, (batch, image_ids) in enumerate(test_dl): 60 | if bid % 100 == 0: 61 | print("... {:d} batches (of {:d}) vectors generated for {:s}".format( 62 | bid, BATCH_SIZE, split)) 63 | batch = {k: v.to(device) for k, v in batch.items()} 64 | with torch.no_grad(): 65 | outputs = model.get_image_features(pixel_values=batch["pixel_values"]) 66 | outputs = outputs.cpu().numpy() 67 | for i in range(outputs.shape[0]): 68 | image_id = image_ids[i] 69 | vector = outputs[i].reshape(-1).tolist() 70 | fvec.write("{:s}\t{:s}\n".format( 71 | image_id, 72 | ",".join(["{:.5f}".format(v) for v in vector]))) 73 | # break 74 | 75 | print("... 
{:d} batches (of {:d}) vectors generated for {:s}, COMPLETE".format( 76 | bid, BATCH_SIZE, split)) 77 | fvec.close() 78 | -------------------------------------------------------------------------------- /image-search/app.py: -------------------------------------------------------------------------------- 1 | import text_to_image 2 | import image_to_image 3 | 4 | import streamlit as st 5 | import logging 6 | 7 | 8 | logging.basicConfig(filename="streamlit_logs.txt") 9 | 10 | PAGES = { 11 | "Text to Image Search": text_to_image, 12 | "Image to Image Search": image_to_image 13 | } 14 | 15 | st.sidebar.title("CLIP Image Search Demo") 16 | st.sidebar.markdown(""" 17 | Demo to showcase image search capabilities of a CLIP transformer model 18 | fine-tuned on the ImageCLEF 2017 Caption Prediction dataset. 19 | 20 | CLIP is a transformer model from OpenAI that was trained on a large number 21 | of image + text pairs from the Internet. CLIP learns a joint embedding for 22 | images and their associated captions, such that images and their captions 23 | are pushed closer together in the embedding space. The resulting CLIP model 24 | can be used for text-to-image or image-to-image search, and performs well 25 | on general images. 26 | 27 | Out-of-the-box performance on medical images is not as good, however. To remedy that, 28 | we fine-tuned the CLIP model with medical images and captions from ImageCLEF, which resulted 29 | in significant performance improvements (as measured by MRR@k, for k=1, 3, 30 | 5, 10, 20). 31 | 32 | The CLIP model trained on ImageCLEF data was used to generate vectors for the 33 | entire ImageCLEF dataset (training, validation, and unseen test images) and 34 | loaded into a Vespa search index, which provides the Approximate Nearest 35 | Neighbors vector search functionality for this demo. 
36 | """) 37 | 38 | selection = st.sidebar.radio("Go to", list(PAGES.keys())) 39 | page = PAGES[selection] 40 | page.app() 41 | -------------------------------------------------------------------------------- /image-search/image_to_image.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import streamlit as st 4 | 5 | from PIL import Image 6 | import utils 7 | 8 | 9 | def app(): 10 | model, processor = utils.load_model(utils.MODEL_PATH, utils.BASELINE_MODEL) 11 | 12 | st.title("Retrieve Images given an Image") 13 | image_id = st.text_input("Enter Image-ID:") 14 | query_type = st.radio("Select search type:", 15 | options=["image-only", "image+text"], 16 | index=0) 17 | 18 | if st.button("Search"): 19 | logging.info("image_to_image: {:s} ({:s})".format(image_id, query_type)) 20 | st.text("returning results") 21 | image_vec, caption_text = utils.get_image_vector_and_caption_text_by_id( 22 | image_id, utils.SEARCH_ENDPOINT_URL, utils.GET_ENDPOINT_URL) 23 | if query_type == "image-only": 24 | results = utils.do_vector_search(image_vec, utils.SEARCH_ENDPOINT_URL) 25 | else: 26 | results = utils.do_combined_search(caption_text, image_vec, utils.SEARCH_ENDPOINT_URL) 27 | 28 | metadata, data = utils.parse_results(results) 29 | st.markdown("## Results 1-{:d} from {:d} matches".format( 30 | len(data), metadata["num_results"])) 31 | for rid, row in enumerate(data): 32 | image_id = row["image_id"] 33 | image_path = row["image_path"] 34 | image = Image.open(image_path).convert("RGB") 35 | caption = row["caption_text"] 36 | relevance = row["relevance"] 37 | col1, col2, col3 = st.columns([2, 10, 10]) 38 | col1.markdown("{:d}.".format(rid + 1)) 39 | col2.image(image) 40 | col3.markdown(""" 41 | * **Image-ID**: {:s} 42 | * **Caption**: {:s} 43 | * **Relevance**: {:.3f} 44 | """.format(image_id, caption, relevance)) 45 | st.markdown("---") 46 | 47 | -------------------------------------------------------------------------------- /image-search/load-vespa-index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import requests 5 | 6 | # DATA_DIR = "/home/ubuntu/CaptionPrediction" 7 | DATA_SUBDIRS = ["training", "validation", "test"] 8 | 9 | APP_NAME = "clip-demo" 10 | SCHEMA_NAME = "image" 11 | ENDPOINT = "http://localhost:8080/document/v1/{:s}/{:s}/docid/{:d}" 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("image_dir", help="path to folder containing training, validation and test images and captions") 16 | parser.add_argument("vector_dir", help="path to folder containing vector TSV files") 17 | args = parser.parse_args() 18 | 19 | 20 | # scan directories and compose paths 21 | image_paths = {} 22 | for data_subdir in DATA_SUBDIRS: 23 | for image_folder_cand in os.listdir(os.path.join(args.image_dir, data_subdir)): 24 | # print(data_subdir, subdir_content) 25 | if os.path.isdir(os.path.join(args.image_dir, data_subdir, image_folder_cand)): 26 | for image_file in os.listdir(os.path.join(args.image_dir, data_subdir, image_folder_cand)): 27 | image_path = os.path.join(args.image_dir, data_subdir, image_folder_cand, image_file) 28 | image_id = image_file.replace(".jpg", "") 29 | # if image_id in image_paths: 30 | # print("duplicate image:", image_file) 31 | image_paths[image_id] = image_path 32 | 33 | print("# of image paths:", len(image_paths)) 34 | 35 | image_captions = {} 36 | for data_subdir in DATA_SUBDIRS: 37 | for 
image_folder_cand in os.listdir(os.path.join(args.image_dir, data_subdir)): 38 | if image_folder_cand.find("-Captions") > -1: 39 | with open(os.path.join(args.image_dir, data_subdir, image_folder_cand), "r") as f: 40 | for line in f: 41 | image_id, caption = line.strip().split('\t') 42 | image_captions[image_id] = caption 43 | 44 | print("# of image captions:", len(image_captions)) 45 | 46 | 47 | doc_id = 1 48 | failures, successes = 0, 0 49 | headers = { "Content-Type": "application/json" } 50 | for vec_file_cand in os.listdir(args.vector_dir): 51 | if vec_file_cand.startswith("vectors-"): 52 | with open(os.path.join(args.vector_dir, vec_file_cand), "r") as f: 53 | for line in f: 54 | image_id, vec_str = line.strip().split('\t') 55 | vec = np.array([float(x) for x in vec_str.split(',')]) 56 | vec_norm = np.linalg.norm(vec, 2) 57 | vec /= vec_norm 58 | vec = vec.tolist() 59 | image_path = image_paths[image_id] 60 | try: 61 | caption_text = image_captions[image_id] 62 | except KeyError: 63 | caption_text = "(no caption provided)" 64 | input_rec = { 65 | "fields": { 66 | "image_id": image_id, 67 | "image_path": image_path, 68 | "caption_text": caption_text, 69 | "clip_vector": { 70 | "values": vec 71 | } 72 | } 73 | } 74 | url = ENDPOINT.format(APP_NAME, SCHEMA_NAME, doc_id) 75 | resp = requests.post(url, headers=headers, json=input_rec) 76 | if resp.status_code != 200: 77 | print("ERROR loading [{:d}] {:s}: {:s}".format( 78 | doc_id, image_id, resp.reason)) 79 | failures += 1 80 | else: 81 | successes += 1 82 | print("Inserted document {:20s}, {:6d} ok, {:6d} failed, {:6d} total\r" 83 | .format(image_id, successes, failures, doc_id), end="") 84 | doc_id += 1 85 | 86 | print("\n{:d} documents read, {:d} succeeded, {:d} failed, COMPLETE" 87 | .format(doc_id + 1, successes, failures)) 88 | -------------------------------------------------------------------------------- /image-search/search-vespa-index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import requests 5 | 6 | # from requests.sessions import dispatch_hook 7 | 8 | APP_NAME = "clip-demo" 9 | SCHEMA_NAME = "image" 10 | GET_ENDPOINT_URL = "http://localhost:8080/document/v1/{:s}/{:s}/docid/{:d}" 11 | SEARCH_ENDPOINT_URL = "http://localhost:8080/search/" 12 | 13 | 14 | def do_text_search(query_text, endpoint_url): 15 | headers = { "Content-type" : "application/json" } 16 | yql = """ select * from sources image where caption_text contains '{:s}'; """.format( 17 | query_text) 18 | params = { 19 | "yql": yql, 20 | "hits": 10, 21 | "ranking.profile": "caption-search" 22 | } 23 | resp = requests.post(endpoint_url, headers=headers, json=params) 24 | return resp.json() 25 | 26 | 27 | def do_vector_search(query_vec, endpoint_url): 28 | headers = { "Content-Type" : "application/json" } 29 | params = { 30 | "yql": """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)); """, 31 | "hits": 10, 32 | "ranking.features.query(query_vector)": query_vec.tolist(), 33 | "ranking.profile": "image-search" 34 | } 35 | resp = requests.post(endpoint_url, headers=headers, json=params) 36 | return resp.json() 37 | 38 | 39 | def do_combined_search(query_text, query_vec, endpoint_url): 40 | headers = { "Content-Type" : "application/json" } 41 | yql = """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)) OR caption_text contains '%s'; """ % (query_text) 42 | params = { 43 | 
"yql": yql, 44 | "hits": 10, 45 | "ranking.features.query(query_vector)": query_vec.tolist(), 46 | "ranking.profile": "combined-search" 47 | } 48 | data = json.dumps(params) 49 | resp = requests.post(endpoint_url, headers=headers, data=data) 50 | return resp.json() 51 | 52 | 53 | def get_document_by_id(doc_id, endpoint_url): 54 | headers = { "Content-Type" : "application/json" } 55 | resp = requests.get(endpoint_url.format(APP_NAME, SCHEMA_NAME, doc_id), 56 | headers=headers) 57 | return resp.json() 58 | 59 | 60 | def parse_image_vector_from_response(resp_json): 61 | emb = np.zeros((512), dtype=np.float32) 62 | for cell in resp_json["fields"]["clip_vector"]["cells"]: 63 | pos = int(cell["address"]["x"]) 64 | val = cell["value"] 65 | emb[pos] = val 66 | return emb 67 | 68 | 69 | def get_best_vector_from_text_query(query_text, endpoint_url, getdoc_url): 70 | text_results = do_text_search(query_text, endpoint_url) 71 | top_result_id = int(text_results["root"]["children"][0]["id"].split(":")[-1]) 72 | docid_json = get_document_by_id(top_result_id, getdoc_url) 73 | emb = parse_image_vector_from_response(docid_json) 74 | return emb 75 | 76 | 77 | def get_image_vector_by_id(image_id, endpoint_url, getdoc_url): 78 | headers = { "Content-Type" : "application/json" } 79 | params = { 80 | "yql": """select documentid from sources image where image_id matches '%s'; """ % (image_id), 81 | "hits": 1, 82 | } 83 | resp = requests.post(endpoint_url, headers=headers, json=params) 84 | doc_id = int(resp.json()["root"]["children"][0]["id"].split(":")[-1]) 85 | docid_json = get_document_by_id(doc_id, getdoc_url) 86 | image_vector = parse_image_vector_from_response(docid_json) 87 | return image_vector 88 | 89 | 90 | def parse_results(result_json): 91 | metadata = { 92 | "num_results": result_json["root"]["fields"]["totalCount"] 93 | } 94 | data = [] 95 | try: 96 | for child in result_json["root"]["children"]: 97 | data.append({ 98 | "id": child["id"], 99 | "image_id": child["fields"]["image_id"], 100 | "image_path": child["fields"]["image_path"], 101 | "caption_text": child["fields"]["caption_text"], 102 | "relevance": child["relevance"] 103 | }) 104 | except KeyError: 105 | pass 106 | return metadata, data 107 | 108 | 109 | def pad_truncate(text, length): 110 | pad = " " * length 111 | text = (text + pad)[0:length] 112 | return text 113 | 114 | 115 | def print_results(metadata, data): 116 | print("## Top 10 of {:d} results".format(metadata["num_results"])) 117 | for row in data: 118 | print("{:20s} | {:50s} | {:.3f}".format( 119 | pad_truncate(row["image_id"], 20), 120 | pad_truncate(row["caption_text"], 50), 121 | row["relevance"] 122 | )) 123 | 124 | 125 | ################################## main ################################## 126 | 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--query_type", "-t", 129 | choices=["text", "vector", "combined"], 130 | default="text", 131 | help="specify type of query to use") 132 | parser.add_argument("--query", "-q", 133 | default="X-rays", 134 | help="text of query") 135 | args = parser.parse_args() 136 | 137 | query_type = args.query_type 138 | query_text = args.query 139 | 140 | search_results = None 141 | if query_type == "text": 142 | search_results = do_text_search(query_text, SEARCH_ENDPOINT_URL) 143 | elif query_type == "vector": 144 | query_vector = get_best_vector_from_text_query( 145 | query_text, SEARCH_ENDPOINT_URL, GET_ENDPOINT_URL) 146 | search_results = do_vector_search(query_vector, SEARCH_ENDPOINT_URL) 147 | else: 148 | query_vector = 
get_best_vector_from_text_query( 149 | query_text, SEARCH_ENDPOINT_URL, GET_ENDPOINT_URL) 150 | search_results = do_combined_search( 151 | query_text, query_vector, SEARCH_ENDPOINT_URL) 152 | 153 | print("--- {:s} search results ---".format(query_type)) 154 | metadata, data = parse_results(search_results) 155 | print_results(metadata, data) 156 | 157 | # image_vector = get_image_vector_by_id('ORT-1745-3674-80-548-g002', 158 | # SEARCH_ENDPOINT_URL, GET_ENDPOINT_URL) 159 | # print(image_vector) 160 | -------------------------------------------------------------------------------- /image-search/start-streamlit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm nohup.out 3 | nohup streamlit run app.py --logger.level=info 2>streamlit_logs.txt & 4 | 5 | -------------------------------------------------------------------------------- /image-search/streamlit.nginx: -------------------------------------------------------------------------------- 1 | #### begin streamlit app #### 2 | 3 | location /clip-demo { 4 | location ^~ /clip-demo/healthz { 5 | proxy_pass http://127.0.0.1:8501/healthz; 6 | } 7 | location /clip-demo/stream { 8 | proxy_pass http://127.0.0.1:8501/stream; 9 | proxy_http_version 1.1; 10 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 11 | proxy_set_header Host $host; 12 | proxy_set_header Upgrade $http_upgrade; 13 | proxy_set_header Connection "upgrade"; 14 | proxy_read_timeout 86400; 15 | } 16 | location /clip-demo/media { 17 | proxy_pass http://127.0.0.1:8501/media; 18 | } 19 | proxy_pass http://127.0.0.1:8501/; 20 | } 21 | location ^~ /static { 22 | proxy_pass http://127.0.0.1:8501/static/; 23 | } 24 | location ^~ /vendor { 25 | proxy_pass http://127.0.0.1:8501/vendor; 26 | } 27 | 28 | #### end streamlit app #### 29 | -------------------------------------------------------------------------------- /image-search/streamlit.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=StreamlitService 3 | After=vespa.service 4 | 5 | [Service] 6 | User=ubuntu 7 | Group=ubuntu 8 | WorkingDirectory=/home/ubuntu/labs-clip-imageclef/demo 9 | ExecStart=/home/ubuntu/labs-clip-imageclef/demo/start-streamlit.sh 10 | 11 | [Install] 12 | WantedBy=multi-user.target 13 | 14 | -------------------------------------------------------------------------------- /image-search/text_to_image.py: -------------------------------------------------------------------------------- 1 | import json 2 | import streamlit as st 3 | import logging 4 | 5 | from PIL import Image 6 | import utils 7 | 8 | 9 | def app(): 10 | model, processor = utils.load_model(utils.MODEL_PATH, utils.BASELINE_MODEL) 11 | 12 | st.title("Retrieve Images given Text") 13 | query_text = st.text_input("Enter a text query:") 14 | query_type = st.radio("Select search type:", 15 | options=["text search", "vector search", "combined search"], 16 | index=0) 17 | 18 | if st.button("Search"): 19 | logging.info("text_to_image: {:s} ({:s})".format(query_text, query_type)) 20 | st.text("returning results") 21 | 22 | if query_type == "text search": 23 | results = utils.do_text_search(query_text, utils.SEARCH_ENDPOINT_URL) 24 | elif query_type == "vector search": 25 | query_vec = utils.get_query_vec(query_text, model, processor) 26 | results = utils.do_vector_search(query_vec, utils.SEARCH_ENDPOINT_URL) 27 | else: 28 | query_vec = utils.get_query_vec(query_text, model, processor) 29 | results = 
utils.do_combined_search(query_text, query_vec, utils.SEARCH_ENDPOINT_URL) 30 | 31 | # st.text(json.dumps(results, indent=2)) 32 | 33 | metadata, data = utils.parse_results(results) 34 | st.markdown("## Results 1-{:d} from {:d} matches".format( 35 | len(data), metadata["num_results"])) 36 | for rid, row in enumerate(data): 37 | image_id = row["image_id"] 38 | image_path = row["image_path"] 39 | image = Image.open(image_path).convert("RGB") 40 | caption = row["caption_text"] 41 | relevance = row["relevance"] 42 | col1, col2, col3 = st.columns([2, 10, 10]) 43 | col1.markdown("{:d}.".format(rid + 1)) 44 | col2.image(image) 45 | col3.markdown(""" 46 | * **Image-ID**: {:s} 47 | * **Caption**: {:s} 48 | * **Relevance**: {:.3f} 49 | """.format(image_id, caption, relevance)) 50 | st.markdown("---") 51 | 52 | -------------------------------------------------------------------------------- /image-search/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import requests 5 | import streamlit as st 6 | import torch 7 | 8 | from transformers import CLIPModel, CLIPProcessor 9 | 10 | BASELINE_MODEL = "openai/clip-vit-base-patch32" 11 | MODEL_PATH = "/home/ubuntu/clip-model/clip-imageclef-run5-ckpt10" 12 | APP_NAME = "clip-demo" 13 | SCHEMA_NAME = "image" 14 | GET_ENDPOINT_URL = "http://localhost:8080/document/v1/{:s}/{:s}/docid/{:d}" 15 | SEARCH_ENDPOINT_URL = "http://localhost:8080/search/" 16 | 17 | 18 | @st.cache(allow_output_mutation=True) 19 | def load_model(model_path, baseline_model): 20 | model = CLIPModel.from_pretrained(model_path) 21 | # model = CLIPModel.from_pretrained(baseline_model) 22 | processor = CLIPProcessor.from_pretrained(baseline_model) 23 | return model, processor 24 | 25 | 26 | def do_text_search(query_text, endpoint_url): 27 | headers = { "Content-type" : "application/json" } 28 | yql = """ select * from sources image where caption_text contains '{:s}'; """.format( 29 | query_text) 30 | params = { 31 | "yql": yql, 32 | "hits": 10, 33 | "ranking.profile": "caption-search" 34 | } 35 | resp = requests.post(endpoint_url, headers=headers, json=params) 36 | return resp.json() 37 | 38 | 39 | def do_vector_search(query_vec, endpoint_url): 40 | headers = { "Content-Type" : "application/json" } 41 | params = { 42 | "yql": """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)); """, 43 | "hits": 10, 44 | "ranking.features.query(query_vector)": query_vec.tolist(), 45 | "ranking.profile": "image-search" 46 | } 47 | resp = requests.post(endpoint_url, headers=headers, json=params) 48 | return resp.json() 49 | 50 | 51 | def do_combined_search(query_text, query_vec, endpoint_url): 52 | headers = { "Content-Type" : "application/json" } 53 | yql = """select * from sources image where ([{"targetHits": 10}]nearestNeighbor(clip_vector, query_vector)) OR caption_text contains '%s'; """ % (query_text) 54 | params = { 55 | "yql": yql, 56 | "hits": 10, 57 | "ranking.features.query(query_vector)": query_vec.tolist(), 58 | "ranking.profile": "combined-search" 59 | } 60 | data = json.dumps(params) 61 | resp = requests.post(endpoint_url, headers=headers, data=data) 62 | return resp.json() 63 | 64 | 65 | def get_document_by_id(doc_id, endpoint_url): 66 | headers = { "Content-Type" : "application/json" } 67 | resp = requests.get(endpoint_url.format(APP_NAME, SCHEMA_NAME, doc_id), 68 | headers=headers) 69 | return resp.json() 70 | 71 | 72 | def 
parse_image_vector_from_response(resp_json): 73 | emb = np.zeros((512), dtype=np.float32) 74 | for cell in resp_json["fields"]["clip_vector"]["cells"]: 75 | pos = int(cell["address"]["x"]) 76 | val = cell["value"] 77 | emb[pos] = val 78 | return emb 79 | 80 | 81 | def get_best_vector_from_text_query(query_text, endpoint_url, getdoc_url): 82 | text_results = do_text_search(query_text, endpoint_url) 83 | top_result_id = int(text_results["root"]["children"][0]["id"].split(":")[-1]) 84 | docid_json = get_document_by_id(top_result_id, getdoc_url) 85 | emb = parse_image_vector_from_response(docid_json) 86 | return emb 87 | 88 | 89 | def get_image_vector_and_caption_text_by_id(image_id, endpoint_url, getdoc_url): 90 | headers = { "Content-Type" : "application/json" } 91 | params = { 92 | "yql": """select documentid from sources image where image_id matches '%s'; """ % (image_id), 93 | "hits": 1, 94 | } 95 | resp = requests.post(endpoint_url, headers=headers, json=params) 96 | doc_id = int(resp.json()["root"]["children"][0]["id"].split(":")[-1]) 97 | docid_json = get_document_by_id(doc_id, getdoc_url) 98 | image_vector = parse_image_vector_from_response(docid_json) 99 | caption_text = docid_json["fields"]["caption_text"] 100 | return image_vector, caption_text 101 | 102 | 103 | def parse_results(result_json): 104 | metadata = { 105 | "num_results": result_json["root"]["fields"]["totalCount"] 106 | } 107 | data = [] 108 | try: 109 | for child in result_json["root"]["children"]: 110 | data.append({ 111 | "id": child["id"], 112 | "image_id": child["fields"]["image_id"], 113 | "image_path": child["fields"]["image_path"], 114 | "caption_text": child["fields"]["caption_text"], 115 | "relevance": child["relevance"] 116 | }) 117 | except KeyError: 118 | pass 119 | return metadata, data 120 | 121 | 122 | def get_query_vec(query_text, model, processor): 123 | inputs = processor([query_text], padding=True, return_tensors="pt") 124 | with torch.no_grad(): 125 | query_vec = model.get_text_features(**inputs) 126 | query_vec = query_vec.reshape(-1).numpy() 127 | query_vec /= np.linalg.norm(query_vec, 2) 128 | return query_vec 129 | -------------------------------------------------------------------------------- /image-search/vespa.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=VespaService 3 | After=docker.service 4 | 5 | [Service] 6 | User=ubuntu 7 | Group=ubuntu 8 | WorkingDirectory=/home/ubuntu/vespa-poc/bash-scripts 9 | ExecStart=/home/ubuntu/vespa-poc/bash-scripts/start.sh 10 | ExecStop=/home/ubuntu/vespa-poc/bash-scripts/stop.sh 11 | 12 | [Install] 13 | WantedBy=multi-user.target 14 | 15 | -------------------------------------------------------------------------------- /image-search/vespa/src/main/application/hosts.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | node1 5 | 6 | 7 | -------------------------------------------------------------------------------- /image-search/vespa/src/main/application/schemas/image.sd: -------------------------------------------------------------------------------- 1 | schema image { 2 | document image { 3 | field image_id type string { 4 | indexing: summary | attribute 5 | } 6 | field image_path type string { 7 | indexing: summary | attribute 8 | } 9 | field caption_text type string { 10 | indexing: summary | index 11 | index: enable-bm25 12 | } 13 | field clip_vector type tensor(x[512]) { 14 | indexing: attribute | index 15 | attribute { 16 | 
distance-metric: innerproduct 17 | } 18 | index { 19 | hnsw { 20 | max-links-per-node: 32 21 | neighbors-to-explore-at-insert: 500 22 | } 23 | } 24 | } 25 | } 26 | 27 | fieldset default { 28 | fields: image_id, image_path, caption_text, clip_vector 29 | } 30 | 31 | rank-profile caption-search inherits default { 32 | first-phase { 33 | expression: bm25(caption_text) 34 | } 35 | } 36 | 37 | rank-profile image-search inherits default { 38 | first-phase { 39 | expression: closeness(field, clip_vector) 40 | } 41 | } 42 | 43 | rank-profile combined-search inherits default { 44 | first-phase { 45 | expression: bm25(caption_text) + closeness(clip_vector) 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /image-search/vespa/src/main/application/search/query-profiles/default.xml: -------------------------------------------------------------------------------- 1 | 2 | 1000 3 | 1000 4 | 2s 5 | 6 | -------------------------------------------------------------------------------- /image-search/vespa/src/main/application/search/query-profiles/types/root.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /image-search/vespa/src/main/application/services.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 1 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 1.0 22 | 23 | 24 | 25 | 26 | 27 | 28 | 2 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | pyperclip 3 | spacy 4 | streamlit 5 | torch 6 | torchvision 7 | transformers 8 | --------------------------------------------------------------------------------