├── .gitignore ├── LICENSE ├── README.md ├── assets ├── cir_candi_1.png ├── cir_candi_2.png ├── cir_query.png ├── corpus │ ├── 000000032077.jpg │ ├── 000000050549.jpg │ ├── 000000098911.jpg │ ├── 000000156031.jpg │ ├── 000000244097.jpg │ ├── 000000272130.jpg │ ├── 000000275230.jpg │ ├── 000000311907.jpg │ ├── 000000357304.jpg │ ├── 000000478916.jpg │ └── 000000545037.jpg ├── query │ └── 000000530944.jpg ├── res-ft-mmeb.png ├── res-scaling.png ├── res-zs-cir.png └── res-zs-mmeb.png ├── eval ├── data │ ├── circo_corpus.jsonl │ ├── circo_query.jsonl │ ├── fashioniq_dress_corpus.jsonl │ ├── fashioniq_dress_query_val.jsonl │ ├── fashioniq_shirt_corpus.jsonl │ ├── fashioniq_shirt_query_val.jsonl │ ├── fashioniq_toptee_corpus.jsonl │ └── fashioniq_toptee_query_val.jsonl ├── eval_Circo.py ├── eval_fashioniq.py ├── flag_dataset.py ├── flag_mmret.py └── results │ ├── mmret_base_circo.json │ └── mmret_large_circo.json ├── modeling_MMRet_CLIP.py └── retrieval_demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | my_util/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 JUNJIE99 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval

2 | 3 |

4 | 5 | Build 6 | 7 | 8 | Build 9 | 10 | 11 | Build 12 |

13 | 14 |

15 | 16 | 17 | Build 18 | 19 | 20 | Build 21 | 22 | 23 | Build 24 | 25 | 26 | Build 27 | 28 |

29 |

30 | 31 |

32 | 33 | 34 | Build 35 | 36 | 37 | Build 38 | 39 |

40 |

41 | 42 | 43 | ## News 44 | ```2025-5-20``` 🚀🚀 We are excited to announce the release of **BGE-VL-v1.5**! BGE-VL-v1.5 is developed based on the BGE-VL-MLLM-S1 model and further trained on additional multi-task multimodal data synthesized and collected by our team. Our [zero-shot model](https://huggingface.co/BAAI/BGE-VL-v1.5-zs) achieves state-of-the-art zero-shot performance on the [MMEB leaderboard](https://huggingface.co/spaces/TIGER-Lab/MMEB-Leaderboard). Furthermore, [the fine-tuned version](https://huggingface.co/BAAI/BGE-VL-v1.5-mmeb) achieves the best results among all methods using the same base model (Llava-1.6-7B), reaching a new high in retrieval tasks with a Recall@1 of 72.2%. 45 | 46 | ```2025-5-16``` 🎉🎉 We are pleased to share that our works, **MegaPairs** ([repo](https://github.com/VectorSpaceLab/MegaPairs), [paper](https://arxiv.org/abs/2412.14475)) and **Vis-IR** ([repo](https://github.com/VectorSpaceLab/Vis-IR), [paper](https://arxiv.org/pdf/2502.11431)), have been accepted to the **ACL 2025 Main Conference**! 47 | 48 | ```2025-4-13``` 🎉🎉 We have uploaded our MegaPairs dataset to [🤗Hugging Face](https://huggingface.co/datasets/JUNJIE99/MegaPairs), which contains over 26 million multimodal retrieval instruction-tuning triplets. To reduce upload time and enhance data accessibility, we resized all images to a resolution of 512 × 512 instead of using their original size. This adjustment has minimal impact on performance, considering that most vision-language models (e.g., CLIP) use even smaller input image sizes. [Dataset Card](https://github.com/VectorSpaceLab/MegaPairs?tab=readme-ov-file#megapairs-dataset-card) 49 | 50 | ```2025-4-2``` 🌟🌟 BGE-VL models are also available on [WiseModel](https://www.wisemodel.cn/models/JUNJIE99/BGE-VL-large). 51 | 52 | ```2025-3-6``` 📰📰 Thank you to [SyncedTech (机器之心)](https://mp.weixin.qq.com/s/iw9BmSDwv6NYtD7pkC5kxQ), [QbitAI (量子位)](https://mp.weixin.qq.com/s/r_zWAZ0ir5732OfIrEsDtg), and [AI Era (新智元)](https://mp.weixin.qq.com/s/FZwKYJnx_78YDAEreu1edg) for reporting on our work! 53 | 54 | ```2025-3-4``` 🚀🚀 We have released the BGE-VL-MLLM models on Huggingface: [BGE-VL-MLLM-S1](https://huggingface.co/BAAI/BGE-VL-MLLM-S1) and [BGE-VL-MLLM-S2](https://huggingface.co/BAAI/BGE-VL-MLLM-S2). **BGE-VL-MLLM-S1** is trained exclusively on our MegaPairs dataset, achieving outstanding performance in composed image retrieval, with an 8.1% improvement on the CIRCO benchmark (mAP@5) over the previous state-of-the-art. **BGE-VL-MLLM-S2** builds on BGE-VL-MLLM-S1 with an additional epoch of fine-tuning on the MMEB benchmark training set, delivering enhanced performance across a broader range of multimodal embedding tasks. 55 | 56 | ```2024-12-27``` 🚀🚀 BGE-VL-CLIP models are released on Huggingface: [BGE-VL-base](https://huggingface.co/BAAI/BGE-VL-base) and [BGE-VL-large](https://huggingface.co/BAAI/BGE-VL-large). 57 | 58 | ```2024-12-19``` 🎉🎉 Release our paper: [MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval](https://arxiv.org/pdf/2412.14475). 59 | 60 | ## Release Plan 61 | - [x] Paper 62 | - [x] BGE-VL-base and BGE-VL-large models 63 | - [x] BGE-VL-MLLM model 64 | - [x] MegaPairs Dataset 65 | - [x] Evaluation code examples 66 | - [ ] Fine-tuning code 67 | 68 | 69 | ## Introduction 70 | In this work, we introduce **MegaPairs**, a novel data synthesis method that leverages open-domain images to create *heterogeneous KNN triplets* for universal multimodal retrieval. 
Our MegaPairs dataset contains over 26 million triplets, and we have trained a series of multimodal retrieval models, **BGE-VL**, including BGE-VL-CLIP (base and large) and BGE-VL-MLLM. 71 | 72 | BGE-VL achieve state-of-the-art performance on four popular zero-shot composed image retrieval benchmarks and the massive multimodal embedding benchmark (MMEB). Extensive experiments demonstrate the ***efficiency, scalability, and generalization*** features of MegaPairs. Please refer to our [paper](https://arxiv.org/abs/2412.14475) for more details. 73 | 74 | 75 | 76 | ## Model Usage 77 | 78 | ### 1. BGE-VL-CLIP Models 79 | You can easily use BGE-VL-CLIP models based on ```transformers``` 80 | > Our code works well on transformers==4.45.2, and we recommend using this version. 81 | ```python 82 | import torch 83 | from transformers import AutoModel 84 | 85 | MODEL_NAME = "BAAI/BGE-VL-base" # or "BAAI/BGE-VL-large" 86 | 87 | model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) # You must set trust_remote_code=True 88 | model.set_processor(MODEL_NAME) 89 | model.eval() 90 | 91 | with torch.no_grad(): 92 | query = model.encode( 93 | images = "./assets/cir_query.png", 94 | text = "Make the background dark, as if the camera has taken the photo at night" 95 | ) 96 | 97 | candidates = model.encode( 98 | images = ["./assets/cir_candi_1.png", "./assets/cir_candi_2.png"] 99 | ) 100 | 101 | scores = query @ candidates.T 102 | print(scores) 103 | ``` 104 | 105 | See the [demo](./retrieval_demo.ipynb) for a complete example of using BGE-VL for multimodel retrieval. 106 | 107 | 108 | ### 2. BGE-VL-MLLM Models 109 | 110 | > Our code works well on transformers==4.45.2, and we recommend using this version. 111 | 112 | ```python 113 | import torch 114 | from transformers import AutoModel 115 | from PIL import Image 116 | 117 | MODEL_NAME= "BAAI/BGE-VL-MLLM-S1" 118 | 119 | model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) 120 | model.eval() 121 | model.cuda() 122 | 123 | with torch.no_grad(): 124 | model.set_processor(MODEL_NAME) 125 | 126 | query_inputs = model.data_process( 127 | text="Make the background dark, as if the camera has taken the photo at night", 128 | images="./assets/cir_query.png", 129 | q_or_c="q", 130 | task_instruction="Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: " 131 | ) 132 | 133 | candidate_inputs = model.data_process( 134 | images=["./assets/cir_candi_1.png", "./assets/cir_candi_2.png"], 135 | q_or_c="c", 136 | ) 137 | 138 | query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :] 139 | candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :] 140 | 141 | query_embs = torch.nn.functional.normalize(query_embs, dim=-1) 142 | candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1) 143 | 144 | scores = torch.matmul(query_embs, candi_embs.T) 145 | print(scores) 146 | ``` 147 | 148 | ## MegaPairs Dataset Card 149 | 150 | We are excited to release the **MegaPairs** dataset on [Hugging Face](https://huggingface.co/datasets/JUNJIE99/MegaPairs), which contains over **26 million training samples** tailored for composed image retrieval and universal multimodal retrieval tasks. 151 | 152 | ### Dataset Structure 153 | 154 | Each entry in the dataset consists of the following fields: 155 | 156 | - **q_img**: `str` 157 | The file path to the query image. 
158 | 159 | - **q_text**: `list` 160 | A list of textual query statements related to the query image. During training, you can randomly select one statement from this list. 161 | 162 | - **t_img**: `str` 163 | The file path to the target image, which serves as the **positive example** for the combination of `q_img` and `q_text`. 164 | 165 | - **hns**: `list` 166 | A list of file paths for **hard negative sample** images. These are challenging distractors that are visually or semantically similar to the query. It is recommended to include at least one hard negative sample during training, with **`hns[0]` (the query image itself)** being a mandatory choice. In our experiments, we used **four hard negative samples** per query. 167 | 168 | 169 | ### Usage 170 | 171 | The dataset is available for download and exploration on [Hugging Face](https://huggingface.co/datasets/JUNJIE99/MegaPairs). We encourage researchers and practitioners to leverage this dataset to advance multimodal retrieval research and systems. 172 | 173 | ## Model Performance 174 | ### Zero-Shot Composed Image Retrieval 175 | 176 | BGE-VL sets a new performance benchmark in zero-shot composed image retrieval tasks. On the CIRCO benchmark, our BGE-VL-base model, with only 149 million parameters, surpasses all previous models, including those with 50 times more parameters. Additionally, BGE-VL-MLLM achieves an 8.1% improvement over the previous state-of-the-art model. 177 | 178 | 179 | 180 | ### Zero-Shot Performance on MMEB 181 | 182 | BGE-VL-MLLM achieves state-of-the-art zero-shot performance on the Massive Multimodal Embedding Benchmark (MMEB), despite being trained only on the ImageText-to-Image paradigm. This demonstrates the excellent generalization capability of MegaPairs for multimodal embedding. 183 | 184 | 185 | 186 | ### Fine-Tuning Performance on MMEB 187 | 188 | After fine-tuning on downstream tasks, BGE-VL-MLLM maintains its leading performance. Notably, it surpasses the previous state-of-the-art by 7.1% on the MMEB out-of-distribution (OOD) set. These results demonstrate the robust generalization capability of BGE-VL-MLLM and highlight the potential of MegaPairs as foundational training data for universal multimodal embedding. 189 | 190 | 191 | 192 | ### Performance Scaling 193 | MegaPairs showcases **scalability**: BGE-VL-base improves as training data increases. It also demonstrates **efficiency**: with just 0.5M training samples, BGE-VL-base significantly outperforms MagicLens, which uses the same CLIP-base backbone and was trained on 36.7M samples. 194 | 195 | 196 | 197 | 198 | ## License 199 | The annotations for MegaPairs and the BGE-VL models are released under the [MIT License](LICENSE). The images in MegaPairs originate from the [Recap-Datacomp](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B), which is released under the CC BY 4.0 license. 
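For illustration, here is a minimal sketch of how a single MegaPairs record, following the field layout in the Dataset Card above, could be assembled into a training example with hard negatives. This is not part of the released code: the `load_dataset` call, the `IMAGE_ROOT` path, and the helper `build_training_example` are assumptions made for this example only.

```python
import os
import random
from datasets import load_dataset

IMAGE_ROOT = "YOUR_MEGAPAIRS_IMAGE_DIRECTORY"   # assumed location of the downloaded images

# Assumed loading call; adjust the dataset name/config to how you obtained the data.
dataset = load_dataset("JUNJIE99/MegaPairs", split="train")

def build_training_example(record, num_hard_negatives=4):
    """Assemble (query image, query text, positive image, hard negatives) from one record."""
    query_image = os.path.join(IMAGE_ROOT, record["q_img"])
    query_text = random.choice(record["q_text"])            # pick one textual query statement
    positive_image = os.path.join(IMAGE_ROOT, record["t_img"])

    # hns[0] (the query image itself) is kept as a mandatory hard negative;
    # the remaining negatives are sampled from the rest of the list.
    extra = record["hns"][1:]
    k = min(num_hard_negatives - 1, len(extra))
    negatives = [record["hns"][0]] + random.sample(extra, k=k)
    negative_images = [os.path.join(IMAGE_ROOT, p) for p in negatives]

    return query_image, query_text, positive_image, negative_images

print(build_training_example(dataset[0]))
```

A training loop would then encode `(query_image, query_text)` as the composed query and contrast it against the positive image and the sampled hard negatives, in line with the recommendation above to keep `hns[0]` and use four hard negatives per query.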
200 | 201 | 202 | 203 | ## Citation 204 | If you find this repository useful, please consider giving a star ⭐ and citation 205 | 206 | ``` 207 | @article{zhou2024megapairs, 208 | title={MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval}, 209 | author={Zhou, Junjie and Liu, Zheng and Liu, Ze and Xiao, Shitao and Wang, Yueze and Zhao, Bo and Zhang, Chen Jason and Lian, Defu and Xiong, Yongping}, 210 | journal={arXiv preprint arXiv:2412.14475}, 211 | year={2024} 212 | } 213 | ``` 214 | -------------------------------------------------------------------------------- /assets/cir_candi_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/cir_candi_1.png -------------------------------------------------------------------------------- /assets/cir_candi_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/cir_candi_2.png -------------------------------------------------------------------------------- /assets/cir_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/cir_query.png -------------------------------------------------------------------------------- /assets/corpus/000000032077.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000032077.jpg -------------------------------------------------------------------------------- /assets/corpus/000000050549.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000050549.jpg -------------------------------------------------------------------------------- /assets/corpus/000000098911.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000098911.jpg -------------------------------------------------------------------------------- /assets/corpus/000000156031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000156031.jpg -------------------------------------------------------------------------------- /assets/corpus/000000244097.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000244097.jpg -------------------------------------------------------------------------------- /assets/corpus/000000272130.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000272130.jpg -------------------------------------------------------------------------------- /assets/corpus/000000275230.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000275230.jpg -------------------------------------------------------------------------------- /assets/corpus/000000311907.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000311907.jpg -------------------------------------------------------------------------------- /assets/corpus/000000357304.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000357304.jpg -------------------------------------------------------------------------------- /assets/corpus/000000478916.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000478916.jpg -------------------------------------------------------------------------------- /assets/corpus/000000545037.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000545037.jpg -------------------------------------------------------------------------------- /assets/query/000000530944.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/query/000000530944.jpg -------------------------------------------------------------------------------- /assets/res-ft-mmeb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-ft-mmeb.png -------------------------------------------------------------------------------- /assets/res-scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-scaling.png -------------------------------------------------------------------------------- /assets/res-zs-cir.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-zs-cir.png -------------------------------------------------------------------------------- /assets/res-zs-mmeb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-zs-mmeb.png -------------------------------------------------------------------------------- /eval/eval_Circo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import faiss 3 | import torch 4 | import logging 5 | import datasets 6 | import numpy as np 7 | from tqdm import tqdm 8 | from typing import Optional 9 | from dataclasses import dataclass, field 10 | from transformers import HfArgumentParser 11 | from flag_mmret import 
Flag_mmret 12 | import json 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @dataclass 18 | class Args: 19 | model_name: str = field( 20 | default="BAAI/BGE-VL-large", 21 | metadata={'help': 'Model Name'} 22 | ) 23 | result_save_path: str = field( 24 | default="./eval/mmret_large_circo.json", 25 | metadata={'help': 'Where to save the results.'} 26 | ) 27 | image_dir: str = field( 28 | default="YOUR_COCO_IMAGE_DIRECTORY", 29 | metadata={'help': 'Where the images located on.'} 30 | ) 31 | fp16: bool = field( 32 | default=False, 33 | metadata={'help': 'Use fp16 in inference?'} 34 | ) 35 | max_query_length: int = field( 36 | default=64, 37 | metadata={'help': 'Max query length.'} 38 | ) 39 | max_passage_length: int = field( 40 | default=77, 41 | metadata={'help': 'Max passage length.'} 42 | ) 43 | batch_size: int = field( 44 | default=256, 45 | metadata={'help': 'Inference batch size.'} 46 | ) 47 | index_factory: str = field( 48 | default="Flat", 49 | metadata={'help': 'Faiss index factory.'} 50 | ) 51 | k: int = field( 52 | default=100, 53 | metadata={'help': 'How many neighbors to retrieve?'} 54 | ) 55 | save_embedding: bool = field( 56 | default=False, 57 | metadata={'help': 'Save embeddings in memmap at save_dir?'} 58 | ) 59 | load_embedding: bool = field( 60 | default=False, 61 | metadata={'help': 'Load embeddings from save_dir?'} 62 | ) 63 | save_path: str = field( 64 | default="embeddings.memmap", 65 | metadata={'help': 'Path to save embeddings.'} 66 | ) 67 | 68 | 69 | 70 | def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False): 71 | """ 72 | 1. Encode the entire corpus into dense embeddings; 73 | 2. Create faiss index; 74 | 3. Optionally save embeddings. 
75 | """ 76 | if load_embedding: 77 | test = model.encode("test") 78 | dtype = test.dtype 79 | dim = len(test) 80 | 81 | corpus_embeddings = np.memmap( 82 | save_path, 83 | mode="r", 84 | dtype=dtype 85 | ).reshape(-1, dim) 86 | 87 | else: 88 | 89 | corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image') 90 | 91 | dim = corpus_embeddings.shape[-1] 92 | 93 | if save_embedding: 94 | logger.info(f"saving embeddings at {save_path}...") 95 | memmap = np.memmap( 96 | save_path, 97 | shape=corpus_embeddings.shape, 98 | mode="w+", 99 | dtype=corpus_embeddings.dtype 100 | ) 101 | 102 | length = corpus_embeddings.shape[0] 103 | # add in batch 104 | save_batch_size = 10000 105 | if length > save_batch_size: 106 | for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"): 107 | j = min(i + save_batch_size, length) 108 | memmap[i: j] = corpus_embeddings[i: j] 109 | else: 110 | memmap[:] = corpus_embeddings 111 | 112 | # create faiss index 113 | faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT) 114 | 115 | 116 | if model.device == torch.device("cuda"): 117 | # co = faiss.GpuClonerOptions() 118 | co = faiss.GpuMultipleClonerOptions() 119 | co.useFloat16 = True 120 | # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co) 121 | faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co) 122 | 123 | # NOTE: faiss only accepts float32 124 | logger.info("Adding embeddings...") 125 | corpus_embeddings = corpus_embeddings.astype(np.float32) 126 | faiss_index.train(corpus_embeddings) 127 | faiss_index.add(corpus_embeddings) 128 | 129 | 130 | 131 | return faiss_index 132 | 133 | 134 | def search(model: Flag_mmret, queries: datasets, faiss_index: faiss.Index, k:int = 100, batch_size: int = 256, max_length: int=512): 135 | """ 136 | 1. Encode queries into dense embeddings; 137 | 2. 
Search through faiss index 138 | """ 139 | query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]], 140 | batch_size=batch_size, 141 | max_length=max_length, 142 | query_type='mm_it') 143 | 144 | 145 | query_size = len(query_embeddings) 146 | 147 | all_scores = [] 148 | all_indices = [] 149 | 150 | for i in tqdm(range(0, query_size, batch_size), desc="Searching"): 151 | j = min(i + batch_size, query_size) 152 | query_embedding = query_embeddings[i: j] 153 | score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k) 154 | all_scores.append(score) 155 | all_indices.append(indice) 156 | 157 | all_scores = np.concatenate(all_scores, axis=0) 158 | all_indices = np.concatenate(all_indices, axis=0) 159 | return all_scores, all_indices 160 | 161 | 162 | def main(): 163 | parser = HfArgumentParser([Args]) 164 | args: Args = parser.parse_args_into_dataclasses()[0] 165 | 166 | print(f"Results will be saved in {args.result_save_path}") 167 | eval_data = datasets.load_dataset('json', data_files="./eval/data/circo_query.jsonl", split='train') 168 | image_corpus_test = datasets.load_dataset('json', data_files="./eval/data/circo_corpus.jsonl", split='train') 169 | 170 | model = Flag_mmret(model_name=args.model_name, 171 | normlized = True, 172 | image_dir=args.image_dir, 173 | use_fp16=False, 174 | ) 175 | 176 | 177 | faiss_index = index( 178 | model=model, 179 | corpus=image_corpus_test, 180 | batch_size=args.batch_size, 181 | max_length=args.max_passage_length, 182 | index_factory=args.index_factory, 183 | save_path=args.save_path, 184 | save_embedding=args.save_embedding, 185 | load_embedding=args.load_embedding 186 | ) 187 | 188 | scores, indices = search( 189 | model=model, 190 | queries=eval_data, 191 | faiss_index=faiss_index, 192 | k=args.k, 193 | batch_size=args.batch_size, 194 | max_length=args.max_query_length 195 | ) 196 | 197 | 198 | retrieval_results = [] 199 | for indice in indices: 200 | # filter invalid indices 201 | indice = indice[indice != -1].tolist() 202 | retrieval_results.append(image_corpus_test[indice]["content"]) 203 | 204 | ########## results in test corpus ######### 205 | q_images = eval_data["q_img"] 206 | 207 | q_ids = [] 208 | for _img in q_images: 209 | _id = os.path.basename(_img) 210 | _id = os.path.splitext(_id)[0] 211 | q_ids.append(_id) 212 | 213 | pairids = eval_data["id"] 214 | results = {} 215 | for pairid, re_results, q_img in zip(pairids, retrieval_results, q_images): 216 | id = str(pairid) 217 | top_50_results = re_results[0:50] 218 | results[id] = top_50_results 219 | 220 | with open(args.result_save_path, "w") as f: 221 | json.dump(results, f) 222 | 223 | 224 | if __name__ == "__main__": 225 | main() -------------------------------------------------------------------------------- /eval/eval_fashioniq.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 3 | 4 | import sys 5 | print(os.getcwd()) 6 | 7 | import faiss 8 | import torch 9 | import logging 10 | import datasets 11 | import numpy as np 12 | from tqdm import tqdm 13 | from typing import Optional 14 | from dataclasses import dataclass, field 15 | from transformers import HfArgumentParser 16 | from flag_mmret import Flag_mmret 17 | import json 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | @dataclass 23 | class Args: 24 | model_name: str = field( 25 | default="BAAI/BGE-VL-large", 26 | metadata={'help': 'Model Name'} 27 | ) 28 | image_dir: str = field( 29 | 
default="YOUR_FASHIONIQ_IMAGE_DIRECTORY", 30 | metadata={'help': 'Where are the images located on.'} 31 | ) 32 | fp16: bool = field( 33 | default=False, 34 | metadata={'help': 'Use fp16 in inference?'} 35 | ) 36 | max_query_length: int = field( 37 | default=64, 38 | metadata={'help': 'Max query length.'} 39 | ) 40 | max_passage_length: int = field( 41 | default=77, 42 | metadata={'help': 'Max passage length.'} 43 | ) 44 | batch_size: int = field( 45 | default=256, 46 | metadata={'help': 'Inference batch size.'} 47 | ) 48 | index_factory: str = field( 49 | default="Flat", 50 | metadata={'help': 'Faiss index factory.'} 51 | ) 52 | k: int = field( 53 | default=100, 54 | metadata={'help': 'How many neighbors to retrieve?'} 55 | ) 56 | save_embedding: bool = field( 57 | default=False, 58 | metadata={'help': 'Save embeddings in memmap at save_dir?'} 59 | ) 60 | load_embedding: bool = field( 61 | default=False, 62 | metadata={'help': 'Load embeddings from save_dir?'} 63 | ) 64 | save_path: str = field( 65 | default="embeddings.memmap", 66 | metadata={'help': 'Path to save embeddings.'} 67 | ) 68 | 69 | 70 | 71 | def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False): 72 | """ 73 | 1. Encode the entire corpus into dense embeddings; 74 | 2. Create faiss index; 75 | 3. Optionally save embeddings. 76 | """ 77 | if load_embedding: 78 | test = model.encode("test") 79 | dtype = test.dtype 80 | dim = len(test) 81 | 82 | corpus_embeddings = np.memmap( 83 | save_path, 84 | mode="r", 85 | dtype=dtype 86 | ).reshape(-1, dim) 87 | 88 | else: 89 | 90 | corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image') 91 | 92 | dim = corpus_embeddings.shape[-1] 93 | 94 | if save_embedding: 95 | logger.info(f"saving embeddings at {save_path}...") 96 | memmap = np.memmap( 97 | save_path, 98 | shape=corpus_embeddings.shape, 99 | mode="w+", 100 | dtype=corpus_embeddings.dtype 101 | ) 102 | 103 | length = corpus_embeddings.shape[0] 104 | # add in batch 105 | save_batch_size = 10000 106 | if length > save_batch_size: 107 | for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"): 108 | j = min(i + save_batch_size, length) 109 | memmap[i: j] = corpus_embeddings[i: j] 110 | else: 111 | memmap[:] = corpus_embeddings 112 | 113 | # create faiss index 114 | faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT) 115 | 116 | 117 | if model.device == torch.device("cuda"): 118 | # co = faiss.GpuClonerOptions() 119 | co = faiss.GpuMultipleClonerOptions() 120 | co.useFloat16 = True 121 | # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co) 122 | faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co) 123 | 124 | # NOTE: faiss only accepts float32 125 | logger.info("Adding embeddings...") 126 | corpus_embeddings = corpus_embeddings.astype(np.float32) 127 | faiss_index.train(corpus_embeddings) 128 | faiss_index.add(corpus_embeddings) 129 | 130 | 131 | 132 | return faiss_index 133 | 134 | 135 | def search(model: Flag_mmret, queries: datasets, faiss_index: faiss.Index, k:int = 100, batch_size: int = 256, max_length: int=512): 136 | """ 137 | 1. Encode queries into dense embeddings; 138 | 2. 
Search through faiss index 139 | """ 140 | query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]], 141 | batch_size=batch_size, 142 | max_length=max_length, 143 | query_type='mm_it') 144 | 145 | query_size = len(query_embeddings) 146 | 147 | all_scores = [] 148 | all_indices = [] 149 | 150 | for i in tqdm(range(0, query_size, batch_size), desc="Searching"): 151 | j = min(i + batch_size, query_size) 152 | query_embedding = query_embeddings[i: j] 153 | score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k) 154 | all_scores.append(score) 155 | all_indices.append(indice) 156 | 157 | all_scores = np.concatenate(all_scores, axis=0) 158 | all_indices = np.concatenate(all_indices, axis=0) 159 | return all_scores, all_indices 160 | 161 | 162 | def evaluate(preds, labels, cutoffs=[1,5,10,20,50,100]): 163 | """ 164 | Evaluate MRR and Recall at cutoffs. 165 | """ 166 | metrics = {} 167 | 168 | # MRR 169 | mrrs = np.zeros(len(cutoffs)) 170 | for pred, label in zip(preds, labels): 171 | jump = False 172 | for i, x in enumerate(pred, 1): 173 | if x in label: 174 | for k, cutoff in enumerate(cutoffs): 175 | if i <= cutoff: 176 | mrrs[k] += 1 / i 177 | jump = True 178 | if jump: 179 | break 180 | mrrs /= len(preds) 181 | for i, cutoff in enumerate(cutoffs): 182 | mrr = mrrs[i] 183 | metrics[f"MRR@{cutoff}"] = mrr 184 | 185 | # Recall 186 | recalls = np.zeros(len(cutoffs)) 187 | for pred, label in zip(preds, labels): 188 | if not isinstance(label, list): 189 | label = [label] 190 | for k, cutoff in enumerate(cutoffs): 191 | recall = np.intersect1d(label, pred[:cutoff]) 192 | recalls[k] += len(recall) / len(label) 193 | recalls /= len(preds) 194 | for i, cutoff in enumerate(cutoffs): 195 | recall = recalls[i] 196 | metrics[f"Recall@{cutoff}"] = recall 197 | 198 | return metrics 199 | 200 | def main(): 201 | parser = HfArgumentParser([Args]) 202 | args: Args = parser.parse_args_into_dataclasses()[0] 203 | 204 | model = Flag_mmret(model_name=args.model_name, 205 | normlized = True, 206 | image_dir=args.image_dir, 207 | use_fp16=False, 208 | ) 209 | 210 | eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_shirt_query_val.jsonl", split='train') 211 | image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_shirt_corpus.jsonl", split='train') 212 | 213 | faiss_index = index( 214 | model=model, 215 | corpus=image_corpus, 216 | batch_size=args.batch_size, 217 | max_length=args.max_passage_length, 218 | index_factory=args.index_factory, 219 | save_path=args.save_path, 220 | save_embedding=args.save_embedding, 221 | load_embedding=args.load_embedding 222 | ) 223 | 224 | scores, indices = search( 225 | model=model, 226 | queries=eval_data, 227 | faiss_index=faiss_index, 228 | k=args.k, 229 | batch_size=args.batch_size, 230 | max_length=args.max_query_length 231 | ) 232 | 233 | 234 | 235 | retrieval_results = [] 236 | for indice in indices: 237 | # filter invalid indices 238 | indice = indice[indice != -1].tolist() 239 | retrieval_results.append(image_corpus[indice]["content"]) 240 | 241 | ground_truths = [] 242 | for sample in eval_data: 243 | ground_truths.append(sample["positive_key"]) 244 | 245 | metrics_shirt = evaluate(retrieval_results, ground_truths) 246 | print("FashionIQ tasks (shirt):") 247 | print(metrics_shirt) 248 | 249 | 250 | 251 | 252 | 253 | eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_dress_query_val.jsonl", split='train') 254 | image_corpus = datasets.load_dataset('json', 
data_files="./eval/data/fashioniq_dress_corpus.jsonl", split='train') 255 | 256 | faiss_index = index( 257 | model=model, 258 | corpus=image_corpus, 259 | batch_size=args.batch_size, 260 | max_length=args.max_passage_length, 261 | index_factory=args.index_factory, 262 | save_path=args.save_path, 263 | save_embedding=args.save_embedding, 264 | load_embedding=args.load_embedding 265 | ) 266 | 267 | scores, indices = search( 268 | model=model, 269 | queries=eval_data, 270 | faiss_index=faiss_index, 271 | k=args.k, 272 | batch_size=args.batch_size, 273 | max_length=args.max_query_length 274 | ) 275 | 276 | 277 | 278 | retrieval_results = [] 279 | for indice in indices: 280 | # filter invalid indices 281 | indice = indice[indice != -1].tolist() 282 | retrieval_results.append(image_corpus[indice]["content"]) 283 | 284 | ground_truths = [] 285 | for sample in eval_data: 286 | ground_truths.append(sample["positive_key"]) 287 | 288 | metrics_dress = evaluate(retrieval_results, ground_truths) 289 | print("FashionIQ tasks (dress):") 290 | print(metrics_dress) 291 | 292 | 293 | eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_toptee_query_val.jsonl", split='train') 294 | image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_toptee_corpus.jsonl", split='train') 295 | 296 | 297 | faiss_index = index( 298 | model=model, 299 | corpus=image_corpus, 300 | batch_size=args.batch_size, 301 | max_length=args.max_passage_length, 302 | index_factory=args.index_factory, 303 | save_path=args.save_path, 304 | save_embedding=args.save_embedding, 305 | load_embedding=args.load_embedding 306 | ) 307 | 308 | scores, indices = search( 309 | model=model, 310 | queries=eval_data, 311 | faiss_index=faiss_index, 312 | k=args.k, 313 | batch_size=args.batch_size, 314 | max_length=args.max_query_length 315 | ) 316 | 317 | 318 | 319 | retrieval_results = [] 320 | for indice in indices: 321 | # filter invalid indices 322 | indice = indice[indice != -1].tolist() 323 | retrieval_results.append(image_corpus[indice]["content"]) 324 | 325 | ground_truths = [] 326 | for sample in eval_data: 327 | ground_truths.append(sample["positive_key"]) 328 | 329 | metrics_toptee = evaluate(retrieval_results, ground_truths) 330 | print("FashionIQ tasks (toptee):") 331 | print(metrics_toptee) 332 | 333 | 334 | 335 | print(f"shirt: {metrics_shirt['Recall@10'] * 100:.2f} / {metrics_shirt['Recall@50'] * 100:.2f}") 336 | print(f"dress: {metrics_dress['Recall@10'] * 100:.2f} / {metrics_dress['Recall@50'] * 100:.2f}") 337 | print(f"toptee: {metrics_toptee['Recall@10'] * 100:.2f} / {metrics_toptee['Recall@50'] * 100:.2f}") 338 | print(f"overall: {(metrics_shirt['Recall@10'] + metrics_dress['Recall@10'] + metrics_toptee['Recall@10']) * 100 / 3:.2f} / {(metrics_shirt['Recall@50'] + metrics_dress['Recall@50'] + metrics_toptee['Recall@50']) * 100 / 3:.2f}") 339 | 340 | 341 | if __name__ == "__main__": 342 | main() -------------------------------------------------------------------------------- /eval/flag_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os.path 3 | import random 4 | from dataclasses import dataclass 5 | from typing import Iterator 6 | 7 | import datasets 8 | from torch.utils.data import Dataset, IterableDataset 9 | from transformers import DataCollatorWithPadding 10 | from transformers import PreTrainedTokenizer, BatchEncoding 11 | from transformers import CLIPImageProcessor 12 | 13 | 14 | from PIL import Image 15 | import 
json 16 | import torch 17 | import torch.distributed 18 | 19 | from io import BytesIO 20 | import warnings 21 | 22 | class MMIT_Dataset(Dataset): 23 | def __init__(self, captions, image_ids, image_dir, image_processor) -> None: 24 | img_id_example = image_ids[0] 25 | img_id_example = str(img_id_example) 26 | if img_id_example[-4:] in [".jpg", ".png", "JPEG"]: 27 | self.image_path =[os.path.join(image_dir, str(id)) for id in image_ids] 28 | else: 29 | warnings.warn("Not found file extention in image_ids, will forcefully add '.jpg'.", UserWarning) 30 | self.image_path =[os.path.join(image_dir, str(id) + '.jpg') for id in image_ids] 31 | self.captions = captions 32 | self.image_processor = image_processor 33 | 34 | def __getitem__(self, item): 35 | pil_data = Image.open(self.image_path[item]) 36 | pil_data = pil_data.convert('RGB') 37 | image = self.image_processor(pil_data) 38 | 39 | 40 | 41 | 42 | caption = self.captions[item] 43 | 44 | return caption, image 45 | 46 | def __len__(self): 47 | return len(self.image_path) 48 | 49 | 50 | class MMIT_Collator: 51 | def __init__(self, tokenizer, caption_max_len): 52 | self.tokenizer = tokenizer 53 | self.caption_max_len = caption_max_len 54 | 55 | 56 | 57 | def __call__(self, features): 58 | caption = [f[0] for f in features] 59 | images = [f[1] for f in features] 60 | 61 | c_collated = self.tokenizer( 62 | caption, 63 | truncation=True, 64 | padding = True, 65 | max_length=self.caption_max_len, 66 | return_tensors="pt", 67 | ) 68 | 69 | # i_collated = torch.stack(images) 70 | 71 | # for clip model 72 | images = [f["pixel_values"][0] for f in images] 73 | images = [torch.tensor(arr) for arr in images] 74 | i_collated = torch.stack(images) 75 | ##clip_end 76 | 77 | return c_collated, i_collated 78 | 79 | class Image_Dataset(Dataset): 80 | def __init__(self, image_ids, image_dir, image_processor) -> None: 81 | 82 | self.image_path =[os.path.join(image_dir, str(id)) for id in image_ids] 83 | self.image_processor = image_processor 84 | 85 | def __getitem__(self, item): 86 | pil_data = Image.open(self.image_path[item]) 87 | image = self.image_processor(pil_data) 88 | 89 | return image 90 | 91 | def __len__(self): 92 | return len(self.image_path) 93 | 94 | class Image_Collator: 95 | def __init__(self, tokenizer, caption_max_len): 96 | self.tokenizer = tokenizer 97 | self.caption_max_len = caption_max_len 98 | 99 | 100 | def __call__(self, features): 101 | # images = features 102 | # i_collated = torch.stack(images) 103 | 104 | # for clip model 105 | images = [f["pixel_values"][0] for f in features] 106 | images = [torch.tensor(arr) for arr in images] 107 | i_collated = torch.stack(images) 108 | ## clip-end 109 | return i_collated -------------------------------------------------------------------------------- /eval/flag_mmret.py: -------------------------------------------------------------------------------- 1 | from typing import cast, List, Union, Tuple 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, CLIPModel, CLIPImageProcessor, CLIPTokenizer 6 | import os 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from torch import nn 10 | from flag_dataset import MMIT_Dataset, MMIT_Collator, Image_Dataset, Image_Collator 11 | class Flag_mmret(nn.Module): 12 | def __init__( 13 | self, 14 | model_name: str = None, 15 | normlized: bool = True, 16 | pooling_method: str = 'cls', 17 | use_fp16: bool=True, 18 | image_dir: str = 
None, 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.model = AutoModel.from_pretrained(model_name) 23 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 24 | self.image_processor = CLIPImageProcessor.from_pretrained(model_name) 25 | 26 | 27 | self.normalize_embeddings = normlized 28 | self.pooling_method = pooling_method 29 | 30 | self.image_dir = image_dir 31 | 32 | if use_fp16: 33 | self.use_fp16 = True 34 | self.model.half() 35 | else: 36 | self.use_fp16 = False 37 | 38 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 39 | self.model = self.model.to(self.device) 40 | 41 | self.num_gpus = torch.cuda.device_count() 42 | if self.num_gpus > 1: 43 | print(f"----------using {self.num_gpus}*GPUs----------") 44 | self.model = torch.nn.DataParallel(self.model) 45 | 46 | 47 | def encode_queries(self, queries: Union[List[str], str], 48 | batch_size: int=256, 49 | max_length: int=77, 50 | query_type: str = None, 51 | ) -> np.ndarray: 52 | 53 | 54 | if query_type == 'text': 55 | input_texts = queries 56 | 57 | return self.encode_text(input_texts, batch_size=batch_size, max_length=max_length) 58 | elif query_type == 'mm_it': 59 | q_text, q_img = queries 60 | 61 | input_texts = q_text 62 | 63 | 64 | return self.encode_mm_it(input_texts, q_img, batch_size=batch_size) 65 | elif query_type == 'image': 66 | q_img = queries 67 | return self.encode_image(q_img, batch_size=batch_size) 68 | else: 69 | raise NotImplementedError 70 | 71 | 72 | def encode_corpus(self, 73 | corpus: dict, 74 | batch_size: int=256, 75 | max_length: int=77, 76 | corpus_type: str = None, 77 | ) -> np.ndarray: 78 | if corpus_type == 'text': 79 | return self.encode_text(corpus["text"], batch_size=batch_size, max_length=max_length) 80 | elif corpus_type == 'mm_it': 81 | return self.encode_mm_it(corpus["text"], corpus["image"], batch_size=batch_size, max_length=max_length) 82 | elif corpus_type == 'image': 83 | return self.encode_image(corpus["image"], batch_size=batch_size, max_length=max_length) 84 | else: 85 | raise RuntimeError(f"You must choose a corpus type from: [mm_it, text, image]") 86 | 87 | 88 | 89 | @torch.no_grad() 90 | def encode_text(self, sentences: Union[List[str], str], batch_size: int=256, max_length: int=77) -> np.ndarray: 91 | if self.num_gpus > 0: 92 | batch_size = batch_size * self.num_gpus 93 | self.model.eval() 94 | 95 | input_was_string = False 96 | if isinstance(sentences, str): 97 | sentences = [sentences] 98 | input_was_string = True 99 | 100 | all_embeddings = [] 101 | for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings", disable=len(sentences)<256): 102 | sentences_batch = sentences[start_index:start_index + batch_size] 103 | inputs = self.tokenizer( 104 | sentences_batch, 105 | padding=True, 106 | truncation=True, 107 | return_tensors='pt', 108 | max_length=max_length, 109 | ).to(self.device) 110 | 111 | embeddings = self.model.get_text_features(**inputs) 112 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 113 | embeddings = cast(torch.Tensor, embeddings) 114 | all_embeddings.append(embeddings.cpu().numpy()) 115 | 116 | all_embeddings = np.concatenate(all_embeddings, axis=0) 117 | if input_was_string: 118 | return all_embeddings[0] 119 | return all_embeddings 120 | 121 | 122 | @torch.no_grad() 123 | def encode_mm_it(self, captions: Union[List[str], str], image_ids: Union[List[str], str], batch_size: int=256, max_length: int=77) -> np.ndarray: 124 | if self.num_gpus > 0: 125 | batch_size = 
batch_size * self.num_gpus 126 | self.model.eval() 127 | 128 | input_was_string = False 129 | if isinstance(captions, str): 130 | captions = [captions] 131 | image_ids = [image_ids] 132 | input_was_string = True 133 | 134 | all_embeddings = [] 135 | mm_it_dataset = MMIT_Dataset(captions=captions, 136 | image_ids=image_ids, 137 | image_dir=self.image_dir, 138 | image_processor=self.image_processor 139 | ) 140 | mm_it_collator = MMIT_Collator(self.tokenizer, caption_max_len=75) 141 | 142 | mm_it_dataloader = DataLoader(dataset=mm_it_dataset, 143 | collate_fn=mm_it_collator, 144 | num_workers=8, 145 | batch_size=batch_size, 146 | shuffle=False, 147 | drop_last=False,) 148 | 149 | for data in tqdm(mm_it_dataloader, desc="Inference Embeddings", disable=len(captions)<256): 150 | captions_inputs = data[0].to(self.device) 151 | 152 | images = data[1].to(self.device) 153 | if self.use_fp16 and images.dtype != torch.float16: 154 | images = images.half() 155 | 156 | text_embeddings = self.model.get_text_features(**captions_inputs) 157 | image_embeddings = self.model.get_image_features(images) 158 | 159 | embeddings = text_embeddings + image_embeddings 160 | 161 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 162 | 163 | embeddings = cast(torch.Tensor, embeddings) 164 | all_embeddings.append(embeddings.cpu().numpy()) 165 | 166 | all_embeddings = np.concatenate(all_embeddings, axis=0) 167 | if input_was_string: 168 | return all_embeddings[0] 169 | return all_embeddings 170 | 171 | @torch.no_grad() 172 | def encode_image(self, image_ids: Union[List[str], str], batch_size: int=256, max_length: int=77) -> np.ndarray: 173 | if self.num_gpus > 0: 174 | batch_size = batch_size * self.num_gpus 175 | self.model.eval() 176 | 177 | all_embeddings = [] 178 | image_dataset = Image_Dataset(image_ids=image_ids, 179 | image_dir=self.image_dir, 180 | image_processor=self.image_processor 181 | ) 182 | image_collator = Image_Collator(self.tokenizer, caption_max_len=312) 183 | 184 | image_dataloader = DataLoader(dataset=image_dataset, 185 | collate_fn=image_collator, 186 | num_workers=8, 187 | batch_size=batch_size, 188 | shuffle=False, 189 | drop_last=False,) 190 | 191 | for data in tqdm(image_dataloader, desc="Inference Image Embeddings"): 192 | 193 | images = data.to(self.device) 194 | if self.use_fp16 and images.dtype != torch.float16: 195 | images = images.half() 196 | 197 | 198 | 199 | embeddings = self.model.get_image_features(images) 200 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 201 | embeddings = cast(torch.Tensor, embeddings) 202 | all_embeddings.append(embeddings.cpu().numpy()) 203 | 204 | all_embeddings = np.concatenate(all_embeddings, axis=0) 205 | 206 | return all_embeddings -------------------------------------------------------------------------------- /modeling_MMRet_CLIP.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch CLIP model.""" 16 | 17 | from dataclasses import dataclass 18 | from typing import Any, Optional, Tuple, Union 19 | 20 | import torch 21 | import torch.utils.checkpoint 22 | from torch import nn 23 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 24 | from PIL import Image 25 | from transformers.activations import ACT2FN 26 | from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask 27 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput 28 | from transformers.modeling_utils import PreTrainedModel 29 | from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 30 | from transformers.utils import ( 31 | ModelOutput, 32 | add_code_sample_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | is_flash_attn_2_available, 36 | is_flash_attn_greater_or_equal_2_10, 37 | logging, 38 | replace_return_docstrings, 39 | ) 40 | from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig 41 | from transformers import CLIPProcessor 42 | 43 | if is_flash_attn_2_available(): 44 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 45 | 46 | 47 | logger = logging.get_logger(__name__) 48 | 49 | # General docstring 50 | _CONFIG_FOR_DOC = "MMRet_CLIP" 51 | 52 | # Image classification docstring 53 | _IMAGE_CLASS_CHECKPOINT = "JUNJIE99/MMRet-base" 54 | _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0" 55 | 56 | 57 | # contrastive loss function, adapted from 58 | # https://sachinruk.github.io/blog/2021-03-07-clip.html 59 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: 60 | return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) 61 | 62 | 63 | def clip_loss(similarity: torch.Tensor) -> torch.Tensor: 64 | caption_loss = contrastive_loss(similarity) 65 | image_loss = contrastive_loss(similarity.t()) 66 | return (caption_loss + image_loss) / 2.0 67 | 68 | 69 | def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor: 70 | """ 71 | This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make 72 | model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566 73 | """ 74 | square_tensor = torch.pow(tensor, 2) 75 | sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True) 76 | normed_tensor = torch.pow(sum_tensor, 0.5) 77 | return normed_tensor 78 | 79 | 80 | @dataclass 81 | class CLIPVisionModelOutput(ModelOutput): 82 | """ 83 | Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. 84 | 85 | Args: 86 | image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): 87 | The image embeddings obtained by applying the projection layer to the pooler_output. 88 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 89 | Sequence of hidden-states at the output of the last layer of the model. 
90 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 91 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 92 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 93 | 94 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 95 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 96 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 97 | sequence_length)`. 98 | 99 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 100 | heads. 101 | """ 102 | 103 | image_embeds: Optional[torch.FloatTensor] = None 104 | last_hidden_state: torch.FloatTensor = None 105 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 106 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 107 | 108 | 109 | @dataclass 110 | class CLIPTextModelOutput(ModelOutput): 111 | """ 112 | Base class for text model's outputs that also contains a pooling of the last hidden states. 113 | 114 | Args: 115 | text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): 116 | The text embeddings obtained by applying the projection layer to the pooler_output. 117 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 118 | Sequence of hidden-states at the output of the last layer of the model. 119 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 120 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 121 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 122 | 123 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 124 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 125 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 126 | sequence_length)`. 127 | 128 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 129 | heads. 130 | """ 131 | 132 | text_embeds: Optional[torch.FloatTensor] = None 133 | last_hidden_state: torch.FloatTensor = None 134 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 135 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 136 | 137 | 138 | @dataclass 139 | class CLIPOutput(ModelOutput): 140 | """ 141 | Args: 142 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): 143 | Contrastive loss for image-text similarity. 144 | logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): 145 | The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text 146 | similarity scores. 147 | logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): 148 | The scaled dot product scores between `text_embeds` and `image_embeds`. 
This represents the text-image 149 | similarity scores. 150 | text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): 151 | The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`]. 152 | image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): 153 | The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`]. 154 | text_model_output (`BaseModelOutputWithPooling`): 155 | The output of the [`CLIPTextModel`]. 156 | vision_model_output (`BaseModelOutputWithPooling`): 157 | The output of the [`CLIPVisionModel`]. 158 | """ 159 | 160 | loss: Optional[torch.FloatTensor] = None 161 | logits_per_image: torch.FloatTensor = None 162 | logits_per_text: torch.FloatTensor = None 163 | text_embeds: torch.FloatTensor = None 164 | image_embeds: torch.FloatTensor = None 165 | text_model_output: BaseModelOutputWithPooling = None 166 | vision_model_output: BaseModelOutputWithPooling = None 167 | 168 | def to_tuple(self) -> Tuple[Any]: 169 | return tuple( 170 | self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() 171 | for k in self.keys() 172 | ) 173 | 174 | 175 | class CLIPVisionEmbeddings(nn.Module): 176 | def __init__(self, config: CLIPVisionConfig): 177 | super().__init__() 178 | self.config = config 179 | self.embed_dim = config.hidden_size 180 | self.image_size = config.image_size 181 | self.patch_size = config.patch_size 182 | 183 | self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) 184 | 185 | self.patch_embedding = nn.Conv2d( 186 | in_channels=config.num_channels, 187 | out_channels=self.embed_dim, 188 | kernel_size=self.patch_size, 189 | stride=self.patch_size, 190 | bias=False, 191 | ) 192 | 193 | self.num_patches = (self.image_size // self.patch_size) ** 2 194 | self.num_positions = self.num_patches + 1 195 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 196 | self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) 197 | 198 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 199 | batch_size = pixel_values.shape[0] 200 | target_dtype = self.patch_embedding.weight.dtype 201 | patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] 202 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 203 | 204 | class_embeds = self.class_embedding.expand(batch_size, 1, -1) 205 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 206 | embeddings = embeddings + self.position_embedding(self.position_ids) 207 | return embeddings 208 | 209 | 210 | class CLIPTextEmbeddings(nn.Module): 211 | def __init__(self, config: CLIPTextConfig): 212 | super().__init__() 213 | embed_dim = config.hidden_size 214 | 215 | self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) 216 | self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) 217 | 218 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 219 | self.register_buffer( 220 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False 221 | ) 222 | 223 | def forward( 224 | self, 225 | input_ids: Optional[torch.LongTensor] = None, 226 | position_ids: Optional[torch.LongTensor] = None, 227 | inputs_embeds: Optional[torch.FloatTensor] = None, 228 | ) -> torch.Tensor: 229 | seq_length = input_ids.shape[-1] if input_ids is not None else 
inputs_embeds.shape[-2] 230 | 231 | if position_ids is None: 232 | position_ids = self.position_ids[:, :seq_length] 233 | 234 | if inputs_embeds is None: 235 | inputs_embeds = self.token_embedding(input_ids) 236 | 237 | position_embeddings = self.position_embedding(position_ids) 238 | embeddings = inputs_embeds + position_embeddings 239 | 240 | return embeddings 241 | 242 | 243 | class CLIPAttention(nn.Module): 244 | """Multi-headed attention from 'Attention Is All You Need' paper""" 245 | 246 | def __init__(self, config): 247 | super().__init__() 248 | self.config = config 249 | self.embed_dim = config.hidden_size 250 | self.num_heads = config.num_attention_heads 251 | self.head_dim = self.embed_dim // self.num_heads 252 | if self.head_dim * self.num_heads != self.embed_dim: 253 | raise ValueError( 254 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 255 | f" {self.num_heads})." 256 | ) 257 | self.scale = self.head_dim**-0.5 258 | self.dropout = config.attention_dropout 259 | 260 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) 261 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) 262 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) 263 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) 264 | 265 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 266 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 267 | 268 | def forward( 269 | self, 270 | hidden_states: torch.Tensor, 271 | attention_mask: Optional[torch.Tensor] = None, 272 | causal_attention_mask: Optional[torch.Tensor] = None, 273 | output_attentions: Optional[bool] = False, 274 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 275 | """Input shape: Batch x Time x Channel""" 276 | 277 | bsz, tgt_len, embed_dim = hidden_states.size() 278 | 279 | # get query proj 280 | query_states = self.q_proj(hidden_states) * self.scale 281 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 282 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 283 | 284 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 285 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 286 | key_states = key_states.view(*proj_shape) 287 | value_states = value_states.view(*proj_shape) 288 | 289 | src_len = key_states.size(1) 290 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 291 | 292 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 293 | raise ValueError( 294 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 295 | f" {attn_weights.size()}" 296 | ) 297 | 298 | # apply the causal_attention_mask first 299 | if causal_attention_mask is not None: 300 | if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): 301 | raise ValueError( 302 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" 303 | f" {causal_attention_mask.size()}" 304 | ) 305 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask 306 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 307 | 308 | if attention_mask is not None: 309 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 310 | raise ValueError( 311 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 312 | ) 313 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 314 | attn_weights = 
attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 315 | 316 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 317 | 318 | if output_attentions: 319 | # this operation is a bit akward, but it's required to 320 | # make sure that attn_weights keeps its gradient. 321 | # In order to do so, attn_weights have to reshaped 322 | # twice and have to be reused in the following 323 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 324 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 325 | else: 326 | attn_weights_reshaped = None 327 | 328 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 329 | 330 | attn_output = torch.bmm(attn_probs, value_states) 331 | 332 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 333 | raise ValueError( 334 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 335 | f" {attn_output.size()}" 336 | ) 337 | 338 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 339 | attn_output = attn_output.transpose(1, 2) 340 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) 341 | 342 | attn_output = self.out_proj(attn_output) 343 | 344 | return attn_output, attn_weights_reshaped 345 | 346 | 347 | class CLIPFlashAttention2(CLIPAttention): 348 | """ 349 | CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays 350 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 351 | flash attention and deal with padding tokens in case the input contains any of them. 352 | """ 353 | 354 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 355 | def __init__(self, *args, **kwargs): 356 | super().__init__(*args, **kwargs) 357 | 358 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 359 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 360 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
361 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 362 | 363 | # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward 364 | def forward( 365 | self, 366 | hidden_states: torch.Tensor, 367 | attention_mask: Optional[torch.Tensor] = None, 368 | causal_attention_mask: Optional[torch.Tensor] = None, 369 | output_attentions: Optional[bool] = False, 370 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 371 | output_attentions = False 372 | 373 | batch_size, q_len, _ = hidden_states.size() 374 | 375 | query_states = self.q_proj(hidden_states) 376 | key_states = self.k_proj(hidden_states) 377 | value_states = self.v_proj(hidden_states) 378 | 379 | # Flash attention requires the input to have the shape 380 | # batch_size x seq_length x head_dim x hidden_dim 381 | # therefore we just need to keep the original shape 382 | query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) 383 | key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) 384 | value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) 385 | 386 | dropout_rate = self.dropout if self.training else 0.0 387 | 388 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 389 | # therefore the input hidden states gets silently casted in float32. Hence, we need 390 | # cast them back in the correct dtype just to be sure everything works as expected. 391 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 392 | # in fp32. 393 | 394 | input_dtype = query_states.dtype 395 | if input_dtype == torch.float32: 396 | if torch.is_autocast_enabled(): 397 | target_dtype = torch.get_autocast_gpu_dtype() 398 | # Handle the case where the model is quantized 399 | elif hasattr(self.config, "_pre_quantization_dtype"): 400 | target_dtype = self.config._pre_quantization_dtype 401 | else: 402 | target_dtype = self.q_proj.weight.dtype 403 | 404 | logger.warning_once( 405 | f"The input hidden states seems to be silently casted in float32, this might be related to" 406 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 407 | f" {target_dtype}." 408 | ) 409 | 410 | query_states = query_states.to(target_dtype) 411 | key_states = key_states.to(target_dtype) 412 | value_states = value_states.to(target_dtype) 413 | 414 | attn_output = _flash_attention_forward( 415 | query_states, 416 | key_states, 417 | value_states, 418 | attention_mask, 419 | q_len, 420 | dropout=dropout_rate, 421 | is_causal=causal_attention_mask is not None, 422 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 423 | ) 424 | 425 | attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() 426 | attn_output = self.out_proj(attn_output) 427 | 428 | if not output_attentions: 429 | attn_weights = None 430 | 431 | return attn_output, attn_weights 432 | 433 | 434 | class CLIPSdpaAttention(CLIPAttention): 435 | """ 436 | SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from 437 | `CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to 438 | SDPA API. 
439 | """ 440 | 441 | # Adapted from CLIPAttention.forward 442 | def forward( 443 | self, 444 | hidden_states: torch.Tensor, 445 | attention_mask: Optional[torch.Tensor] = None, 446 | causal_attention_mask: Optional[torch.Tensor] = None, 447 | output_attentions: Optional[bool] = False, 448 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 449 | if output_attentions: 450 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 451 | logger.warning_once( 452 | "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " 453 | "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " 454 | "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " 455 | 'be removed using the argument `attn_implementation="eager"` when loading the model.' 456 | ) 457 | return super().forward( 458 | hidden_states=hidden_states, 459 | attention_mask=attention_mask, 460 | causal_attention_mask=causal_attention_mask, 461 | output_attentions=output_attentions, 462 | ) 463 | 464 | # CLIP text model uses both `causal_attention_mask` and `attention_mask` 465 | if attention_mask is not None and causal_attention_mask is not None: 466 | attn_mask = attention_mask + causal_attention_mask 467 | elif causal_attention_mask is not None: 468 | attn_mask = causal_attention_mask 469 | else: 470 | attn_mask = attention_mask 471 | 472 | bsz, tgt_len, embed_dim = hidden_states.size() 473 | 474 | query_states = self.q_proj(hidden_states) 475 | key_states = self.k_proj(hidden_states) 476 | value_states = self.v_proj(hidden_states) 477 | 478 | query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 479 | key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 480 | value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) 481 | 482 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, 483 | # Reference: https://github.com/pytorch/pytorch/issues/112577. 484 | if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None: 485 | query_states = query_states.contiguous() 486 | key_states = key_states.contiguous() 487 | value_states = value_states.contiguous() 488 | 489 | # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially. 
490 | attn_output = torch.nn.functional.scaled_dot_product_attention( 491 | query_states, 492 | key_states, 493 | value_states, 494 | attn_mask=attn_mask, 495 | dropout_p=self.dropout if self.training else 0.0, 496 | scale=self.scale, 497 | ) 498 | 499 | attn_output = attn_output.transpose(1, 2) 500 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) 501 | 502 | attn_output = self.out_proj(attn_output) 503 | 504 | return attn_output, None 505 | 506 | 507 | CLIP_ATTENTION_CLASSES = { 508 | "eager": CLIPAttention, 509 | "sdpa": CLIPSdpaAttention, 510 | "flash_attention_2": CLIPFlashAttention2, 511 | } 512 | 513 | 514 | class CLIPMLP(nn.Module): 515 | def __init__(self, config): 516 | super().__init__() 517 | self.config = config 518 | self.activation_fn = ACT2FN[config.hidden_act] 519 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 520 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 521 | 522 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 523 | hidden_states = self.fc1(hidden_states) 524 | hidden_states = self.activation_fn(hidden_states) 525 | hidden_states = self.fc2(hidden_states) 526 | return hidden_states 527 | 528 | 529 | class CLIPEncoderLayer(nn.Module): 530 | def __init__(self, config: CLIPConfig): 531 | super().__init__() 532 | self.embed_dim = config.hidden_size 533 | self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config) 534 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 535 | self.mlp = CLIPMLP(config) 536 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 537 | 538 | def forward( 539 | self, 540 | hidden_states: torch.Tensor, 541 | attention_mask: torch.Tensor, 542 | causal_attention_mask: torch.Tensor, 543 | output_attentions: Optional[bool] = False, 544 | ) -> Tuple[torch.FloatTensor]: 545 | """ 546 | Args: 547 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 548 | attention_mask (`torch.FloatTensor`): attention mask of size 549 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 550 | `(config.encoder_attention_heads,)`. 551 | output_attentions (`bool`, *optional*): 552 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 553 | returned tensors for more detail. 554 | """ 555 | residual = hidden_states 556 | 557 | hidden_states = self.layer_norm1(hidden_states) 558 | hidden_states, attn_weights = self.self_attn( 559 | hidden_states=hidden_states, 560 | attention_mask=attention_mask, 561 | causal_attention_mask=causal_attention_mask, 562 | output_attentions=output_attentions, 563 | ) 564 | hidden_states = residual + hidden_states 565 | 566 | residual = hidden_states 567 | hidden_states = self.layer_norm2(hidden_states) 568 | hidden_states = self.mlp(hidden_states) 569 | hidden_states = residual + hidden_states 570 | 571 | outputs = (hidden_states,) 572 | 573 | if output_attentions: 574 | outputs += (attn_weights,) 575 | 576 | return outputs 577 | 578 | 579 | class CLIPPreTrainedModel(PreTrainedModel): 580 | """ 581 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 582 | models. 
583 | """ 584 | 585 | config_class = CLIPConfig 586 | base_model_prefix = "clip" 587 | supports_gradient_checkpointing = True 588 | _supports_sdpa = True 589 | _supports_flash_attn_2 = True 590 | 591 | def _init_weights(self, module): 592 | """Initialize the weights""" 593 | factor = self.config.initializer_factor 594 | if isinstance(module, CLIPTextEmbeddings): 595 | module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) 596 | module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) 597 | elif isinstance(module, CLIPVisionEmbeddings): 598 | factor = self.config.initializer_factor 599 | nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) 600 | nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) 601 | nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) 602 | elif isinstance(module, CLIPAttention): 603 | factor = self.config.initializer_factor 604 | in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor 605 | out_proj_std = (module.embed_dim**-0.5) * factor 606 | nn.init.normal_(module.q_proj.weight, std=in_proj_std) 607 | nn.init.normal_(module.k_proj.weight, std=in_proj_std) 608 | nn.init.normal_(module.v_proj.weight, std=in_proj_std) 609 | nn.init.normal_(module.out_proj.weight, std=out_proj_std) 610 | elif isinstance(module, CLIPMLP): 611 | factor = self.config.initializer_factor 612 | in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor 613 | fc_std = (2 * module.config.hidden_size) ** -0.5 * factor 614 | nn.init.normal_(module.fc1.weight, std=fc_std) 615 | nn.init.normal_(module.fc2.weight, std=in_proj_std) 616 | elif isinstance(module, CLIPModel): 617 | nn.init.normal_( 618 | module.text_projection.weight, 619 | std=module.text_embed_dim**-0.5 * self.config.initializer_factor, 620 | ) 621 | nn.init.normal_( 622 | module.visual_projection.weight, 623 | std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, 624 | ) 625 | elif isinstance(module, CLIPVisionModelWithProjection): 626 | nn.init.normal_( 627 | module.visual_projection.weight, 628 | std=self.config.hidden_size**-0.5 * self.config.initializer_factor, 629 | ) 630 | elif isinstance(module, CLIPTextModelWithProjection): 631 | nn.init.normal_( 632 | module.text_projection.weight, 633 | std=self.config.hidden_size**-0.5 * self.config.initializer_factor, 634 | ) 635 | elif isinstance(module, CLIPForImageClassification): 636 | nn.init.normal_( 637 | module.classifier.weight, 638 | std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, 639 | ) 640 | 641 | if isinstance(module, nn.LayerNorm): 642 | module.bias.data.zero_() 643 | module.weight.data.fill_(1.0) 644 | if isinstance(module, nn.Linear) and module.bias is not None: 645 | module.bias.data.zero_() 646 | 647 | 648 | CLIP_START_DOCSTRING = r""" 649 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 650 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 651 | etc.) 652 | 653 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 654 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 655 | and behavior. 
656 | 657 | Parameters: 658 | config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. 659 | Initializing with a config file does not load the weights associated with the model, only the 660 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 661 | """ 662 | 663 | CLIP_TEXT_INPUTS_DOCSTRING = r""" 664 | Args: 665 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 666 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 667 | it. 668 | 669 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 670 | [`PreTrainedTokenizer.__call__`] for details. 671 | 672 | [What are input IDs?](../glossary#input-ids) 673 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 674 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 675 | 676 | - 1 for tokens that are **not masked**, 677 | - 0 for tokens that are **masked**. 678 | 679 | [What are attention masks?](../glossary#attention-mask) 680 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 681 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 682 | config.max_position_embeddings - 1]`. 683 | 684 | [What are position IDs?](../glossary#position-ids) 685 | output_attentions (`bool`, *optional*): 686 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 687 | tensors for more detail. 688 | output_hidden_states (`bool`, *optional*): 689 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 690 | more detail. 691 | return_dict (`bool`, *optional*): 692 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 693 | """ 694 | 695 | CLIP_VISION_INPUTS_DOCSTRING = r""" 696 | Args: 697 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 698 | Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using 699 | [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. 700 | output_attentions (`bool`, *optional*): 701 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 702 | tensors for more detail. 703 | output_hidden_states (`bool`, *optional*): 704 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 705 | more detail. 706 | return_dict (`bool`, *optional*): 707 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 708 | """ 709 | 710 | CLIP_INPUTS_DOCSTRING = r""" 711 | Args: 712 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 713 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 714 | it. 715 | 716 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 717 | [`PreTrainedTokenizer.__call__`] for details. 718 | 719 | [What are input IDs?](../glossary#input-ids) 720 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 721 | Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: 722 | 723 | - 1 for tokens that are **not masked**, 724 | - 0 for tokens that are **masked**. 725 | 726 | [What are attention masks?](../glossary#attention-mask) 727 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 728 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 729 | config.max_position_embeddings - 1]`. 730 | 731 | [What are position IDs?](../glossary#position-ids) 732 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 733 | Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using 734 | [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. 735 | return_loss (`bool`, *optional*): 736 | Whether or not to return the contrastive loss. 737 | output_attentions (`bool`, *optional*): 738 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 739 | tensors for more detail. 740 | output_hidden_states (`bool`, *optional*): 741 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 742 | more detail. 743 | return_dict (`bool`, *optional*): 744 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 745 | """ 746 | 747 | 748 | class CLIPEncoder(nn.Module): 749 | """ 750 | Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a 751 | [`CLIPEncoderLayer`]. 752 | 753 | Args: 754 | config: CLIPConfig 755 | """ 756 | 757 | def __init__(self, config: CLIPConfig): 758 | super().__init__() 759 | self.config = config 760 | self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) 761 | self.gradient_checkpointing = False 762 | 763 | def forward( 764 | self, 765 | inputs_embeds, 766 | attention_mask: Optional[torch.Tensor] = None, 767 | causal_attention_mask: Optional[torch.Tensor] = None, 768 | output_attentions: Optional[bool] = None, 769 | output_hidden_states: Optional[bool] = None, 770 | return_dict: Optional[bool] = None, 771 | ) -> Union[Tuple, BaseModelOutput]: 772 | r""" 773 | Args: 774 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 775 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 776 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 777 | than the model's internal embedding lookup matrix. 778 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 779 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 780 | 781 | - 1 for tokens that are **not masked**, 782 | - 0 for tokens that are **masked**. 783 | 784 | [What are attention masks?](../glossary#attention-mask) 785 | causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 786 | Causal mask for the text model. Mask values selected in `[0, 1]`: 787 | 788 | - 1 for tokens that are **not masked**, 789 | - 0 for tokens that are **masked**. 790 | 791 | [What are attention masks?](../glossary#attention-mask) 792 | output_attentions (`bool`, *optional*): 793 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 794 | returned tensors for more detail. 
795 | output_hidden_states (`bool`, *optional*): 796 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 797 | for more detail. 798 | return_dict (`bool`, *optional*): 799 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 800 | """ 801 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 802 | output_hidden_states = ( 803 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 804 | ) 805 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 806 | 807 | encoder_states = () if output_hidden_states else None 808 | all_attentions = () if output_attentions else None 809 | 810 | hidden_states = inputs_embeds 811 | for idx, encoder_layer in enumerate(self.layers): 812 | if output_hidden_states: 813 | encoder_states = encoder_states + (hidden_states,) 814 | if self.gradient_checkpointing and self.training: 815 | layer_outputs = self._gradient_checkpointing_func( 816 | encoder_layer.__call__, 817 | hidden_states, 818 | attention_mask, 819 | causal_attention_mask, 820 | output_attentions, 821 | ) 822 | else: 823 | layer_outputs = encoder_layer( 824 | hidden_states, 825 | attention_mask, 826 | causal_attention_mask, 827 | output_attentions=output_attentions, 828 | ) 829 | 830 | hidden_states = layer_outputs[0] 831 | 832 | if output_attentions: 833 | all_attentions = all_attentions + (layer_outputs[1],) 834 | 835 | if output_hidden_states: 836 | encoder_states = encoder_states + (hidden_states,) 837 | 838 | if not return_dict: 839 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 840 | return BaseModelOutput( 841 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 842 | ) 843 | 844 | 845 | class CLIPTextTransformer(nn.Module): 846 | def __init__(self, config: CLIPTextConfig): 847 | super().__init__() 848 | self.config = config 849 | embed_dim = config.hidden_size 850 | self.embeddings = CLIPTextEmbeddings(config) 851 | self.encoder = CLIPEncoder(config) 852 | self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 853 | 854 | # For `pooled_output` computation 855 | self.eos_token_id = config.eos_token_id 856 | 857 | # For attention mask, it differs between `flash_attention_2` and other attention implementations 858 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 859 | 860 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) 861 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) 862 | def forward( 863 | self, 864 | input_ids: Optional[torch.Tensor] = None, 865 | attention_mask: Optional[torch.Tensor] = None, 866 | position_ids: Optional[torch.Tensor] = None, 867 | output_attentions: Optional[bool] = None, 868 | output_hidden_states: Optional[bool] = None, 869 | return_dict: Optional[bool] = None, 870 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 871 | r""" 872 | Returns: 873 | 874 | """ 875 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 876 | output_hidden_states = ( 877 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 878 | ) 879 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 880 | 881 | if input_ids is None: 882 | raise ValueError("You have to 
specify input_ids") 883 | 884 | input_shape = input_ids.size() 885 | input_ids = input_ids.view(-1, input_shape[-1]) 886 | 887 | hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) 888 | 889 | # CLIP's text model uses causal mask, prepare it here. 890 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 891 | causal_attention_mask = _create_4d_causal_attention_mask( 892 | input_shape, hidden_states.dtype, device=hidden_states.device 893 | ) 894 | 895 | # expand attention_mask 896 | if attention_mask is not None and not self._use_flash_attention_2: 897 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 898 | attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) 899 | 900 | encoder_outputs = self.encoder( 901 | inputs_embeds=hidden_states, 902 | attention_mask=attention_mask, 903 | causal_attention_mask=causal_attention_mask, 904 | output_attentions=output_attentions, 905 | output_hidden_states=output_hidden_states, 906 | return_dict=return_dict, 907 | ) 908 | 909 | last_hidden_state = encoder_outputs[0] 910 | last_hidden_state = self.final_layer_norm(last_hidden_state) 911 | 912 | if self.eos_token_id == 2: 913 | # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. 914 | # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added 915 | # ------------------------------------------------------------ 916 | # text_embeds.shape = [batch_size, sequence_length, transformer.width] 917 | # take features from the eot embedding (eot_token is the highest number in each sequence) 918 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 919 | pooled_output = last_hidden_state[ 920 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 921 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), 922 | ] 923 | else: 924 | # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) 925 | pooled_output = last_hidden_state[ 926 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 927 | # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) 928 | # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. 
prepared by the tokenizer) 929 | (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id) 930 | .int() 931 | .argmax(dim=-1), 932 | ] 933 | 934 | if not return_dict: 935 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 936 | 937 | return BaseModelOutputWithPooling( 938 | last_hidden_state=last_hidden_state, 939 | pooler_output=pooled_output, 940 | hidden_states=encoder_outputs.hidden_states, 941 | attentions=encoder_outputs.attentions, 942 | ) 943 | 944 | 945 | @add_start_docstrings( 946 | """The text model from CLIP without any head or projection on top.""", 947 | CLIP_START_DOCSTRING, 948 | ) 949 | class CLIPTextModel(CLIPPreTrainedModel): 950 | config_class = CLIPTextConfig 951 | 952 | _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] 953 | 954 | def __init__(self, config: CLIPTextConfig): 955 | super().__init__(config) 956 | self.text_model = CLIPTextTransformer(config) 957 | # Initialize weights and apply final processing 958 | self.post_init() 959 | 960 | def get_input_embeddings(self) -> nn.Module: 961 | return self.text_model.embeddings.token_embedding 962 | 963 | def set_input_embeddings(self, value): 964 | self.text_model.embeddings.token_embedding = value 965 | 966 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) 967 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) 968 | def forward( 969 | self, 970 | input_ids: Optional[torch.Tensor] = None, 971 | attention_mask: Optional[torch.Tensor] = None, 972 | position_ids: Optional[torch.Tensor] = None, 973 | output_attentions: Optional[bool] = None, 974 | output_hidden_states: Optional[bool] = None, 975 | return_dict: Optional[bool] = None, 976 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 977 | r""" 978 | Returns: 979 | 980 | Examples: 981 | 982 | ```python 983 | >>> from transformers import AutoTokenizer, CLIPTextModel 984 | 985 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") 986 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 987 | 988 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 989 | 990 | >>> outputs = model(**inputs) 991 | >>> last_hidden_state = outputs.last_hidden_state 992 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states 993 | ```""" 994 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 995 | 996 | return self.text_model( 997 | input_ids=input_ids, 998 | attention_mask=attention_mask, 999 | position_ids=position_ids, 1000 | output_attentions=output_attentions, 1001 | output_hidden_states=output_hidden_states, 1002 | return_dict=return_dict, 1003 | ) 1004 | 1005 | 1006 | class CLIPVisionTransformer(nn.Module): 1007 | def __init__(self, config: CLIPVisionConfig): 1008 | super().__init__() 1009 | self.config = config 1010 | embed_dim = config.hidden_size 1011 | 1012 | self.embeddings = CLIPVisionEmbeddings(config) 1013 | self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 1014 | self.encoder = CLIPEncoder(config) 1015 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 1016 | 1017 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) 1018 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) 1019 | def forward( 1020 | self, 1021 | pixel_values: Optional[torch.FloatTensor] = None, 1022 | output_attentions: Optional[bool] = 
None, 1023 | output_hidden_states: Optional[bool] = None, 1024 | return_dict: Optional[bool] = None, 1025 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 1026 | r""" 1027 | Returns: 1028 | 1029 | """ 1030 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1031 | output_hidden_states = ( 1032 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1033 | ) 1034 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1035 | 1036 | if pixel_values is None: 1037 | raise ValueError("You have to specify pixel_values") 1038 | 1039 | hidden_states = self.embeddings(pixel_values) 1040 | hidden_states = self.pre_layrnorm(hidden_states) 1041 | 1042 | encoder_outputs = self.encoder( 1043 | inputs_embeds=hidden_states, 1044 | output_attentions=output_attentions, 1045 | output_hidden_states=output_hidden_states, 1046 | return_dict=return_dict, 1047 | ) 1048 | 1049 | last_hidden_state = encoder_outputs[0] 1050 | pooled_output = last_hidden_state[:, 0, :] 1051 | pooled_output = self.post_layernorm(pooled_output) 1052 | 1053 | if not return_dict: 1054 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 1055 | 1056 | return BaseModelOutputWithPooling( 1057 | last_hidden_state=last_hidden_state, 1058 | pooler_output=pooled_output, 1059 | hidden_states=encoder_outputs.hidden_states, 1060 | attentions=encoder_outputs.attentions, 1061 | ) 1062 | 1063 | 1064 | @add_start_docstrings( 1065 | """The vision model from CLIP without any head or projection on top.""", 1066 | CLIP_START_DOCSTRING, 1067 | ) 1068 | class CLIPVisionModel(CLIPPreTrainedModel): 1069 | config_class = CLIPVisionConfig 1070 | main_input_name = "pixel_values" 1071 | _no_split_modules = ["CLIPEncoderLayer"] 1072 | 1073 | def __init__(self, config: CLIPVisionConfig): 1074 | super().__init__(config) 1075 | self.vision_model = CLIPVisionTransformer(config) 1076 | # Initialize weights and apply final processing 1077 | self.post_init() 1078 | 1079 | def get_input_embeddings(self) -> nn.Module: 1080 | return self.vision_model.embeddings.patch_embedding 1081 | 1082 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) 1083 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) 1084 | def forward( 1085 | self, 1086 | pixel_values: Optional[torch.FloatTensor] = None, 1087 | output_attentions: Optional[bool] = None, 1088 | output_hidden_states: Optional[bool] = None, 1089 | return_dict: Optional[bool] = None, 1090 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 1091 | r""" 1092 | Returns: 1093 | 1094 | Examples: 1095 | 1096 | ```python 1097 | >>> from PIL import Image 1098 | >>> import requests 1099 | >>> from transformers import AutoProcessor, CLIPVisionModel 1100 | 1101 | >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") 1102 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") 1103 | 1104 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 1105 | >>> image = Image.open(requests.get(url, stream=True).raw) 1106 | 1107 | >>> inputs = processor(images=image, return_tensors="pt") 1108 | 1109 | >>> outputs = model(**inputs) 1110 | >>> last_hidden_state = outputs.last_hidden_state 1111 | >>> pooled_output = outputs.pooler_output # pooled CLS states 1112 | ```""" 1113 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1114 | 1115 | return 
self.vision_model( 1116 | pixel_values=pixel_values, 1117 | output_attentions=output_attentions, 1118 | output_hidden_states=output_hidden_states, 1119 | return_dict=return_dict, 1120 | ) 1121 | 1122 | 1123 | @add_start_docstrings(CLIP_START_DOCSTRING) 1124 | class CLIPModel(CLIPPreTrainedModel): 1125 | config_class = CLIPConfig 1126 | _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"] 1127 | 1128 | def __init__(self, config: CLIPConfig): 1129 | super().__init__(config) 1130 | 1131 | if not isinstance(config.text_config, CLIPTextConfig): 1132 | raise TypeError( 1133 | "config.text_config is expected to be of type CLIPTextConfig but is of type" 1134 | f" {type(config.text_config)}." 1135 | ) 1136 | 1137 | if not isinstance(config.vision_config, CLIPVisionConfig): 1138 | raise TypeError( 1139 | "config.vision_config is expected to be of type CLIPVisionConfig but is of type" 1140 | f" {type(config.vision_config)}." 1141 | ) 1142 | 1143 | text_config = config.text_config 1144 | vision_config = config.vision_config 1145 | 1146 | self.projection_dim = config.projection_dim 1147 | self.text_embed_dim = text_config.hidden_size 1148 | self.vision_embed_dim = vision_config.hidden_size 1149 | 1150 | text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation) 1151 | self.text_model = text_model.text_model 1152 | 1153 | vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation) 1154 | self.vision_model = vision_model.vision_model 1155 | 1156 | self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) 1157 | self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) 1158 | self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) 1159 | 1160 | # Initialize weights and apply final processing 1161 | self.post_init() 1162 | 1163 | def set_processor(self, model_name): 1164 | self.processor = CLIPProcessor.from_pretrained(model_name) 1165 | 1166 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) 1167 | def get_text_features( 1168 | self, 1169 | input_ids: Optional[torch.Tensor] = None, 1170 | attention_mask: Optional[torch.Tensor] = None, 1171 | position_ids: Optional[torch.Tensor] = None, 1172 | output_attentions: Optional[bool] = None, 1173 | output_hidden_states: Optional[bool] = None, 1174 | return_dict: Optional[bool] = None, 1175 | ) -> torch.FloatTensor: 1176 | r""" 1177 | Returns: 1178 | text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by 1179 | applying the projection layer to the pooled output of [`CLIPTextModel`]. 1180 | 1181 | Examples: 1182 | 1183 | ```python 1184 | >>> from transformers import AutoTokenizer, CLIPModel 1185 | 1186 | >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 1187 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 1188 | 1189 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 1190 | >>> text_features = model.get_text_features(**inputs) 1191 | ```""" 1192 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
1193 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1194 | output_hidden_states = ( 1195 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1196 | ) 1197 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1198 | 1199 | text_outputs = self.text_model( 1200 | input_ids=input_ids, 1201 | attention_mask=attention_mask, 1202 | position_ids=position_ids, 1203 | output_attentions=output_attentions, 1204 | output_hidden_states=output_hidden_states, 1205 | return_dict=return_dict, 1206 | ) 1207 | 1208 | pooled_output = text_outputs[1] 1209 | text_features = self.text_projection(pooled_output) 1210 | 1211 | return text_features 1212 | 1213 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) 1214 | def get_image_features( 1215 | self, 1216 | pixel_values: Optional[torch.FloatTensor] = None, 1217 | output_attentions: Optional[bool] = None, 1218 | output_hidden_states: Optional[bool] = None, 1219 | return_dict: Optional[bool] = None, 1220 | ) -> torch.FloatTensor: 1221 | r""" 1222 | Returns: 1223 | image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by 1224 | applying the projection layer to the pooled output of [`CLIPVisionModel`]. 1225 | 1226 | Examples: 1227 | 1228 | ```python 1229 | >>> from PIL import Image 1230 | >>> import requests 1231 | >>> from transformers import AutoProcessor, CLIPModel 1232 | 1233 | >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 1234 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") 1235 | 1236 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 1237 | >>> image = Image.open(requests.get(url, stream=True).raw) 1238 | 1239 | >>> inputs = processor(images=image, return_tensors="pt") 1240 | 1241 | >>> image_features = model.get_image_features(**inputs) 1242 | ```""" 1243 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
1244 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1245 | output_hidden_states = ( 1246 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1247 | ) 1248 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1249 | 1250 | vision_outputs = self.vision_model( 1251 | pixel_values=pixel_values, 1252 | output_attentions=output_attentions, 1253 | output_hidden_states=output_hidden_states, 1254 | return_dict=return_dict, 1255 | ) 1256 | 1257 | pooled_output = vision_outputs[1] # pooled_output 1258 | image_features = self.visual_projection(pooled_output) 1259 | 1260 | return image_features 1261 | 1262 | 1263 | def encode_image(self, images): 1264 | embeddings = self.get_image_features(images) 1265 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 1266 | return embeddings 1267 | 1268 | def encode_text(self, text): 1269 | embeddings = self.get_text_features(**text) 1270 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 1271 | return embeddings 1272 | 1273 | def encode_multimodal(self, images, text): 1274 | text_embeddings = self.get_text_features(**text) 1275 | image_embeddings = self.get_image_features(images) 1276 | 1277 | embeddings = text_embeddings + image_embeddings 1278 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 1279 | 1280 | return embeddings.contiguous() 1281 | 1282 | def data_process(self, images=None, text=None): 1283 | if images is None and text is not None: 1284 | text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device) 1285 | 1286 | return images, text, "text" 1287 | elif images is not None and text is None: 1288 | if isinstance(images, str): 1289 | images = Image.open(images).convert("RGB") 1290 | elif isinstance(images, list): 1291 | images = [Image.open(image).convert("RGB") for image in images] 1292 | images = self.processor(images=images, return_tensors="pt").to(self.device) 1293 | images = images["pixel_values"] 1294 | return images, text, "images" 1295 | elif images is not None and text is not None: 1296 | assert type(images) == type(text), "images and text must be the same type: list or str" 1297 | if isinstance(images, str): 1298 | images = Image.open(images).convert("RGB") 1299 | elif isinstance(images, list): 1300 | assert len(images) == len(text), "images and text must be lists of the same length when use list" 1301 | images = [Image.open(image).convert("RGB") for image in images] 1302 | images = self.processor(images=images, return_tensors="pt").to(self.device) 1303 | images = images["pixel_values"] 1304 | text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device) 1305 | return images, text, "multimodal" 1306 | else: 1307 | raise ValueError("images and text cannot both be None") 1308 | 1309 | def encode(self, images=None, text=None): 1310 | images, text, data_type = self.data_process(images, text) 1311 | if data_type == "images": 1312 | return self.encode_image(images) 1313 | elif data_type == "text": 1314 | return self.encode_text(text) 1315 | elif data_type == "multimodal": 1316 | return self.encode_multimodal(images, text) 1317 | 1318 | 1319 | @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) 1320 | @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig) 1321 | def forward( 1322 | self, 1323 | input_ids: Optional[torch.LongTensor] = None, 1324 | pixel_values: Optional[torch.FloatTensor] = None, 
1325 | attention_mask: Optional[torch.Tensor] = None, 1326 | position_ids: Optional[torch.LongTensor] = None, 1327 | return_loss: Optional[bool] = None, 1328 | output_attentions: Optional[bool] = None, 1329 | output_hidden_states: Optional[bool] = None, 1330 | return_dict: Optional[bool] = None, 1331 | ) -> Union[Tuple, CLIPOutput]: 1332 | r""" 1333 | Returns: 1334 | 1335 | Examples: 1336 | 1337 | ```python 1338 | >>> from PIL import Image 1339 | >>> import requests 1340 | >>> from transformers import AutoProcessor, CLIPModel 1341 | 1342 | >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 1343 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") 1344 | 1345 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 1346 | >>> image = Image.open(requests.get(url, stream=True).raw) 1347 | 1348 | >>> inputs = processor( 1349 | ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True 1350 | ... ) 1351 | 1352 | >>> outputs = model(**inputs) 1353 | >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score 1354 | >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities 1355 | ```""" 1356 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 1357 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1358 | output_hidden_states = ( 1359 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1360 | ) 1361 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1362 | 1363 | vision_outputs = self.vision_model( 1364 | pixel_values=pixel_values, 1365 | output_attentions=output_attentions, 1366 | output_hidden_states=output_hidden_states, 1367 | return_dict=return_dict, 1368 | ) 1369 | 1370 | text_outputs = self.text_model( 1371 | input_ids=input_ids, 1372 | attention_mask=attention_mask, 1373 | position_ids=position_ids, 1374 | output_attentions=output_attentions, 1375 | output_hidden_states=output_hidden_states, 1376 | return_dict=return_dict, 1377 | ) 1378 | 1379 | image_embeds = vision_outputs[1] 1380 | image_embeds = self.visual_projection(image_embeds) 1381 | 1382 | text_embeds = text_outputs[1] 1383 | text_embeds = self.text_projection(text_embeds) 1384 | 1385 | # normalized features 1386 | image_embeds = image_embeds / _get_vector_norm(image_embeds) 1387 | text_embeds = text_embeds / _get_vector_norm(text_embeds) 1388 | 1389 | # cosine similarity as logits 1390 | logit_scale = self.logit_scale.exp() 1391 | logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to( 1392 | text_embeds.device 1393 | ) 1394 | logits_per_image = logits_per_text.t() 1395 | 1396 | loss = None 1397 | if return_loss: 1398 | loss = clip_loss(logits_per_text) 1399 | 1400 | if not return_dict: 1401 | output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) 1402 | return ((loss,) + output) if loss is not None else output 1403 | 1404 | return CLIPOutput( 1405 | loss=loss, 1406 | logits_per_image=logits_per_image, 1407 | logits_per_text=logits_per_text, 1408 | text_embeds=text_embeds, 1409 | image_embeds=image_embeds, 1410 | text_model_output=text_outputs, 1411 | vision_model_output=vision_outputs, 1412 | ) 1413 | 1414 | 1415 | @add_start_docstrings( 1416 | """ 1417 | CLIP Text Model with a 
projection layer on top (a linear layer on top of the pooled output). 1418 | """, 1419 | CLIP_START_DOCSTRING, 1420 | ) 1421 | class CLIPTextModelWithProjection(CLIPPreTrainedModel): 1422 | config_class = CLIPTextConfig 1423 | 1424 | _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] 1425 | 1426 | def __init__(self, config: CLIPTextConfig): 1427 | super().__init__(config) 1428 | 1429 | text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation) 1430 | self.text_model = text_model.text_model 1431 | 1432 | self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) 1433 | 1434 | # Initialize weights and apply final processing 1435 | self.post_init() 1436 | 1437 | def get_input_embeddings(self) -> nn.Module: 1438 | return self.text_model.embeddings.token_embedding 1439 | 1440 | def set_input_embeddings(self, value): 1441 | self.text_model.embeddings.token_embedding = value 1442 | 1443 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) 1444 | @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig) 1445 | def forward( 1446 | self, 1447 | input_ids: Optional[torch.Tensor] = None, 1448 | attention_mask: Optional[torch.Tensor] = None, 1449 | position_ids: Optional[torch.Tensor] = None, 1450 | output_attentions: Optional[bool] = None, 1451 | output_hidden_states: Optional[bool] = None, 1452 | return_dict: Optional[bool] = None, 1453 | ) -> Union[Tuple, CLIPTextModelOutput]: 1454 | r""" 1455 | Returns: 1456 | 1457 | Examples: 1458 | 1459 | ```python 1460 | >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection 1461 | 1462 | >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") 1463 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 1464 | 1465 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 1466 | 1467 | >>> outputs = model(**inputs) 1468 | >>> text_embeds = outputs.text_embeds 1469 | ```""" 1470 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1471 | 1472 | text_outputs = self.text_model( 1473 | input_ids=input_ids, 1474 | attention_mask=attention_mask, 1475 | position_ids=position_ids, 1476 | output_attentions=output_attentions, 1477 | output_hidden_states=output_hidden_states, 1478 | return_dict=return_dict, 1479 | ) 1480 | 1481 | pooled_output = text_outputs[1] 1482 | 1483 | text_embeds = self.text_projection(pooled_output) 1484 | 1485 | if not return_dict: 1486 | outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] 1487 | return tuple(output for output in outputs if output is not None) 1488 | 1489 | return CLIPTextModelOutput( 1490 | text_embeds=text_embeds, 1491 | last_hidden_state=text_outputs.last_hidden_state, 1492 | hidden_states=text_outputs.hidden_states, 1493 | attentions=text_outputs.attentions, 1494 | ) 1495 | 1496 | 1497 | @add_start_docstrings( 1498 | """ 1499 | CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output). 
1500 | """, 1501 | CLIP_START_DOCSTRING, 1502 | ) 1503 | class CLIPVisionModelWithProjection(CLIPPreTrainedModel): 1504 | config_class = CLIPVisionConfig 1505 | main_input_name = "pixel_values" 1506 | 1507 | def __init__(self, config: CLIPVisionConfig): 1508 | super().__init__(config) 1509 | 1510 | vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation) 1511 | self.vision_model = vision_model.vision_model 1512 | 1513 | self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) 1514 | 1515 | # Initialize weights and apply final processing 1516 | self.post_init() 1517 | 1518 | def get_input_embeddings(self) -> nn.Module: 1519 | return self.vision_model.embeddings.patch_embedding 1520 | 1521 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) 1522 | @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig) 1523 | def forward( 1524 | self, 1525 | pixel_values: Optional[torch.FloatTensor] = None, 1526 | output_attentions: Optional[bool] = None, 1527 | output_hidden_states: Optional[bool] = None, 1528 | return_dict: Optional[bool] = None, 1529 | ) -> Union[Tuple, CLIPVisionModelOutput]: 1530 | r""" 1531 | Returns: 1532 | 1533 | Examples: 1534 | 1535 | ```python 1536 | >>> from PIL import Image 1537 | >>> import requests 1538 | >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection 1539 | 1540 | >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") 1541 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") 1542 | 1543 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 1544 | >>> image = Image.open(requests.get(url, stream=True).raw) 1545 | 1546 | >>> inputs = processor(images=image, return_tensors="pt") 1547 | 1548 | >>> outputs = model(**inputs) 1549 | >>> image_embeds = outputs.image_embeds 1550 | ```""" 1551 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1552 | 1553 | vision_outputs = self.vision_model( 1554 | pixel_values=pixel_values, 1555 | output_attentions=output_attentions, 1556 | output_hidden_states=output_hidden_states, 1557 | return_dict=return_dict, 1558 | ) 1559 | 1560 | pooled_output = vision_outputs[1] # pooled_output 1561 | 1562 | image_embeds = self.visual_projection(pooled_output) 1563 | 1564 | if not return_dict: 1565 | outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] 1566 | return tuple(output for output in outputs if output is not None) 1567 | 1568 | return CLIPVisionModelOutput( 1569 | image_embeds=image_embeds, 1570 | last_hidden_state=vision_outputs.last_hidden_state, 1571 | hidden_states=vision_outputs.hidden_states, 1572 | attentions=vision_outputs.attentions, 1573 | ) 1574 | 1575 | 1576 | @add_start_docstrings( 1577 | """ 1578 | CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of 1579 | the patch tokens) e.g. for ImageNet. 
1580 | """, 1581 | CLIP_START_DOCSTRING, 1582 | ) 1583 | class CLIPForImageClassification(CLIPPreTrainedModel): 1584 | main_input_name = "pixel_values" 1585 | 1586 | def __init__(self, config: CLIPConfig) -> None: 1587 | super().__init__(config) 1588 | 1589 | self.num_labels = config.num_labels 1590 | vision_model = CLIPVisionModel._from_config( 1591 | config.vision_config, attn_implementation=config._attn_implementation 1592 | ) 1593 | self.vision_model = vision_model.vision_model 1594 | 1595 | # Classifier head 1596 | self.classifier = ( 1597 | nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() 1598 | ) 1599 | 1600 | # Initialize weights and apply final processing 1601 | self.post_init() 1602 | 1603 | @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) 1604 | @add_code_sample_docstrings( 1605 | checkpoint=_IMAGE_CLASS_CHECKPOINT, 1606 | output_type=ImageClassifierOutput, 1607 | config_class=_CONFIG_FOR_DOC, 1608 | expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, 1609 | ) 1610 | def forward( 1611 | self, 1612 | pixel_values: Optional[torch.Tensor] = None, 1613 | labels: Optional[torch.Tensor] = None, 1614 | output_attentions: Optional[bool] = None, 1615 | output_hidden_states: Optional[bool] = None, 1616 | return_dict: Optional[bool] = None, 1617 | ) -> Union[tuple, ImageClassifierOutput]: 1618 | r""" 1619 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1620 | Labels for computing the image classification/regression loss. Indices should be in `[0, ..., 1621 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1622 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1623 | """ 1624 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1625 | output_hidden_states = ( 1626 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1627 | ) 1628 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1629 | 1630 | outputs = self.vision_model( 1631 | pixel_values, 1632 | output_attentions=output_attentions, 1633 | output_hidden_states=output_hidden_states, 1634 | return_dict=return_dict, 1635 | ) 1636 | 1637 | sequence_output = outputs[0] 1638 | 1639 | # average pool the patch tokens 1640 | sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1) 1641 | # apply classifier 1642 | logits = self.classifier(sequence_output) 1643 | 1644 | loss = None 1645 | if labels is not None: 1646 | # move labels to correct device to enable model parallelism 1647 | labels = labels.to(logits.device) 1648 | if self.config.problem_type is None: 1649 | if self.num_labels == 1: 1650 | self.config.problem_type = "regression" 1651 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1652 | self.config.problem_type = "single_label_classification" 1653 | else: 1654 | self.config.problem_type = "multi_label_classification" 1655 | 1656 | if self.config.problem_type == "regression": 1657 | loss_fct = MSELoss() 1658 | if self.num_labels == 1: 1659 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1660 | else: 1661 | loss = loss_fct(logits, labels) 1662 | elif self.config.problem_type == "single_label_classification": 1663 | loss_fct = CrossEntropyLoss() 1664 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1665 | elif self.config.problem_type == "multi_label_classification": 
1666 | loss_fct = BCEWithLogitsLoss() 1667 | loss = loss_fct(logits, labels) 1668 | 1669 | if not return_dict: 1670 | output = (logits,) + outputs[2:] 1671 | return ((loss,) + output) if loss is not None else output 1672 | 1673 | return ImageClassifierOutput( 1674 | loss=loss, 1675 | logits=logits, 1676 | hidden_states=outputs.hidden_states, 1677 | attentions=outputs.attentions, 1678 | ) 1679 | --------------------------------------------------------------------------------
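The projection classes listed above only return raw `text_embeds` / `image_embeds`; turning them into a retrieval score additionally requires L2-normalizing the embeddings and comparing them with a similarity measure (the full `CLIPModel` does this internally with its learned logit scale). The sketch below is not part of `modeling_MMRet_CLIP.py`: it is a minimal, illustrative pairing of `CLIPTextModelWithProjection` and `CLIPVisionModelWithProjection`, reusing the `openai/clip-vit-base-patch32` checkpoint and COCO image URL that already appear in the docstring examples above. For the models actually released with this repository, `retrieval_demo.ipynb` remains the reference usage.

```python
# Minimal zero-shot image-text similarity sketch (assumptions noted in the text above).
import requests
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

checkpoint = "openai/clip-vit-base-patch32"  # illustrative public checkpoint, as in the docstrings
text_model = CLIPTextModelWithProjection.from_pretrained(checkpoint)
vision_model = CLIPVisionModelWithProjection.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    # Projected text embeddings for two candidate captions.
    text_inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
    text_embeds = text_model(**text_inputs).text_embeds

    # Projected image embedding for the query image.
    image_inputs = processor(images=image, return_tensors="pt")
    image_embeds = vision_model(**image_inputs).image_embeds

# L2-normalize so the dot product becomes cosine similarity.
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

similarity = image_embeds @ text_embeds.T  # shape: (num_images, num_texts)
print(similarity)
```

Normalization matters here: both projection heads are bias-free linear maps, so embedding norms vary across inputs, and ranking by unnormalized dot products would conflate norm with relevance.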