├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── cir_candi_1.png
│   ├── cir_candi_2.png
│   ├── cir_query.png
│   ├── corpus
│   │   ├── 000000032077.jpg
│   │   ├── 000000050549.jpg
│   │   ├── 000000098911.jpg
│   │   ├── 000000156031.jpg
│   │   ├── 000000244097.jpg
│   │   ├── 000000272130.jpg
│   │   ├── 000000275230.jpg
│   │   ├── 000000311907.jpg
│   │   ├── 000000357304.jpg
│   │   ├── 000000478916.jpg
│   │   └── 000000545037.jpg
│   ├── query
│   │   └── 000000530944.jpg
│   ├── res-ft-mmeb.png
│   ├── res-scaling.png
│   ├── res-zs-cir.png
│   └── res-zs-mmeb.png
├── eval
│   ├── data
│   │   ├── circo_corpus.jsonl
│   │   ├── circo_query.jsonl
│   │   ├── fashioniq_dress_corpus.jsonl
│   │   ├── fashioniq_dress_query_val.jsonl
│   │   ├── fashioniq_shirt_corpus.jsonl
│   │   ├── fashioniq_shirt_query_val.jsonl
│   │   ├── fashioniq_toptee_corpus.jsonl
│   │   └── fashioniq_toptee_query_val.jsonl
│   ├── eval_Circo.py
│   ├── eval_fashioniq.py
│   ├── flag_dataset.py
│   ├── flag_mmret.py
│   └── results
│       ├── mmret_base_circo.json
│       └── mmret_large_circo.json
├── modeling_MMRet_CLIP.py
└── retrieval_demo.ipynb
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | my_util/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 JUNJIE99
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
43 | ## News
44 | ```2025-5-20``` 🚀🚀 We are excited to announce the release of **BGE-VL-v1.5**! BGE-VL-v1.5 builds on the BGE-VL-MLLM-S1 model and is further trained on additional multi-task multimodal data synthesized and collected by our team. Our [zero-shot model](https://huggingface.co/BAAI/BGE-VL-v1.5-zs) achieves state-of-the-art zero-shot performance on the [MMEB leaderboard](https://huggingface.co/spaces/TIGER-Lab/MMEB-Leaderboard). Furthermore, [the fine-tuned version](https://huggingface.co/BAAI/BGE-VL-v1.5-mmeb) achieves the best results among all methods using the same base model (LLaVA-1.6-7B), setting a new high of 72.2% Recall@1 on the retrieval tasks.
45 |
46 | ```2025-5-16``` 🎉🎉 We are pleased to share that our works, **MegaPairs** ([repo](https://github.com/VectorSpaceLab/MegaPairs), [paper](https://arxiv.org/abs/2412.14475)) and **Vis-IR** ([repo](https://github.com/VectorSpaceLab/Vis-IR), [paper](https://arxiv.org/pdf/2502.11431)), have been accepted to the **ACL 2025 Main Conference**!
47 |
48 | ```2025-4-13``` 🎉🎉 We have uploaded our MegaPairs dataset to [🤗Hugging Face](https://huggingface.co/datasets/JUNJIE99/MegaPairs), which contains over 26 million multimodal retrieval instruction-tuning triplets. To reduce upload time and enhance data accessibility, we resized all images to a resolution of 512 × 512 instead of using their original size. This adjustment has minimal impact on performance, considering that most vision-language models (e.g., CLIP) use even smaller input image sizes. [Dataset Card](https://github.com/VectorSpaceLab/MegaPairs?tab=readme-ov-file#megapairs-dataset-card)
49 |
50 | ```2025-4-2``` 🌟🌟 BGE-VL models are also available on [WiseModel](https://www.wisemodel.cn/models/JUNJIE99/BGE-VL-large).
51 |
52 | ```2025-3-6``` 📰📰 Thank you to [SyncedTech (机器之心)](https://mp.weixin.qq.com/s/iw9BmSDwv6NYtD7pkC5kxQ), [QbitAI (量子位)](https://mp.weixin.qq.com/s/r_zWAZ0ir5732OfIrEsDtg), and [AI Era (新智元)](https://mp.weixin.qq.com/s/FZwKYJnx_78YDAEreu1edg) for reporting on our work!
53 |
54 | ```2025-3-4``` 🚀🚀 We have released the BGE-VL-MLLM models on Huggingface: [BGE-VL-MLLM-S1](https://huggingface.co/BAAI/BGE-VL-MLLM-S1) and [BGE-VL-MLLM-S2](https://huggingface.co/BAAI/BGE-VL-MLLM-S2). **BGE-VL-MLLM-S1** is trained exclusively on our MegaPairs dataset, achieving outstanding performance in composed image retrieval, with an 8.1% improvement on the CIRCO benchmark (mAP@5) over the previous state-of-the-art. **BGE-VL-MLLM-S2** builds on BGE-VL-MLLM-S1 with an additional epoch of fine-tuning on the MMEB benchmark training set, delivering enhanced performance across a broader range of multimodal embedding tasks.
55 |
56 | ```2024-12-27``` 🚀🚀 BGE-VL-CLIP models are released on Huggingface: [BGE-VL-base](https://huggingface.co/BAAI/BGE-VL-base) and [BGE-VL-large](https://huggingface.co/BAAI/BGE-VL-large).
57 |
58 | ```2024-12-19``` 🎉🎉 Release our paper: [MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval](https://arxiv.org/pdf/2412.14475).
59 |
60 | ## Release Plan
61 | - [x] Paper
62 | - [x] BGE-VL-base and BGE-VL-large models
63 | - [x] BGE-VL-MLLM model
64 | - [x] MegaPairs Dataset
65 | - [x] Evaluation code examples
66 | - [ ] Fine-tuning code
67 |
68 |
69 | ## Introduction
70 | In this work, we introduce **MegaPairs**, a novel data synthesis method that leverages open-domain images to create *heterogeneous KNN triplets* for universal multimodal retrieval. Our MegaPairs dataset contains over 26 million triplets, and we have trained a series of multimodal retrieval models, **BGE-VL**, including BGE-VL-CLIP (base and large) and BGE-VL-MLLM.
71 |
72 | BGE-VL achieves state-of-the-art performance on four popular zero-shot composed image retrieval benchmarks and the Massive Multimodal Embedding Benchmark (MMEB). Extensive experiments demonstrate the ***efficiency, scalability, and generalization*** of MegaPairs. Please refer to our [paper](https://arxiv.org/abs/2412.14475) for more details.
73 |
74 |
75 |
76 | ## Model Usage
77 |
78 | ### 1. BGE-VL-CLIP Models
79 | You can easily use the BGE-VL-CLIP models with the ```transformers``` library:
80 | > Our code works well on transformers==4.45.2, and we recommend using this version.
81 | ```python
82 | import torch
83 | from transformers import AutoModel
84 |
85 | MODEL_NAME = "BAAI/BGE-VL-base" # or "BAAI/BGE-VL-large"
86 |
87 | model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) # You must set trust_remote_code=True
88 | model.set_processor(MODEL_NAME)
89 | model.eval()
90 |
91 | with torch.no_grad():
92 | query = model.encode(
93 | images = "./assets/cir_query.png",
94 | text = "Make the background dark, as if the camera has taken the photo at night"
95 | )
96 |
97 | candidates = model.encode(
98 | images = ["./assets/cir_candi_1.png", "./assets/cir_candi_2.png"]
99 | )
100 |
101 | scores = query @ candidates.T
102 | print(scores)
103 | ```
104 |
105 | See the [demo](./retrieval_demo.ipynb) for a complete example of using BGE-VL for multimodal retrieval.
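
To turn the similarity scores into a ranked retrieval result, you can simply sort the candidates by score. The following is a minimal sketch that reuses the `query` and `candidates` embeddings from the snippet above; the `candidate_paths` list just mirrors the example inputs for readable output, and `torch.as_tensor` is only there in case `encode` returns NumPy arrays rather than tensors.

```python
import torch

# Rank the two example candidates for the single query (higher score = better match).
candidate_paths = ["./assets/cir_candi_1.png", "./assets/cir_candi_2.png"]  # same order as `candidates`

scores = torch.as_tensor(query @ candidates.T).flatten()  # shape: (num_candidates,)
ranking = torch.argsort(scores, descending=True)          # candidate indices, best first

for rank, idx in enumerate(ranking.tolist(), start=1):
    print(f"rank {rank}: {candidate_paths[idx]} (score={scores[idx].item():.4f})")
```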
106 |
107 |
108 | ### 2. BGE-VL-MLLM Models
109 |
110 | > Our code works well on transformers==4.45.2, and we recommend using this version.
111 |
112 | ```python
113 | import torch
114 | from transformers import AutoModel
115 | from PIL import Image
116 |
117 | MODEL_NAME= "BAAI/BGE-VL-MLLM-S1"
118 |
119 | model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
120 | model.eval()
121 | model.cuda()
122 |
123 | with torch.no_grad():
124 | model.set_processor(MODEL_NAME)
125 |
126 | query_inputs = model.data_process(
127 | text="Make the background dark, as if the camera has taken the photo at night",
128 | images="./assets/cir_query.png",
129 | q_or_c="q",
130 | task_instruction="Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: "
131 | )
132 |
133 | candidate_inputs = model.data_process(
134 | images=["./assets/cir_candi_1.png", "./assets/cir_candi_2.png"],
135 | q_or_c="c",
136 | )
137 |
138 | query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]
139 | candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]
140 |
141 | query_embs = torch.nn.functional.normalize(query_embs, dim=-1)
142 | candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)
143 |
144 | scores = torch.matmul(query_embs, candi_embs.T)
145 | print(scores)
146 | ```
147 |
148 | ## MegaPairs Dataset Card
149 |
150 | We are excited to release the **MegaPairs** dataset on [Hugging Face](https://huggingface.co/datasets/JUNJIE99/MegaPairs), which contains over **26 million training samples** tailored for composed image retrieval and universal multimodal retrieval tasks.
151 |
152 | ### Dataset Structure
153 |
154 | Each entry in the dataset consists of the following fields:
155 |
156 | - **q_img**: `str`
157 | The file path to the query image.
158 |
159 | - **q_text**: `list`
160 | A list of textual query statements related to the query image. During training, you can randomly select one statement from this list.
161 |
162 | - **t_img**: `str`
163 | The file path to the target image, which serves as the **positive example** for the combination of `q_img` and `q_text`.
164 |
165 | - **hns**: `list`
166 | A list of file paths for **hard negative sample** images. These are challenging distractors that are visually or semantically similar to the query. It is recommended to include at least one hard negative sample during training, with **`hns[0]` (the query image itself)** being a mandatory choice. In our experiments, we used **four hard negative samples** per query.
167 |
168 |
169 | ### Usage
170 |
171 | The dataset is available for download and exploration on [Hugging Face](https://huggingface.co/datasets/JUNJIE99/MegaPairs). We encourage researchers and practitioners to leverage this dataset to advance multimodal retrieval research and systems.
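
For a rough illustration of how the fields above map to training triplets, here is a minimal sketch. It assumes the release can be loaded with the standard `datasets` API under the repository id above and that a `train` split exists; the exact configuration, split names, and image packaging may differ, so treat the loading call as a placeholder.

```python
import random
from datasets import load_dataset

# Placeholder loading call; adjust the repo id, config, and split to the actual release.
ds = load_dataset("JUNJIE99/MegaPairs", split="train")

def build_triplet(example, num_hard_negatives=4):
    """Assemble one (query, positive, hard negatives) triplet from a MegaPairs record."""
    query_image = example["q_img"]                  # path of the query image
    query_text = random.choice(example["q_text"])   # randomly pick one textual instruction
    positive_image = example["t_img"]               # target (positive) image path
    # hns[0] (the query image itself) is mandatory; take up to `num_hard_negatives` in total.
    hard_negatives = example["hns"][:num_hard_negatives]
    return {
        "query": (query_image, query_text),
        "positive": positive_image,
        "negatives": hard_negatives,
    }

print(build_triplet(ds[0]))
```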
172 |
173 | ## Model Performance
174 | ### Zero-Shot Composed Image Retrieval
175 |
176 | BGE-VL sets a new performance benchmark in zero-shot composed image retrieval tasks. On the CIRCO benchmark, our BGE-VL-base model, with only 149 million parameters, surpasses all previous models, including those with 50 times more parameters. Additionally, BGE-VL-MLLM achieves an 8.1% improvement over the previous state-of-the-art model.
177 |
178 |
179 |
180 | ### Zero-Shot Performance on MMEB
181 |
182 | BGE-VL-MLLM achieves state-of-the-art zero-shot performance on the Massive Multimodal Embedding Benchmark (MMEB), despite being trained only on the ImageText-to-Image paradigm. This demonstrates the excellent generalization capability of MegaPairs for multimodal embedding.
183 |
184 |
185 |
186 | ### Fine-Tuning Performance on MMEB
187 |
188 | After fine-tuning on downstream tasks, BGE-VL-MLLM maintains its leading performance. Notably, it surpasses the previous state-of-the-art by 7.1% on the MMEB out-of-distribution (OOD) set. These results demonstrate the robust generalization capability of BGE-VL-MLLM and highlight the potential of MegaPairs as foundational training data for universal multimodal embedding.
189 |
190 |
191 |
192 | ### Performance Scaling
193 | MegaPairs showcases **scalability**: BGE-VL-base improves as training data increases. It also demonstrates **efficiency**: with just 0.5M training samples, BGE-VL-base significantly outperforms MagicLens, which uses the same CLIP-base backbone and was trained on 36.7M samples.
194 |
195 |
196 |
197 |
198 | ## License
199 | The annotations for MegaPairs and the BGE-VL models are released under the [MIT License](LICENSE). The images in MegaPairs originate from the [Recap-DataComp-1B](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B) dataset, which is released under the CC BY 4.0 license.
200 |
201 |
202 |
203 | ## Citation
204 | If you find this repository useful, please consider giving it a star ⭐ and citing our work:
205 |
206 | ```
207 | @article{zhou2024megapairs,
208 | title={MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval},
209 | author={Zhou, Junjie and Liu, Zheng and Liu, Ze and Xiao, Shitao and Wang, Yueze and Zhao, Bo and Zhang, Chen Jason and Lian, Defu and Xiong, Yongping},
210 | journal={arXiv preprint arXiv:2412.14475},
211 | year={2024}
212 | }
213 | ```
214 |
--------------------------------------------------------------------------------
/assets/cir_candi_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/cir_candi_1.png
--------------------------------------------------------------------------------
/assets/cir_candi_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/cir_candi_2.png
--------------------------------------------------------------------------------
/assets/cir_query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/cir_query.png
--------------------------------------------------------------------------------
/assets/corpus/000000032077.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000032077.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000050549.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000050549.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000098911.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000098911.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000156031.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000156031.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000244097.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000244097.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000272130.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000272130.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000275230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000275230.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000311907.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000311907.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000357304.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000357304.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000478916.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000478916.jpg
--------------------------------------------------------------------------------
/assets/corpus/000000545037.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/corpus/000000545037.jpg
--------------------------------------------------------------------------------
/assets/query/000000530944.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/query/000000530944.jpg
--------------------------------------------------------------------------------
/assets/res-ft-mmeb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-ft-mmeb.png
--------------------------------------------------------------------------------
/assets/res-scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-scaling.png
--------------------------------------------------------------------------------
/assets/res-zs-cir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-zs-cir.png
--------------------------------------------------------------------------------
/assets/res-zs-mmeb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorSpaceLab/MegaPairs/ad99aca2f2e07d58e61e0c1825ea4ed6208cfc7c/assets/res-zs-mmeb.png
--------------------------------------------------------------------------------
/eval/eval_Circo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import faiss
3 | import torch
4 | import logging
5 | import datasets
6 | import numpy as np
7 | from tqdm import tqdm
8 | from typing import Optional
9 | from dataclasses import dataclass, field
10 | from transformers import HfArgumentParser
11 | from flag_mmret import Flag_mmret
12 | import json
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | @dataclass
18 | class Args:
19 | model_name: str = field(
20 | default="BAAI/BGE-VL-large",
21 | metadata={'help': 'Model Name'}
22 | )
23 | result_save_path: str = field(
24 | default="./eval/mmret_large_circo.json",
25 | metadata={'help': 'Where to save the results.'}
26 | )
27 | image_dir: str = field(
28 | default="YOUR_COCO_IMAGE_DIRECTORY",
29 | metadata={'help': 'Directory where the images are located.'}
30 | )
31 | fp16: bool = field(
32 | default=False,
33 | metadata={'help': 'Use fp16 in inference?'}
34 | )
35 | max_query_length: int = field(
36 | default=64,
37 | metadata={'help': 'Max query length.'}
38 | )
39 | max_passage_length: int = field(
40 | default=77,
41 | metadata={'help': 'Max passage length.'}
42 | )
43 | batch_size: int = field(
44 | default=256,
45 | metadata={'help': 'Inference batch size.'}
46 | )
47 | index_factory: str = field(
48 | default="Flat",
49 | metadata={'help': 'Faiss index factory.'}
50 | )
51 | k: int = field(
52 | default=100,
53 | metadata={'help': 'How many neighbors to retrieve?'}
54 | )
55 | save_embedding: bool = field(
56 | default=False,
57 | metadata={'help': 'Save embeddings as a memmap at save_path?'}
58 | )
59 | load_embedding: bool = field(
60 | default=False,
61 | metadata={'help': 'Load embeddings from save_path?'}
62 | )
63 | save_path: str = field(
64 | default="embeddings.memmap",
65 | metadata={'help': 'Path to save embeddings.'}
66 | )
67 |
68 |
69 |
70 | def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False):
71 | """
72 | 1. Encode the entire corpus into dense embeddings;
73 | 2. Create faiss index;
74 | 3. Optionally save embeddings.
75 | """
76 | if load_embedding:
77 | test = model.encode("test")
78 | dtype = test.dtype
79 | dim = len(test)
80 |
81 | corpus_embeddings = np.memmap(
82 | save_path,
83 | mode="r",
84 | dtype=dtype
85 | ).reshape(-1, dim)
86 |
87 | else:
88 |
89 | corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image')
90 |
91 | dim = corpus_embeddings.shape[-1]
92 |
93 | if save_embedding:
94 | logger.info(f"saving embeddings at {save_path}...")
95 | memmap = np.memmap(
96 | save_path,
97 | shape=corpus_embeddings.shape,
98 | mode="w+",
99 | dtype=corpus_embeddings.dtype
100 | )
101 |
102 | length = corpus_embeddings.shape[0]
103 | # add in batch
104 | save_batch_size = 10000
105 | if length > save_batch_size:
106 | for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
107 | j = min(i + save_batch_size, length)
108 | memmap[i: j] = corpus_embeddings[i: j]
109 | else:
110 | memmap[:] = corpus_embeddings
111 |
112 | # create faiss index
113 | faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
114 |
115 |
116 | if model.device == torch.device("cuda"):
117 | # co = faiss.GpuClonerOptions()
118 | co = faiss.GpuMultipleClonerOptions()
119 | co.useFloat16 = True
120 | # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
121 | faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
122 |
123 | # NOTE: faiss only accepts float32
124 | logger.info("Adding embeddings...")
125 | corpus_embeddings = corpus_embeddings.astype(np.float32)
126 | faiss_index.train(corpus_embeddings)
127 | faiss_index.add(corpus_embeddings)
128 |
129 |
130 |
131 | return faiss_index
132 |
133 |
134 | def search(model: Flag_mmret, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 512):
135 | """
136 | 1. Encode queries into dense embeddings;
137 | 2. Search through faiss index
138 | """
139 | query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]],
140 | batch_size=batch_size,
141 | max_length=max_length,
142 | query_type='mm_it')
143 |
144 |
145 | query_size = len(query_embeddings)
146 |
147 | all_scores = []
148 | all_indices = []
149 |
150 | for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
151 | j = min(i + batch_size, query_size)
152 | query_embedding = query_embeddings[i: j]
153 | score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
154 | all_scores.append(score)
155 | all_indices.append(indice)
156 |
157 | all_scores = np.concatenate(all_scores, axis=0)
158 | all_indices = np.concatenate(all_indices, axis=0)
159 | return all_scores, all_indices
160 |
161 |
162 | def main():
163 | parser = HfArgumentParser([Args])
164 | args: Args = parser.parse_args_into_dataclasses()[0]
165 |
166 | print(f"Results will be saved in {args.result_save_path}")
167 | eval_data = datasets.load_dataset('json', data_files="./eval/data/circo_query.jsonl", split='train')
168 | image_corpus_test = datasets.load_dataset('json', data_files="./eval/data/circo_corpus.jsonl", split='train')
169 |
170 | model = Flag_mmret(model_name=args.model_name,
171 | normlized = True,
172 | image_dir=args.image_dir,
173 | use_fp16=args.fp16,  # honor the --fp16 flag
174 | )
175 |
176 |
177 | faiss_index = index(
178 | model=model,
179 | corpus=image_corpus_test,
180 | batch_size=args.batch_size,
181 | max_length=args.max_passage_length,
182 | index_factory=args.index_factory,
183 | save_path=args.save_path,
184 | save_embedding=args.save_embedding,
185 | load_embedding=args.load_embedding
186 | )
187 |
188 | scores, indices = search(
189 | model=model,
190 | queries=eval_data,
191 | faiss_index=faiss_index,
192 | k=args.k,
193 | batch_size=args.batch_size,
194 | max_length=args.max_query_length
195 | )
196 |
197 |
198 | retrieval_results = []
199 | for indice in indices:
200 | # filter invalid indices
201 | indice = indice[indice != -1].tolist()
202 | retrieval_results.append(image_corpus_test[indice]["content"])
203 |
204 | ########## results in test corpus #########
205 | q_images = eval_data["q_img"]
206 |
207 | q_ids = []
208 | for _img in q_images:
209 | _id = os.path.basename(_img)
210 | _id = os.path.splitext(_id)[0]
211 | q_ids.append(_id)
212 |
213 | pairids = eval_data["id"]
214 | results = {}
215 | for pairid, re_results, q_img in zip(pairids, retrieval_results, q_images):
216 | id = str(pairid)
217 | top_50_results = re_results[0:50]  # keep the top-50 retrieved images per query for the CIRCO submission file
218 | results[id] = top_50_results
219 |
220 | with open(args.result_save_path, "w") as f:
221 | json.dump(results, f)
222 |
223 |
224 | if __name__ == "__main__":
225 | main()
--------------------------------------------------------------------------------
/eval/eval_fashioniq.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
3 |
4 | import sys
5 | print(os.getcwd())
6 |
7 | import faiss
8 | import torch
9 | import logging
10 | import datasets
11 | import numpy as np
12 | from tqdm import tqdm
13 | from typing import Optional
14 | from dataclasses import dataclass, field
15 | from transformers import HfArgumentParser
16 | from flag_mmret import Flag_mmret
17 | import json
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | @dataclass
23 | class Args:
24 | model_name: str = field(
25 | default="BAAI/BGE-VL-large",
26 | metadata={'help': 'Model Name'}
27 | )
28 | image_dir: str = field(
29 | default="YOUR_FASHIONIQ_IMAGE_DIRECTORY",
30 | metadata={'help': 'Directory where the images are located.'}
31 | )
32 | fp16: bool = field(
33 | default=False,
34 | metadata={'help': 'Use fp16 in inference?'}
35 | )
36 | max_query_length: int = field(
37 | default=64,
38 | metadata={'help': 'Max query length.'}
39 | )
40 | max_passage_length: int = field(
41 | default=77,
42 | metadata={'help': 'Max passage length.'}
43 | )
44 | batch_size: int = field(
45 | default=256,
46 | metadata={'help': 'Inference batch size.'}
47 | )
48 | index_factory: str = field(
49 | default="Flat",
50 | metadata={'help': 'Faiss index factory.'}
51 | )
52 | k: int = field(
53 | default=100,
54 | metadata={'help': 'How many neighbors to retrieve?'}
55 | )
56 | save_embedding: bool = field(
57 | default=False,
58 | metadata={'help': 'Save embeddings as a memmap at save_path?'}
59 | )
60 | load_embedding: bool = field(
61 | default=False,
62 | metadata={'help': 'Load embeddings from save_path?'}
63 | )
64 | save_path: str = field(
65 | default="embeddings.memmap",
66 | metadata={'help': 'Path to save embeddings.'}
67 | )
68 |
69 |
70 |
71 | def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False):
72 | """
73 | 1. Encode the entire corpus into dense embeddings;
74 | 2. Create faiss index;
75 | 3. Optionally save embeddings.
76 | """
77 | if load_embedding:
78 | test = model.encode("test")
79 | dtype = test.dtype
80 | dim = len(test)
81 |
82 | corpus_embeddings = np.memmap(
83 | save_path,
84 | mode="r",
85 | dtype=dtype
86 | ).reshape(-1, dim)
87 |
88 | else:
89 |
90 | corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image')
91 |
92 | dim = corpus_embeddings.shape[-1]
93 |
94 | if save_embedding:
95 | logger.info(f"saving embeddings at {save_path}...")
96 | memmap = np.memmap(
97 | save_path,
98 | shape=corpus_embeddings.shape,
99 | mode="w+",
100 | dtype=corpus_embeddings.dtype
101 | )
102 |
103 | length = corpus_embeddings.shape[0]
104 | # add in batch
105 | save_batch_size = 10000
106 | if length > save_batch_size:
107 | for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
108 | j = min(i + save_batch_size, length)
109 | memmap[i: j] = corpus_embeddings[i: j]
110 | else:
111 | memmap[:] = corpus_embeddings
112 |
113 | # create faiss index
114 | faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
115 |
116 |
117 | if model.device == torch.device("cuda"):
118 | # co = faiss.GpuClonerOptions()
119 | co = faiss.GpuMultipleClonerOptions()
120 | co.useFloat16 = True
121 | # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
122 | faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
123 |
124 | # NOTE: faiss only accepts float32
125 | logger.info("Adding embeddings...")
126 | corpus_embeddings = corpus_embeddings.astype(np.float32)
127 | faiss_index.train(corpus_embeddings)
128 | faiss_index.add(corpus_embeddings)
129 |
130 |
131 |
132 | return faiss_index
133 |
134 |
135 | def search(model: Flag_mmret, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 512):
136 | """
137 | 1. Encode queries into dense embeddings;
138 | 2. Search through faiss index
139 | """
140 | query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]],
141 | batch_size=batch_size,
142 | max_length=max_length,
143 | query_type='mm_it')
144 |
145 | query_size = len(query_embeddings)
146 |
147 | all_scores = []
148 | all_indices = []
149 |
150 | for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
151 | j = min(i + batch_size, query_size)
152 | query_embedding = query_embeddings[i: j]
153 | score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
154 | all_scores.append(score)
155 | all_indices.append(indice)
156 |
157 | all_scores = np.concatenate(all_scores, axis=0)
158 | all_indices = np.concatenate(all_indices, axis=0)
159 | return all_scores, all_indices
160 |
161 |
162 | def evaluate(preds, labels, cutoffs=[1,5,10,20,50,100]):
163 | """
164 | Evaluate MRR and Recall at cutoffs.
165 | """
166 | metrics = {}
167 |
168 | # MRR
169 | mrrs = np.zeros(len(cutoffs))
170 | for pred, label in zip(preds, labels):
171 | jump = False
172 | for i, x in enumerate(pred, 1):
173 | if x in label:
174 | for k, cutoff in enumerate(cutoffs):
175 | if i <= cutoff:
176 | mrrs[k] += 1 / i
177 | jump = True
178 | if jump:
179 | break
180 | mrrs /= len(preds)
181 | for i, cutoff in enumerate(cutoffs):
182 | mrr = mrrs[i]
183 | metrics[f"MRR@{cutoff}"] = mrr
184 |
185 | # Recall
186 | recalls = np.zeros(len(cutoffs))
187 | for pred, label in zip(preds, labels):
188 | if not isinstance(label, list):
189 | label = [label]
190 | for k, cutoff in enumerate(cutoffs):
191 | recall = np.intersect1d(label, pred[:cutoff])
192 | recalls[k] += len(recall) / len(label)
193 | recalls /= len(preds)
194 | for i, cutoff in enumerate(cutoffs):
195 | recall = recalls[i]
196 | metrics[f"Recall@{cutoff}"] = recall
197 |
198 | return metrics
199 |
200 | def main():
201 | parser = HfArgumentParser([Args])
202 | args: Args = parser.parse_args_into_dataclasses()[0]
203 |
204 | model = Flag_mmret(model_name=args.model_name,
205 | normlized = True,
206 | image_dir=args.image_dir,
207 | use_fp16=args.fp16,  # honor the --fp16 flag
208 | )
209 |
210 | eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_shirt_query_val.jsonl", split='train')
211 | image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_shirt_corpus.jsonl", split='train')
212 |
213 | faiss_index = index(
214 | model=model,
215 | corpus=image_corpus,
216 | batch_size=args.batch_size,
217 | max_length=args.max_passage_length,
218 | index_factory=args.index_factory,
219 | save_path=args.save_path,
220 | save_embedding=args.save_embedding,
221 | load_embedding=args.load_embedding
222 | )
223 |
224 | scores, indices = search(
225 | model=model,
226 | queries=eval_data,
227 | faiss_index=faiss_index,
228 | k=args.k,
229 | batch_size=args.batch_size,
230 | max_length=args.max_query_length
231 | )
232 |
233 |
234 |
235 | retrieval_results = []
236 | for indice in indices:
237 | # filter invalid indices
238 | indice = indice[indice != -1].tolist()
239 | retrieval_results.append(image_corpus[indice]["content"])
240 |
241 | ground_truths = []
242 | for sample in eval_data:
243 | ground_truths.append(sample["positive_key"])
244 |
245 | metrics_shirt = evaluate(retrieval_results, ground_truths)
246 | print("FashionIQ tasks (shirt):")
247 | print(metrics_shirt)
248 |
249 |
250 |
251 |
252 |
253 | eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_dress_query_val.jsonl", split='train')
254 | image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_dress_corpus.jsonl", split='train')
255 |
256 | faiss_index = index(
257 | model=model,
258 | corpus=image_corpus,
259 | batch_size=args.batch_size,
260 | max_length=args.max_passage_length,
261 | index_factory=args.index_factory,
262 | save_path=args.save_path,
263 | save_embedding=args.save_embedding,
264 | load_embedding=args.load_embedding
265 | )
266 |
267 | scores, indices = search(
268 | model=model,
269 | queries=eval_data,
270 | faiss_index=faiss_index,
271 | k=args.k,
272 | batch_size=args.batch_size,
273 | max_length=args.max_query_length
274 | )
275 |
276 |
277 |
278 | retrieval_results = []
279 | for indice in indices:
280 | # filter invalid indices
281 | indice = indice[indice != -1].tolist()
282 | retrieval_results.append(image_corpus[indice]["content"])
283 |
284 | ground_truths = []
285 | for sample in eval_data:
286 | ground_truths.append(sample["positive_key"])
287 |
288 | metrics_dress = evaluate(retrieval_results, ground_truths)
289 | print("FashionIQ tasks (dress):")
290 | print(metrics_dress)
291 |
292 |
293 | eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_toptee_query_val.jsonl", split='train')
294 | image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_toptee_corpus.jsonl", split='train')
295 |
296 |
297 | faiss_index = index(
298 | model=model,
299 | corpus=image_corpus,
300 | batch_size=args.batch_size,
301 | max_length=args.max_passage_length,
302 | index_factory=args.index_factory,
303 | save_path=args.save_path,
304 | save_embedding=args.save_embedding,
305 | load_embedding=args.load_embedding
306 | )
307 |
308 | scores, indices = search(
309 | model=model,
310 | queries=eval_data,
311 | faiss_index=faiss_index,
312 | k=args.k,
313 | batch_size=args.batch_size,
314 | max_length=args.max_query_length
315 | )
316 |
317 |
318 |
319 | retrieval_results = []
320 | for indice in indices:
321 | # filter invalid indices
322 | indice = indice[indice != -1].tolist()
323 | retrieval_results.append(image_corpus[indice]["content"])
324 |
325 | ground_truths = []
326 | for sample in eval_data:
327 | ground_truths.append(sample["positive_key"])
328 |
329 | metrics_toptee = evaluate(retrieval_results, ground_truths)
330 | print("FashionIQ tasks (toptee):")
331 | print(metrics_toptee)
332 |
333 |
334 |
335 | print(f"shirt: {metrics_shirt['Recall@10'] * 100:.2f} / {metrics_shirt['Recall@50'] * 100:.2f}")
336 | print(f"dress: {metrics_dress['Recall@10'] * 100:.2f} / {metrics_dress['Recall@50'] * 100:.2f}")
337 | print(f"toptee: {metrics_toptee['Recall@10'] * 100:.2f} / {metrics_toptee['Recall@50'] * 100:.2f}")
338 | print(f"overall: {(metrics_shirt['Recall@10'] + metrics_dress['Recall@10'] + metrics_toptee['Recall@10']) * 100 / 3:.2f} / {(metrics_shirt['Recall@50'] + metrics_dress['Recall@50'] + metrics_toptee['Recall@50']) * 100 / 3:.2f}")
339 |
340 |
341 | if __name__ == "__main__":
342 | main()
--------------------------------------------------------------------------------
/eval/flag_dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os.path
3 | import random
4 | from dataclasses import dataclass
5 | from typing import Iterator
6 |
7 | import datasets
8 | from torch.utils.data import Dataset, IterableDataset
9 | from transformers import DataCollatorWithPadding
10 | from transformers import PreTrainedTokenizer, BatchEncoding
11 | from transformers import CLIPImageProcessor
12 |
13 |
14 | from PIL import Image
15 | import json
16 | import torch
17 | import torch.distributed
18 |
19 | from io import BytesIO
20 | import warnings
21 |
22 | class MMIT_Dataset(Dataset):
23 | def __init__(self, captions, image_ids, image_dir, image_processor) -> None:
24 | img_id_example = image_ids[0]
25 | img_id_example = str(img_id_example)
26 | if img_id_example.lower().endswith((".jpg", ".png", ".jpeg")):  # case-insensitive extension check
27 | self.image_path =[os.path.join(image_dir, str(id)) for id in image_ids]
28 | else:
29 | warnings.warn("No file extension found in image_ids; '.jpg' will be appended.", UserWarning)
30 | self.image_path =[os.path.join(image_dir, str(id) + '.jpg') for id in image_ids]
31 | self.captions = captions
32 | self.image_processor = image_processor
33 |
34 | def __getitem__(self, item):
35 | pil_data = Image.open(self.image_path[item])
36 | pil_data = pil_data.convert('RGB')
37 | image = self.image_processor(pil_data)
38 |
39 |
40 |
41 |
42 | caption = self.captions[item]
43 |
44 | return caption, image
45 |
46 | def __len__(self):
47 | return len(self.image_path)
48 |
49 |
50 | class MMIT_Collator:
51 | def __init__(self, tokenizer, caption_max_len):
52 | self.tokenizer = tokenizer
53 | self.caption_max_len = caption_max_len
54 |
55 |
56 |
57 | def __call__(self, features):
58 | caption = [f[0] for f in features]
59 | images = [f[1] for f in features]
60 |
61 | c_collated = self.tokenizer(
62 | caption,
63 | truncation=True,
64 | padding = True,
65 | max_length=self.caption_max_len,
66 | return_tensors="pt",
67 | )
68 |
69 | # i_collated = torch.stack(images)
70 |
71 | # for clip model
72 | images = [f["pixel_values"][0] for f in images]
73 | images = [torch.tensor(arr) for arr in images]
74 | i_collated = torch.stack(images)
75 | ##clip_end
76 |
77 | return c_collated, i_collated
78 |
79 | class Image_Dataset(Dataset):
80 | def __init__(self, image_ids, image_dir, image_processor) -> None:
81 |
82 | self.image_path =[os.path.join(image_dir, str(id)) for id in image_ids]
83 | self.image_processor = image_processor
84 |
85 | def __getitem__(self, item):
86 | pil_data = Image.open(self.image_path[item])
87 | image = self.image_processor(pil_data)
88 |
89 | return image
90 |
91 | def __len__(self):
92 | return len(self.image_path)
93 |
94 | class Image_Collator:
95 | def __init__(self, tokenizer, caption_max_len):
96 | self.tokenizer = tokenizer
97 | self.caption_max_len = caption_max_len
98 |
99 |
100 | def __call__(self, features):
101 | # images = features
102 | # i_collated = torch.stack(images)
103 |
104 | # for clip model
105 | images = [f["pixel_values"][0] for f in features]
106 | images = [torch.tensor(arr) for arr in images]
107 | i_collated = torch.stack(images)
108 | ## clip-end
109 | return i_collated
--------------------------------------------------------------------------------
/eval/flag_mmret.py:
--------------------------------------------------------------------------------
1 | from typing import cast, List, Union, Tuple
2 | import numpy as np
3 | import torch
4 | from tqdm import tqdm
5 | from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, CLIPModel, CLIPImageProcessor, CLIPTokenizer
6 | import os
7 | from PIL import Image
8 | from torch.utils.data import DataLoader
9 | from torch import nn
10 | from flag_dataset import MMIT_Dataset, MMIT_Collator, Image_Dataset, Image_Collator
11 | class Flag_mmret(nn.Module):
12 | def __init__(
13 | self,
14 | model_name: str = None,
15 | normlized: bool = True,
16 | pooling_method: str = 'cls',
17 | use_fp16: bool=True,
18 | image_dir: str = None,
19 | ) -> None:
20 | super().__init__()
21 |
22 | self.model = AutoModel.from_pretrained(model_name)
23 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
24 | self.image_processor = CLIPImageProcessor.from_pretrained(model_name)
25 |
26 |
27 | self.normalize_embeddings = normlized
28 | self.pooling_method = pooling_method
29 |
30 | self.image_dir = image_dir
31 |
32 | if use_fp16:
33 | self.use_fp16 = True
34 | self.model.half()
35 | else:
36 | self.use_fp16 = False
37 |
38 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
39 | self.model = self.model.to(self.device)
40 |
41 | self.num_gpus = torch.cuda.device_count()
42 | if self.num_gpus > 1:
43 | print(f"----------using {self.num_gpus}*GPUs----------")
44 | self.model = torch.nn.DataParallel(self.model)
45 |
46 |
47 | def encode_queries(self, queries: Union[List[str], str],
48 | batch_size: int=256,
49 | max_length: int=77,
50 | query_type: str = None,
51 | ) -> np.ndarray:
52 |
53 |
54 | if query_type == 'text':
55 | input_texts = queries
56 |
57 | return self.encode_text(input_texts, batch_size=batch_size, max_length=max_length)
58 | elif query_type == 'mm_it':
59 | q_text, q_img = queries
60 |
61 | input_texts = q_text
62 |
63 |
64 | return self.encode_mm_it(input_texts, q_img, batch_size=batch_size)
65 | elif query_type == 'image':
66 | q_img = queries
67 | return self.encode_image(q_img, batch_size=batch_size)
68 | else:
69 | raise NotImplementedError
70 |
71 |
72 | def encode_corpus(self,
73 | corpus: dict,
74 | batch_size: int=256,
75 | max_length: int=77,
76 | corpus_type: str = None,
77 | ) -> np.ndarray:
78 | if corpus_type == 'text':
79 | return self.encode_text(corpus["text"], batch_size=batch_size, max_length=max_length)
80 | elif corpus_type == 'mm_it':
81 | return self.encode_mm_it(corpus["text"], corpus["image"], batch_size=batch_size, max_length=max_length)
82 | elif corpus_type == 'image':
83 | return self.encode_image(corpus["image"], batch_size=batch_size, max_length=max_length)
84 | else:
85 | raise RuntimeError(f"You must choose a corpus type from: [mm_it, text, image]")
86 |
87 |
88 |
89 | @torch.no_grad()
90 | def encode_text(self, sentences: Union[List[str], str], batch_size: int=256, max_length: int=77) -> np.ndarray:
91 | if self.num_gpus > 0:
92 | batch_size = batch_size * self.num_gpus
93 | self.model.eval()
94 |
95 | input_was_string = False
96 | if isinstance(sentences, str):
97 | sentences = [sentences]
98 | input_was_string = True
99 |
100 | all_embeddings = []
101 | for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings", disable=len(sentences)<256):
102 | sentences_batch = sentences[start_index:start_index + batch_size]
103 | inputs = self.tokenizer(
104 | sentences_batch,
105 | padding=True,
106 | truncation=True,
107 | return_tensors='pt',
108 | max_length=max_length,
109 | ).to(self.device)
110 |
111 | embeddings = self.model.get_text_features(**inputs)
112 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
113 | embeddings = cast(torch.Tensor, embeddings)
114 | all_embeddings.append(embeddings.cpu().numpy())
115 |
116 | all_embeddings = np.concatenate(all_embeddings, axis=0)
117 | if input_was_string:
118 | return all_embeddings[0]
119 | return all_embeddings
120 |
121 |
122 | @torch.no_grad()
123 | def encode_mm_it(self, captions: Union[List[str], str], image_ids: Union[List[str], str], batch_size: int=256, max_length: int=77) -> np.ndarray:
124 | if self.num_gpus > 0:
125 | batch_size = batch_size * self.num_gpus
126 | self.model.eval()
127 |
128 | input_was_string = False
129 | if isinstance(captions, str):
130 | captions = [captions]
131 | image_ids = [image_ids]
132 | input_was_string = True
133 |
134 | all_embeddings = []
135 | mm_it_dataset = MMIT_Dataset(captions=captions,
136 | image_ids=image_ids,
137 | image_dir=self.image_dir,
138 | image_processor=self.image_processor
139 | )
140 | mm_it_collator = MMIT_Collator(self.tokenizer, caption_max_len=75)
141 |
142 | mm_it_dataloader = DataLoader(dataset=mm_it_dataset,
143 | collate_fn=mm_it_collator,
144 | num_workers=8,
145 | batch_size=batch_size,
146 | shuffle=False,
147 | drop_last=False,)
148 |
149 | for data in tqdm(mm_it_dataloader, desc="Inference Embeddings", disable=len(captions)<256):
150 | captions_inputs = data[0].to(self.device)
151 |
152 | images = data[1].to(self.device)
153 | if self.use_fp16 and images.dtype != torch.float16:
154 | images = images.half()
155 |
156 | text_embeddings = self.model.get_text_features(**captions_inputs)
157 | image_embeddings = self.model.get_image_features(images)
158 |
159 | embeddings = text_embeddings + image_embeddings  # fuse the two modalities by element-wise sum before L2 normalization
160 |
161 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
162 |
163 | embeddings = cast(torch.Tensor, embeddings)
164 | all_embeddings.append(embeddings.cpu().numpy())
165 |
166 | all_embeddings = np.concatenate(all_embeddings, axis=0)
167 | if input_was_string:
168 | return all_embeddings[0]
169 | return all_embeddings
170 |
171 | @torch.no_grad()
172 | def encode_image(self, image_ids: Union[List[str], str], batch_size: int=256, max_length: int=77) -> np.ndarray:
173 | if self.num_gpus > 0:
174 | batch_size = batch_size * self.num_gpus
175 | self.model.eval()
176 |
177 | all_embeddings = []
178 | image_dataset = Image_Dataset(image_ids=image_ids,
179 | image_dir=self.image_dir,
180 | image_processor=self.image_processor
181 | )
182 | image_collator = Image_Collator(self.tokenizer, caption_max_len=312)
183 |
184 | image_dataloader = DataLoader(dataset=image_dataset,
185 | collate_fn=image_collator,
186 | num_workers=8,
187 | batch_size=batch_size,
188 | shuffle=False,
189 | drop_last=False,)
190 |
191 | for data in tqdm(image_dataloader, desc="Inference Image Embeddings"):
192 |
193 | images = data.to(self.device)
194 | if self.use_fp16 and images.dtype != torch.float16:
195 | images = images.half()
196 |
197 |
198 |
199 | embeddings = self.model.get_image_features(images)
200 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
201 | embeddings = cast(torch.Tensor, embeddings)
202 | all_embeddings.append(embeddings.cpu().numpy())
203 |
204 | all_embeddings = np.concatenate(all_embeddings, axis=0)
205 |
206 | return all_embeddings
--------------------------------------------------------------------------------
/modeling_MMRet_CLIP.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch CLIP model."""
16 |
17 | from dataclasses import dataclass
18 | from typing import Any, Optional, Tuple, Union
19 |
20 | import torch
21 | import torch.utils.checkpoint
22 | from torch import nn
23 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24 | from PIL import Image
25 | from transformers.activations import ACT2FN
26 | from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
27 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
28 | from transformers.modeling_utils import PreTrainedModel
29 | from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
30 | from transformers.utils import (
31 | ModelOutput,
32 | add_code_sample_docstrings,
33 | add_start_docstrings,
34 | add_start_docstrings_to_model_forward,
35 | is_flash_attn_2_available,
36 | is_flash_attn_greater_or_equal_2_10,
37 | logging,
38 | replace_return_docstrings,
39 | )
40 | from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
41 | from transformers import CLIPProcessor
42 |
43 | if is_flash_attn_2_available():
44 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
45 |
46 |
47 | logger = logging.get_logger(__name__)
48 |
49 | # General docstring
50 | _CONFIG_FOR_DOC = "MMRet_CLIP"
51 |
52 | # Image classification docstring
53 | _IMAGE_CLASS_CHECKPOINT = "JUNJIE99/MMRet-base"
54 | _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
55 |
56 |
57 | # contrastive loss function, adapted from
58 | # https://sachinruk.github.io/blog/2021-03-07-clip.html
59 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
60 | return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
61 |
62 |
63 | def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
64 | caption_loss = contrastive_loss(similarity)
65 | image_loss = contrastive_loss(similarity.t())
66 | return (caption_loss + image_loss) / 2.0
67 |
68 |
69 | def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
70 | """
71 | This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
72 | model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
73 | """
74 | square_tensor = torch.pow(tensor, 2)
75 | sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
76 | normed_tensor = torch.pow(sum_tensor, 0.5)
77 | return normed_tensor
78 |
79 |
80 | @dataclass
81 | class CLIPVisionModelOutput(ModelOutput):
82 | """
83 | Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
84 |
85 | Args:
86 | image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
87 | The image embeddings obtained by applying the projection layer to the pooler_output.
88 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
89 | Sequence of hidden-states at the output of the last layer of the model.
90 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
91 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
92 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
93 |
94 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
95 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
96 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
97 | sequence_length)`.
98 |
99 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
100 | heads.
101 | """
102 |
103 | image_embeds: Optional[torch.FloatTensor] = None
104 | last_hidden_state: torch.FloatTensor = None
105 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
106 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
107 |
108 |
109 | @dataclass
110 | class CLIPTextModelOutput(ModelOutput):
111 | """
112 | Base class for text model's outputs that also contains a pooling of the last hidden states.
113 |
114 | Args:
115 | text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
116 | The text embeddings obtained by applying the projection layer to the pooler_output.
117 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
118 | Sequence of hidden-states at the output of the last layer of the model.
119 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
120 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
121 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
122 |
123 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
124 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
125 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
126 | sequence_length)`.
127 |
128 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
129 | heads.
130 | """
131 |
132 | text_embeds: Optional[torch.FloatTensor] = None
133 | last_hidden_state: torch.FloatTensor = None
134 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
135 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
136 |
137 |
138 | @dataclass
139 | class CLIPOutput(ModelOutput):
140 | """
141 | Args:
142 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
143 | Contrastive loss for image-text similarity.
144 | logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
145 | The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
146 | similarity scores.
147 | logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
148 | The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
149 | similarity scores.
150 | text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
151 | The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
152 | image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
153 | The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
154 | text_model_output (`BaseModelOutputWithPooling`):
155 | The output of the [`CLIPTextModel`].
156 | vision_model_output (`BaseModelOutputWithPooling`):
157 | The output of the [`CLIPVisionModel`].
158 | """
159 |
160 | loss: Optional[torch.FloatTensor] = None
161 | logits_per_image: torch.FloatTensor = None
162 | logits_per_text: torch.FloatTensor = None
163 | text_embeds: torch.FloatTensor = None
164 | image_embeds: torch.FloatTensor = None
165 | text_model_output: BaseModelOutputWithPooling = None
166 | vision_model_output: BaseModelOutputWithPooling = None
167 |
168 | def to_tuple(self) -> Tuple[Any]:
169 | return tuple(
170 | self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
171 | for k in self.keys()
172 | )
173 |
174 |
175 | class CLIPVisionEmbeddings(nn.Module):
176 | def __init__(self, config: CLIPVisionConfig):
177 | super().__init__()
178 | self.config = config
179 | self.embed_dim = config.hidden_size
180 | self.image_size = config.image_size
181 | self.patch_size = config.patch_size
182 |
183 | self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
184 |
185 | self.patch_embedding = nn.Conv2d(
186 | in_channels=config.num_channels,
187 | out_channels=self.embed_dim,
188 | kernel_size=self.patch_size,
189 | stride=self.patch_size,
190 | bias=False,
191 | )
192 |
193 | self.num_patches = (self.image_size // self.patch_size) ** 2
194 | self.num_positions = self.num_patches + 1
195 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
196 | self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
197 |
198 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
199 | batch_size = pixel_values.shape[0]
200 | target_dtype = self.patch_embedding.weight.dtype
201 | patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
202 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
203 |
204 | class_embeds = self.class_embedding.expand(batch_size, 1, -1)
205 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
206 | embeddings = embeddings + self.position_embedding(self.position_ids)
207 | return embeddings
208 |
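# Shape sketch for the embedding sequence built above (illustrative numbers, assuming a ViT-B/32-style
# vision config with image_size=224, patch_size=32 and hidden_size=768; actual values come from the config):
#   patch_embedding(pixel_values)          -> [batch, 768, 7, 7]    (224 // 32 = 7 patches per side)
#   .flatten(2).transpose(1, 2)            -> [batch, 49, 768]
#   cat([class_embeds, patch_embeds], 1)   -> [batch, 50, 768]      (num_positions = 49 + 1)
#   + position_embedding(position_ids)     -> [batch, 50, 768]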
209 |
210 | class CLIPTextEmbeddings(nn.Module):
211 | def __init__(self, config: CLIPTextConfig):
212 | super().__init__()
213 | embed_dim = config.hidden_size
214 |
215 | self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
216 | self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
217 |
218 | # position_ids (1, len position emb) is contiguous in memory and, since persistent=False, not exported when serialized
219 | self.register_buffer(
220 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
221 | )
222 |
223 | def forward(
224 | self,
225 | input_ids: Optional[torch.LongTensor] = None,
226 | position_ids: Optional[torch.LongTensor] = None,
227 | inputs_embeds: Optional[torch.FloatTensor] = None,
228 | ) -> torch.Tensor:
229 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
230 |
231 | if position_ids is None:
232 | position_ids = self.position_ids[:, :seq_length]
233 |
234 | if inputs_embeds is None:
235 | inputs_embeds = self.token_embedding(input_ids)
236 |
237 | position_embeddings = self.position_embedding(position_ids)
238 | embeddings = inputs_embeds + position_embeddings
239 |
240 | return embeddings
241 |
242 |
243 | class CLIPAttention(nn.Module):
244 | """Multi-headed attention from 'Attention Is All You Need' paper"""
245 |
246 | def __init__(self, config):
247 | super().__init__()
248 | self.config = config
249 | self.embed_dim = config.hidden_size
250 | self.num_heads = config.num_attention_heads
251 | self.head_dim = self.embed_dim // self.num_heads
252 | if self.head_dim * self.num_heads != self.embed_dim:
253 | raise ValueError(
254 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
255 | f" {self.num_heads})."
256 | )
257 | self.scale = self.head_dim**-0.5
258 | self.dropout = config.attention_dropout
259 |
260 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
261 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
262 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
263 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
264 |
265 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
266 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
267 |
268 | def forward(
269 | self,
270 | hidden_states: torch.Tensor,
271 | attention_mask: Optional[torch.Tensor] = None,
272 | causal_attention_mask: Optional[torch.Tensor] = None,
273 | output_attentions: Optional[bool] = False,
274 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
275 | """Input shape: Batch x Time x Channel"""
276 |
277 | bsz, tgt_len, embed_dim = hidden_states.size()
278 |
279 | # get query proj
280 | query_states = self.q_proj(hidden_states) * self.scale
281 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
282 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
283 |
284 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
285 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
286 | key_states = key_states.view(*proj_shape)
287 | value_states = value_states.view(*proj_shape)
288 |
289 | src_len = key_states.size(1)
290 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
291 |
292 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
293 | raise ValueError(
294 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
295 | f" {attn_weights.size()}"
296 | )
297 |
298 | # apply the causal_attention_mask first
299 | if causal_attention_mask is not None:
300 | if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
301 | raise ValueError(
302 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
303 | f" {causal_attention_mask.size()}"
304 | )
305 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
306 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
307 |
308 | if attention_mask is not None:
309 | if attention_mask.size() != (bsz, 1, tgt_len, src_len):
310 | raise ValueError(
311 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
312 | )
313 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
314 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
315 |
316 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
317 |
318 | if output_attentions:
319 | # this operation is a bit awkward, but it's required to
320 | # make sure that attn_weights keeps its gradient.
321 | # In order to do so, attn_weights have to be reshaped
322 | # twice and reused in the following computation
323 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
324 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
325 | else:
326 | attn_weights_reshaped = None
327 |
328 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
329 |
330 | attn_output = torch.bmm(attn_probs, value_states)
331 |
332 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
333 | raise ValueError(
334 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
335 | f" {attn_output.size()}"
336 | )
337 |
338 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
339 | attn_output = attn_output.transpose(1, 2)
340 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
341 |
342 | attn_output = self.out_proj(attn_output)
343 |
344 | return attn_output, attn_weights_reshaped
345 |
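# Shape flow through the eager attention above (a sketch; B = batch, H = num_heads, T = tgt_len,
# S = src_len, d = head_dim):
#   q/k/v after `_shape` + `view`       -> [B * H, T, d]  (key/value use S instead of T)
#   attn_weights = bmm(q, k^T)          -> [B * H, T, S], masks added, softmax over the last dim
#   attn_output  = bmm(attn_probs, v)   -> [B * H, T, d] -> [B, T, H * d] -> out_proj -> [B, T, embed_dim]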
346 |
347 | class CLIPFlashAttention2(CLIPAttention):
348 | """
349 | CLIPAttention flash attention module. This module inherits from `CLIPAttention`, as the weights of the module stay
350 | untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
351 | flash attention and deal with padding tokens in case the input contains any of them.
352 | """
353 |
354 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
355 | def __init__(self, *args, **kwargs):
356 | super().__init__(*args, **kwargs)
357 |
358 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
359 | # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
360 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
361 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
362 |
363 | # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
364 | def forward(
365 | self,
366 | hidden_states: torch.Tensor,
367 | attention_mask: Optional[torch.Tensor] = None,
368 | causal_attention_mask: Optional[torch.Tensor] = None,
369 | output_attentions: Optional[bool] = False,
370 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
371 | output_attentions = False
372 |
373 | batch_size, q_len, _ = hidden_states.size()
374 |
375 | query_states = self.q_proj(hidden_states)
376 | key_states = self.k_proj(hidden_states)
377 | value_states = self.v_proj(hidden_states)
378 |
379 | # Flash attention requires the input to have the shape
380 | # batch_size x seq_length x num_heads x head_dim,
381 | # therefore we just need to keep the original layout and avoid transposing the head dimension
382 | query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
383 | key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
384 | value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
385 |
386 | dropout_rate = self.dropout if self.training else 0.0
387 |
388 | # In PEFT, the layer norms are usually cast to float32 for training stability,
389 | # so the input hidden states get silently cast to float32. Hence, we need to
390 | # cast them back to the correct dtype just to be sure everything works as expected.
391 | # This might slow down training & inference, so it is recommended not to cast the LayerNorms
392 | # to fp32.
393 |
394 | input_dtype = query_states.dtype
395 | if input_dtype == torch.float32:
396 | if torch.is_autocast_enabled():
397 | target_dtype = torch.get_autocast_gpu_dtype()
398 | # Handle the case where the model is quantized
399 | elif hasattr(self.config, "_pre_quantization_dtype"):
400 | target_dtype = self.config._pre_quantization_dtype
401 | else:
402 | target_dtype = self.q_proj.weight.dtype
403 |
404 | logger.warning_once(
405 | f"The input hidden states seems to be silently casted in float32, this might be related to"
406 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
407 | f" {target_dtype}."
408 | )
409 |
410 | query_states = query_states.to(target_dtype)
411 | key_states = key_states.to(target_dtype)
412 | value_states = value_states.to(target_dtype)
413 |
414 | attn_output = _flash_attention_forward(
415 | query_states,
416 | key_states,
417 | value_states,
418 | attention_mask,
419 | q_len,
420 | dropout=dropout_rate,
421 | is_causal=causal_attention_mask is not None,
422 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
423 | )
424 |
425 | attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
426 | attn_output = self.out_proj(attn_output)
427 |
428 | if not output_attentions:
429 | attn_weights = None
430 |
431 | return attn_output, attn_weights
432 |
433 |
434 | class CLIPSdpaAttention(CLIPAttention):
435 | """
436 | SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
437 | `CLIPAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
438 | the SDPA API.
439 | """
440 |
441 | # Adapted from CLIPAttention.forward
442 | def forward(
443 | self,
444 | hidden_states: torch.Tensor,
445 | attention_mask: Optional[torch.Tensor] = None,
446 | causal_attention_mask: Optional[torch.Tensor] = None,
447 | output_attentions: Optional[bool] = False,
448 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
449 | if output_attentions:
450 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
451 | logger.warning_once(
452 | "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
453 | "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
454 | "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
455 | 'be removed using the argument `attn_implementation="eager"` when loading the model.'
456 | )
457 | return super().forward(
458 | hidden_states=hidden_states,
459 | attention_mask=attention_mask,
460 | causal_attention_mask=causal_attention_mask,
461 | output_attentions=output_attentions,
462 | )
463 |
464 | # CLIP text model uses both `causal_attention_mask` and `attention_mask`
465 | if attention_mask is not None and causal_attention_mask is not None:
466 | attn_mask = attention_mask + causal_attention_mask
467 | elif causal_attention_mask is not None:
468 | attn_mask = causal_attention_mask
469 | else:
470 | attn_mask = attention_mask
471 |
472 | bsz, tgt_len, embed_dim = hidden_states.size()
473 |
474 | query_states = self.q_proj(hidden_states)
475 | key_states = self.k_proj(hidden_states)
476 | value_states = self.v_proj(hidden_states)
477 |
478 | query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
479 | key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
480 | value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
481 |
482 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
483 | # Reference: https://github.com/pytorch/pytorch/issues/112577.
484 | if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
485 | query_states = query_states.contiguous()
486 | key_states = key_states.contiguous()
487 | value_states = value_states.contiguous()
488 |
489 | # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
490 | attn_output = torch.nn.functional.scaled_dot_product_attention(
491 | query_states,
492 | key_states,
493 | value_states,
494 | attn_mask=attn_mask,
495 | dropout_p=self.dropout if self.training else 0.0,
496 | scale=self.scale,
497 | )
498 |
499 | attn_output = attn_output.transpose(1, 2)
500 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
501 |
502 | attn_output = self.out_proj(attn_output)
503 |
504 | return attn_output, None
505 |
506 |
507 | CLIP_ATTENTION_CLASSES = {
508 | "eager": CLIPAttention,
509 | "sdpa": CLIPSdpaAttention,
510 | "flash_attention_2": CLIPFlashAttention2,
511 | }
512 |
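# The mapping above is indexed with `config._attn_implementation` in `CLIPEncoderLayer.__init__`.
# A minimal selection sketch (assuming a recent `transformers` release, where the implementation is
# usually chosen at load time):
#
#   model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="sdpa")
#
# "flash_attention_2" additionally requires the `flash-attn` package and a supported GPU, while
# "eager" is the reference implementation and the only one that can return attention weights.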
513 |
514 | class CLIPMLP(nn.Module):
515 | def __init__(self, config):
516 | super().__init__()
517 | self.config = config
518 | self.activation_fn = ACT2FN[config.hidden_act]
519 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
520 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
521 |
522 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
523 | hidden_states = self.fc1(hidden_states)
524 | hidden_states = self.activation_fn(hidden_states)
525 | hidden_states = self.fc2(hidden_states)
526 | return hidden_states
527 |
528 |
529 | class CLIPEncoderLayer(nn.Module):
530 | def __init__(self, config: CLIPConfig):
531 | super().__init__()
532 | self.embed_dim = config.hidden_size
533 | self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
534 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
535 | self.mlp = CLIPMLP(config)
536 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
537 |
538 | def forward(
539 | self,
540 | hidden_states: torch.Tensor,
541 | attention_mask: torch.Tensor,
542 | causal_attention_mask: torch.Tensor,
543 | output_attentions: Optional[bool] = False,
544 | ) -> Tuple[torch.FloatTensor]:
545 | """
546 | Args:
547 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
548 | attention_mask (`torch.FloatTensor`): attention mask of size
549 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
550 | causal_attention_mask (`torch.FloatTensor`): causal attention mask of size `(batch, 1, tgt_len, src_len)`.
551 | output_attentions (`bool`, *optional*):
552 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
553 | returned tensors for more detail.
554 | """
555 | residual = hidden_states
556 |
557 | hidden_states = self.layer_norm1(hidden_states)
558 | hidden_states, attn_weights = self.self_attn(
559 | hidden_states=hidden_states,
560 | attention_mask=attention_mask,
561 | causal_attention_mask=causal_attention_mask,
562 | output_attentions=output_attentions,
563 | )
564 | hidden_states = residual + hidden_states
565 |
566 | residual = hidden_states
567 | hidden_states = self.layer_norm2(hidden_states)
568 | hidden_states = self.mlp(hidden_states)
569 | hidden_states = residual + hidden_states
570 |
571 | outputs = (hidden_states,)
572 |
573 | if output_attentions:
574 | outputs += (attn_weights,)
575 |
576 | return outputs
577 |
578 |
579 | class CLIPPreTrainedModel(PreTrainedModel):
580 | """
581 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
582 | models.
583 | """
584 |
585 | config_class = CLIPConfig
586 | base_model_prefix = "clip"
587 | supports_gradient_checkpointing = True
588 | _supports_sdpa = True
589 | _supports_flash_attn_2 = True
590 |
591 | def _init_weights(self, module):
592 | """Initialize the weights"""
593 | factor = self.config.initializer_factor
594 | if isinstance(module, CLIPTextEmbeddings):
595 | module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
596 | module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
597 | elif isinstance(module, CLIPVisionEmbeddings):
598 | factor = self.config.initializer_factor
599 | nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
600 | nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
601 | nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
602 | elif isinstance(module, CLIPAttention):
603 | factor = self.config.initializer_factor
604 | in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
605 | out_proj_std = (module.embed_dim**-0.5) * factor
606 | nn.init.normal_(module.q_proj.weight, std=in_proj_std)
607 | nn.init.normal_(module.k_proj.weight, std=in_proj_std)
608 | nn.init.normal_(module.v_proj.weight, std=in_proj_std)
609 | nn.init.normal_(module.out_proj.weight, std=out_proj_std)
610 | elif isinstance(module, CLIPMLP):
611 | factor = self.config.initializer_factor
612 | in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
613 | fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
614 | nn.init.normal_(module.fc1.weight, std=fc_std)
615 | nn.init.normal_(module.fc2.weight, std=in_proj_std)
616 | elif isinstance(module, CLIPModel):
617 | nn.init.normal_(
618 | module.text_projection.weight,
619 | std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
620 | )
621 | nn.init.normal_(
622 | module.visual_projection.weight,
623 | std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
624 | )
625 | elif isinstance(module, CLIPVisionModelWithProjection):
626 | nn.init.normal_(
627 | module.visual_projection.weight,
628 | std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
629 | )
630 | elif isinstance(module, CLIPTextModelWithProjection):
631 | nn.init.normal_(
632 | module.text_projection.weight,
633 | std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
634 | )
635 | elif isinstance(module, CLIPForImageClassification):
636 | nn.init.normal_(
637 | module.classifier.weight,
638 | std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
639 | )
640 |
641 | if isinstance(module, nn.LayerNorm):
642 | module.bias.data.zero_()
643 | module.weight.data.fill_(1.0)
644 | if isinstance(module, nn.Linear) and module.bias is not None:
645 | module.bias.data.zero_()
646 |
647 |
648 | CLIP_START_DOCSTRING = r"""
649 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
650 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
651 | etc.).
652 |
653 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
654 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
655 | and behavior.
656 |
657 | Parameters:
658 | config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
659 | Initializing with a config file does not load the weights associated with the model, only the
660 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
661 | """
662 |
663 | CLIP_TEXT_INPUTS_DOCSTRING = r"""
664 | Args:
665 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
666 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
667 | it.
668 |
669 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
670 | [`PreTrainedTokenizer.__call__`] for details.
671 |
672 | [What are input IDs?](../glossary#input-ids)
673 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
674 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
675 |
676 | - 1 for tokens that are **not masked**,
677 | - 0 for tokens that are **masked**.
678 |
679 | [What are attention masks?](../glossary#attention-mask)
680 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
681 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
682 | config.max_position_embeddings - 1]`.
683 |
684 | [What are position IDs?](../glossary#position-ids)
685 | output_attentions (`bool`, *optional*):
686 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
687 | tensors for more detail.
688 | output_hidden_states (`bool`, *optional*):
689 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
690 | more detail.
691 | return_dict (`bool`, *optional*):
692 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
693 | """
694 |
695 | CLIP_VISION_INPUTS_DOCSTRING = r"""
696 | Args:
697 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
698 | Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
699 | [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
700 | output_attentions (`bool`, *optional*):
701 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
702 | tensors for more detail.
703 | output_hidden_states (`bool`, *optional*):
704 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
705 | more detail.
706 | return_dict (`bool`, *optional*):
707 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
708 | """
709 |
710 | CLIP_INPUTS_DOCSTRING = r"""
711 | Args:
712 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
713 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
714 | it.
715 |
716 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
717 | [`PreTrainedTokenizer.__call__`] for details.
718 |
719 | [What are input IDs?](../glossary#input-ids)
720 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
721 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
722 |
723 | - 1 for tokens that are **not masked**,
724 | - 0 for tokens that are **masked**.
725 |
726 | [What are attention masks?](../glossary#attention-mask)
727 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
728 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
729 | config.max_position_embeddings - 1]`.
730 |
731 | [What are position IDs?](../glossary#position-ids)
732 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
733 | Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
734 | [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
735 | return_loss (`bool`, *optional*):
736 | Whether or not to return the contrastive loss.
737 | output_attentions (`bool`, *optional*):
738 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
739 | tensors for more detail.
740 | output_hidden_states (`bool`, *optional*):
741 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
742 | more detail.
743 | return_dict (`bool`, *optional*):
744 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
745 | """
746 |
747 |
748 | class CLIPEncoder(nn.Module):
749 | """
750 | Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
751 | [`CLIPEncoderLayer`].
752 |
753 | Args:
754 | config: CLIPConfig
755 | """
756 |
757 | def __init__(self, config: CLIPConfig):
758 | super().__init__()
759 | self.config = config
760 | self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
761 | self.gradient_checkpointing = False
762 |
763 | def forward(
764 | self,
765 | inputs_embeds,
766 | attention_mask: Optional[torch.Tensor] = None,
767 | causal_attention_mask: Optional[torch.Tensor] = None,
768 | output_attentions: Optional[bool] = None,
769 | output_hidden_states: Optional[bool] = None,
770 | return_dict: Optional[bool] = None,
771 | ) -> Union[Tuple, BaseModelOutput]:
772 | r"""
773 | Args:
774 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
775 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
776 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors
777 | than the model's internal embedding lookup matrix.
778 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
779 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
780 |
781 | - 1 for tokens that are **not masked**,
782 | - 0 for tokens that are **masked**.
783 |
784 | [What are attention masks?](../glossary#attention-mask)
785 | causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
786 | Causal mask for the text model. Mask values selected in `[0, 1]`:
787 |
788 | - 1 for tokens that are **not masked**,
789 | - 0 for tokens that are **masked**.
790 |
791 | [What are attention masks?](../glossary#attention-mask)
792 | output_attentions (`bool`, *optional*):
793 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
794 | returned tensors for more detail.
795 | output_hidden_states (`bool`, *optional*):
796 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
797 | for more detail.
798 | return_dict (`bool`, *optional*):
799 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
800 | """
801 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
802 | output_hidden_states = (
803 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
804 | )
805 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806 |
807 | encoder_states = () if output_hidden_states else None
808 | all_attentions = () if output_attentions else None
809 |
810 | hidden_states = inputs_embeds
811 | for idx, encoder_layer in enumerate(self.layers):
812 | if output_hidden_states:
813 | encoder_states = encoder_states + (hidden_states,)
814 | if self.gradient_checkpointing and self.training:
815 | layer_outputs = self._gradient_checkpointing_func(
816 | encoder_layer.__call__,
817 | hidden_states,
818 | attention_mask,
819 | causal_attention_mask,
820 | output_attentions,
821 | )
822 | else:
823 | layer_outputs = encoder_layer(
824 | hidden_states,
825 | attention_mask,
826 | causal_attention_mask,
827 | output_attentions=output_attentions,
828 | )
829 |
830 | hidden_states = layer_outputs[0]
831 |
832 | if output_attentions:
833 | all_attentions = all_attentions + (layer_outputs[1],)
834 |
835 | if output_hidden_states:
836 | encoder_states = encoder_states + (hidden_states,)
837 |
838 | if not return_dict:
839 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
840 | return BaseModelOutput(
841 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
842 | )
843 |
844 |
845 | class CLIPTextTransformer(nn.Module):
846 | def __init__(self, config: CLIPTextConfig):
847 | super().__init__()
848 | self.config = config
849 | embed_dim = config.hidden_size
850 | self.embeddings = CLIPTextEmbeddings(config)
851 | self.encoder = CLIPEncoder(config)
852 | self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
853 |
854 | # For `pooled_output` computation
855 | self.eos_token_id = config.eos_token_id
856 |
857 | # The attention mask format differs between `flash_attention_2` and other attention implementations
858 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
859 |
860 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
861 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
862 | def forward(
863 | self,
864 | input_ids: Optional[torch.Tensor] = None,
865 | attention_mask: Optional[torch.Tensor] = None,
866 | position_ids: Optional[torch.Tensor] = None,
867 | output_attentions: Optional[bool] = None,
868 | output_hidden_states: Optional[bool] = None,
869 | return_dict: Optional[bool] = None,
870 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
871 | r"""
872 | Returns:
873 |
874 | """
875 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
876 | output_hidden_states = (
877 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
878 | )
879 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
880 |
881 | if input_ids is None:
882 | raise ValueError("You have to specify input_ids")
883 |
884 | input_shape = input_ids.size()
885 | input_ids = input_ids.view(-1, input_shape[-1])
886 |
887 | hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
888 |
889 | # CLIP's text model uses causal mask, prepare it here.
890 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
891 | causal_attention_mask = _create_4d_causal_attention_mask(
892 | input_shape, hidden_states.dtype, device=hidden_states.device
893 | )
894 |
895 | # expand attention_mask
896 | if attention_mask is not None and not self._use_flash_attention_2:
897 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
898 | attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
899 |
900 | encoder_outputs = self.encoder(
901 | inputs_embeds=hidden_states,
902 | attention_mask=attention_mask,
903 | causal_attention_mask=causal_attention_mask,
904 | output_attentions=output_attentions,
905 | output_hidden_states=output_hidden_states,
906 | return_dict=return_dict,
907 | )
908 |
909 | last_hidden_state = encoder_outputs[0]
910 | last_hidden_state = self.final_layer_norm(last_hidden_state)
911 |
912 | if self.eos_token_id == 2:
913 | # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
914 | # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
915 | # ------------------------------------------------------------
916 | # text_embeds.shape = [batch_size, sequence_length, transformer.width]
917 | # take features from the eot embedding (eot_token is the highest number in each sequence)
918 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
919 | pooled_output = last_hidden_state[
920 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
921 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
922 | ]
923 | else:
924 | # The config gets an updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
925 | pooled_output = last_hidden_state[
926 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
927 | # We need to get the first position of the `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
928 | # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
929 | (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
930 | .int()
931 | .argmax(dim=-1),
932 | ]
933 |
934 | if not return_dict:
935 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
936 |
937 | return BaseModelOutputWithPooling(
938 | last_hidden_state=last_hidden_state,
939 | pooler_output=pooled_output,
940 | hidden_states=encoder_outputs.hidden_states,
941 | attentions=encoder_outputs.attentions,
942 | )
943 |
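# Pooling sketch for the text transformer above: `pooled_output` is the hidden state at the EOS
# position of each sequence. For configs updated by PR #24773 (the `else` branch), that position is
# the first occurrence of `eos_token_id`. Illustrative values (the token ids are assumptions):
#
#   input_ids     = torch.tensor([[49406, 320, 1125, 49407, 49407]])   # <bos> a photo <eos> <pad>
#   eos_positions = (input_ids == eos_token_id).int().argmax(dim=-1)   # -> tensor([3])
#   pooled_output = last_hidden_state[torch.arange(input_ids.shape[0]), eos_positions]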
944 |
945 | @add_start_docstrings(
946 | """The text model from CLIP without any head or projection on top.""",
947 | CLIP_START_DOCSTRING,
948 | )
949 | class CLIPTextModel(CLIPPreTrainedModel):
950 | config_class = CLIPTextConfig
951 |
952 | _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
953 |
954 | def __init__(self, config: CLIPTextConfig):
955 | super().__init__(config)
956 | self.text_model = CLIPTextTransformer(config)
957 | # Initialize weights and apply final processing
958 | self.post_init()
959 |
960 | def get_input_embeddings(self) -> nn.Module:
961 | return self.text_model.embeddings.token_embedding
962 |
963 | def set_input_embeddings(self, value):
964 | self.text_model.embeddings.token_embedding = value
965 |
966 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
967 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
968 | def forward(
969 | self,
970 | input_ids: Optional[torch.Tensor] = None,
971 | attention_mask: Optional[torch.Tensor] = None,
972 | position_ids: Optional[torch.Tensor] = None,
973 | output_attentions: Optional[bool] = None,
974 | output_hidden_states: Optional[bool] = None,
975 | return_dict: Optional[bool] = None,
976 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
977 | r"""
978 | Returns:
979 |
980 | Examples:
981 |
982 | ```python
983 | >>> from transformers import AutoTokenizer, CLIPTextModel
984 |
985 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
986 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
987 |
988 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
989 |
990 | >>> outputs = model(**inputs)
991 | >>> last_hidden_state = outputs.last_hidden_state
992 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
993 | ```"""
994 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
995 |
996 | return self.text_model(
997 | input_ids=input_ids,
998 | attention_mask=attention_mask,
999 | position_ids=position_ids,
1000 | output_attentions=output_attentions,
1001 | output_hidden_states=output_hidden_states,
1002 | return_dict=return_dict,
1003 | )
1004 |
1005 |
1006 | class CLIPVisionTransformer(nn.Module):
1007 | def __init__(self, config: CLIPVisionConfig):
1008 | super().__init__()
1009 | self.config = config
1010 | embed_dim = config.hidden_size
1011 |
1012 | self.embeddings = CLIPVisionEmbeddings(config)
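# NOTE: the "pre_layrnorm" spelling below is presumably kept as-is so that parameter names continue
# to match the weight names in released CLIP checkpoints; renaming it would break `from_pretrained` loading.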
1013 | self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1014 | self.encoder = CLIPEncoder(config)
1015 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1016 |
1017 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1018 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
1019 | def forward(
1020 | self,
1021 | pixel_values: Optional[torch.FloatTensor] = None,
1022 | output_attentions: Optional[bool] = None,
1023 | output_hidden_states: Optional[bool] = None,
1024 | return_dict: Optional[bool] = None,
1025 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
1026 | r"""
1027 | Returns:
1028 |
1029 | """
1030 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1031 | output_hidden_states = (
1032 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1033 | )
1034 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1035 |
1036 | if pixel_values is None:
1037 | raise ValueError("You have to specify pixel_values")
1038 |
1039 | hidden_states = self.embeddings(pixel_values)
1040 | hidden_states = self.pre_layrnorm(hidden_states)
1041 |
1042 | encoder_outputs = self.encoder(
1043 | inputs_embeds=hidden_states,
1044 | output_attentions=output_attentions,
1045 | output_hidden_states=output_hidden_states,
1046 | return_dict=return_dict,
1047 | )
1048 |
1049 | last_hidden_state = encoder_outputs[0]
1050 | pooled_output = last_hidden_state[:, 0, :]
1051 | pooled_output = self.post_layernorm(pooled_output)
1052 |
1053 | if not return_dict:
1054 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1055 |
1056 | return BaseModelOutputWithPooling(
1057 | last_hidden_state=last_hidden_state,
1058 | pooler_output=pooled_output,
1059 | hidden_states=encoder_outputs.hidden_states,
1060 | attentions=encoder_outputs.attentions,
1061 | )
1062 |
1063 |
1064 | @add_start_docstrings(
1065 | """The vision model from CLIP without any head or projection on top.""",
1066 | CLIP_START_DOCSTRING,
1067 | )
1068 | class CLIPVisionModel(CLIPPreTrainedModel):
1069 | config_class = CLIPVisionConfig
1070 | main_input_name = "pixel_values"
1071 | _no_split_modules = ["CLIPEncoderLayer"]
1072 |
1073 | def __init__(self, config: CLIPVisionConfig):
1074 | super().__init__(config)
1075 | self.vision_model = CLIPVisionTransformer(config)
1076 | # Initialize weights and apply final processing
1077 | self.post_init()
1078 |
1079 | def get_input_embeddings(self) -> nn.Module:
1080 | return self.vision_model.embeddings.patch_embedding
1081 |
1082 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1083 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
1084 | def forward(
1085 | self,
1086 | pixel_values: Optional[torch.FloatTensor] = None,
1087 | output_attentions: Optional[bool] = None,
1088 | output_hidden_states: Optional[bool] = None,
1089 | return_dict: Optional[bool] = None,
1090 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
1091 | r"""
1092 | Returns:
1093 |
1094 | Examples:
1095 |
1096 | ```python
1097 | >>> from PIL import Image
1098 | >>> import requests
1099 | >>> from transformers import AutoProcessor, CLIPVisionModel
1100 |
1101 | >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1102 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1103 |
1104 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1105 | >>> image = Image.open(requests.get(url, stream=True).raw)
1106 |
1107 | >>> inputs = processor(images=image, return_tensors="pt")
1108 |
1109 | >>> outputs = model(**inputs)
1110 | >>> last_hidden_state = outputs.last_hidden_state
1111 | >>> pooled_output = outputs.pooler_output # pooled CLS states
1112 | ```"""
1113 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1114 |
1115 | return self.vision_model(
1116 | pixel_values=pixel_values,
1117 | output_attentions=output_attentions,
1118 | output_hidden_states=output_hidden_states,
1119 | return_dict=return_dict,
1120 | )
1121 |
1122 |
1123 | @add_start_docstrings(CLIP_START_DOCSTRING)
1124 | class CLIPModel(CLIPPreTrainedModel):
1125 | config_class = CLIPConfig
1126 | _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
1127 |
1128 | def __init__(self, config: CLIPConfig):
1129 | super().__init__(config)
1130 |
1131 | if not isinstance(config.text_config, CLIPTextConfig):
1132 | raise TypeError(
1133 | "config.text_config is expected to be of type CLIPTextConfig but is of type"
1134 | f" {type(config.text_config)}."
1135 | )
1136 |
1137 | if not isinstance(config.vision_config, CLIPVisionConfig):
1138 | raise TypeError(
1139 | "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1140 | f" {type(config.vision_config)}."
1141 | )
1142 |
1143 | text_config = config.text_config
1144 | vision_config = config.vision_config
1145 |
1146 | self.projection_dim = config.projection_dim
1147 | self.text_embed_dim = text_config.hidden_size
1148 | self.vision_embed_dim = vision_config.hidden_size
1149 |
1150 | text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
1151 | self.text_model = text_model.text_model
1152 |
1153 | vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
1154 | self.vision_model = vision_model.vision_model
1155 |
1156 | self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1157 | self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
1158 | self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1159 |
1160 | # Initialize weights and apply final processing
1161 | self.post_init()
1162 |
1163 | def set_processor(self, model_name):
1164 | self.processor = CLIPProcessor.from_pretrained(model_name)
1165 |
1166 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1167 | def get_text_features(
1168 | self,
1169 | input_ids: Optional[torch.Tensor] = None,
1170 | attention_mask: Optional[torch.Tensor] = None,
1171 | position_ids: Optional[torch.Tensor] = None,
1172 | output_attentions: Optional[bool] = None,
1173 | output_hidden_states: Optional[bool] = None,
1174 | return_dict: Optional[bool] = None,
1175 | ) -> torch.FloatTensor:
1176 | r"""
1177 | Returns:
1178 | text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1179 | applying the projection layer to the pooled output of [`CLIPTextModel`].
1180 |
1181 | Examples:
1182 |
1183 | ```python
1184 | >>> from transformers import AutoTokenizer, CLIPModel
1185 |
1186 | >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1187 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1188 |
1189 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1190 | >>> text_features = model.get_text_features(**inputs)
1191 | ```"""
1192 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1193 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1194 | output_hidden_states = (
1195 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1196 | )
1197 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1198 |
1199 | text_outputs = self.text_model(
1200 | input_ids=input_ids,
1201 | attention_mask=attention_mask,
1202 | position_ids=position_ids,
1203 | output_attentions=output_attentions,
1204 | output_hidden_states=output_hidden_states,
1205 | return_dict=return_dict,
1206 | )
1207 |
1208 | pooled_output = text_outputs[1]
1209 | text_features = self.text_projection(pooled_output)
1210 |
1211 | return text_features
1212 |
1213 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1214 | def get_image_features(
1215 | self,
1216 | pixel_values: Optional[torch.FloatTensor] = None,
1217 | output_attentions: Optional[bool] = None,
1218 | output_hidden_states: Optional[bool] = None,
1219 | return_dict: Optional[bool] = None,
1220 | ) -> torch.FloatTensor:
1221 | r"""
1222 | Returns:
1223 | image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1224 | applying the projection layer to the pooled output of [`CLIPVisionModel`].
1225 |
1226 | Examples:
1227 |
1228 | ```python
1229 | >>> from PIL import Image
1230 | >>> import requests
1231 | >>> from transformers import AutoProcessor, CLIPModel
1232 |
1233 | >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1234 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1235 |
1236 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1237 | >>> image = Image.open(requests.get(url, stream=True).raw)
1238 |
1239 | >>> inputs = processor(images=image, return_tensors="pt")
1240 |
1241 | >>> image_features = model.get_image_features(**inputs)
1242 | ```"""
1243 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1244 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1245 | output_hidden_states = (
1246 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1247 | )
1248 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1249 |
1250 | vision_outputs = self.vision_model(
1251 | pixel_values=pixel_values,
1252 | output_attentions=output_attentions,
1253 | output_hidden_states=output_hidden_states,
1254 | return_dict=return_dict,
1255 | )
1256 |
1257 | pooled_output = vision_outputs[1] # pooled_output
1258 | image_features = self.visual_projection(pooled_output)
1259 |
1260 | return image_features
1261 |
1262 |
1263 | def encode_image(self, images):
1264 | embeddings = self.get_image_features(images)
1265 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
1266 | return embeddings
1267 |
1268 | def encode_text(self, text):
1269 | embeddings = self.get_text_features(**text)
1270 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
1271 | return embeddings
1272 |
1273 | def encode_multimodal(self, images, text):
1274 | text_embeddings = self.get_text_features(**text)
1275 | image_embeddings = self.get_image_features(images)
1276 |
1277 | embeddings = text_embeddings + image_embeddings
1278 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
1279 |
1280 | return embeddings.contiguous()
1281 |
1282 | def data_process(self, images=None, text=None):
1283 | if images is None and text is not None:
1284 | text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
1285 |
1286 | return images, text, "text"
1287 | elif images is not None and text is None:
1288 | if isinstance(images, str):
1289 | images = Image.open(images).convert("RGB")
1290 | elif isinstance(images, list):
1291 | images = [Image.open(image).convert("RGB") for image in images]
1292 | images = self.processor(images=images, return_tensors="pt").to(self.device)
1293 | images = images["pixel_values"]
1294 | return images, text, "images"
1295 | elif images is not None and text is not None:
1296 | assert type(images) == type(text), "images and text must be the same type: list or str"
1297 | if isinstance(images, str):
1298 | images = Image.open(images).convert("RGB")
1299 | elif isinstance(images, list):
1300 | assert len(images) == len(text), "images and text must be lists of the same length when using lists"
1301 | images = [Image.open(image).convert("RGB") for image in images]
1302 | images = self.processor(images=images, return_tensors="pt").to(self.device)
1303 | images = images["pixel_values"]
1304 | text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
1305 | return images, text, "multimodal"
1306 | else:
1307 | raise ValueError("images and text cannot both be None")
1308 |
1309 | def encode(self, images=None, text=None):
1310 | images, text, data_type = self.data_process(images, text)
1311 | if data_type == "images":
1312 | return self.encode_image(images)
1313 | elif data_type == "text":
1314 | return self.encode_text(text)
1315 | elif data_type == "multimodal":
1316 | return self.encode_multimodal(images, text)
1317 |
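# Usage sketch for the retrieval helpers above (a minimal, illustrative example: the image paths and
# the instruction text are placeholders, and `set_processor` must be called first so that
# `self.processor` is available):
#
#   model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
#   model.set_processor("openai/clip-vit-base-patch32")
#   with torch.no_grad():
#       q_emb = model.encode(images="./query.jpg", text="make the background a beach")  # fused query
#       c_emb = model.encode(images=["./cand_1.jpg", "./cand_2.jpg"])                   # candidate pool
#   scores = q_emb @ c_emb.T  # embeddings are L2-normalized, so this is cosine similarity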
1318 |
1319 | @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1320 | @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
1321 | def forward(
1322 | self,
1323 | input_ids: Optional[torch.LongTensor] = None,
1324 | pixel_values: Optional[torch.FloatTensor] = None,
1325 | attention_mask: Optional[torch.Tensor] = None,
1326 | position_ids: Optional[torch.LongTensor] = None,
1327 | return_loss: Optional[bool] = None,
1328 | output_attentions: Optional[bool] = None,
1329 | output_hidden_states: Optional[bool] = None,
1330 | return_dict: Optional[bool] = None,
1331 | ) -> Union[Tuple, CLIPOutput]:
1332 | r"""
1333 | Returns:
1334 |
1335 | Examples:
1336 |
1337 | ```python
1338 | >>> from PIL import Image
1339 | >>> import requests
1340 | >>> from transformers import AutoProcessor, CLIPModel
1341 |
1342 | >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1343 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1344 |
1345 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1346 | >>> image = Image.open(requests.get(url, stream=True).raw)
1347 |
1348 | >>> inputs = processor(
1349 | ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1350 | ... )
1351 |
1352 | >>> outputs = model(**inputs)
1353 | >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1354 | >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1355 | ```"""
1356 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1357 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1358 | output_hidden_states = (
1359 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1360 | )
1361 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1362 |
1363 | vision_outputs = self.vision_model(
1364 | pixel_values=pixel_values,
1365 | output_attentions=output_attentions,
1366 | output_hidden_states=output_hidden_states,
1367 | return_dict=return_dict,
1368 | )
1369 |
1370 | text_outputs = self.text_model(
1371 | input_ids=input_ids,
1372 | attention_mask=attention_mask,
1373 | position_ids=position_ids,
1374 | output_attentions=output_attentions,
1375 | output_hidden_states=output_hidden_states,
1376 | return_dict=return_dict,
1377 | )
1378 |
1379 | image_embeds = vision_outputs[1]
1380 | image_embeds = self.visual_projection(image_embeds)
1381 |
1382 | text_embeds = text_outputs[1]
1383 | text_embeds = self.text_projection(text_embeds)
1384 |
1385 | # normalized features
1386 | image_embeds = image_embeds / _get_vector_norm(image_embeds)
1387 | text_embeds = text_embeds / _get_vector_norm(text_embeds)
1388 |
1389 | # cosine similarity as logits
1390 | logit_scale = self.logit_scale.exp()
1391 | logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to(
1392 | text_embeds.device
1393 | )
1394 | logits_per_image = logits_per_text.t()
1395 |
1396 | loss = None
1397 | if return_loss:
1398 | loss = clip_loss(logits_per_text)
1399 |
1400 | if not return_dict:
1401 | output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1402 | return ((loss,) + output) if loss is not None else output
1403 |
1404 | return CLIPOutput(
1405 | loss=loss,
1406 | logits_per_image=logits_per_image,
1407 | logits_per_text=logits_per_text,
1408 | text_embeds=text_embeds,
1409 | image_embeds=image_embeds,
1410 | text_model_output=text_outputs,
1411 | vision_model_output=vision_outputs,
1412 | )
1413 |
1414 |
1415 | @add_start_docstrings(
1416 | """
1417 | CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
1418 | """,
1419 | CLIP_START_DOCSTRING,
1420 | )
1421 | class CLIPTextModelWithProjection(CLIPPreTrainedModel):
1422 | config_class = CLIPTextConfig
1423 |
1424 | _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
1425 |
1426 | def __init__(self, config: CLIPTextConfig):
1427 | super().__init__(config)
1428 |
1429 | text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation)
1430 | self.text_model = text_model.text_model
1431 |
1432 | self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1433 |
1434 | # Initialize weights and apply final processing
1435 | self.post_init()
1436 |
1437 | def get_input_embeddings(self) -> nn.Module:
1438 | return self.text_model.embeddings.token_embedding
1439 |
1440 | def set_input_embeddings(self, value):
1441 | self.text_model.embeddings.token_embedding = value
1442 |
1443 | @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1444 | @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
1445 | def forward(
1446 | self,
1447 | input_ids: Optional[torch.Tensor] = None,
1448 | attention_mask: Optional[torch.Tensor] = None,
1449 | position_ids: Optional[torch.Tensor] = None,
1450 | output_attentions: Optional[bool] = None,
1451 | output_hidden_states: Optional[bool] = None,
1452 | return_dict: Optional[bool] = None,
1453 | ) -> Union[Tuple, CLIPTextModelOutput]:
1454 | r"""
1455 | Returns:
1456 |
1457 | Examples:
1458 |
1459 | ```python
1460 | >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
1461 |
1462 | >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1463 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1464 |
1465 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1466 |
1467 | >>> outputs = model(**inputs)
1468 | >>> text_embeds = outputs.text_embeds
1469 | ```"""
1470 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1471 |
1472 | text_outputs = self.text_model(
1473 | input_ids=input_ids,
1474 | attention_mask=attention_mask,
1475 | position_ids=position_ids,
1476 | output_attentions=output_attentions,
1477 | output_hidden_states=output_hidden_states,
1478 | return_dict=return_dict,
1479 | )
1480 |
1481 | pooled_output = text_outputs[1]
1482 |
1483 | text_embeds = self.text_projection(pooled_output)
1484 |
1485 | if not return_dict:
1486 | outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
1487 | return tuple(output for output in outputs if output is not None)
1488 |
1489 | return CLIPTextModelOutput(
1490 | text_embeds=text_embeds,
1491 | last_hidden_state=text_outputs.last_hidden_state,
1492 | hidden_states=text_outputs.hidden_states,
1493 | attentions=text_outputs.attentions,
1494 | )
1495 |
1496 |
1497 | @add_start_docstrings(
1498 | """
1499 | CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
1500 | """,
1501 | CLIP_START_DOCSTRING,
1502 | )
1503 | class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
1504 | config_class = CLIPVisionConfig
1505 | main_input_name = "pixel_values"
1506 |
1507 | def __init__(self, config: CLIPVisionConfig):
1508 | super().__init__(config)
1509 |
1510 | vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation)
1511 | self.vision_model = vision_model.vision_model
1512 |
1513 | self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1514 |
1515 | # Initialize weights and apply final processing
1516 | self.post_init()
1517 |
1518 | def get_input_embeddings(self) -> nn.Module:
1519 | return self.vision_model.embeddings.patch_embedding
1520 |
1521 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1522 | @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
1523 | def forward(
1524 | self,
1525 | pixel_values: Optional[torch.FloatTensor] = None,
1526 | output_attentions: Optional[bool] = None,
1527 | output_hidden_states: Optional[bool] = None,
1528 | return_dict: Optional[bool] = None,
1529 | ) -> Union[Tuple, CLIPVisionModelOutput]:
1530 | r"""
1531 | Returns:
1532 |
1533 | Examples:
1534 |
1535 | ```python
1536 | >>> from PIL import Image
1537 | >>> import requests
1538 | >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
1539 |
1540 | >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1541 | >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1542 |
1543 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1544 | >>> image = Image.open(requests.get(url, stream=True).raw)
1545 |
1546 | >>> inputs = processor(images=image, return_tensors="pt")
1547 |
1548 | >>> outputs = model(**inputs)
1549 | >>> image_embeds = outputs.image_embeds
1550 | ```"""
1551 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1552 |
1553 | vision_outputs = self.vision_model(
1554 | pixel_values=pixel_values,
1555 | output_attentions=output_attentions,
1556 | output_hidden_states=output_hidden_states,
1557 | return_dict=return_dict,
1558 | )
1559 |
1560 | pooled_output = vision_outputs[1]  # pooler_output: layer-normalized [CLS] token embedding
1561 |
1562 | image_embeds = self.visual_projection(pooled_output)
1563 |
1564 | if not return_dict:
1565 | outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
1566 | return tuple(output for output in outputs if output is not None)
1567 |
1568 | return CLIPVisionModelOutput(
1569 | image_embeds=image_embeds,
1570 | last_hidden_state=vision_outputs.last_hidden_state,
1571 | hidden_states=vision_outputs.hidden_states,
1572 | attentions=vision_outputs.attentions,
1573 | )
1574 |
1575 |
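# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only): the two projection heads above are
# typically paired for CLIP-style text-to-image retrieval. The helper name
# `projection_retrieval_demo` is hypothetical, and the
# "openai/clip-vit-base-patch32" checkpoint is assumed here only because it is
# the one already used in the docstring examples above; it is not part of the
# MMRet pipeline itself.
# ---------------------------------------------------------------------------
def projection_retrieval_demo(image, captions):
    """Return cosine similarities between one PIL image and a list of captions."""
    from transformers import AutoProcessor, AutoTokenizer

    checkpoint = "openai/clip-vit-base-patch32"  # assumed demo checkpoint
    vision_model = CLIPVisionModelWithProjection.from_pretrained(checkpoint)
    text_model = CLIPTextModelWithProjection.from_pretrained(checkpoint)
    processor = AutoProcessor.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    with torch.no_grad():
        image_embeds = vision_model(**processor(images=image, return_tensors="pt")).image_embeds
        text_embeds = text_model(**tokenizer(captions, padding=True, return_tensors="pt")).text_embeds

    # L2-normalize so the dot product is a cosine similarity, as in CLIP.
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    return (text_embeds @ image_embeds.T).squeeze(-1)  # one score per caption
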
1576 | @add_start_docstrings(
1577 | """
1578 | CLIP vision encoder with an image classification head on top (a linear layer on top of the mean-pooled final hidden states of
1579 | the patch tokens), e.g. for ImageNet.
1580 | """,
1581 | CLIP_START_DOCSTRING,
1582 | )
1583 | class CLIPForImageClassification(CLIPPreTrainedModel):
1584 | main_input_name = "pixel_values"
1585 |
1586 | def __init__(self, config: CLIPConfig) -> None:
1587 | super().__init__(config)
1588 |
1589 | self.num_labels = config.num_labels
1590 | vision_model = CLIPVisionModel._from_config(
1591 | config.vision_config, attn_implementation=config._attn_implementation
1592 | )
1593 | self.vision_model = vision_model.vision_model
1594 |
1595 | # Classifier head
1596 | self.classifier = (
1597 | nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1598 | )
1599 |
1600 | # Initialize weights and apply final processing
1601 | self.post_init()
1602 |
1603 | @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1604 | @add_code_sample_docstrings(
1605 | checkpoint=_IMAGE_CLASS_CHECKPOINT,
1606 | output_type=ImageClassifierOutput,
1607 | config_class=_CONFIG_FOR_DOC,
1608 | expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1609 | )
1610 | def forward(
1611 | self,
1612 | pixel_values: Optional[torch.Tensor] = None,
1613 | labels: Optional[torch.Tensor] = None,
1614 | output_attentions: Optional[bool] = None,
1615 | output_hidden_states: Optional[bool] = None,
1616 | return_dict: Optional[bool] = None,
1617 | ) -> Union[tuple, ImageClassifierOutput]:
1618 | r"""
1619 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1620 | Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1621 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
1622 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1623 | """
1624 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1625 | output_hidden_states = (
1626 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1627 | )
1628 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1629 |
1630 | outputs = self.vision_model(
1631 | pixel_values,
1632 | output_attentions=output_attentions,
1633 | output_hidden_states=output_hidden_states,
1634 | return_dict=return_dict,
1635 | )
1636 |
1637 | sequence_output = outputs[0]
1638 |
1639 | # average pool the patch tokens
1640 | sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
1641 | # apply classifier
1642 | logits = self.classifier(sequence_output)
1643 |
1644 | loss = None
1645 | if labels is not None:
1646 | # move labels to correct device to enable model parallelism
1647 | labels = labels.to(logits.device)
1648 | if self.config.problem_type is None:
1649 | if self.num_labels == 1:
1650 | self.config.problem_type = "regression"
1651 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1652 | self.config.problem_type = "single_label_classification"
1653 | else:
1654 | self.config.problem_type = "multi_label_classification"
1655 |
1656 | if self.config.problem_type == "regression":
1657 | loss_fct = MSELoss()
1658 | if self.num_labels == 1:
1659 | loss = loss_fct(logits.squeeze(), labels.squeeze())
1660 | else:
1661 | loss = loss_fct(logits, labels)
1662 | elif self.config.problem_type == "single_label_classification":
1663 | loss_fct = CrossEntropyLoss()
1664 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1665 | elif self.config.problem_type == "multi_label_classification":
1666 | loss_fct = BCEWithLogitsLoss()
1667 | loss = loss_fct(logits, labels)
1668 |
1669 | if not return_dict:
1670 | output = (logits,) + outputs[2:]
1671 | return ((loss,) + output) if loss is not None else output
1672 |
1673 | return ImageClassifierOutput(
1674 | loss=loss,
1675 | logits=logits,
1676 | hidden_states=outputs.hidden_states,
1677 | attentions=outputs.attentions,
1678 | )
1679 |
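# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only): running the classification head
# defined above. The "openai/clip-vit-base-patch32" checkpoint and the COCO
# image URL are taken from the docstring examples earlier in this file and are
# assumptions for illustration; that checkpoint ships no classifier weights,
# so the head is randomly initialized and a fine-tuned classification
# checkpoint should be substituted for real use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import requests
    from PIL import Image
    from transformers import AutoProcessor

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPForImageClassification.from_pretrained("openai/clip-vit-base-patch32")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, config.num_labels)
    print(logits.argmax(-1))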
--------------------------------------------------------------------------------