├── .gitignore
├── 1k.csv
├── README.md
├── create_textual_knn_indices.py
├── create_visual_knn_indices.py
├── environment.yml
├── extract_img_links_from_csv.py
├── extract_textual_clip_embeddings.py
├── extract_visual_clip_embeddings.py
├── prompt-search.png
├── test_knn_index.py
└── test_visual_knn_index.py

/.gitignore:
--------------------------------------------------------------------------------
imgs
knn_indices
visual_embeddings

*.onnx

img_links.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
prompt search

simple implementation of CLIP search with a database of prompts.

explore prompts

newsletter · community

# About

This is the code we used to build our CLIP semantic search engine for [krea.ai](https://krea.ai). This work is heavily inspired by [clip-retrieval](https://github.com/rom1504/clip-retrieval/), [autofaiss](https://github.com/criteo/autofaiss), and [CLIP-ONNX](https://github.com/Lednik7/CLIP-ONNX). We kept our implementation simple, focused on working with data from [open-prompts](https://github.com/krea-ai/open-prompts), and set up to run efficiently on a CPU.

# CLIP Search

We use CLIP to find generated images given an input, which can be either a prompt or another image. The same approach can also be used to find prompts given the same kinds of input.

## CLIP
If you are not familiar with CLIP, we recommend starting with the [blog post](https://openai.com/blog/clip/) that OpenAI wrote about it.

CLIP is a multi-modal neural network that can encode both images and text in a common feature space. This means that we can create vectors that contain semantic information extracted from a text or an image. We can compare these semantic vectors with operations such as cosine similarity, which gives us a similarity score.

As a high-level example, when CLIP extracts features from an image of a red car, it produces a vector similar to the one it creates for the text "a red car" or for an image of another red car, since the semantics of all these inputs are related.

So far, CLIP has been helpful for creating datasets like [LAION-5B](https://laion.ai/blog/laion-5b/), for guiding generative models like VQ-GAN, for image classification tasks where there is little labeled data, and as a backbone for AI models like Stable Diffusion.

## Semantic Search
Semantic search consists of finding similar items within a dataset by comparing feature vectors. These feature vectors are also known as embeddings, and they can be computed in different ways.
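
To make "comparing feature vectors" concrete, here is a minimal NumPy sketch of cosine similarity on made-up vectors (the numbers are illustrative, not real CLIP outputs). It also shows why the scripts in this repository L2-normalize every embedding: for unit-length vectors, cosine similarity is just the dot product, which is what an inner-product index computes. The helper names are our own.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale a vector (or each row of a matrix) to unit L2 norm."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Made-up 4-dimensional vectors; real CLIP ViT-B/32 embeddings have 512 dimensions.
red_car_text = np.array([0.9, 0.1, 0.3, 0.0])
red_car_image = np.array([0.8, 0.2, 0.4, 0.1])
blue_robot_image = np.array([0.0, 0.9, 0.1, 0.7])

print(cosine_similarity(red_car_text, red_car_image))     # about 0.98: related
print(cosine_similarity(red_car_text, blue_robot_image))  # about 0.11: unrelated

# Once both vectors are unit length, a plain dot product gives the same score.
same_score = np.dot(l2_normalize(red_car_text), l2_normalize(red_car_image))
print(np.isclose(same_score, cosine_similarity(red_car_text, red_car_image)))  # True
```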
CLIP is one of the most interesting models for extracting features for semantic search.

The search process consists of encoding items as embeddings, indexing them, and using these indices for fast search. Romain Beaumont wrote a great [Medium post](https://rom1504.medium.com/semantic-search-with-embeddings-index-anything-8fb18556443c) about semantic search; we highly recommend reading it.

With this code, you will compute embeddings using CLIP, index them for K-Nearest Neighbors search, and efficiently find the items most similar to an input CLIP embedding.

# Environment

We used conda to set up our environment. These are the main packages you will need:

```
- autofaiss==2.15.3
- clip-by-openai==1.1
- img2dataset==1.33.0
- onnxruntime==1.12.1
- onnx==1.11.0
```

We used `python 3.7`.

Create a new conda environment with the following command:

`conda env create -f environment.yml`

Then activate it with `conda activate prompt-search`.

# Data Preparation

We will use a dataset of prompts generated with Stable Diffusion from [open-prompts](https://github.com/krea-ai/open-prompts). `1k.csv` is a subset of a [larger dataset](https://drive.google.com/file/d/1c4WHxtlzvHYd0UY5WCMJNn2EO-Aiv2A0/view) that you can find there, which makes it perfect for testing!

# CLIP Search

You will need a GPU for this one. We recommend using [Lambda](https://lambdalabs.com/service/gpu-cloud): great pricing and easy to set up.

## Set Up Lambda
[Sign up](https://lambdalabs.com/cloud/entrance) to Lambda. Fill in your information, complete the email verification, and add a credit card.

Press the "Launch instance" button and enter your public SSH key. Your public key should be in the `~/.ssh/` folder, in a file named `id_rsa.pub`. You can see its contents with `cat ~/.ssh/id_rsa.pub`. Copy and paste the result into Lambda and you'll be set.
If the folder or the file does not exist, see [here](https://docs.oracle.com/en/cloud/cloud-at-customer/occ-get-started/generate-ssh-key-pair.html) for how to generate an SSH key pair; it's really straightforward.

For this project, a *1x RTX 6000 (24GB)* instance should be enough. Launch the instance with the *Launch instance* button on the *Instances* page of the Lambda dashboard.

Once the instance is launched, wait a minute or two until its status says "Running". Then copy the command under "SSH LOGIN", which looks like `ssh ubuntu@` followed by the instance's IP address (a series of numbers in the form `123.456.789.012`). Paste it into your terminal, type "yes" at the prompt that appears, and you'll be logged in to your new GPU machine!

## Search Images
### Download images
The first step is downloading the images referenced in the `csv` file that contains all the prompt data. To do so, we will leverage the [img2dataset](https://github.com/rom1504/img2dataset) package.

Execute the following command to create a new file with the image links:

```
python extract_img_links_from_csv.py
```

Note that by default, the process extracts the links from `1k.csv`. Change the `CSV_FILE` variable in `extract_img_links_from_csv.py` if you want to use another data file as input:

```python
CSV_FILE = "./1k.csv"
OUTPUT_FILE = "./img_links.txt"
```

The results will be stored in `img_links.txt`. The script appends to this file, so delete it before re-running the script to avoid duplicate links.

Run the following command to download the images:

```bash
img2dataset --url_list img_links.txt --output_folder imgs --thread_count=64 --image_size=256
```

The output will be stored in a sub-folder named `00000` within `imgs`.
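
The link-extraction step above can also be written as a reusable function. This is a sketch, not the repository's script: `extract_image_links` is a hypothetical name, and like `extract_img_links_from_csv.py` it assumes the last CSV column holds JSON with a `raw_discord_data.image_uri` field. Unlike the script, it opens the output file once in write mode, so re-running it replaces `img_links.txt` instead of appending duplicates, and it reports the CSV row number of every row it skips. Calling `extract_image_links("./1k.csv", "./img_links.txt")` would produce the file that `img2dataset` reads.

```python
import csv
import json


def extract_image_links(csv_path: str, output_path: str) -> int:
    """Write one image URL per line to output_path and return how many were written."""
    written = 0
    with open(csv_path, newline="") as csv_file, open(output_path, "w") as out_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # skip the header row
        # Row numbers count the header as row 1.
        for row_number, row in enumerate(reader, start=2):
            try:
                # The last column holds JSON metadata with the image URL.
                link = json.loads(row[-1])["raw_discord_data"]["image_uri"]
            except (IndexError, KeyError, TypeError, json.JSONDecodeError) as err:
                print(f"skipping CSV row {row_number}: {err!r}")
                continue
            out_file.write(link + "\n")
            written += 1
    return written
```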

### Compute Visual CLIP Embeddings

Once the `imgs` folder has been created and filled with generated images, run the following command to compute a visual CLIP embedding for each of them:

`python extract_visual_clip_embeddings.py`

These are the main parameters in `extract_visual_clip_embeddings.py` that you might need to change:
```python
IMG_DIR = "./imgs/00000"  # directory where your images were downloaded
BATCH_SIZE = 128  # number of CLIP embeddings computed at each iteration
NUM_WORKERS = 14  # number of data-loading workers running in parallel (we recommend number_of_cores - 2)
```

Once the process finishes, you will see a new folder named `visual_embeddings` containing two sub-folders, `ids` and `embeddings`. For each batch, `ids` gets a `.npy` file with the IDs of the images in that batch (their file names without the extension), and `embeddings` gets a `.npy` file with the corresponding embeddings. This data is used to build the KNN index.

### Compute Visual KNN indices
If you did not modify the default output structure from the previous step, this is as easy as running the following command:

`python create_visual_knn_indices.py`

Otherwise, you might want to modify the following variables in `create_visual_knn_indices.py`:

```python
INDICES_FOLDER = "knn_indices"
EMBEDDINGS_DIR = "visual_embeddings"
```

The index will be stored in a new folder named `knn_indices`, in a file named `visual_prompts.index`. The script also concatenates the per-batch IDs into `visual_embeddings/visual_ids.npy`, which maps each index row back to an image. That file is only written if it does not exist yet, so delete it if you regenerate the embeddings.

### Search Images

To search generated images more efficiently, we will use an ONNX version of CLIP, based on the implementation from [`CLIP-ONNX`](https://github.com/Lednik7/CLIP-ONNX).
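
Whichever CLIP implementation encodes the query, the search step itself is small. The sketch below is a brute-force stand-in for the faiss index, not what the scripts run: it stacks the per-batch `.npy` files from `visual_embeddings/ids` and `visual_embeddings/embeddings` in sorted file-name order (the zero-padded batch names make that the batch order) and ranks rows by inner product with the normalized query. autofaiss may build an approximate index, so on large collections its results can differ slightly. `load_batches` and `exact_knn` are hypothetical names; you would call `ids, embeddings = load_batches("./visual_embeddings")`, then `exact_knn(embeddings, query, 5)`, and map each returned row to a file with `f"{ids[row]}.jpg"`.

```python
import glob
import os

import numpy as np


def load_batches(embeddings_dir: str):
    """Stack the per-batch ids/*.npy and embeddings/*.npy files in sorted order."""
    id_files = sorted(glob.glob(os.path.join(embeddings_dir, "ids", "*.npy")))
    emb_files = sorted(glob.glob(os.path.join(embeddings_dir, "embeddings", "*.npy")))
    ids = np.concatenate([np.load(path) for path in id_files])
    embeddings = np.concatenate([np.load(path) for path in emb_files]).astype("float32")
    if len(ids) != len(embeddings):
        raise ValueError("ids and embeddings have different lengths")
    return ids, embeddings


def exact_knn(embeddings: np.ndarray, query: np.ndarray, k: int):
    """Return (scores, row indices) of the k rows with the largest inner product."""
    # Accept a (512,) vector or a (1, 512) batch of one, as the test scripts produce.
    query = np.asarray(query, dtype="float32").ravel()
    query = query / np.linalg.norm(query)  # unit length, like the indexed embeddings
    scores = embeddings @ query
    rows = np.argsort(-scores)[:k]
    return scores[rows], rows
```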

Install the following package:
```bash
pip install git+https://github.com/Lednik7/CLIP-ONNX.git --no-deps
```

Once it is installed, download the ONNX CLIP models with the following commands:
```bash
wget https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/ViT-B-32/visual.onnx
wget https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/ViT-B-32/textual.onnx
```

Finally, run the following command from the repository root (where the `.onnx` files were downloaded) to perform the search with both regular CLIP and ONNX CLIP:

```bash
python test_visual_knn_index.py
```

The output should be lists of the image filenames most similar to the prompt `"image of a blue robot with red background"` and to the image `prompt-search.png`, first using regular CLIP and then using ONNX CLIP. The filenames refer to the images in `imgs/00000`.

Change the following parameters in `test_visual_knn_index.py` to try out different input prompts and images:

```python
INPUT_IMG_PATH = "./prompt-search.png"
INPUT_PROMPT = "image of a blue robot with red background"
NUM_RESULTS = 5
```

Have fun!
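
The test scripts print the regular CLIP and ONNX CLIP results side by side. The sketch below (synthetic random vectors, hypothetical helpers `top_k` and `overlap_at_k`) puts a number on how much two top-k lists agree. It also illustrates why leaving the ONNX query unnormalized, as `test_knn_index.py` does, is harmless when the indexed embeddings are normalized: scaling a query by a positive constant scales every inner-product score equally, so the ranking is unchanged. Rounding to 4 decimals is a small perturbation that usually, but not always, leaves the top results intact.

```python
import numpy as np


def top_k(embeddings: np.ndarray, query: np.ndarray, k: int) -> list:
    """Row indices of the k largest inner products, best first."""
    scores = embeddings @ query
    return [int(i) for i in np.argsort(-scores)[:k]]


def overlap_at_k(a: list, b: list) -> float:
    """Fraction of the first list's results that also appear in the second."""
    return len(set(a) & set(b)) / len(a)


# Synthetic stand-ins: 1,000 unit-norm "image embeddings" of CLIP's size (512).
rng = np.random.default_rng(0)
database = rng.normal(size=(1000, 512)).astype("float32")
database /= np.linalg.norm(database, axis=1, keepdims=True)

query = rng.normal(size=512).astype("float32")
query /= np.linalg.norm(query)

reference = top_k(database, query, 5)
scaled = top_k(database, 4.0 * query, 5)  # unnormalized query, same direction
rounded = top_k(database, np.around(query, decimals=4), 5)  # as the ONNX path does

print("scaled query overlap:", overlap_at_k(reference, scaled))
print("rounded query overlap:", overlap_at_k(reference, rounded))
```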

# Get in touch

- Follow and DM us on Twitter: [@krea_ai](https://twitter.com/krea_ai)
- Join [our Discord community](https://discord.gg/3mkFbvPYut)
- Email either `v` or `d` (`v` at `krea` dot `ai`; `d` at `krea` dot `ai` respectively)
--------------------------------------------------------------------------------
/create_textual_knn_indices.py:
--------------------------------------------------------------------------------
import os

import numpy as np
from autofaiss import build_index

indices_folder = "knn_indices"
os.makedirs(indices_folder, exist_ok=True)  # make sure the output folder exists

# Normalized text embeddings written by extract_textual_clip_embeddings.py
embeddings = np.load("embeddings/text_embeddings.npy")

prompt_index_filename = os.path.join(indices_folder, "prompts.index")
index, index_infos = build_index(
    embeddings,
    index_path=prompt_index_filename,
)
--------------------------------------------------------------------------------
/create_visual_knn_indices.py:
--------------------------------------------------------------------------------
import os
import itertools
import glob

import numpy as np
from autofaiss import build_index

INDICES_FOLDER = "knn_indices"
EMBEDDINGS_DIR = "visual_embeddings"

os.makedirs(INDICES_FOLDER, exist_ok=True)  # make sure the output folder exists

embeddings_path = os.path.join(EMBEDDINGS_DIR, "embeddings")
out_ids_path = os.path.join(EMBEDDINGS_DIR, "visual_ids.npy")

# Concatenate the per-batch ids (sorted by zero-padded batch file name) into a
# single array. This assumes autofaiss reads the embedding files in the same
# sorted order. Skipped if visual_ids.npy already exists.
if not os.path.exists(out_ids_path):
    ids_path = os.path.join(EMBEDDINGS_DIR, "ids")

    ids_paths = glob.glob(f"{ids_path}/*")
    ids_paths.sort()
    ids = [[str(embedding_id) for embedding_id in np.load(path)]
           for path in ids_paths]
    ids = np.asarray(list(itertools.chain.from_iterable(ids)))

    np.save(out_ids_path, ids)

prompt_index_filename = os.path.join(INDICES_FOLDER, "visual_prompts.index")
infos_index_filename = os.path.join(INDICES_FOLDER,
                                    "visual-prompts-infos.index")
# autofaiss loads the .npy files from the embeddings folder and builds the index
index, index_infos = build_index(
    embeddings_path,
    index_path=prompt_index_filename,
    index_infos_path=infos_index_filename,
)
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: prompt-search
channels:
  - defaults
dependencies:
  - python=3.7.13
  - pip=22.1.2
  - pip:
    - autofaiss==2.15.3
    - clip-by-openai==1.1
    - img2dataset==1.33.0
    - onnxruntime==1.12.1
--------------------------------------------------------------------------------
/extract_img_links_from_csv.py:
--------------------------------------------------------------------------------
import csv
import json

CSV_FILE = "./1k.csv"
OUTPUT_FILE = "./img_links.txt"

with open(CSV_FILE, 'r') as f:
    reader = csv.reader(f)
    _headers = next(reader)

    for idx, csv_data in enumerate(reader):
        try:
            # The last column holds JSON metadata that includes the image URL
            img_link = json.loads(
                csv_data[-1])["raw_discord_data"]["image_uri"]
            # Append mode: delete img_links.txt before re-running the script
            with open(OUTPUT_FILE, 'a') as out_f:
                out_f.write(img_link + "\n")

        except Exception as e:
            print(f'error in line {idx + 1} :/')
            print(e)
--------------------------------------------------------------------------------
/extract_textual_clip_embeddings.py:
--------------------------------------------------------------------------------
import os
import csv

import numpy as np
import torch
import clip

USE_CACHE = False
BATCH_SIZE = 2048
OUTDIR = "embeddings"

os.makedirs(OUTDIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

with open('data.csv', newline='') as csvfile:
    reader = csv.reader(csvfile)
    _headers = next(reader)

    # Deduplicated (id, prompt) pairs, skipping rows with an empty prompt
    prompt_data = set([(row[0], row[1]) for row in reader if row[1] != ''])

prompt_ids = [data[0] for data in prompt_data]
prompts = (data[1] for data in prompt_data)

prompt_ids_filename = os.path.join(OUTDIR, "prompt_ids.npy")
np.save(prompt_ids_filename, prompt_ids)

text_embeddings = None
batched_prompts = []
for idx, prompt in enumerate(prompts):
    batched_prompts.append(prompt)

    # Encode a full batch, or whatever is left after the last prompt
    if len(batched_prompts) % BATCH_SIZE == 0 or idx == len(prompt_ids) - 1:
        print(f"processing -- {idx + 1}")

        batch_text_embeddings_filename = os.path.join(
            OUTDIR, f"text_embeddings_{idx + 1}.npy")

        if os.path.exists(batch_text_embeddings_filename) and USE_CACHE:
            batch_text_embeddings = np.load(batch_text_embeddings_filename)

        else:
            with torch.no_grad():
                batched_text = clip.tokenize(
                    batched_prompts,
                    truncate=True,
                ).to(device)

                batch_text_embeddings = model.encode_text(batched_text, )
                # L2-normalize so that inner product equals cosine similarity
                batch_text_embeddings /= batch_text_embeddings.norm(
                    dim=-1, keepdim=True)

            batch_text_embeddings = batch_text_embeddings.cpu().numpy().astype(
                'float32')

            if USE_CACHE:
                np.save(batch_text_embeddings_filename, batch_text_embeddings)

        if text_embeddings is None:
            text_embeddings = batch_text_embeddings

        else:
            text_embeddings = np.concatenate(
                (text_embeddings, batch_text_embeddings))

        print(f"text embeddings shape -- {text_embeddings.shape}")
        print("\n")

        batched_prompts = []

print(f"{len(text_embeddings)} CLIP embeddings extracted!")
text_embeddings_filename = os.path.join(OUTDIR, "text_embeddings.npy")
np.save(text_embeddings_filename, text_embeddings)
--------------------------------------------------------------------------------
/extract_visual_clip_embeddings.py:
--------------------------------------------------------------------------------
import os
import csv
import glob

import torch
import clip
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC

except ImportError:
    BICUBIC = Image.BICUBIC

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

USE_CACHE = False
IMG_DIR = "./imgs/00000"
BATCH_SIZE = 128
NUM_WORKERS = 14
PERFETCH_FACTOR = 14
OUTDIR = "./visual_embeddings"

os.makedirs(OUTDIR, exist_ok=True)
os.makedirs(OUTDIR + "/ids", exist_ok=True)
os.makedirs(OUTDIR + "/embeddings", exist_ok=True)


def _convert_image_to_rgb(image):
    return image.convert("RGB")


class CLIPImgDataset(Dataset):

    def __init__(
        self,
        img_dir: str,
    ):
        self.img_paths = glob.glob(f"{img_dir}/*.jpg", )

        # CLIP ViT-B/32 preprocessing: 224x224 input with CLIP's mean/std
        self.transform = Compose([
            Resize(224, interpolation=BICUBIC),
            CenterCrop(224),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(
        self,
        idx,
    ):
        img_path = self.img_paths[idx]

        # The file name without its extension is used as the image id
        generation_id = img_path.split("/")[-1].split(".")[0]

        img = Image.open(img_path)
        img = self.transform(img)

        return img, generation_id


def main():
    print("setting up dataloader...")
    model, _preprocess = clip.load(
        "ViT-B/32",
        device=DEVICE,
    )
    clip_img_dataset = CLIPImgDataset(img_dir=IMG_DIR, )
    clip_img_dataloader = DataLoader(
        clip_img_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        prefetch_factor=PERFETCH_FACTOR,
        persistent_workers=True,
        # multiprocessing_context="spawn",
    )

    print("starting to process!")
    for idx, (batched_imgs,
              generation_ids) in enumerate(clip_img_dataloader):
        processed = min((idx + 1) * BATCH_SIZE, len(clip_img_dataset))
        print(f"processing! -- {processed} / {len(clip_img_dataset)}")

        with torch.no_grad():
            # Match the model's dtype (float16 on GPU, float32 on CPU)
            batched_img_embeddings = model.visual(
                batched_imgs.to(DEVICE, model.dtype), )

            # L2-normalize so that inner product equals cosine similarity
            batched_img_embeddings /= batched_img_embeddings.norm(dim=-1,
                                                                  keepdim=True)
        batched_img_embeddings = batched_img_embeddings.cpu().numpy().astype(
            'float32')

        # Zero-padded batch numbers keep sorted file order equal to batch order
        prompt_ids_filename = os.path.join(OUTDIR,
                                           f"ids/{str(idx).zfill(9)}.npy")
        np.save(prompt_ids_filename, np.asarray(generation_ids))

        img_embeddings_filename = os.path.join(
            OUTDIR, f"embeddings/{str(idx).zfill(9)}.npy")
        np.save(img_embeddings_filename, batched_img_embeddings)

        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/prompt-search.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krea-ai/prompt-search/4216c03412f900a8b26f90e8dc138aa649fd0c2f/prompt-search.png
--------------------------------------------------------------------------------
/test_knn_index.py:
--------------------------------------------------------------------------------
import os
import faiss
import numpy as np

import clip
import torch
from clip_onnx import clip_onnx

indices_folder = "knn_indices"

prompt_index_filename = os.path.join(indices_folder, "prompts.index")
embeddings_dir = "embeddings"

prompt_ids = np.load(os.path.join(embeddings_dir, "prompt_ids.npy"))

# Memory-map the index read-only instead of loading it fully into RAM
loaded_index = faiss.read_index(
    prompt_index_filename,
    faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
)

text = "cute cat"

device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
onnx_model = clip_onnx(None)
onnx_model.load_onnx(
    visual_path="visual.onnx",
    textual_path="textual.onnx",
    logit_scale=100.0000,
)
onnx_model.start_sessions(providers=["CPUExecutionProvider"], )

tokenized_text = clip.tokenize(
    [text],
    truncate=True,
).to(device)

with torch.no_grad():
    text_embedding = model.encode_text(tokenized_text, )
    text_embedding /= text_embedding.norm(dim=-1, keepdim=True)

text_embedding = text_embedding.cpu().numpy().astype('float32')

tokenized_text = tokenized_text.detach().cpu().numpy().astype(np.int64)
onnx_text_embedding = onnx_model.encode_text(tokenized_text, )
# The indexed embeddings are normalized, so skipping query normalization only
# rescales the scores; the ranking does not change.
# onnx_text_embedding /= np.linalg.norm(onnx_text_embedding, axis=-1, keepdims=True)
onnx_text_embedding = np.around(onnx_text_embedding, decimals=4).astype('float32')

_, I = loaded_index.search(text_embedding, 5)
print("CLIP RESULTS")
print([f"{str(prompt_ids[idx])}" for idx in I[0]])

_, I = loaded_index.search(onnx_text_embedding, 5)
print("ONNX RESULTS")
print([f"{str(prompt_ids[idx])}" for idx in I[0]])
--------------------------------------------------------------------------------
/test_visual_knn_index.py:
--------------------------------------------------------------------------------
import os
import faiss
import numpy as np

import clip
import torch
from PIL import Image
from clip_onnx import clip_onnx


USE_ONNX = True
INDICES_FOLDER = "./knn_indices"
INDEX_FILE_PATH = os.path.join(INDICES_FOLDER, "visual_prompts.index")
VISUAL_EMBEDDINGS_DIR = "./visual_embeddings"
DEVICE = "cpu"
INPUT_IMG_PATH = "./prompt-search.png"
INPUT_PROMPT = "image of a blue robot with red background"
NUM_RESULTS = 5

prompt_index_filename = os.path.join(INDICES_FOLDER, "visual_prompts.index")
prompt_ids = np.load(os.path.join(VISUAL_EMBEDDINGS_DIR, "visual_ids.npy"))

# Memory-map the index read-only instead of loading it fully into RAM
loaded_index = faiss.read_index(
    prompt_index_filename,
    faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
)


model, preprocess = clip.load("ViT-B/32", device=DEVICE)
onnx_model = clip_onnx(None)
onnx_model.load_onnx(
    visual_path="visual.onnx",
    textual_path="textual.onnx",
    logit_scale=100.0000,
)
onnx_model.start_sessions(providers=["CPUExecutionProvider"], )

img = Image.open(INPUT_IMG_PATH)
processed_img = preprocess(img).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    visual_embedding = model.encode_image(processed_img, )
    visual_embedding /= visual_embedding.norm(dim=-1, keepdim=True)

visual_embedding = visual_embedding.cpu().numpy().astype('float32')

_, I = loaded_index.search(visual_embedding, NUM_RESULTS)

print("SIMILAR IMGS FROM INPUT IMG")
print([f"{str(prompt_ids[idx])}.jpg" for idx in I[0]])

tokenized_text = clip.tokenize(
    [INPUT_PROMPT],
    truncate=True,  # truncate prompts longer than CLIP's 77-token context
).to(DEVICE)

with torch.no_grad():
    text_embedding = model.encode_text(tokenized_text, )
    text_embedding /= text_embedding.norm(dim=-1, keepdim=True)

text_embedding = text_embedding.cpu().numpy().astype('float32')

_, I = loaded_index.search(text_embedding, NUM_RESULTS)

print("SIMILAR IMGS FROM INPUT PROMPT")
print([f"{str(prompt_ids[idx])}.jpg" for idx in I[0]])

if USE_ONNX:
    onnx_visual_embedding = onnx_model.encode_image(processed_img.numpy(), )
    onnx_visual_embedding /= np.linalg.norm(onnx_visual_embedding, axis=-1, keepdims=True)
    # faiss expects float32 queries
    onnx_visual_embedding = np.around(onnx_visual_embedding, decimals=4).astype('float32')

    _, I = loaded_index.search(onnx_visual_embedding, NUM_RESULTS)
    print("ONNX SIMILAR IMGS FROM INPUT IMG")
    print([f"{str(prompt_ids[idx])}.jpg" for idx in I[0]])

    # Cast token ids to int64, as test_knn_index.py does
    onnx_text_embedding = onnx_model.encode_text(tokenized_text.numpy().astype(np.int64), )
    onnx_text_embedding /= np.linalg.norm(onnx_text_embedding, axis=-1, keepdims=True)

    _, I = loaded_index.search(onnx_text_embedding.astype('float32'), NUM_RESULTS)
    print("ONNX SIMILAR IMGS FROM INPUT PROMPT")
    print([f"{str(prompt_ids[idx])}.jpg" for idx in I[0]])
--------------------------------------------------------------------------------