├── README.md
├── datasets
│   ├── .DS_Store
│   └── example_dataset
│       ├── .DS_Store
│       ├── image_299803.jpg
│       ├── image_349837.JPEG
│       ├── image_465128.jpg
│       ├── image_521600.jpg
│       ├── image_531937.jpg
│       ├── image_554874.JPEG
│       ├── image_590676.jpg
│       └── image_610893.jpg
├── environment.yml
├── imageRAG_OmniGen.py
├── imageRAG_SDXL.py
├── images
│   ├── .DS_Store
│   ├── african_grey_underwater.png
│   ├── golden_cradle.png
│   ├── koshka.jpeg
│   ├── koshka_halloween.png
│   └── origami_birds_NYC.png
├── retrieval.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 |

ImageRAG: Dynamic Image Retrieval for Reference-Guided Image Generation

2 | 3 | 4 |

5 | 6 | 7 | 8 | 9 | 10 | 11 |

12 | 13 | ## Environment & Dependencies 14 | First, create the ImageRAG environment: 15 | ``` 16 | conda env create -f environment.yml 17 | conda activate ImageRAG 18 | ``` 19 | 20 | Next, save the dataset you wish to retrieve images from (the retrieval dataset) in the `datasets` folder, following the structure of `example_dataset`. 21 | 22 | ## Usage 23 | 24 | For OmniGen, clone the [OmniGen repo](https://github.com/VectorSpaceLab/OmniGen), and use `imageRAG_OmniGen.py` as follows: 25 | 26 | ``` 27 | python imageRAG_OmniGen.py \ 28 | --prompt <prompt> \ 29 | --dataset <dataset name> \ 30 | --omnigen_path <path to the cloned OmniGen repo> \ 31 | --openai_api_key <your OpenAI API key> 32 | ``` 33 | 34 | For example: 35 | 36 | ``` 37 | python imageRAG_OmniGen.py \ 38 | --prompt "An african grey underwater." \ 39 | --out_name "african_grey_underwater" \ 40 | --dataset "example_dataset" \ 41 | --omnigen_path <path to the cloned OmniGen repo> \ 42 | --openai_api_key <your OpenAI API key> 43 | ``` 44 | 45 | will generate: 46 | 47 | ![OmniGen example](images/african_grey_underwater.png) 48 | 49 | For personalized generation, use ```--input_images <image path>``` with an image of the subject you would like to generate. 50 | 51 | For example: 52 | 53 | ``` 54 | python imageRAG_OmniGen.py \ 55 | --prompt "My cat wearing a Halloween costume. My cat is the cat in this image: <|image_1|>." \ 56 | --input_images "images/koshka.jpeg" \ 57 | --out_name "koshka_halloween" \ 58 | --dataset "example_dataset" \ 59 | --omnigen_path <path to the cloned OmniGen repo> \ 60 | --openai_api_key <your OpenAI API key> 61 | ``` 62 | 63 | will generate: 64 | 65 | ![personalization example](images/koshka_halloween.png) 66 | 67 | If the output image isn't satisfactory, you can try the 'generation' mode, which retrieves images for all the concepts in the prompt rather than only the missing ones. 68 | For example: 69 | 70 | ``` 71 | python imageRAG_OmniGen.py \ 72 | --prompt "Origami birds flying over New York City." \ 73 | --mode "generation" \ 74 | --out_name "origami_birds_NYC" \ 75 | --dataset "example_dataset" \ 76 | --omnigen_path <path to the cloned OmniGen repo> \ 77 | --openai_api_key <your OpenAI API key> 78 | ``` 79 | 80 | will generate: 81 | 82 | ![generation example](images/origami_birds_NYC.png) 83 | 84 | 85 | For SDXL, use ```imageRAG_SDXL.py``` as follows: 86 | 87 | ``` 88 | python imageRAG_SDXL.py \ 89 | --prompt <prompt> \ 90 | --dataset <dataset name> \ 91 | --openai_api_key <your OpenAI API key> 92 | ``` 93 | 94 | For example: 95 | 96 | ``` 97 | python imageRAG_SDXL.py \ 98 | --prompt "A golden retriever and a cradle." \ 99 | --out_name "golden_cradle" \ 100 | --dataset "example_dataset" \ 101 | --openai_api_key <your OpenAI API key> 102 | ``` 103 | 104 | will generate: 105 | 106 | ![SDXL example](images/golden_cradle.png) 107 | 108 | ## Citation 109 | If you find this repository useful, please cite our paper: 110 | ``` 111 | @misc{shalevarkushin2025imageragdynamicimageretrieval, 112 | title={ImageRAG: Dynamic Image Retrieval for Reference-Guided Image Generation}, 113 | author={Rotem Shalev-Arkushin and Rinon Gal and Amit H.
Bermano and Ohad Fried}, 114 | year={2025}, 115 | eprint={2502.09411}, 116 | archivePrefix={arXiv}, 117 | primaryClass={cs.CV}, 118 | url={https://arxiv.org/abs/2502.09411}, 119 | } 120 | ``` 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/.DS_Store -------------------------------------------------------------------------------- /datasets/example_dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/.DS_Store -------------------------------------------------------------------------------- /datasets/example_dataset/image_299803.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_299803.jpg -------------------------------------------------------------------------------- /datasets/example_dataset/image_349837.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_349837.JPEG -------------------------------------------------------------------------------- /datasets/example_dataset/image_465128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_465128.jpg -------------------------------------------------------------------------------- /datasets/example_dataset/image_521600.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_521600.jpg -------------------------------------------------------------------------------- /datasets/example_dataset/image_531937.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_531937.jpg -------------------------------------------------------------------------------- /datasets/example_dataset/image_554874.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_554874.JPEG -------------------------------------------------------------------------------- /datasets/example_dataset/image_590676.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_590676.jpg -------------------------------------------------------------------------------- /datasets/example_dataset/image_610893.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/datasets/example_dataset/image_610893.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ImageRAG 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - bzip2=1.0.8=h5eee18b_6 8 | - ca-certificates=2024.9.24=h06a4308_0 9 | - ld_impl_linux-64=2.40=h12ee557_0 10 | - libffi=3.4.4=h6a678d5_1 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - libuuid=1.41.5=h5eee18b_0 15 | - ncurses=6.4=h6a678d5_0 16 | - openssl=3.0.15=h5eee18b_0 17 | - pip=24.2=py310h06a4308_0 18 | - python=3.10.13=h955ad1f_0 19 | - readline=8.2=h5eee18b_0 20 | - setuptools=75.1.0=py310h06a4308_0 21 | - sqlite=3.45.3=h5eee18b_0 22 | - tk=8.6.14=h39e8969_0 23 | - wheel=0.44.0=py310h06a4308_0 24 | - xz=5.4.6=h5eee18b_1 25 | - zlib=1.2.13=h5eee18b_1 26 | - pip: 27 | - accelerate==1.1.1 28 | - aiohappyeyeballs==2.4.3 29 | - aiohttp==3.10.10 30 | - aiosignal==1.3.1 31 | - annotated-types==0.7.0 32 | - anyio==4.7.0 33 | - async-timeout==4.0.3 34 | - attrs==24.2.0 35 | - certifi==2024.8.30 36 | - charset-normalizer==3.4.0 37 | - git+https://github.com/openai/CLIP.git 38 | - contourpy==1.3.1 39 | - cycler==0.12.1 40 | - datasets==3.1.0 41 | - diffusers==0.31.0 42 | - dill==0.3.8 43 | - distro==1.9.0 44 | - exceptiongroup==1.2.2 45 | - filelock==3.13.1 46 | - fonttools==4.55.2 47 | - frozenlist==1.5.0 48 | - fsspec==2024.2.0 49 | - ftfy==6.3.1 50 | - h11==0.14.0 51 | - httpcore==1.0.7 52 | - httpx==0.28.1 53 | - huggingface-hub==0.26.2 54 | - idna==3.10 55 | - importlib-metadata==8.5.0 56 | - jinja2==3.1.3 57 | - jiter==0.8.0 58 | - kiwisolver==1.4.7 59 | - markupsafe==2.1.5 60 | - matplotlib==3.9.3 61 | - mpmath==1.3.0 62 | - multidict==6.1.0 63 | - multiprocess==0.70.16 64 | - networkx==3.2.1 65 | - numpy==1.26.3 66 | - open-clip-torch==2.29.0 67 | - openai==1.57.0 68 | - packaging==24.2 69 | - pandas==2.2.3 70 | - peft==0.13.2 71 | - pillow==10.2.0 72 | - propcache==0.2.0 73 | - psutil==6.1.0 74 | - pyarrow==18.0.0 75 | - pydantic==2.10.3 76 | - pydantic-core==2.27.1 77 | - pyparsing==3.2.0 78 | - python-dateutil==2.9.0.post0 79 | - pytz==2024.2 80 | - pyyaml==6.0.2 81 | - rank-bm25==0.2.2 82 | - regex==2024.11.6 83 | - requests==2.32.3 84 | - safetensors==0.4.5 85 | - six==1.16.0 86 | - sniffio==1.3.1 87 | - sympy==1.13.1 88 | - timm==1.0.11 89 | - tokenizers==0.20.3 90 | - torch==2.4.1 91 | - torchvision==0.19.1 92 | - tqdm==4.67.0 93 | - transformers==4.46.2 94 | - triton==3.0.0 95 | - typing-extensions==4.12.2 96 | - tzdata==2024.2 97 | - urllib3==2.2.3 98 | - wcwidth==0.2.13 99 | - xxhash==3.5.0 100 | - yarl==1.17.1 101 | - zipp==3.21.0 102 | -------------------------------------------------------------------------------- /imageRAG_OmniGen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import openai 5 | import numpy as np 6 | 7 | from retrieval import * 8 | from utils import * 9 | 10 | def run_omnigen(prompt, input_images, out_path, args): 11 | print("running OmniGen inference") 12 | device = f"cuda:{args.device}" if int(args.device) >= 0 else "cuda" 13 | pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", device=device, 14 | model_cpu_offload=args.model_cpu_offload) 15 | images 
= pipe(prompt=prompt, input_images=input_images, height=args.height, width=args.width, 16 | guidance_scale=args.guidance_scale, img_guidance_scale=args.image_guidance_scale, 17 | seed=args.seed, use_input_image_size_as_output=args.use_input_image_size_as_output) 18 | 19 | images[0].save(out_path) 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser(description="imageRAG pipeline") 23 | parser.add_argument("--omnigen_path", type=str) 24 | parser.add_argument("--openai_api_key", type=str) 25 | parser.add_argument("--dataset", type=str) 26 | parser.add_argument("--device", type=int, default=-1) 27 | parser.add_argument("--seed", type=int, default=0) 28 | parser.add_argument("--guidance_scale", type=float, default=2.5) 29 | parser.add_argument("--image_guidance_scale", type=float, default=1.6) 30 | parser.add_argument("--height", type=int, default=1024) 31 | parser.add_argument("--width", type=int, default=1024) 32 | parser.add_argument("--data_lim", type=int, default=-1) 33 | parser.add_argument("--prompt", type=str, default="") 34 | parser.add_argument("--out_name", type=str, default="out") 35 | parser.add_argument("--out_path", type=str, default="results") 36 | parser.add_argument("--embeddings_path", type=str, default="") 37 | parser.add_argument("--input_images", type=str, default="") 38 | parser.add_argument("--mode", type=str, default="omnigen_first", choices=['omnigen_first', 'generation', 'personalization']) 39 | parser.add_argument("--model_cpu_offload", action='store_true') 40 | parser.add_argument("--use_input_image_size_as_output", action='store_true') 41 | parser.add_argument("--only_rephrase", action='store_true') 42 | parser.add_argument("--retrieval_method", type=str, default="CLIP", choices=['CLIP', 'SigLIP', 'MoE', 'gpt_rerank']) 43 | 44 | args = parser.parse_args() 45 | 46 | sys.path.append(args.omnigen_path) 47 | from OmniGen import OmniGenPipeline 48 | 49 | openai.api_key = args.openai_api_key 50 | os.environ["OPENAI_API_KEY"] = openai.api_key 51 | client = openai.OpenAI() 52 | 53 | os.makedirs(args.out_path, exist_ok=True) 54 | out_txt_file = os.path.join(args.out_path, args.out_name + ".txt") 55 | f = open(out_txt_file, "w") 56 | device = f"cuda:{args.device}" if int(args.device) >= 0 else "cuda" 57 | data_path = f"datasets/{args.dataset}" 58 | 59 | prompt_w_retreival = args.prompt 60 | 61 | retrieval_image_paths = [os.path.join(data_path, fname) for fname in os.listdir(data_path)] 62 | if args.data_lim != -1: 63 | retrieval_image_paths = retrieval_image_paths[:args.data_lim] 64 | 65 | embeddings_path = args.embeddings_path or f"datasets/embeddings/{args.dataset}" 66 | input_images = args.input_images.split(",") if args.input_images else [] 67 | k_concepts = 3 - len(input_images) if args.mode != "personalization" else 1 68 | k_captions_per_concept = 1 69 | 70 | f.write(f"prompt: {args.prompt}\n") 71 | 72 | if args.mode == "omnigen_first": 73 | out_name = f"{args.out_name}_no_imageRAG.png" 74 | out_path = os.path.join(args.out_path, out_name) 75 | if not os.path.exists(out_path): 76 | f.write(f"running OmniGen, will save results to {out_path}\n") 77 | run_omnigen(args.prompt, input_images, out_path, args) 78 | 79 | if args.only_rephrase: 80 | rephrased_prompt = retrieval_caption_generation(args.prompt, input_images + [out_path], 81 | gpt_client=client, 82 | k_captions_per_concept=k_captions_per_concept, 83 | only_rephrase=args.only_rephrase) 84 | if rephrased_prompt == True: 85 | f.write("result matches prompt, not running imageRAG.") 86 | f.close() 
87 | exit() 88 | 89 | f.write(f"running OmniGen, rephrased prompt is: {rephrased_prompt}\n") 90 | out_name = f"{args.out_name}_rephrased.png" 91 | out_path = os.path.join(args.out_path, out_name) 92 | run_omnigen(rephrased_prompt, input_images, out_path, args) 93 | f.close() 94 | exit() 95 | else: 96 | ans = retrieval_caption_generation(args.prompt, 97 | input_images + [out_path], 98 | gpt_client=client, 99 | k_captions_per_concept=k_captions_per_concept) 100 | 101 | if type(ans) != bool: 102 | captions = convert_res_to_captions(ans) 103 | f.write(f"captions: {captions}\n") 104 | else: 105 | f.write("result matches prompt, not running imageRAG.") 106 | f.close() 107 | exit() 108 | 109 | omnigen_out_path = out_path 110 | 111 | elif args.mode == "generation": 112 | captions = retrieval_caption_generation(args.prompt, 113 | input_images, 114 | gpt_client=client, 115 | k_captions_per_concept=k_captions_per_concept, 116 | decision=False) 117 | captions = convert_res_to_captions(captions) 118 | f.write(f"captions: {captions}\n") 119 | 120 | k_imgs_per_caption = 1 121 | paths = retrieve_img_per_caption(captions, retrieval_image_paths, embeddings_path=embeddings_path, 122 | k=k_imgs_per_caption, device=device, method=args.retrieval_method) 123 | final_paths = np.array(paths).flatten().tolist() 124 | j = len(input_images) 125 | k = 3 # can use up to 3 images in prompt with omnigen 126 | paths = final_paths[:k - j] 127 | f.write(f"final retrieved paths: {paths}\n") 128 | image_paths_extended = input_images + paths 129 | 130 | examples = ", ".join([f'{captions[i]}: <|image_{i + j + 1}|>' for i in range(len(paths))]) 131 | prompt_w_retreival = f"According to these images of {examples}, generate {args.prompt}" 132 | f.write(f"prompt_w_retreival: {prompt_w_retreival}\n") 133 | 134 | out_name = f"{args.out_name}_gs_{args.guidance_scale}_im_gs_{args.image_guidance_scale}.png" 135 | out_path = os.path.join(args.out_path, out_name) 136 | f.write(f"running OmniGen, will save result to: {out_path}\n") 137 | 138 | run_omnigen(prompt_w_retreival, image_paths_extended, out_path, args) 139 | f.close() 140 | exit() -------------------------------------------------------------------------------- /imageRAG_SDXL.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import openai 6 | import torch 7 | from diffusers import AutoPipelineForText2Image, DiffusionPipeline 8 | from transformers import CLIPVisionModelWithProjection 9 | 10 | from utils import * 11 | from retrieval import * 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser(description="imageRAG pipeline") 15 | parser.add_argument("--openai_api_key", type=str) 16 | parser.add_argument("--dataset", type=str) 17 | parser.add_argument("--device", type=int, default=-1) 18 | parser.add_argument("--seed", type=int, default=0) 19 | parser.add_argument("--hf_cache_dir", type=str, default=None) 20 | parser.add_argument("--ip_scale", type=float, default=0.5) 21 | parser.add_argument("--data_lim", type=int, default=-1) 22 | parser.add_argument("--prompt", type=str, default="") 23 | parser.add_argument("--out_name", type=str, default="out") 24 | parser.add_argument("--out_path", type=str, default="results") 25 | parser.add_argument("--embeddings_path", type=str, default="") 26 | parser.add_argument("--mode", type=str, default="sd_first", choices=['sd_first', 'generation']) 27 | parser.add_argument("--only_rephrase", 
action='store_true') 28 | parser.add_argument("--retrieval_method", type=str, default="CLIP", choices=['CLIP', 'SigLIP', 'MoE', 'gpt_rerank']) 29 | 30 | args = parser.parse_args() 31 | 32 | openai.api_key = args.openai_api_key 33 | os.environ["OPENAI_API_KEY"] = openai.api_key 34 | client = openai.OpenAI() 35 | 36 | os.makedirs(args.out_path, exist_ok=True) 37 | out_txt_file = os.path.join(args.out_path, args.out_name + ".txt") 38 | f = open(out_txt_file, "w") 39 | device = f"cuda:{args.device}" if int(args.device) >= 0 else "cuda" 40 | data_path = f"datasets/{args.dataset}" 41 | 42 | prompt_w_retreival = args.prompt 43 | 44 | retrieval_image_paths = [os.path.join(data_path, fname) for fname in os.listdir(data_path)] 45 | if args.data_lim != -1: 46 | retrieval_image_paths = retrieval_image_paths[:args.data_lim] 47 | 48 | embeddings_path = args.embeddings_path or f"datasets/embeddings/{args.dataset}" 49 | 50 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 51 | "h94/IP-Adapter", 52 | subfolder="models/image_encoder", 53 | torch_dtype=torch.float16, 54 | cache_dir=args.hf_cache_dir 55 | ) 56 | 57 | pipe_clean = AutoPipelineForText2Image.from_pretrained( 58 | "stabilityai/stable-diffusion-xl-base-1.0", 59 | image_encoder=image_encoder, 60 | torch_dtype=torch.float16, 61 | cache_dir=args.hf_cache_dir 62 | ).to(device) 63 | 64 | generator1 = torch.Generator(device="cuda").manual_seed(args.seed) 65 | pipe_ip = AutoPipelineForText2Image.from_pretrained( 66 | "stabilityai/stable-diffusion-xl-base-1.0", 67 | image_encoder=image_encoder, 68 | torch_dtype=torch.float16, 69 | cache_dir=args.hf_cache_dir 70 | ).to(device) 71 | 72 | pipe_ip.load_ip_adapter("h94/IP-Adapter", 73 | subfolder="sdxl_models", 74 | weight_name="ip-adapter-plus_sdxl_vit-h.safetensors", 75 | cache_dir=args.hf_cache_dir) 76 | 77 | pipe_ip.set_ip_adapter_scale(args.ip_scale) 78 | generator2 = torch.Generator(device=device).manual_seed(args.seed) 79 | 80 | sd_first = args.mode == "sd_first" 81 | 82 | if sd_first: 83 | cur_out_path = os.path.join(args.out_path, f"{args.out_name}_no_imageRAG.png") 84 | if not os.path.exists(cur_out_path): 85 | out_image = pipe_clean( 86 | prompt=args.prompt, 87 | negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 88 | num_inference_steps=50, 89 | generator=generator1, 90 | ).images[0] 91 | out_image.save(cur_out_path) 92 | 93 | ans = retrieval_caption_generation(args.prompt, [cur_out_path], 94 | gpt_client=client, 95 | k_captions_per_concept=1, 96 | k_concepts=1, 97 | only_rephrase=args.only_rephrase) 98 | if type(ans) != bool: 99 | if args.only_rephrase: 100 | print(f"running SDXL, rephrased prompt is: {ans}\n") 101 | cur_out_path = os.path.join(args.out_path, f"{args.out_name}_rephrased.png") 102 | out_image = pipe_clean( 103 | prompt=ans, 104 | negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 105 | num_inference_steps=50, 106 | generator=generator1, 107 | ).images[0] 108 | out_image.save(cur_out_path) 109 | exit() 110 | 111 | caption = ans 112 | caption = convert_res_to_captions(caption)[0] 113 | print(f"caption: {caption}\n") 114 | else: 115 | print(f"prompt: {args.prompt}") 116 | print("result matches prompt, not running imageRAG.") 117 | exit() 118 | else: 119 | caption = retrieval_caption_generation(args.prompt, [], 120 | gpt_client=client, 121 | k_captions_per_concept=1, 122 | decision=False) 123 | caption = convert_res_to_captions(caption)[0] 124 | f.write(f"captions: {caption}\n") 125 | 126 | paths = 
retrieve_img_per_caption([caption], retrieval_image_paths, embeddings_path=embeddings_path, 127 | k=1, device=device, method=args.retrieval_method) 128 | image_path = np.array(paths).flatten()[0] 129 | print("ref path:", image_path) 130 | 131 | new_prompt = f"According to this image of {caption}, generate {args.prompt}" 132 | image = Image.open(image_path) 133 | 134 | out_image = pipe_ip( 135 | prompt=new_prompt, 136 | ip_adapter_image=image, 137 | negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 138 | num_inference_steps=50, 139 | generator=generator2, 140 | ).images[0] 141 | 142 | cur_out_path = os.path.join(args.out_path, f"{args.out_name}.png") 143 | out_image.save(cur_out_path) -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/images/.DS_Store -------------------------------------------------------------------------------- /images/african_grey_underwater.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/images/african_grey_underwater.png -------------------------------------------------------------------------------- /images/golden_cradle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/images/golden_cradle.png -------------------------------------------------------------------------------- /images/koshka.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/images/koshka.jpeg -------------------------------------------------------------------------------- /images/koshka_halloween.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/images/koshka_halloween.png -------------------------------------------------------------------------------- /images/origami_birds_NYC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rotem-shalev/ImageRAG/16c9502a09b5a049f7c30f39b7f48998fd1e2526/images/origami_birds_NYC.png -------------------------------------------------------------------------------- /retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import clip 4 | from open_clip import create_model_from_pretrained, get_tokenizer 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | import numpy as np 8 | 9 | def get_clip_similarities(prompts, image_paths, embeddings_path="", bs=1024, k=50, device='cuda:1'): 10 | model, preprocess = clip.load("ViT-B/32", device=device) 11 | text = clip.tokenize(prompts).to(device) 12 | 13 | top_text_im_paths = [] 14 | top_text_im_scores = [] 15 | top_img_embeddings = torch.empty((0, 512)) 16 | 17 | with torch.no_grad(): 18 | text_features = model.encode_text(text) 19 | normalized_text_vectors = torch.nn.functional.normalize(text_features, p=2, dim=1) 20 | 21 | if bs == len(image_paths): 22 | end = len(image_paths) 23 | else: 24 | end = 
len(image_paths) - bs 25 | 26 | for bi in range(0, end, bs): 27 | if os.path.exists(os.path.join(embeddings_path, f"clip_embeddings_b{bi}.pt")): 28 | normalized_ims = torch.load(os.path.join(embeddings_path, f"clip_embeddings_b{bi}.pt"), map_location=device) 29 | normalized_im_vectors = normalized_ims['normalized_clip_embeddings'] 30 | final_bi_paths = normalized_ims['paths'] 31 | 32 | else: 33 | to_remove = [] 34 | images = [] 35 | for i in range(bs): 36 | try: 37 | image = preprocess(Image.open(image_paths[bi+i])).unsqueeze(0).to(device) 38 | images.append(image) 39 | except: 40 | print(f"couldn't read {image_paths[bi+i]}") 41 | to_remove.append(image_paths[bi+i]) 42 | continue 43 | 44 | images = torch.stack(images).squeeze(1).to(device) 45 | image_features = model.encode_image(images) 46 | normalized_im_vectors = torch.nn.functional.normalize(image_features, p=2, dim=1) 47 | 48 | final_bi_paths = [path for path in image_paths[bi:bi+bs] if path not in to_remove] 49 | if embeddings_path != "": 50 | os.makedirs(embeddings_path, exist_ok=True) 51 | torch.save({"normalized_clip_embeddings": normalized_im_vectors, "paths": final_bi_paths}, 52 | os.path.join(embeddings_path, f"clip_embeddings_b{bi}.pt")) 53 | 54 | # compute cosine similarities 55 | text_similarity_matrix = torch.matmul(normalized_text_vectors, normalized_im_vectors.T) 56 | 57 | text_sim = text_similarity_matrix.cpu().numpy().squeeze() 58 | text_sim = np.concatenate([top_text_im_scores, text_sim]) 59 | cur_paths = np.concatenate([top_text_im_paths, final_bi_paths]) 60 | top_similarities = text_sim.argsort()[-k:] 61 | cur_paths = np.array(cur_paths) 62 | if cur_paths.shape[0] == 1: 63 | cur_paths = cur_paths[0] 64 | top_text_im_paths = cur_paths[top_similarities] 65 | top_text_im_scores = text_sim[top_similarities] 66 | cur_embeddings = torch.cat([top_img_embeddings, normalized_im_vectors.cpu()]) 67 | top_img_embeddings = cur_embeddings[top_similarities] 68 | 69 | return top_text_im_paths[::-1], top_text_im_scores[::-1] 70 | 71 | def rerank_BM25(candidates, retrieval_captions, k=1): 72 | from rank_bm25 import BM25Okapi 73 | from retrieval_w_gpt import get_image_captions 74 | 75 | candidates = list(set(candidates)) 76 | candidate_captions = get_image_captions(candidates) 77 | 78 | tokenized_captions = [candidate_captions[candidate].lower().split() for candidate in candidates] 79 | bm25 = BM25Okapi(tokenized_captions) 80 | tokenized_query = retrieval_captions[0].lower().split() # TODO currently only works for 1 caption 81 | scores = bm25.get_scores(tokenized_query) 82 | ranked_indices = np.argsort(-scores) 83 | 84 | return np.array(candidates)[ranked_indices[:k]].tolist(), np.array(scores)[ranked_indices[:k]].tolist() 85 | 86 | def get_moe_similarities(prompts, image_paths, embeddings_path="", bs=1024, k=1, device='cuda:2', save=False): 87 | pairs, im_emb = get_clip_similarities(prompts, image_paths, 88 | embeddings_path=embeddings_path, 89 | bs=min(2048, bs), k=3, device=device) 90 | pairs2, im_emb2 = get_siglip_similarities(prompts, image_paths, 91 | embeddings_path=embeddings_path, 92 | bs=min(64, bs), k=3, device=device, save=save) 93 | 94 | candidates = pairs[0].tolist() + pairs2[0].tolist() 95 | scores = pairs[1].tolist() + pairs2[1].tolist() 96 | bm25_best, bm25_scores = rerank_BM25(candidates, prompts, k=3) 97 | path2score = {c: 0 for c in candidates} 98 | for i in range(len(candidates)): 99 | path2score[candidates[i]] += scores[i] 100 | if candidates[i] in bm25_best: 101 | path2score[candidates[i]] += 
bm25_scores[bm25_best.index(candidates[i])] 102 | 103 | best_score = max(list(path2score.values())) 104 | best_path = [p for p,v in path2score.items() if v == best_score] 105 | return best_path, [best_score] 106 | 107 | def get_siglip_similarities(prompts, image_paths, embeddings_path="", bs=1024, k=50, device='cuda:2', save=False, cache_dir=None): 108 | model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384', cache_dir=cache_dir, device=device) 109 | tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384', cache_dir=cache_dir) 110 | text = tokenizer(prompts, context_length=model.context_length).to(device) 111 | 112 | with torch.no_grad(): 113 | text_features = model.encode_text(text) 114 | normalized_text_vectors = F.normalize(text_features, dim=-1) 115 | 116 | top_text_im_paths = [] 117 | top_text_im_scores = [] 118 | top_img_embeddings = torch.empty((0, 1152)) 119 | 120 | if bs == len(image_paths): 121 | end = len(image_paths) 122 | else: 123 | end = len(image_paths) - bs 124 | 125 | for bi in range(0, end, bs): 126 | if os.path.exists(os.path.join(embeddings_path, f"siglip_embeddings_b{bi}.pt")): 127 | normalized_ims = torch.load(os.path.join(embeddings_path, f"siglip_embeddings_b{bi}.pt"), map_location=device) 128 | normalized_im_vectors = normalized_ims['normalized_siglip_embeddings']#.to(device) 129 | final_bi_paths = normalized_ims['paths'] 130 | 131 | elif save: 132 | to_remove = [] 133 | images = [] 134 | for i in range(bs): 135 | try: 136 | image = preprocess(Image.open(image_paths[bi+i])).unsqueeze(0).to(device) 137 | images.append(image) 138 | except: 139 | print(f"couldn't read {image_paths[bi+i]}") 140 | to_remove.append(image_paths[bi+i]) 141 | continue 142 | 143 | if not images: 144 | continue 145 | 146 | images = torch.stack(images).squeeze(1).to(device) 147 | image_features = model.encode_image(images) 148 | normalized_im_vectors = F.normalize(image_features, dim=-1) 149 | 150 | final_bi_paths = [path for path in image_paths[bi:bi+bs] if path not in to_remove] 151 | if embeddings_path != "" and save: 152 | os.makedirs(embeddings_path, exist_ok=True) 153 | torch.save({"normalized_siglip_embeddings": normalized_im_vectors, "paths": final_bi_paths}, 154 | os.path.join(embeddings_path, f"siglip_embeddings_b{bi}.pt")) 155 | else: 156 | continue 157 | 158 | # compute cosine similarities 159 | text_similarity_matrix = torch.matmul(normalized_text_vectors, normalized_im_vectors.T) 160 | 161 | text_sim = text_similarity_matrix.cpu().numpy().squeeze() 162 | text_sim = np.concatenate([top_text_im_scores, text_sim]) 163 | cur_paths = np.concatenate([top_text_im_paths, final_bi_paths]) 164 | top_similarities = text_sim.argsort()[-k:] 165 | cur_paths = np.array(cur_paths) 166 | if cur_paths.shape[0] == 1: 167 | cur_paths = cur_paths[0] 168 | top_text_im_paths = cur_paths[top_similarities] 169 | top_text_im_scores = text_sim[top_similarities] 170 | cur_embeddings = torch.cat([top_img_embeddings, normalized_im_vectors.cpu()]) 171 | top_img_embeddings = cur_embeddings[top_similarities] 172 | 173 | return top_text_im_paths[::-1], top_text_im_scores[::-1] 174 | 175 | def gpt_rerank(caption, image_paths, embeddings_path="", bs=1024, k=1, device='cuda', save=False): 176 | pairs, im_emb = get_clip_similarities(caption, image_paths, 177 | embeddings_path=embeddings_path, 178 | bs=min(2048, bs), k=3, device=device) 179 | pairs2, im_emb2 = get_siglip_similarities(caption, image_paths, 180 | embeddings_path=embeddings_path, 181 | bs=min(64, bs), k=3, 
device=device, save=save) 182 | print(f"CLIP candidates: {pairs}") 183 | print(f"SigLIP candidates: {pairs2}") 184 | 185 | candidates = pairs[0].tolist() + pairs2[0].tolist() 186 | scores = pairs[1].tolist() + pairs2[1].tolist() 187 | 188 | best_paths = retrieve_from_small_set(candidates, caption, k=k) 189 | 190 | return (best_paths, [scores[candidates.index(p)] for p in best_paths]), im_emb 191 | 192 | def retrieve_from_small_set(image_paths, prompt, k=3): 193 | best = [] 194 | bs = min(6, len(image_paths)) 195 | for i in range(0, len(image_paths), bs): 196 | cur_paths = best + image_paths[i:i+bs] 197 | msg = (f'Which of these images is the most similar to the prompt {prompt}?' 198 | f'in your answer only provide the indices of the {k} most relevant images with a comma between them with no spaces, starting from index 0, e.g. answer: 0,3 if the most similar images are the ones in indices 0 and 3.' 199 | f'If you can\'t determine, return the first {k} indices, e.g. 0,1 if {k}=2.') 200 | best_ind = message_gpt(msg, cur_paths).split(",") 201 | try: 202 | best = [cur_paths[int(j.strip("'").strip('"').strip())] for j in best_ind] 203 | except: 204 | print(f"didn't get ind for i {i}") 205 | print(best_ind) 206 | continue 207 | return best 208 | 209 | def retrieve_img_per_caption(captions, image_paths, embeddings_path="", k=3, device='cuda', method='CLIP'): 210 | paths = [] 211 | for caption in captions: 212 | if method == 'CLIP': 213 | pairs = get_clip_similarities(caption, image_paths, 214 | embeddings_path=embeddings_path, 215 | bs=min(2048, len(image_paths)), k=k, device=device) 216 | elif method == 'SigLIP': 217 | pairs = get_siglip_similarities(caption, image_paths, 218 | embeddings_path=embeddings_path, 219 | bs=min(2048, len(image_paths)), k=k, device=device) 220 | elif method == 'MoE': 221 | pairs = get_moe_similarities(caption, image_paths, 222 | embeddings_path=embeddings_path, 223 | bs=min(2048, len(image_paths)), k=k, device=device) 224 | 225 | elif method == 'gpt_rerank': 226 | pairs = gpt_rerank(caption, image_paths, 227 | embeddings_path=embeddings_path, 228 | bs=min(2048, len(image_paths)), k=k, device=device) 229 | print(f"gpt rerank best path: {pairs[0]}") 230 | 231 | print("pairs:", pairs) 232 | paths.append(pairs[0]) 233 | 234 | return paths -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | def convert_res_to_captions(res): 4 | captions = [c.strip() for c in res.split("\n") if c != ""] 5 | for i in range(len(captions)): 6 | if captions[i][0].isnumeric() and captions[i][1] == ".": 7 | captions[i] = captions[i][2:] 8 | elif captions[i][0] == "-": 9 | captions[i] = captions[i][1:] 10 | elif f"{i+1}." 
in captions[i]: 11 | captions[i] = captions[i][captions[i].find(f"{i+1}.")+len(f"{i+1}."):] 12 | 13 | captions[i] = captions[i].strip().replace("'", "").replace('"', '') 14 | return captions 15 | 16 | def encode_image(image_path): 17 | with open(image_path, "rb") as image_file: 18 | return base64.b64encode(image_file.read()).decode('utf-8') 19 | 20 | def message_gpt(msg, client, image_paths=[], context_msgs=[], images_idx=-1, temperature=0): 21 | messages = [{"role": "user", 22 | "content": [{"type": "text", "text": msg}] 23 | }] 24 | if context_msgs: 25 | messages = context_msgs + messages 26 | 27 | if image_paths: 28 | base_64_images = [encode_image(image_path) for image_path in image_paths] 29 | for i, img in enumerate(base_64_images): 30 | messages[images_idx]["content"].append({ 31 | "type": "image_url", 32 | "image_url": {"url": f"data:image/{image_paths[i][image_paths[i].rfind('.') + 1:]};base64,{img}"}}) 33 | 34 | res = client.chat.completions.create( 35 | model="gpt-4o", 36 | messages=messages, 37 | response_format={"type": "text"}, 38 | temperature=temperature # for less randomness 39 | ) 40 | 41 | res_text = res.choices[0].message.content 42 | return res_text 43 | 44 | def message_gpt_w_error_handle(msg, client, image_paths, context_msgs, max_tries=3): 45 | unable = True 46 | temp = 0 47 | while unable and max_tries > 0: 48 | concepts = message_gpt(msg, client, image_paths, context_msgs=context_msgs, images_idx=0) 49 | print("concepts from images", concepts) 50 | 51 | if "unable" not in concepts and "can't" not in concepts: # TODO make more generic 52 | unable = False 53 | 54 | temp += 1 / max_tries 55 | max_tries -= 1 56 | 57 | if unable: 58 | print("was unable to generate concepts, using prompt as caption") 59 | return "" 60 | 61 | return concepts 62 | 63 | def retrieval_caption_generation(prompt, image_paths, gpt_client, k_captions_per_concept=1, k_concepts=-1, decision=True, only_rephrase=False): 64 | if decision: 65 | if len(image_paths) > 1: 66 | msg1 = f'Does the second image match the instruction "{prompt}" applied over the first one? consider both content and style aspects. only answer yes or no.' 67 | else: 68 | msg1 = f'Does this image match the prompt "{prompt}"? consider both content and style aspects. only answer yes or no.' 69 | 70 | ans = message_gpt(msg1, gpt_client, image_paths) 71 | if 'yes' in ans.lower(): 72 | return True 73 | 74 | context_msgs = [{"role": "user", 75 | "content": [{"type": "text", "text": msg1}] 76 | }, 77 | {"role": "assistant", 78 | "content": [{"type": "text", "text": ans}] 79 | }] 80 | 81 | print(f"Answer was {ans}. Running imageRAG") 82 | if only_rephrase: 83 | rephrased_prompt = get_rephrased_prompt(prompt, gpt_client, image_paths, context_msgs=context_msgs, images_idx=0) 84 | print("rephrased_prompt:", rephrased_prompt) 85 | return rephrased_prompt 86 | 87 | msg2 = 'What are the differences between this image and the required prompt? in your answer only provide missing concepts in terms of content and style, each in a separate line. For example, if the prompt is "An oil painting of a sheep and a car" and the image is a painting of a car but not an oil painting, the missing concepts will be:\noil painting style\na sheep' 88 | if k_concepts > 0: 89 | msg2 += f'Return up to {k_concepts} concepts.' 
90 | 91 | concepts = message_gpt_w_error_handle(msg2, gpt_client, image_paths, context_msgs, max_tries=3) 92 | if concepts == "": 93 | return prompt 94 | else: # generation mode 95 | context_msgs = [] 96 | msg2 = ( 97 | f'What visual concepts does a generative model need to know to generate an image described by the prompt "{prompt}"?\n' 98 | 'The concepts should be things like objects that should appear in the image, the style of it, etc.' 99 | 'For example, if the prompt is "An elephant standing on a ball", 2 concepts would be: elephant, ball.' 100 | 'In your answer only provide the concepts, each in a separate line.') 101 | 102 | concepts = message_gpt_w_error_handle(msg2, gpt_client, image_paths, context_msgs, max_tries=3) 103 | if concepts == "": 104 | return prompt 105 | 106 | print(f'retrieved concepts: {concepts}') 107 | 108 | msg3 = (f'For each concept you suggested above, please suggest {k_captions_per_concept} image captions describing images that explain this concept only. ' 109 | f'The captions should be stand-alone description of the images, assuming no knowledge of the given images and prompt, that I can use to lookup images with automatically. ' 110 | f'In your answer only provide the image captions, each in a new line with nothing else other than the caption.') 111 | context_msgs += [{"role": "user", 112 | "content": [{"type": "text", "text": msg2}] 113 | }, 114 | {"role": "assistant", 115 | "content": [{"type": "text", "text": concepts}] 116 | }] 117 | captions = message_gpt(msg3, gpt_client, image_paths, context_msgs=context_msgs, images_idx=0) 118 | return captions 119 | 120 | def get_rephrased_prompt(prompt, gpt_client, image_paths=[], context_msgs=[], images_idx=-1): 121 | if not context_msgs: 122 | msg = f'Please rephrase the following prompt to make it clearer for a text-to-image generation model. If it\'s already clear, return it as it is. In your answer only provide the prompt and nothing else, and don\'t change the original meaning of the prompt. If it contains rare words, change the words to a description of their meaning. The prompt to be rephrased: "{prompt}"' 123 | else: 124 | msg = f'Please rephrase the following prompt to make it easier and clearer for the text-to-image generation model that generated the above image for this prompt. The goal is to generate an image that matches the given text prompt. If the prompt is already clear, return it as it is. Simplify and shorten long descriptions of known objects/entities but DO NOT change the original meaning of the text prompt. If the prompt contains rare words, change those words to a description of their meaning. In your answer only provide the prompt and nothing else. The prompt to be rephrased: "{prompt}"' 125 | 126 | ans = message_gpt(msg, gpt_client, image_paths, context_msgs=context_msgs, images_idx=images_idx) 127 | return ans.strip().replace('"', '').replace("'", '') --------------------------------------------------------------------------------
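The retrieval module above can also be exercised on its own, independently of OmniGen or SDXL. The snippet below is a minimal, illustrative sketch rather than one of the repository files: it assumes the `ImageRAG` conda environment is active, a CUDA-capable GPU is available, and `datasets/example_dataset` is in place; the caption string and embedding cache path are placeholders chosen for the example.

```
import os

from retrieval import retrieve_img_per_caption

# Collect the paths of the retrieval dataset (assumes datasets/example_dataset exists).
# Unreadable entries such as .DS_Store are skipped inside the retrieval code.
data_path = "datasets/example_dataset"
image_paths = [os.path.join(data_path, fname) for fname in os.listdir(data_path)]

# Retrieve the single best-matching image for an illustrative caption.
# CLIP embeddings of the dataset are computed on the first call and cached under
# embeddings_path (as clip_embeddings_b*.pt), so later calls only encode the text query.
paths = retrieve_img_per_caption(
    ["an african grey parrot"],          # illustrative caption
    image_paths,
    embeddings_path="datasets/embeddings/example_dataset",
    k=1,
    device="cuda",                       # assumes a CUDA device is available
    method="CLIP",
)
print("retrieved:", paths[0])
```

The `embeddings_path` used here matches the default cache location of the two scripts (`datasets/embeddings/<dataset name>`), so embeddings cached by this sketch are reused by `imageRAG_OmniGen.py` and `imageRAG_SDXL.py`, and vice versa.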