├── assets ├── dummy.txt ├── cala_workflow.jpg └── cala_workflow.pdf ├── cala_workflow.jpg ├── LICENSE ├── src ├── twin_attention_compositor_clip.py ├── twin_attention_compositor_blip2.py ├── artemis.py ├── hinge_based_cross_attention_blip2.py ├── hinge_based_cross_attention_clip.py ├── cirr_test_submission_blip2.py ├── cirr_test_submission.py ├── cirr_artemis_submission_blip2.py ├── utils.py ├── data_utils.py ├── validate.py └── blip_fine_tune.py └── README.md /assets/dummy.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /cala_workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chiangsonw/CaLa/HEAD/cala_workflow.jpg -------------------------------------------------------------------------------- /assets/cala_workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chiangsonw/CaLa/HEAD/assets/cala_workflow.jpg -------------------------------------------------------------------------------- /assets/cala_workflow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chiangsonw/CaLa/HEAD/assets/cala_workflow.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MadChiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/twin_attention_compositor_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import sys 5 | from utils import device 6 | 7 | class TwinAttentionCompositorCLIP(nn.Module): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | self.fc = nn.Linear(2560,640) 12 | self.relu = nn.ReLU(inplace=True) 13 | self.reference_as_query_attention = torch.nn.MultiheadAttention(embed_dim=640, num_heads=1, dropout=0.0, batch_first=True) 14 | self.target_as_query_attention = torch.nn.MultiheadAttention(embed_dim=640, num_heads=1, dropout=0.0, batch_first=True) 15 | 16 | def forward(self, reference_embeddings:torch.tensor, target_embeddings:torch.tensor): 17 | bs, hi, h, w = reference_embeddings.size() 18 | #embeddings to tokens bs x length x hidden bs 81 2560 19 | reference_embeddings = reference_embeddings.view(bs,h*w,hi) 20 | target_embeddings = target_embeddings.view(bs,h*w,hi) 21 | #dim compact bs 81 640 linear降维 22 | reference_tokens = self.relu(self.fc(reference_embeddings)) 23 | target_tokens =self.relu(self.fc(target_embeddings)) 24 | cls_token = torch.randn(bs, 1, 640).to(device, non_blocking=True) 25 | #cat cls token bs 82 640 26 | reference_tokens = torch.cat([cls_token, reference_tokens], dim=1) 27 | target_tokens = torch.cat([cls_token, target_tokens], dim=1) 28 | 29 | # 4 layers 30 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=target_tokens, value=target_tokens) 31 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=output1, value=output1) 32 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=output1, value=output1) 33 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=output1, value=output1) 34 | 35 | #4 layers 36 | output2, _ = self.target_as_query_attention(query=target_tokens, key=reference_tokens, value=reference_tokens) 37 | output2, _ = self.target_as_query_attention(query=target_tokens, key=output2, value=output2) 38 | output2, _ = self.target_as_query_attention(query=target_tokens, key=output2, value=output2) 39 | output2, _ = self.target_as_query_attention(query=target_tokens, key=output2, value=output2) 40 | 41 | 42 | output1_features = output1[:,0,:] 43 | output2_features = output2[:,0,:] 44 | output_features = (output1_features + output2_features) / 2 45 | return output_features 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/twin_attention_compositor_blip2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from utils import device 5 | 6 | class TwinAttentionCompositorBLIP2(nn.Module): 7 | def __init__(self, embedding_dim) -> None: 8 | super().__init__() 9 | # self.fusion = nn.Linear(512,256) 10 | # self.relu1 = nn.ReLU(inplace=True) 11 | self.reference_as_query_attention = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=1, dropout=0.0, batch_first=True) 12 | self.target_as_query_attention = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=1, dropout=0.0, batch_first=True) 13 | 14 | def forward(self, reference_embeddings:torch.tensor, target_embeddings:torch.tensor): 15 | #embeddings to tokens bs x length x hidden bs x 32 x 256 16 | 17 | # 4 layers of 
attention 18 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=target_embeddings, value=target_embeddings) 19 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=output1, value=output1) 20 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=output1, value=output1) 21 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=output1, value=output1) 22 | 23 | # 4 layers of attention 24 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=reference_embeddings, value=reference_embeddings) 25 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=output2, value=output2) 26 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=output2, value=output2) 27 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=output2, value=output2) 28 | 29 | # share weight 30 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=reference_embeddings, value=reference_embeddings) 31 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=output2, value=output2) 32 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=output2, value=output2) 33 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=output2, value=output2) 34 | 35 | # use 0 token 作为 features bs x 256 两个features平均 36 | output1_features = output1[:,0,:] 37 | output2_features = output2[:,0,:] 38 | output_features = (output1_features + output2_features) / 2 39 | return output_features 40 | -------------------------------------------------------------------------------- /src/artemis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from utils import l2norm 5 | 6 | class Artemis(nn.Module): 7 | 8 | def __init__(self, clip_feature_dim) -> None: 9 | super().__init__() 10 | self.embed_dim = clip_feature_dim 11 | self.Transform_m = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim), L2Module()) 12 | self.Attention_EM = AttentionMechanism(self.embed_dim) 13 | self.Attention_IS = AttentionMechanism(self.embed_dim) 14 | self.temperature = nn.Parameter(torch.FloatTensor((2.65926,))) 15 | 16 | def apply_attention(self, a, x): 17 | return l2norm(a * x) 18 | 19 | def compute_score_artemis(self, r, m, t, store_intermediary=False): 20 | EM = self.compute_score_EM(r, m, t, store_intermediary) 21 | IS = self.compute_score_IS(r, m, t, store_intermediary) 22 | if store_intermediary: 23 | self.hold_results["EM"] = EM 24 | self.hold_results["IS"] = IS 25 | return EM + IS 26 | 27 | def compute_score_broadcast_artemis(self, r, m, t): 28 | return self.compute_score_broadcast_EM(r, m, t) + self.compute_score_broadcast_IS(r, m, t) 29 | 30 | def compute_score_EM(self, r, m, t, store_intermediary=False): 31 | Tr_m = self.Transform_m(m) 32 | A_EM_t = self.apply_attention(self.Attention_EM(m), t) 33 | if store_intermediary: 34 | self.hold_results["Tr_m"] = Tr_m 35 | self.hold_results["A_EM_t"] = A_EM_t 36 | return (Tr_m * A_EM_t).sum(-1) 37 | 38 | def compute_score_broadcast_EM(self, r, m, t): 39 | batch_size = r.size(0) 40 | A_EM = self.Attention_EM(m) # shape (Bq, d) 41 | Tr_m = self.Transform_m(m) # shape (Bq, d) 42 | # apply each query attention mechanism to all targets 43 | A_EM_all_t = self.apply_attention(A_EM.view(batch_size, 1, self.embed_dim), t.view(1, 
batch_size, self.embed_dim)) # shape (Bq, Bt, d) 44 | EM_score = (Tr_m.view(batch_size, 1, self.embed_dim) * A_EM_all_t).sum(-1) # shape (Bq, Bt) ; coefficient (i,j) is the IS score between query i and target j 45 | return EM_score 46 | 47 | def compute_score_IS(self, r, m, t, store_intermediary=False): 48 | A_IS_r = self.apply_attention(self.Attention_IS(m), r) 49 | A_IS_t = self.apply_attention(self.Attention_IS(m), t) 50 | if store_intermediary: 51 | self.hold_results["A_IS_r"] = A_IS_r 52 | self.hold_results["A_IS_t"] = A_IS_t 53 | return (A_IS_r * A_IS_t).sum(-1) 54 | 55 | def compute_score_broadcast_IS(self, r, m, t): 56 | batch_size = r.size(0) 57 | A_IS = self.Attention_IS(m) # shape (Bq, d) 58 | A_IS_r = self.apply_attention(A_IS, r) # shape (Bq, d) 59 | # apply each query attention mechanism to all targets 60 | A_IS_all_t = self.apply_attention(A_IS.view(batch_size, 1, self.embed_dim), t.view(1, batch_size, self.embed_dim)) # shape (Bq, Bt, d) 61 | IS_score = (A_IS_r.view(batch_size, 1, self.embed_dim) * A_IS_all_t).sum(-1) # shape (Bq, Bt) ; coefficient (i,j) is the IS score between query i and target j 62 | return IS_score 63 | 64 | 65 | class L2Module(nn.Module): 66 | 67 | def __init__(self): 68 | super(L2Module, self).__init__() 69 | 70 | def forward(self, x): 71 | x = l2norm(x) 72 | return x 73 | 74 | class AttentionMechanism(nn.Module): 75 | """ 76 | Module defining the architecture of the attention mechanisms in ARTEMIS. 77 | """ 78 | 79 | def __init__(self, embed_dim): 80 | super(AttentionMechanism, self).__init__() 81 | 82 | self.embed_dim = embed_dim 83 | input_dim = self.embed_dim 84 | 85 | self.attention = nn.Sequential( 86 | nn.Linear(input_dim, self.embed_dim), 87 | nn.ReLU(), 88 | nn.Linear(self.embed_dim, self.embed_dim), 89 | nn.Softmax(dim=1) 90 | ) 91 | 92 | def forward(self, x): 93 | return self.attention(x) -------------------------------------------------------------------------------- /src/hinge_based_cross_attention_blip2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from utils import device 5 | import math 6 | 7 | class HingebasedCrossAttentionBLIP2(nn.Module): 8 | def __init__(self, embed_dim) -> None: 9 | super().__init__() 10 | # attention proj 11 | self.query_ref1 = nn.Linear(embed_dim,embed_dim) 12 | self.key_text1 = nn.Linear(embed_dim,embed_dim) 13 | # self.query_text1 = nn.Linear(embed_dim,embed_dim) 14 | self.key_tar1 = nn.Linear(embed_dim,embed_dim) 15 | self.value1 = nn.Linear(embed_dim,embed_dim) 16 | self.dropout1 = nn.Dropout(0.1) 17 | 18 | self.query_ref2 = nn.Linear(embed_dim,embed_dim) 19 | self.key_text2 = nn.Linear(embed_dim,embed_dim) 20 | self.key_tar2 = nn.Linear(embed_dim,embed_dim) 21 | self.value2 = nn.Linear(embed_dim,embed_dim) 22 | self.dropout2 = nn.Dropout(0.1) 23 | 24 | def forward(self, reference_embeds, caption_embeds, target_embeds): 25 | psudo_T = self.hca_T_share_text(reference_embeds, caption_embeds, target_embeds) 26 | return psudo_T 27 | 28 | def hca_T_share_text(self, reference_embeds, caption_embeds, target_embeds): 29 | bs , len_r , dim = reference_embeds.shape 30 | attA = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 31 | attB = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 32 | attC = self.dropout1(F.softmax(torch.matmul(attA, attB), dim=-1)) 33 | psudo_T = torch.matmul(attC , 
self.value1(target_embeds)) 34 | return psudo_T[:,0,:] 35 | 36 | def hca_T_R_share_text(self, reference_embeds, caption_embeds, target_embeds): 37 | bs , len_r , dim = reference_embeds.shape 38 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 39 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 40 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1)) 41 | psudo_T = torch.matmul(attC1 , self.value1(target_embeds)) 42 | 43 | attA2 = self.multiply(self.query_ref2(target_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim) 44 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(reference_embeds)) / math.sqrt(dim) 45 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1)) 46 | psudo_R = torch.matmul(attC2 , self.value2(reference_embeds)) 47 | 48 | return psudo_T[:,0,:], psudo_R[:,0,:] 49 | 50 | def hca_T_multihead_4layer(self, reference_embeds, caption_embeds, target_embeds): 51 | bs , len_r , dim = reference_embeds.shape 52 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 53 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 54 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1)) 55 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds)) 56 | 57 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim) 58 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim) 59 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1)) 60 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds)) 61 | 62 | attA3 = self.multiply(self.query_ref3(reference_embeds), self.key_text3(caption_embeds)) / math.sqrt(dim) 63 | attB3 = self.multiply(self.key_text3(caption_embeds), self.key_tar3(target_embeds)) / math.sqrt(dim) 64 | attC3 = self.dropout3(F.softmax(torch.matmul(attA3, attB3), dim=-1)) 65 | psudo_T3 = torch.matmul(attC3 , self.value3(target_embeds)) 66 | 67 | attA4 = self.multiply(self.query_ref4(reference_embeds), self.key_text4(caption_embeds)) / math.sqrt(dim) 68 | attB4 = self.multiply(self.key_text4(caption_embeds), self.key_tar4(target_embeds)) / math.sqrt(dim) 69 | attC4 = self.dropout4(F.softmax(torch.matmul(attA4, attB4), dim=-1)) 70 | psudo_T4 = torch.matmul(attC4 , self.value4(target_embeds)) 71 | 72 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:] + psudo_T3[:,0,:] + psudo_T4[:,0,:]) / 4 73 | 74 | def hca_T_multihead_2layer(self, reference_embeds, caption_embeds, target_embeds): 75 | bs , len_r , dim = reference_embeds.shape 76 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 77 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 78 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1)) 79 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds)) 80 | 81 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim) 82 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim) 83 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1)) 84 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds)) 85 | 86 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:]) / 2 87 | 88 | 
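    # multiply() computes a dense token-to-token similarity map between two embedding sequences:
    # given A of shape (bs, len_a, dim) and B of shape (bs, len_b, dim) it returns scores of shape
    # (bs, len_a, len_b). The hca_* methods above chain two such maps through the caption
    # (reference -> text, then text -> target), which is the "hinge" of the hinge-based cross-attention.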
def multiply(self, embedsA, embedsB): 89 | bs, len_a , dim = embedsA.shape 90 | bs, len_b , dim = embedsB.shape 91 | 92 | # flatten the token sequences 93 | embedsA = embedsA.view(bs, -1, dim) # shape bs x len_a x dim 94 | embedsB = embedsB.view(bs, -1, dim) # shape bs x len_b x dim 95 | 96 | # dot-product similarity scores 97 | attention_scores_flat = torch.matmul(embedsA, embedsB.transpose(-1, -2)) # transpose the key dimension 98 | 99 | # restore the shape 100 | attention_scores = attention_scores_flat.view(bs, len_a, len_b) 101 | 102 | return attention_scores 103 | 104 | 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CaLa 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cala-complementary-association-learning-for/image-retrieval-on-cirr)](https://paperswithcode.com/sota/image-retrieval-on-cirr?p=cala-complementary-association-learning-for) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cala-complementary-association-learning-for/image-retrieval-on-fashion-iq)](https://paperswithcode.com/sota/image-retrieval-on-fashion-iq?p=cala-complementary-association-learning-for) 5 | 6 | **CaLa (ACM SIGIR 2024)** is a new composed image retrieval framework that considers two complementary associations in the task. CaLa presents TBIA (text-based image alignment) and CTR (complementary text reasoning) to augment composed image retrieval. 7 | 8 | We highlight the contributions of this paper as follows: 9 | 10 | • We present a new perspective on composed image retrieval: the annotated triplet is viewed as a graph node, and two complementary association clues are disclosed to enhance composed image retrieval. 11 | 12 | • A hinge-based attention and a twin-attention-based visual compositor are proposed to effectively impose the new associations on the network learning. 13 | 14 | • Competitive performance on the CIRR and FashionIQ benchmarks. CaLa benefits several baselines with different backbones and architectures, showing that it is a widely beneficial module for composed image retrieval. 15 | 16 | More details can be found in our paper: [CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval](https://arxiv.org/pdf/2405.19149) 17 | 18 | This is the workflow of our CaLa framework. 19 |
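For a quick sense of the hinge-based attention before reading the full source, below is a minimal, self-contained sketch paraphrased from `src/hinge_based_cross_attention_blip2.py`. The class and layer names, the toy tensor shapes, and the omission of dropout are illustrative choices only; this is not the exact released implementation.

```python
import math
import torch
import torch.nn.functional as F
from torch import nn

class HingeAttentionSketch(nn.Module):
    """Sketch: chain reference->caption and caption->target attention through the caption (the "hinge")."""
    def __init__(self, dim: int):
        super().__init__()
        self.q_ref = nn.Linear(dim, dim)   # projects reference-image tokens (queries)
        self.k_txt = nn.Linear(dim, dim)   # projects caption tokens (the hinge)
        self.k_tar = nn.Linear(dim, dim)   # projects target-image tokens
        self.value = nn.Linear(dim, dim)   # values taken from the target tokens

    def forward(self, ref, cap, tar):
        # ref/tar: (bs, L_img, dim), cap: (bs, L_txt, dim)
        d = ref.size(-1)
        att_rc = self.q_ref(ref) @ self.k_txt(cap).transpose(-1, -2) / math.sqrt(d)  # reference -> caption
        att_ct = self.k_txt(cap) @ self.k_tar(tar).transpose(-1, -2) / math.sqrt(d)  # caption -> target
        att = F.softmax(att_rc @ att_ct, dim=-1)   # chained ("hinged") attention map, (bs, L_img, L_img)
        pseudo_target = att @ self.value(tar)      # (bs, L_img, dim)
        return pseudo_target[:, 0, :]              # pooled pseudo-target feature, (bs, dim)

# Toy usage with BLIP-2-like shapes: 32 image tokens, 12 caption tokens, 768-dim embeddings.
sketch = HingeAttentionSketch(768)
out = sketch(torch.randn(2, 32, 768), torch.randn(2, 12, 768), torch.randn(2, 32, 768))
print(out.shape)  # torch.Size([2, 768])
```

The twin-attention compositor in `src/twin_attention_compositor_blip2.py` follows the same spirit but cross-attends the reference and target token sequences in both directions and averages the two pooled features.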
20 | 21 | ## News 22 | 23 | ## Models and Weights 24 | 25 | ## Usage 26 | 27 | ### Prerequisites 28 | 29 | The following commands will create a local Anaconda environment with the necessary packages installed. 30 | 31 | ```bash 32 | conda create -n cala -y python=3.8 33 | conda activate cala 34 | conda install -y -c pytorch pytorch=1.11.0 torchvision=0.12.0 35 | conda install -y -c anaconda pandas=1.4.2 36 | pip install comet-ml==3.21.0 37 | pip install git+https://github.com/openai/CLIP.git 38 | pip install salesforce-lavis 39 | ``` 40 | 41 | ### Data Preparation 42 | 43 | To work properly with the codebase, the FashionIQ and CIRR datasets should have the following structure: 44 | 45 | ``` 46 | project_base_path 47 | └─── CaLa 48 | └─── src 49 | | blip_fine_tune.py 50 | | data_utils.py 51 | | utils.py 52 | | ... 53 | 54 | └─── fashionIQ_dataset 55 | └─── captions 56 | | cap.dress.test.json 57 | | cap.dress.train.json 58 | | cap.dress.val.json 59 | | ... 60 | 61 | └─── images 62 | | B00006M009.jpg 63 | | B00006M00B.jpg 64 | | B00006M6IH.jpg 65 | | ... 66 | 67 | └─── image_splits 68 | | split.dress.test.json 69 | | split.dress.train.json 70 | | split.dress.val.json 71 | | ... 72 | 73 | └─── cirr_dataset 74 | └─── train 75 | └─── 0 76 | | train-10108-0-img0.png 77 | | train-10108-0-img1.png 78 | | train-10108-1-img0.png 79 | | ... 80 | 81 | └─── 1 82 | | train-10056-0-img0.png 83 | | train-10056-0-img1.png 84 | | train-10056-1-img0.png 85 | | ... 86 | 87 | ... 88 | 89 | └─── dev 90 | | dev-0-0-img0.png 91 | | dev-0-0-img1.png 92 | | dev-0-1-img0.png 93 | | ... 94 | 95 | └─── test1 96 | | test1-0-0-img0.png 97 | | test1-0-0-img1.png 98 | | test1-0-1-img0.png 99 | | ... 100 | 101 | └─── cirr 102 | └─── captions 103 | | cap.rc2.test1.json 104 | | cap.rc2.train.json 105 | | cap.rc2.val.json 106 | 107 | └─── image_splits 108 | | split.rc2.test1.json 109 | | split.rc2.train.json 110 | | split.rc2.val.json 111 | ``` 112 | 113 | 114 | ### Adjustments for dependencies 115 | 116 | For fine-tuning the BLIP-2 encoders, you need to comment out the following line in LAVIS within your conda environment. 117 | ```python 118 | # In lavis/models/blip2_models/blip2_qformer.py line 367 119 | # @torch.no_grad() # comment out this line. 120 | ``` 121 | Commenting out this line allows gradients to be computed for the BLIP-2 model so that its parameters can be updated. 122 | 123 | For fine-tuning the CLIP encoders, you need to replace the following code in the CLIP package so that RN50x4 features can interact with the Q-Former. 124 | ```python 125 | # Replace CLIP/clip/models.py lines 152-154 with the following code. 126 | 152# x = self.attnpool(x) 127 | 153# 128 | 154# return x 129 | 130 | 152# y=x 131 | 153# x = self.attnpool(x) 132 | 154# 133 | 155# return x,y 134 | 135 | # Replace CLIP/clip/models.py lines 343-356 with the following code. Before taking the cls token, keep the full text token sequence as the global text features.
136 | 346# x = x + self.positional_embedding.type(self.dtype) 137 | 347# x = x.permute(1, 0, 2) # NLD -> LND 138 | 348# x = self.transformer(x) 139 | 349# x = x.permute(1, 0, 2) # LND -> NLD 140 | 350# x = self.ln_final(x).type(self.dtype) 141 | 351# 142 | 352# y = x 143 | 353# # x.shape = [batch_size, n_ctx, transformer.width] 144 | 354# # take features from the eot embedding (eot_token is the highest number in each sequence) 145 | 355# x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 146 | 147 | 356# return x,y 148 | 149 | ``` 150 | 151 | ### Training 152 | 153 | 154 | ```shell 155 | # cala finetune 156 | CUDA_VISIBLE_DEVICES='GPU_IDs' python src/blip_fine_tune.py --dataset {'CIRR' or 'FashionIQ'} \ 157 | --num-epochs 30 --batch-size 64 \ 158 | --max-epoch 15 --min-lr 0 \ 159 | --learning-rate 5e-6 \ 160 | --transform targetpad --target-ratio 1.25 \ 161 | --save-training --save-best --validation-frequency 1 \ 162 | --encoder {'both' or 'text' or 'multi'} \ 163 | --encoder-arch {clip or blip2} \ 164 | --cir-frame {sum or artemis} \ 165 | --tac-weight 0.45 \ 166 | --hca-weight 0.1 \ 167 | --embeds-dim {640 for clip and 768 for blip2} \ 168 | --model-name {RN50x4 for clip and None for blip} \ 169 | --api-key {Comet-api-key} \ 170 | --workspace {Comet-workspace} \ 171 | --experiment-name {Comet-experiment-name} 172 | ``` 173 | 174 | 175 | ### CIRR Testing 176 | 177 | 178 | ```shell 179 | CUDA_VISIBLE_DEVICES='GPU_IDs' python src/cirr_test_submission_blip2.py --submission-name {cirr_submission} \ 180 | --combining-function {sum or artemis} \ 181 | --blip2-textual-path {saved_blip2_textual.pt} \ 182 | --blip2-multimodal-path {saved_blip2_multimodal.pt} \ 183 | --blip2-visual-path {saved_blip2_visual.pt} 184 | 185 | ``` 186 | 187 | ```shell 188 | python src/validate.py \ 189 | --dataset {'CIRR' or 'FashionIQ'} \ 190 | --combining-function {'combiner' or 'sum'} \ 191 | --combiner-path {path to trained Combiner} \ 192 | --projection-dim 2560 \ 193 | --hidden-dim 5120 \ 194 | --clip-model-name RN50x4 \ 195 | --clip-model-path {path-to-fine-tuned-CLIP} \ 196 | --target-ratio 1.25 \ 197 | --transform targetpad 198 | ``` 199 | 200 | 201 | ## Reference 202 | If you use CaLa in your research, please cite it using the following BibTeX entry: 203 | 204 | ```bibtex 205 | @article{jiang2024cala, 206 | title={CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval}, 207 | author={Jiang, Xintong and Wang, Yaxiong and Li, Mengjian and Wu, Yujiao and Hu, Bingwen and Qian, Xueming}, 208 | journal={arXiv preprint arXiv:2405.19149}, 209 | year={2024} 210 | } 211 | ``` 212 | 213 | ## Acknowledgement 214 | Our implementation is based on [CLIP4Cir](https://github.com/ABaldrati/CLIP4Cir) and [LAVIS](https://github.com/salesforce/LAVIS).
215 | -------------------------------------------------------------------------------- /src/hinge_based_cross_attention_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import sys 5 | from utils import device 6 | import math 7 | 8 | class HingebasedCrossAttentionCLIP(nn.Module): 9 | def __init__(self, embed_dim) -> None: 10 | super().__init__() 11 | # attention proj 12 | self.query_ref1 = nn.Linear(embed_dim,embed_dim) 13 | self.key_text1 = nn.Linear(embed_dim,embed_dim) 14 | # self.query_text1 = nn.Linear(embed_dim,embed_dim) 15 | self.key_tar1 = nn.Linear(embed_dim,embed_dim) 16 | self.value1 = nn.Linear(embed_dim,embed_dim) 17 | self.dropout1 = nn.Dropout(0.1) 18 | 19 | self.query_ref2 = nn.Linear(embed_dim,embed_dim) 20 | self.key_text2 = nn.Linear(embed_dim,embed_dim) 21 | self.key_tar2 = nn.Linear(embed_dim,embed_dim) 22 | self.value2 = nn.Linear(embed_dim,embed_dim) 23 | self.dropout2 = nn.Dropout(0.1) 24 | 25 | self.fc1 = nn.Linear(2560,640) 26 | self.relu1 = nn.ReLU(inplace=True) 27 | 28 | 29 | def forward(self, reference_embeds, caption_embeds, target_embeds): 30 | psudo_T = self.hca_T_share_text(reference_embeds, caption_embeds, target_embeds) 31 | return psudo_T 32 | 33 | def hca_T_share_text(self, reference_embeds, caption_embeds, target_embeds): 34 | 35 | bs, hi, h, w = reference_embeds.size() 36 | #embeddings to tokens bs x length x hidden bs 81 2560 37 | reference_embeds = reference_embeds.view(bs,h*w,hi) 38 | target_embeds = target_embeds.view(bs,h*w,hi) 39 | #dim compact bs 81 640 linear降维 40 | reference_embeds = self.relu1(self.fc1(reference_embeds)) 41 | target_embeds = self.relu1(self.fc1(target_embeds)) 42 | 43 | attA = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(640) 44 | attB = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(640) 45 | attC = self.dropout1(F.softmax(torch.matmul(attA, attB), dim=-1)) 46 | psudo_T = torch.matmul(attC , self.value1(target_embeds)) 47 | return psudo_T[:,0,:] 48 | 49 | def hca_T_R_share_text(self, reference_embeds, caption_embeds, target_embeds): 50 | bs , len_r , dim = reference_embeds.shape 51 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 52 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 53 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1)) 54 | psudo_T = torch.matmul(attC1 , self.value1(target_embeds)) 55 | 56 | attA2 = self.multiply(self.query_ref2(target_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim) 57 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(reference_embeds)) / math.sqrt(dim) 58 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1)) 59 | psudo_R = torch.matmul(attC2 , self.value2(reference_embeds)) 60 | 61 | return psudo_T[:,0,:], psudo_R[:,0,:] 62 | 63 | def hca_T_multihead_4(self, reference_embeds, caption_embeds, target_embeds): 64 | bs , len_r , dim = reference_embeds.shape 65 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 66 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 67 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1)) 68 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds)) 69 | 
70 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim) 71 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim) 72 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1)) 73 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds)) 74 | 75 | attA3 = self.multiply(self.query_ref3(reference_embeds), self.key_text3(caption_embeds)) / math.sqrt(dim) 76 | attB3 = self.multiply(self.key_text3(caption_embeds), self.key_tar3(target_embeds)) / math.sqrt(dim) 77 | attC3 = self.dropout3(F.softmax(torch.matmul(attA3, attB3), dim=-1)) 78 | psudo_T3 = torch.matmul(attC3 , self.value3(target_embeds)) 79 | 80 | attA4 = self.multiply(self.query_ref4(reference_embeds), self.key_text4(caption_embeds)) / math.sqrt(dim) 81 | attB4 = self.multiply(self.key_text4(caption_embeds), self.key_tar4(target_embeds)) / math.sqrt(dim) 82 | attC4 = self.dropout4(F.softmax(torch.matmul(attA4, attB4), dim=-1)) 83 | psudo_T4 = torch.matmul(attC4 , self.value4(target_embeds)) 84 | 85 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:] + psudo_T3[:,0,:] + psudo_T4[:,0,:]) / 4 86 | 87 | def hca_T_multihead_2(self, reference_embeds, caption_embeds, target_embeds): 88 | bs , len_r , dim = reference_embeds.shape 89 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim) 90 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim) 91 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1)) 92 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds)) 93 | 94 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim) 95 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim) 96 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1)) 97 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds)) 98 | 99 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:]) / 2 100 | 101 | # def rct_block_R(self, reference_embeds, caption_embeds, target_embeds): 102 | # bs , len_r , dim = reference_embeds.shape 103 | # attA = self.multiply(self.query1(target_embeds), self.key1(caption_embeds)) / math.sqrt(dim) 104 | # attB = self.multiply(self.query2(caption_embeds), self.key2(reference_embeds)) / math.sqrt(dim) 105 | 106 | # attC = self.dropout(F.softmax(torch.matmul(attA, attB), dim=-1)) 107 | # psudo_R = torch.matmul(attC , self.value(reference_embeds)) 108 | # return psudo_R 109 | 110 | # def rct_block_cap(self, reference_embeds, caption_embeds, target_embeds): 111 | # bs , len_r , dim = reference_embeds.shape 112 | # attA = self.multiply(self.query1(reference_embeds), self.key1(caption_embeds)) / math.sqrt(dim) 113 | # attB = self.multiply(self.query2(target_embeds), self.key2(caption_embeds)) / math.sqrt(dim) 114 | 115 | # attC = self.dropout(F.softmax(attA * attB, dim=-1)) 116 | # psudo_C = torch.matmul(attC , self.value(caption_embeds)) 117 | # return psudo_C 118 | 119 | 120 | # def rct_block_no_linear(self, reference_embeds, caption_embeds, target_embeds): 121 | # bs , len_r , dim = reference_embeds.shape 122 | # attA = self.multiply(reference_embeds, caption_embeds) / math.sqrt(dim) 123 | # attB = self.multiply(caption_embeds, target_embeds) / math.sqrt(dim) 124 | 125 | # attC = self.dropout(F.softmax(torch.matmul(attA, attB), dim=-1)) 126 | # psudo_T = torch.matmul(attC , (target_embeds)) 127 | # return psudo_T 
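    # The commented-out rct_block_* methods above are alternative formulations that are never
    # executed: forward() only calls hca_T_share_text(). As in the BLIP-2 variant, multiply()
    # below produces the (bs, len_a, len_b) token-to-token similarity map shared by the hca_* methods.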
128 | 129 | def multiply(self, embedsA, embedsB): 130 | bs, len_a , dim = embedsA.shape 131 | bs, len_b , dim = embedsB.shape 132 | 133 | # 扁平化 134 | embedsA = embedsA.view(bs, -1, dim) # 形状为 bs x (length_a * dim) 135 | embedsB = embedsB.view(bs, -1, dim) # 形状为 bs x (length_b * dim) 136 | 137 | # 点积计算 138 | attention_scores_flat = torch.matmul(embedsA, embedsB.transpose(-1, -2)) # 转置 Key 的维度 139 | 140 | # 还原形状 141 | attention_scores = attention_scores_flat.view(bs, len_a, len_b) 142 | 143 | return attention_scores 144 | 145 | 146 | -------------------------------------------------------------------------------- /src/cirr_test_submission_blip2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | from argparse import ArgumentParser 4 | from operator import itemgetter 5 | from pathlib import Path 6 | from typing import List, Tuple, Dict 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | from data_utils import CIRRDataset, targetpad_transform, squarepad_transform, base_path 15 | from utils import element_wise_sum, device, extract_index_features_blip2 16 | from lavis.models import load_model 17 | 18 | 19 | def generate_cirr_test_submissions(combining_function: callable, file_name: str, blip_textual, blip_multimodal, 20 | blip_visual, preprocess: callable): 21 | """ 22 | Generate and save CIRR test submission files to be submitted to evaluation server 23 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 24 | features 25 | :param file_name: file_name of the submission 26 | :param clip_model: CLIP model 27 | :param preprocess: preprocess pipeline 28 | """ 29 | blip_textual = blip_textual.float().eval() 30 | blip_multimodal = blip_multimodal.float().eval() 31 | blip_visual = blip_visual.float().eval() 32 | 33 | # Define the dataset and extract index features 34 | classic_test_dataset = CIRRDataset('test1', 'classic', preprocess) 35 | index_features, index_names = extract_index_features_blip2(classic_test_dataset, blip_visual) 36 | relative_test_dataset = CIRRDataset('test1', 'relative', preprocess) 37 | 38 | # Generate test prediction dicts for CIRR 39 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset, blip_textual, blip_multimodal, 40 | index_features, index_names, 41 | combining_function) 42 | 43 | submission = { 44 | 'version': 'rc2', 45 | 'metric': 'recall' 46 | } 47 | group_submission = { 48 | 'version': 'rc2', 49 | 'metric': 'recall_subset' 50 | } 51 | 52 | submission.update(pairid_to_predictions) 53 | group_submission.update(pairid_to_group_predictions) 54 | 55 | # Define submission path 56 | submissions_folder_path = base_path / "submission" / 'CIRR' 57 | submissions_folder_path.mkdir(exist_ok=True, parents=True) 58 | 59 | print(f"Saving CIRR test predictions") 60 | with open(submissions_folder_path / f"recall_submission_{file_name}.json", 'w+') as file: 61 | json.dump(submission, file, sort_keys=True) 62 | 63 | with open(submissions_folder_path / f"recall_subset_submission_{file_name}.json", 'w+') as file: 64 | json.dump(group_submission, file, sort_keys=True) 65 | 66 | 67 | def generate_cirr_test_dicts(relative_test_dataset: CIRRDataset, blip_textual, blip_multimodal, index_features: torch.tensor, 68 | index_names: List[str], combining_function: callable) \ 69 | -> Tuple[Dict[str, 
List[str]], Dict[str, List[str]]]: 70 | """ 71 | Compute test prediction dicts for CIRR dataset 72 | :param relative_test_dataset: CIRR test dataset in relative mode 73 | :param clip_model: CLIP model 74 | :param index_features: test index features 75 | :param index_names: test index names 76 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 77 | features 78 | :return: Top50 global and Top3 subset prediction for each query (reference_name, caption) 79 | """ 80 | 81 | # Generate predictions 82 | predicted_features, reference_names, group_members, pairs_id = \ 83 | generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset, combining_function, index_names, 84 | index_features) 85 | 86 | print(f"Compute CIRR prediction dicts") 87 | 88 | # Normalize the index features 89 | index_features = F.normalize(index_features, dim=-1).float() 90 | 91 | # Compute the distances and sort the results 92 | distances = 1 - predicted_features @ index_features.T 93 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 94 | sorted_index_names = np.array(index_names)[sorted_indices] 95 | 96 | # Delete the reference image from the results 97 | reference_mask = torch.tensor( 98 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names), 99 | -1)) 100 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 101 | sorted_index_names.shape[1] - 1) 102 | # Compute the subset predictions 103 | group_members = np.array(group_members) 104 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 105 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1) 106 | 107 | # Generate prediction dicts 108 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in 109 | zip(pairs_id, sorted_index_names)} 110 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in 111 | zip(pairs_id, sorted_group_names)} 112 | 113 | return pairid_to_predictions, pairid_to_group_predictions 114 | 115 | 116 | def generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset: CIRRDataset, combining_function: callable, 117 | index_names: List[str], index_features: torch.tensor) -> \ 118 | Tuple[torch.tensor, List[str], List[List[str]], List[str]]: 119 | """ 120 | Compute CIRR predictions on the test set 121 | :param clip_model: CLIP model 122 | :param relative_test_dataset: CIRR test dataset in relative mode 123 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 124 | features 125 | :param index_features: test index features 126 | :param index_names: test index names 127 | 128 | :return: predicted_features, reference_names, group_members and pairs_id 129 | """ 130 | print(f"Compute CIRR test predictions") 131 | 132 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, 133 | num_workers=multiprocessing.cpu_count(), pin_memory=True) 134 | 135 | # Get a mapping from index names to index features 136 | # name_to_feat = dict(zip(index_names, index_features)) 137 | 138 | # Initialize pairs_id, predicted_features, group_members and reference_names 139 | pairs_id = [] 140 | predicted_features = torch.empty((0, 768)).to(device, non_blocking=True) 141 | group_members = [] 142 | reference_names = [] 143 
| 144 | for batch_pairs_id, batch_reference_names, reference_images, captions, batch_group_members in tqdm( 145 | relative_test_loader): # Load data 146 | reference_images = reference_images.to(device) 147 | batch_group_members = np.array(batch_group_members).T.tolist() 148 | # Compute the predicted features 149 | with torch.no_grad(): 150 | text_feats = blip_textual.extract_features({"text_input":captions}, 151 | mode="text").text_embeds[:,0,:] 152 | reference_feats = blip_multimodal.extract_features({"image":reference_images, 153 | "text_input":captions}).multimodal_embeds[:,0,:] 154 | batch_predicted_features = combining_function(text_feats, reference_feats) 155 | 156 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 157 | group_members.extend(batch_group_members) 158 | reference_names.extend(batch_reference_names) 159 | pairs_id.extend(batch_pairs_id) 160 | 161 | return predicted_features, reference_names, group_members, pairs_id 162 | 163 | 164 | def main(): 165 | parser = ArgumentParser() 166 | parser.add_argument("--submission-name", type=str, required=True, help="submission file name") 167 | parser.add_argument("--combining-function", type=str, required=True, 168 | help="Which combining function use, should be in ['combiner', 'sum']") 169 | parser.add_argument("--blip2-textual-path", type=str, help="Path to the fine-tuned BLIP2 model") 170 | parser.add_argument("--blip2-visual-path", type=str, help="Path to the fine-tuned BLIP2 model") 171 | parser.add_argument("--blip2-multimodal-path", type=str, help="Path to the fine-tuned BLIP2 model") 172 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio") 173 | parser.add_argument("--transform", default="targetpad", type=str, 174 | help="Preprocess pipeline, should be in ['clip', 'squarepad', 'targetpad'] ") 175 | args = parser.parse_args() 176 | 177 | blip_textual_path = args.blip2_textual_path 178 | blip_visual_path = args.blip2_visual_path 179 | blip_multimodal_path = args.blip2_multimodal_path 180 | 181 | blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 182 | x_saved_state_dict = torch.load(blip_textual_path, map_location=device) 183 | 184 | blip_textual.load_state_dict(x_saved_state_dict["Blip2Qformer"]) 185 | 186 | blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 187 | r_saved_state_dict = torch.load(blip_multimodal_path, map_location=device) 188 | blip_multimodal.load_state_dict(r_saved_state_dict["Blip2Qformer"]) 189 | 190 | blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 191 | t_saved_state_dict = torch.load(blip_visual_path, map_location=device) 192 | blip_visual.load_state_dict(t_saved_state_dict["Blip2Qformer"]) 193 | input_dim = 224 194 | 195 | if args.transform == 'targetpad': 196 | print('Target pad preprocess pipeline is used') 197 | preprocess = targetpad_transform(args.target_ratio, input_dim) 198 | elif args.transform == 'squarepad': 199 | print('Square pad preprocess pipeline is used') 200 | preprocess = squarepad_transform(input_dim) 201 | 202 | if args.combining_function.lower() == 'sum': 203 | combining_function = element_wise_sum 204 | else: 205 | raise ValueError("combiner_path should be in ['sum', 'combiner']") 206 | 207 | generate_cirr_test_submissions(combining_function, args.submission_name, blip_textual, 
blip_multimodal, blip_visual, preprocess) 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /src/cirr_test_submission.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | from argparse import ArgumentParser 4 | from operator import itemgetter 5 | from pathlib import Path 6 | from typing import List, Tuple, Dict 7 | 8 | import clip 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from clip.model import CLIP 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | 16 | from combiner_train import extract_index_features 17 | from data_utils import CIRRDataset, targetpad_transform, squarepad_transform, base_path 18 | from combiner import Combiner 19 | from utils import element_wise_sum, device 20 | 21 | 22 | def generate_cirr_test_submissions(combining_function: callable, file_name: str, clip_model: CLIP, 23 | preprocess: callable): 24 | """ 25 | Generate and save CIRR test submission files to be submitted to evaluation server 26 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 27 | features 28 | :param file_name: file_name of the submission 29 | :param clip_model: CLIP model 30 | :param preprocess: preprocess pipeline 31 | """ 32 | 33 | clip_model = clip_model.float().eval() 34 | 35 | # Define the dataset and extract index features 36 | classic_test_dataset = CIRRDataset('test1', 'classic', preprocess) 37 | index_features, index_names = extract_index_features(classic_test_dataset, clip_model) 38 | relative_test_dataset = CIRRDataset('test1', 'relative', preprocess) 39 | 40 | # Generate test prediction dicts for CIRR 41 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset, clip_model, 42 | index_features, index_names, 43 | combining_function) 44 | 45 | submission = { 46 | 'version': 'rc2', 47 | 'metric': 'recall' 48 | } 49 | group_submission = { 50 | 'version': 'rc2', 51 | 'metric': 'recall_subset' 52 | } 53 | 54 | submission.update(pairid_to_predictions) 55 | group_submission.update(pairid_to_group_predictions) 56 | 57 | # Define submission path 58 | submissions_folder_path = base_path / "submission" / 'CIRR' 59 | submissions_folder_path.mkdir(exist_ok=True, parents=True) 60 | 61 | print(f"Saving CIRR test predictions") 62 | with open(submissions_folder_path / f"recall_submission_{file_name}.json", 'w+') as file: 63 | json.dump(submission, file, sort_keys=True) 64 | 65 | with open(submissions_folder_path / f"recall_subset_submission_{file_name}.json", 'w+') as file: 66 | json.dump(group_submission, file, sort_keys=True) 67 | 68 | 69 | def generate_cirr_test_dicts(relative_test_dataset: CIRRDataset, clip_model: CLIP, index_features: torch.tensor, 70 | index_names: List[str], combining_function: callable) \ 71 | -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]: 72 | """ 73 | Compute test prediction dicts for CIRR dataset 74 | :param relative_test_dataset: CIRR test dataset in relative mode 75 | :param clip_model: CLIP model 76 | :param index_features: test index features 77 | :param index_names: test index names 78 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 79 | features 80 | :return: Top50 global and Top3 subset prediction for each query (reference_name, caption) 81 | """ 82 | 83 | 
# Generate predictions 84 | predicted_features, reference_names, group_members, pairs_id = \ 85 | generate_cirr_test_predictions(clip_model, relative_test_dataset, combining_function, index_names, 86 | index_features) 87 | 88 | print(f"Compute CIRR prediction dicts") 89 | 90 | # Normalize the index features 91 | index_features = F.normalize(index_features, dim=-1).float() 92 | 93 | # Compute the distances and sort the results 94 | distances = 1 - predicted_features @ index_features.T 95 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 96 | sorted_index_names = np.array(index_names)[sorted_indices] 97 | 98 | # Delete the reference image from the results 99 | reference_mask = torch.tensor( 100 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names), 101 | -1)) 102 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 103 | sorted_index_names.shape[1] - 1) 104 | # Compute the subset predictions 105 | group_members = np.array(group_members) 106 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 107 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1) 108 | 109 | # Generate prediction dicts 110 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in 111 | zip(pairs_id, sorted_index_names)} 112 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in 113 | zip(pairs_id, sorted_group_names)} 114 | 115 | return pairid_to_predictions, pairid_to_group_predictions 116 | 117 | 118 | def generate_cirr_test_predictions(clip_model: CLIP, relative_test_dataset: CIRRDataset, combining_function: callable, 119 | index_names: List[str], index_features: torch.tensor) -> \ 120 | Tuple[torch.tensor, List[str], List[List[str]], List[str]]: 121 | """ 122 | Compute CIRR predictions on the test set 123 | :param clip_model: CLIP model 124 | :param relative_test_dataset: CIRR test dataset in relative mode 125 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 126 | features 127 | :param index_features: test index features 128 | :param index_names: test index names 129 | 130 | :return: predicted_features, reference_names, group_members and pairs_id 131 | """ 132 | print(f"Compute CIRR test predictions") 133 | 134 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, 135 | num_workers=multiprocessing.cpu_count(), pin_memory=True) 136 | 137 | # Get a mapping from index names to index features 138 | name_to_feat = dict(zip(index_names, index_features)) 139 | 140 | # Initialize pairs_id, predicted_features, group_members and reference_names 141 | pairs_id = [] 142 | predicted_features = torch.empty((0, clip_model.visual.output_dim)).to(device, non_blocking=True) 143 | group_members = [] 144 | reference_names = [] 145 | 146 | for batch_pairs_id, batch_reference_names, captions, batch_group_members in tqdm( 147 | relative_test_loader): # Load data 148 | text_inputs = clip.tokenize(captions, context_length=77).to(device) 149 | batch_group_members = np.array(batch_group_members).T.tolist() 150 | 151 | # Compute the predicted features 152 | with torch.no_grad(): 153 | text_features = clip_model.encode_text(text_inputs) 154 | # Check whether a single element is in the batch due to the exception raised by torch.stack when used with 155 | # a single tensor 156 
| if text_features.shape[0] == 1: 157 | reference_image_features = itemgetter(*batch_reference_names)(name_to_feat).unsqueeze(0) 158 | else: 159 | reference_image_features = torch.stack(itemgetter(*batch_reference_names)( 160 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features 161 | batch_predicted_features = combining_function(reference_image_features, text_features) 162 | 163 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 164 | group_members.extend(batch_group_members) 165 | reference_names.extend(batch_reference_names) 166 | pairs_id.extend(batch_pairs_id) 167 | 168 | return predicted_features, reference_names, group_members, pairs_id 169 | 170 | 171 | def main(): 172 | parser = ArgumentParser() 173 | parser.add_argument("--submission-name", type=str, required=True, help="submission file name") 174 | parser.add_argument("--combining-function", type=str, required=True, 175 | help="Which combining function to use, should be in ['combiner', 'sum']") 176 | parser.add_argument("--combiner-path", type=str, help="path to trained Combiner") 177 | parser.add_argument("--projection-dim", default=640 * 4, type=int, help='Combiner projection dim') 178 | parser.add_argument("--hidden-dim", default=640 * 8, type=int, help="Combiner hidden dim") 179 | parser.add_argument("--clip-model-name", default="RN50x4", type=str, help="CLIP model to use, e.g. 'RN50', 'RN50x4'") 180 | parser.add_argument("--clip-model-path", type=Path, help="Path to the fine-tuned CLIP model") 181 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio") 182 | parser.add_argument("--transform", default="targetpad", type=str, 183 | help="Preprocess pipeline, should be in ['clip', 'squarepad', 'targetpad'] ") 184 | args = parser.parse_args() 185 | clip_model, clip_preprocess = clip.load(args.clip_model_name, device=device, jit=False) 186 | input_dim = clip_model.visual.input_resolution 187 | feature_dim = clip_model.visual.output_dim 188 | 189 | if args.clip_model_path: 190 | print('Trying to load the CLIP model') 191 | saved_state_dict = torch.load(args.clip_model_path, map_location=device) 192 | clip_model.load_state_dict(saved_state_dict["CLIP"]) 193 | print('CLIP model loaded successfully') 194 | 195 | if args.transform == 'targetpad': 196 | print('Target pad preprocess pipeline is used') 197 | preprocess = targetpad_transform(args.target_ratio, input_dim) 198 | elif args.transform == 'squarepad': 199 | print('Square pad preprocess pipeline is used') 200 | preprocess = squarepad_transform(input_dim) 201 | else: 202 | print('CLIP default preprocess pipeline is used') 203 | preprocess = clip_preprocess 204 | 205 | if args.combining_function.lower() == 'sum': 206 | if args.combiner_path: 207 | print("Be careful, you are using the element-wise sum as combining_function but you have also passed a path" 208 | " to a trained Combiner.
Such Combiner will not be used") 209 | combining_function = element_wise_sum 210 | elif args.combining_function.lower() == 'combiner': 211 | combiner = Combiner(feature_dim, args.projection_dim, args.hidden_dim).to(device) 212 | saved_state_dict = torch.load(args.combiner_path, map_location=device) 213 | combiner.load_state_dict(saved_state_dict["Combiner"]) 214 | combiner.eval() 215 | combining_function = combiner.combine_features 216 | else: 217 | raise ValueError("combiner_path should be in ['sum', 'combiner']") 218 | 219 | generate_cirr_test_submissions(combining_function, args.submission_name, clip_model, preprocess) 220 | 221 | 222 | if __name__ == '__main__': 223 | main() 224 | -------------------------------------------------------------------------------- /src/cirr_artemis_submission_blip2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | from argparse import ArgumentParser 4 | from operator import itemgetter 5 | from pathlib import Path 6 | from typing import List, Tuple, Dict 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | from data_utils import CIRRDataset_val_submission, targetpad_transform, squarepad_transform, base_path 15 | from utils import element_wise_sum, device, extract_index_features_blip2 16 | from lavis.models import load_model 17 | from artemis import Artemis 18 | 19 | 20 | def generate_cirr_test_submissions(artemis, file_name: str, blip_textual, blip_multimodal, 21 | blip_visual, preprocess: callable): 22 | """ 23 | Generate and save CIRR test submission files to be submitted to evaluation server 24 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 25 | features 26 | :param file_name: file_name of the submission 27 | :param clip_model: CLIP model 28 | :param preprocess: preprocess pipeline 29 | """ 30 | blip_textual = blip_textual.float().eval() 31 | blip_multimodal = blip_multimodal.float().eval() 32 | blip_visual = blip_visual.float().eval() 33 | artemis = artemis.float().eval() 34 | 35 | # Define the dataset and extract index features 36 | classic_test_dataset = CIRRDataset_val_submission('test1', 'classic', preprocess) 37 | index_features, index_names = extract_index_features_blip2(classic_test_dataset, blip_visual) 38 | relative_test_dataset = CIRRDataset_val_submission('test1', 'relative', preprocess) 39 | 40 | # Generate test prediction dicts for CIRR 41 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset, blip_textual, blip_multimodal, 42 | index_features, index_names, 43 | artemis) 44 | 45 | submission = { 46 | 'version': 'rc2', 47 | 'metric': 'recall' 48 | } 49 | group_submission = { 50 | 'version': 'rc2', 51 | 'metric': 'recall_subset' 52 | } 53 | 54 | submission.update(pairid_to_predictions) 55 | group_submission.update(pairid_to_group_predictions) 56 | 57 | # Define submission path 58 | submissions_folder_path = base_path / "submission" / 'CIRR' 59 | submissions_folder_path.mkdir(exist_ok=True, parents=True) 60 | 61 | print(f"Saving CIRR test predictions") 62 | with open(submissions_folder_path / f"retrieval_{file_name}.json", 'w+') as file: 63 | json.dump(submission, file, sort_keys=True) 64 | 65 | with open(submissions_folder_path / f"retrieval_subset_submission_{file_name}.json", 'w+') as file: 66 | json.dump(group_submission, file, 
sort_keys=True) 67 | 68 | 69 | def generate_cirr_test_dicts(relative_test_dataset: CIRRDataset_val_submission, blip_textual, blip_multimodal, index_features: torch.tensor, 70 | index_names: List[str], artemis) \ 71 | -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]: 72 | """ 73 | Compute test prediction dicts for CIRR dataset 74 | :param relative_test_dataset: CIRR test dataset in relative mode 75 | :param clip_model: CLIP model 76 | :param index_features: test index features 77 | :param index_names: test index names 78 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 79 | features 80 | :return: Top50 global and Top3 subset prediction for each query (reference_name, caption) 81 | """ 82 | 83 | # Generate predictions 84 | artemis_scores, reference_names, group_members, pairs_id = \ 85 | generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset, artemis, index_names, 86 | index_features) 87 | 88 | print(f"Compute CIRR prediction dicts") 89 | 90 | # Normalize the index features 91 | index_features = F.normalize(index_features, dim=-1).float() 92 | 93 | # Compute the distances and sort the results 94 | distances = 1 - artemis_scores 95 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 96 | sorted_index_names = np.array(index_names)[sorted_indices] 97 | 98 | # Delete the reference image from the results 99 | reference_mask = torch.tensor( 100 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names), 101 | -1)) 102 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 103 | sorted_index_names.shape[1] - 1) 104 | # Compute the subset predictions 105 | group_members = np.array(group_members) 106 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 107 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1) 108 | 109 | # Generate prediction dicts for test split 110 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in 111 | zip(pairs_id, sorted_index_names)} 112 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in 113 | zip(pairs_id, sorted_group_names)} 114 | 115 | return pairid_to_predictions, pairid_to_group_predictions 116 | 117 | 118 | def generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset: CIRRDataset_val_submission, artemis, 119 | index_names: List[str], index_features: torch.tensor) -> \ 120 | Tuple[torch.tensor, List[str], List[List[str]], List[str]]: 121 | """ 122 | Compute CIRR predictions on the test set 123 | :param clip_model: CLIP model 124 | :param relative_test_dataset: CIRR test dataset in relative mode 125 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 126 | features 127 | :param index_features: test index features 128 | :param index_names: test index names 129 | 130 | :return: predicted_features, reference_names, group_members and pairs_id 131 | """ 132 | print(f"Compute CIRR test predictions") 133 | 134 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, 135 | num_workers=multiprocessing.cpu_count(), pin_memory=True) 136 | 137 | # Get a mapping from index names to index features 138 | # name_to_feat = dict(zip(index_names, index_features)) 139 | 140 | # Initialize pairs_id, 
predicted_features, group_members and reference_names 141 | pairs_id = [] 142 | artemis_scores = torch.empty((0, len(index_names))).to(device, non_blocking=True) 143 | group_members = [] 144 | reference_names = [] 145 | 146 | for batch_pairs_id, batch_reference_names, reference_images, captions, batch_group_members in tqdm( 147 | relative_test_loader): # Load data 148 | reference_images = reference_images.to(device) 149 | batch_group_members = np.array(batch_group_members).T.tolist() 150 | # Compute the predicted features 151 | with torch.no_grad(): 152 | text_feats = blip_textual.extract_features({"text_input":captions}, 153 | mode="text").text_embeds[:,0,:] 154 | reference_feats = blip_multimodal.extract_features({"image":reference_images, 155 | "text_input":captions}).multimodal_embeds[:,0,:] 156 | batch_artemis_score = torch.empty((0, len(index_names))).to(device, non_blocking=True) 157 | for i in range(reference_feats.shape[0]): 158 | one_artemis_score = artemis.compute_score_artemis(reference_feats[i].unsqueeze(0), text_feats[i].unsqueeze(0), index_features) 159 | batch_artemis_score = torch.vstack((batch_artemis_score, one_artemis_score)) 160 | 161 | artemis_scores = torch.vstack((artemis_scores, F.normalize(batch_artemis_score, dim=-1))) 162 | group_members.extend(batch_group_members) 163 | reference_names.extend(batch_reference_names) 164 | pairs_id.extend(batch_pairs_id) 165 | # target_hard_names.extend(target_hard_name) 166 | # captions_list.extend(captions) 167 | 168 | return artemis_scores, reference_names, group_members, pairs_id 169 | #return artemis_scores, reference_names, group_members, pairs_id, target_hard_names, captions_list 170 | 171 | 172 | def main(): 173 | parser = ArgumentParser() 174 | parser.add_argument("--submission-name", type=str, required=True, help="submission file name") 175 | parser.add_argument("--combining-function", type=str, required=True, 176 | help="Which combining function use, should be in ['combiner', 'sum']") 177 | parser.add_argument("--blip2-textual-path", type=str, help="Path to the fine-tuned BLIP2 model") 178 | parser.add_argument("--blip2-visual-path", type=str, help="Path to the fine-tuned BLIP2 model") 179 | parser.add_argument("--blip2-multimodal-path", type=str, help="Path to the fine-tuned BLIP2 model") 180 | parser.add_argument("--artemis-path", type=str, help="Path to the fine-tuned BLIP2 model") 181 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio") 182 | parser.add_argument("--transform", default="targetpad", type=str, 183 | help="Preprocess pipeline, should be in ['clip', 'squarepad', 'targetpad'] ") 184 | args = parser.parse_args() 185 | 186 | blip_textual_path = args.blip2_textual_path 187 | blip_visual_path = args.blip2_visual_path 188 | blip_multimodal_path = args.blip2_multimodal_path 189 | artemis_path = args.artemis_path 190 | 191 | blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 192 | x_saved_state_dict = torch.load(blip_textual_path, map_location=device) 193 | 194 | blip_textual.load_state_dict(x_saved_state_dict["Blip2Qformer"]) 195 | 196 | blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 197 | r_saved_state_dict = torch.load(blip_multimodal_path, map_location=device) 198 | blip_multimodal.load_state_dict(r_saved_state_dict["Blip2Qformer"]) 199 | 200 | blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", 
is_eval=True, device=device) 201 | t_saved_state_dict = torch.load(blip_visual_path, map_location=device) 202 | blip_visual.load_state_dict(t_saved_state_dict["Blip2Qformer"]) 203 | 204 | artemis = Artemis(768).to(device) 205 | a_saved_state_dict = torch.load(artemis_path, map_location=device) 206 | artemis.load_state_dict(a_saved_state_dict["Artemis"]) 207 | 208 | 209 | input_dim = 224 210 | if args.transform == 'targetpad': 211 | print('Target pad preprocess pipeline is used') 212 | preprocess = targetpad_transform(args.target_ratio, input_dim) 213 | elif args.transform == 'squarepad': 214 | print('Square pad preprocess pipeline is used') 215 | preprocess = squarepad_transform(input_dim) 216 | 217 | if args.combining_function.lower() == 'sum': 218 | combining_function = element_wise_sum 219 | else: 220 | raise ValueError("combiner_path should be in ['sum', 'combiner']") 221 | 222 | generate_cirr_test_submissions(artemis, args.submission_name, blip_textual, blip_multimodal, blip_visual, preprocess) 223 | 224 | 225 | if __name__ == '__main__': 226 | main() 227 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import random 3 | from pathlib import Path 4 | from typing import Union, Tuple, List 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from clip.model import CLIP 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from data_utils import CIRRDataset, FashionIQDataset 13 | 14 | if torch.cuda.is_available(): 15 | device = torch.device("cuda") 16 | else: 17 | device = torch.device("cpu") 18 | 19 | 20 | def extract_features_fusion_blip(blip_textual, multimodal_embeds, captions): 21 | self = blip_textual.Qformer.bert 22 | embeddings = self.embeddings 23 | encoder = self.encoder 24 | # pooler = blip_textual.Qformer.bert.pooler 25 | 26 | text_tokens = blip_textual.tokenizer( 27 | captions, 28 | padding="max_length", 29 | truncation=True, 30 | max_length=77, 31 | return_tensors="pt", 32 | ).to(device) 33 | 34 | text_atts = text_tokens.attention_mask 35 | query_atts = torch.ones(multimodal_embeds.size()[:-1], dtype=torch.long).to(device) 36 | # print("query_atts:", query_atts.shape) 37 | # print("text_atts:", text_atts.shape) 38 | attention_mask = torch.cat([query_atts, text_atts], dim=1) 39 | # print("attention_mask:",attention_mask.shape) 40 | # head_mask = blip_textual.Qformer.bert.get_head_mask(head_mask, 41 | # blip_textual.Qformer.bert.config.num_hidden_layers) 42 | 43 | embedding_output = embeddings( 44 | input_ids=text_tokens.input_ids, 45 | query_embeds=multimodal_embeds,) 46 | 47 | input_shape = embedding_output.size()[:-1] 48 | extended_attention_mask = self.get_extended_attention_mask( 49 | attention_mask, input_shape, device, False 50 | ) 51 | # print(extended_attention_mask.shape) 52 | head_mask = self.get_head_mask(None, self.config.num_hidden_layers) 53 | 54 | encoder_outputs = encoder( 55 | embedding_output, 56 | attention_mask=extended_attention_mask, 57 | head_mask=head_mask, 58 | return_dict=True, 59 | ) 60 | sequence_output = encoder_outputs[0] 61 | return sequence_output 62 | 63 | 64 | # def extract_index_features_blip(dataset: Union[CIRRDataset, FashionIQDataset], model: CLIP, vision_layer) -> \ 65 | def extract_index_features_blip2(dataset: Union[CIRRDataset, FashionIQDataset], blip_model) -> \ 66 | Tuple[torch.tensor, List[str]]: 67 | """ 68 | Extract 
FashionIQ or CIRR index features 69 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode 70 | :param model: CLIP model 71 | :return: a tensor of features and a list of images 72 | """ 73 | feature_dim = 768 74 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(), 75 | pin_memory=True, collate_fn=collate_fn) 76 | index_features = torch.empty((0,feature_dim)).to(device, non_blocking=True) 77 | index_names = [] 78 | if isinstance(dataset, CIRRDataset): 79 | print(f"extracting CIRR {dataset.split} index features") 80 | elif isinstance(dataset, FashionIQDataset): 81 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features") 82 | 83 | for names, images in tqdm(classic_val_loader): 84 | images = images.to(device, non_blocking=True) 85 | with torch.no_grad(): 86 | batch_features = blip_model.extract_features({"image":images}, mode="image").image_embeds[:,0,:] 87 | index_features = torch.vstack((index_features, batch_features)) 88 | index_names.extend(names) 89 | return index_features, index_names 90 | 91 | def extract_index_features_blip1(dataset: Union[CIRRDataset, FashionIQDataset], blip_model) -> \ 92 | Tuple[torch.tensor, List[str]]: 93 | """ 94 | Extract FashionIQ or CIRR index features 95 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode 96 | :param model: CLIP model 97 | :return: a tensor of features and a list of images 98 | """ 99 | feature_dim = 256 100 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(), 101 | pin_memory=True, collate_fn=collate_fn) 102 | index_features = torch.empty((0,feature_dim)).to(device, non_blocking=True) 103 | index_names = [] 104 | if isinstance(dataset, CIRRDataset): 105 | print(f"extracting CIRR {dataset.split} index features") 106 | elif isinstance(dataset, FashionIQDataset): 107 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features") 108 | 109 | for names, images in tqdm(classic_val_loader): 110 | images = images.to(device, non_blocking=True) 111 | with torch.no_grad(): 112 | batch_features = blip_model(images)[:,0,:] 113 | index_features = torch.vstack((index_features, batch_features)) 114 | index_names.extend(names) 115 | return index_features, index_names 116 | 117 | def extract_index_features_blip_feature_extractor(dataset: Union[CIRRDataset, FashionIQDataset], model) -> \ 118 | Tuple[torch.tensor, List[str]]: 119 | """ 120 | Extract FashionIQ or CIRR index features 121 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode 122 | :param model: CLIP model 123 | :return: a tensor of features and a list of images 124 | """ 125 | # feature_dim = model.visual.output_dim 126 | feature_dim = 256 127 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(), 128 | pin_memory=True, collate_fn=collate_fn) 129 | index_features = torch.empty((0,feature_dim)).to(device, non_blocking=True) 130 | index_names = [] 131 | if isinstance(dataset, CIRRDataset): 132 | print(f"extracting CIRR {dataset.split} index features") 133 | elif isinstance(dataset, FashionIQDataset): 134 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features") 135 | for names, images in tqdm(classic_val_loader): 136 | images = images.to(device, non_blocking=True) 137 | with torch.no_grad(): 138 | # visual_encoder = model.visual_encoder.float() 139 | # batch_vision_embeds = visual_encoder(images) 140 | # batch_vision_embeds = 
model.ln_vision(batch_vision_embeds).float() 141 | # batch_features = model.extract_features({"image":images}, mode="image").image_embeds 142 | 143 | batch_features = model.extract_features({"image":images}, mode="image").image_embeds_proj[:,0,:] 144 | index_features = torch.vstack((index_features, batch_features)) 145 | index_names.extend(names) 146 | return index_features, index_names 147 | 148 | def extract_index_features_clip(dataset: Union[CIRRDataset, FashionIQDataset], clip_model: CLIP) -> \ 149 | Tuple[torch.tensor, List[str]]: 150 | """ 151 | Extract FashionIQ or CIRR index features 152 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode 153 | :param clip_model: CLIP model 154 | :return: a tensor of features and a list of images 155 | """ 156 | feature_dim = clip_model.visual.output_dim 157 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(), 158 | pin_memory=True, collate_fn=collate_fn) 159 | index_features = torch.empty((0, feature_dim)).to(device, non_blocking=True) 160 | index_names = [] 161 | if isinstance(dataset, CIRRDataset): 162 | print(f"extracting CIRR {dataset.split} index features") 163 | elif isinstance(dataset, FashionIQDataset): 164 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features") 165 | for names, images in tqdm(classic_val_loader): 166 | images = images.to(device, non_blocking=True) 167 | with torch.no_grad(): 168 | batch_features = clip_model.encode_image(images)[0] 169 | # print(clip_model.encode_image(images)[1].shape) 170 | index_features = torch.vstack((index_features, batch_features)) 171 | index_names.extend(names) 172 | return index_features, index_names 173 | 174 | def element_wise_sum(image_features: torch.tensor, text_features: torch.tensor) -> torch.tensor: 175 | """ 176 | Normalized element-wise sum of image features and text features 177 | :param image_features: non-normalized image features 178 | :param text_features: non-normalized text features 179 | :return: normalized element-wise sum of image and text features 180 | """ 181 | return F.normalize(image_features + text_features, dim=-1) 182 | 183 | 184 | def generate_randomized_fiq_caption(flattened_captions: List[str]) -> List[str]: 185 | """ 186 | Function which randomize the FashionIQ training captions in four way: (a) cap1 and cap2 (b) cap2 and cap1 (c) cap1 187 | (d) cap2 188 | :param flattened_captions: the list of caption to randomize, note that the length of such list is 2*batch_size since 189 | to each triplet are associated two captions 190 | :return: the randomized caption list (with length = batch_size) 191 | """ 192 | captions = [] 193 | for i in range(0, len(flattened_captions), 2): 194 | random_num = random.random() 195 | if random_num < 0.25: 196 | captions.append( 197 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}") 198 | elif 0.25 < random_num < 0.5: 199 | captions.append( 200 | f"{flattened_captions[i + 1].strip('.?, ').capitalize()} and {flattened_captions[i].strip('.?, ')}") 201 | elif 0.5 < random_num < 0.75: 202 | captions.append(f"{flattened_captions[i].strip('.?, ').capitalize()}") 203 | else: 204 | captions.append(f"{flattened_captions[i + 1].strip('.?, ').capitalize()}") 205 | return captions 206 | def generate_randomized_fiq_caption_blip(flattened_captions: List[str],txt_processors:callable) -> List[str]: 207 | """ 208 | Function which randomize the FashionIQ training captions in four way: (a) cap1 and cap2 (b) cap2 
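# Illustrative usage sketch (not part of the original utils.py): toy
# input/output for generate_randomized_fiq_caption defined above. FashionIQ
# attaches two captions to every triplet, so the flattened list holds
# 2 * batch_size strings and the function returns one merged caption per triplet.
example_flattened = ["is darker", "has short sleeves.", "is red?", "is more formal"]
example_merged = generate_randomized_fiq_caption(example_flattened)
print(len(example_merged))  # 2 -- one caption per (reference, target) pair
# example_merged[0] is one of: "Is darker and has short sleeves",
# "Has short sleeves and is darker", "Is darker", "Has short sleeves"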
and cap1 (c) cap1 209 | (d) cap2 210 | :param flattened_captions: the list of caption to randomize, note that the length of such list is 2*batch_size since 211 | to each triplet are associated two captions 212 | :return: the randomized caption list (with length = batch_size) 213 | """ 214 | captions = [] 215 | for i in range(0, len(flattened_captions), 2): 216 | random_num = random.random() 217 | caption ='' 218 | if random_num < 0.25: 219 | caption=f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}" 220 | elif 0.25 < random_num < 0.5: 221 | caption=f"{flattened_captions[i + 1].strip('.?, ').capitalize()} and {flattened_captions[i].strip('.?, ')}" 222 | elif 0.5 < random_num < 0.75: 223 | caption=f"{flattened_captions[i].strip('.?, ').capitalize()}" 224 | else: 225 | caption=f"{flattened_captions[i + 1].strip('.?, ').capitalize()}" 226 | captions.append(txt_processors(caption)) 227 | return captions 228 | 229 | def collate_fn(batch: list): 230 | """ 231 | Discard None images in a batch when using torch DataLoader 232 | :param batch: input_batch 233 | :return: output_batch = input_batch - None_values 234 | """ 235 | batch = list(filter(lambda x: x is not None, batch)) 236 | return torch.utils.data.dataloader.default_collate(batch) 237 | 238 | 239 | def update_train_running_results(train_running_results: dict, loss: torch.tensor, images_in_batch: int): 240 | """ 241 | Update `train_running_results` dict during training 242 | :param train_running_results: logging training dict 243 | :param loss: computed loss for batch 244 | :param images_in_batch: num images in the batch 245 | """ 246 | train_running_results['accumulated_train_loss'] += loss.to('cpu', 247 | non_blocking=True).detach().item() * images_in_batch 248 | train_running_results["images_in_epoch"] += images_in_batch 249 | 250 | 251 | def set_train_bar_description(train_bar, epoch: int, num_epochs: int, train_running_results: dict): 252 | """ 253 | Update tqdm train bar during training 254 | :param train_bar: tqdm training bar 255 | :param epoch: current epoch 256 | :param num_epochs: numbers of epochs 257 | :param train_running_results: logging training dict 258 | """ 259 | train_bar.set_description( 260 | desc=f"[{epoch}/{num_epochs}] " 261 | f"train loss: {train_running_results['accumulated_train_loss'] / train_running_results['images_in_epoch']:.3f} " 262 | ) 263 | 264 | 265 | def save_model(name: str, cur_epoch: int, model_to_save: nn.Module, training_path: Path): 266 | """ 267 | Save the weights of the model during training 268 | :param name: name of the file 269 | :param cur_epoch: current epoch 270 | :param model_to_save: pytorch model to be saved 271 | :param training_path: path associated with the training run 272 | """ 273 | models_path = training_path / "saved_models" 274 | models_path.mkdir(exist_ok=True, parents=True) 275 | model_name = model_to_save.__class__.__name__ 276 | torch.save({ 277 | 'epoch': cur_epoch, 278 | model_name: model_to_save.state_dict(), 279 | }, str(models_path / f'{name}.pt')) 280 | 281 | import math 282 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr, onlyGroup0=False): 283 | """Decay the learning rate""" 284 | lr = (init_lr - min_lr) * 0.5 * (1. 
+ math.cos(math.pi * epoch / max_epoch)) + min_lr 285 | param_group_count = 0 286 | for param_group in optimizer.param_groups: 287 | param_group_count += 1 288 | if param_group_count <= 1 and onlyGroup0: # only vary group0 parameters' learning rate, i.e., exclude the text_proj layer 289 | param_group['lr'] = lr 290 | 291 | @torch.no_grad() 292 | def _momentum_update(model_pairs, momentum): 293 | for model_pair in model_pairs: 294 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 295 | param_m.data = param_m.data * momentum + param.data * (1. - momentum) 296 | 297 | @torch.no_grad() 298 | def _dequeue_and_enqueue(model, target_feats, idx, queue_size): 299 | # gather keys before updating queue 300 | batch_size = target_feats.shape[0] 301 | 302 | ptr = int(model.queue_ptr) 303 | assert queue_size % batch_size == 0 # for simplicity 304 | 305 | # replace the keys at ptr (dequeue and enqueue) 306 | model.target_queue[:, ptr:ptr + batch_size] = target_feats.T 307 | model.idx_queue[:, ptr:ptr + batch_size] = idx.T 308 | ptr = (ptr + batch_size) % queue_size # move pointer 309 | 310 | model.queue_ptr[0] = ptr 311 | 312 | 313 | def l2norm(x): 314 | """L2-normalize each row of x""" 315 | norm = torch.pow(x, 2).sum(dim=-1, keepdim=True).sqrt() 316 | return torch.div(x, norm) 317 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List 4 | 5 | import PIL 6 | import PIL.Image 7 | import torchvision.transforms.functional as F 8 | from torch.utils.data import Dataset 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | 11 | PIL.Image.MAX_IMAGE_PIXELS=None 12 | base_path = Path(__file__).absolute().parents[2].absolute() 13 | 14 | 15 | def _convert_image_to_rgb(image): 16 | return image.convert("RGB") 17 | 18 | 19 | class SquarePad: 20 | """ 21 | Square pad the input image with zero padding 22 | """ 23 | 24 | def __init__(self, size: int): 25 | """ 26 | For having a consistent preprocess pipeline with CLIP we need to have the preprocessing output dimension as 27 | a parameter 28 | :param size: preprocessing output dimension 29 | """ 30 | self.size = size 31 | 32 | def __call__(self, image): 33 | w, h = image.size 34 | max_wh = max(w, h) 35 | hp = int((max_wh - w) / 2) 36 | vp = int((max_wh - h) / 2) 37 | padding = [hp, vp, hp, vp] 38 | return F.pad(image, padding, 0, 'constant') 39 | 40 | 41 | class TargetPad: 42 | """ 43 | Pad the image if its aspect ratio is above a target ratio. 
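# Worked example (illustrative, not in the original file): the padding
# arithmetic of TargetPad below, for a 400x150 landscape image with the
# target_ratio of 1.25 used as the default elsewhere in this repo.
w, h, target_ratio = 400, 150, 1.25
scaled_max_wh = max(w, h) / target_ratio      # 320.0 -> length the short side should reach
hp = max(int((scaled_max_wh - w) / 2), 0)     # 0  -> no horizontal padding needed
vp = max(int((scaled_max_wh - h) / 2), 0)     # 85 -> 85 px added above and below
print(w + 2 * hp, h + 2 * vp)                 # 400 320, i.e. aspect ratio 400/320 == 1.25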
44 | Pad the image to match such target ratio 45 | """ 46 | 47 | def __init__(self, target_ratio: float, size: int): 48 | """ 49 | :param target_ratio: target ratio 50 | :param size: preprocessing output dimension 51 | """ 52 | self.size = size 53 | self.target_ratio = target_ratio 54 | 55 | def __call__(self, image): 56 | w, h = image.size 57 | actual_ratio = max(w, h) / min(w, h) 58 | if actual_ratio < self.target_ratio: # check if the ratio is above or below the target ratio 59 | return image 60 | scaled_max_wh = max(w, h) / self.target_ratio # rescale the pad to match the target ratio 61 | hp = max(int((scaled_max_wh - w) / 2), 0) 62 | vp = max(int((scaled_max_wh - h) / 2), 0) 63 | padding = [hp, vp, hp, vp] 64 | return F.pad(image, padding, 0, 'constant') 65 | 66 | 67 | def squarepad_transform(dim: int): 68 | """ 69 | CLIP-like preprocessing transform on a square padded image 70 | :param dim: image output dimension 71 | :return: CLIP-like torchvision Compose transform 72 | """ 73 | return Compose([ 74 | SquarePad(dim), 75 | Resize(dim, interpolation=PIL.Image.BICUBIC), 76 | CenterCrop(dim), 77 | _convert_image_to_rgb, 78 | ToTensor(), 79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 80 | ]) 81 | 82 | 83 | def targetpad_transform(target_ratio: float, dim: int): 84 | """ 85 | CLIP-like preprocessing transform computed after using TargetPad pad 86 | :param target_ratio: target ratio for TargetPad 87 | :param dim: image output dimension 88 | :return: CLIP-like torchvision Compose transform 89 | """ 90 | return Compose([ 91 | TargetPad(target_ratio, dim), 92 | Resize(dim, interpolation=PIL.Image.BICUBIC), 93 | CenterCrop(dim), 94 | _convert_image_to_rgb, 95 | ToTensor(), 96 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 97 | ]) 98 | 99 | 100 | class FashionIQDataset(Dataset): 101 | """ 102 | FashionIQ dataset class which manage FashionIQ data. 103 | The dataset can be used in 'relative' or 'classic' mode: 104 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 105 | - In 'relative' mode the dataset yield tuples made of: 106 | - (reference_image, target_image, image_captions) when split == train 107 | - (reference_name, target_name, image_captions) when split == val 108 | - (reference_name, reference_image, image_captions) when split == test 109 | The dataset manage an arbitrary numbers of FashionIQ category, e.g. only dress, dress+toptee+shirt, dress+shirt... 
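# Usage sketch for the class defined here (illustrative, not part of the
# original file). It assumes the fashionIQ_dataset folder (captions/,
# image_splits/, image_data/) sits beside the repository as expected by
# base_path above.
example_preprocess = targetpad_transform(1.25, 224)
gallery = FashionIQDataset('val', ['dress'], 'classic', example_preprocess)
image_name, image = gallery[0]                      # single (name, tensor) pair
triplets = FashionIQDataset('val', ['dress', 'shirt'], 'relative', example_preprocess)
reference_image, reference_name, target_name, captions = triplets[0]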
110 | """ 111 | 112 | def __init__(self, split: str, dress_types: List[str], mode: str, preprocess: callable): 113 | """ 114 | :param split: dataset split, should be in ['test', 'train', 'val'] 115 | :param dress_types: list of fashionIQ category 116 | :param mode: dataset mode, should be in ['relative', 'classic']: 117 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 118 | - In 'relative' mode the dataset yield tuples made of: 119 | - (reference_image, target_image, image_captions) when split == train 120 | - (reference_name, target_name, image_captions) when split == val 121 | - (reference_name, reference_image, image_captions) when split == test 122 | :param preprocess: function which preprocesses the image 123 | """ 124 | self.mode = mode 125 | self.dress_types = dress_types 126 | self.split = split 127 | 128 | if mode not in ['relative', 'classic']: 129 | raise ValueError("mode should be in ['relative', 'classic']") 130 | if split not in ['test', 'train', 'val']: 131 | raise ValueError("split should be in ['test', 'train', 'val']") 132 | for dress_type in dress_types: 133 | if dress_type not in ['dress', 'shirt', 'toptee']: 134 | raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']") 135 | 136 | self.preprocess = preprocess 137 | 138 | 139 | # get triplets made by (reference_image, target_image, a pair of relative captions) 140 | self.triplets: List[dict] = [] 141 | for dress_type in dress_types: 142 | with open(base_path / 'fashionIQ_dataset' / 'captions' / f'cap.{dress_type}.{split}.json') as f: 143 | self.triplets.extend(json.load(f)) 144 | 145 | # get the image names 146 | self.image_names: list = [] 147 | for dress_type in dress_types: 148 | with open(base_path / 'fashionIQ_dataset' / 'image_splits' / f'split.{dress_type}.{split}.json') as f: 149 | self.image_names.extend(json.load(f)) 150 | 151 | print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized") 152 | for image_name in self.image_names[:]: 153 | file_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{image_name}.jpg" 154 | 155 | # 使用 Path.exists() 方法检查文件是否存在 156 | if not file_path.exists(): 157 | # 如果文件不存在,从列表中删除该项 158 | self.image_names.remove(image_name) 159 | 160 | def __getitem__(self, index): 161 | try: 162 | if self.mode == 'relative': 163 | image_captions = self.triplets[index]['captions'] 164 | reference_name = self.triplets[index]['candidate'].replace(".jpg", "") 165 | 166 | if self.split == 'train': 167 | reference_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{reference_name}.jpg" 168 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert("RGB")) 169 | target_name = self.triplets[index]['target'].replace(".jpg", "") 170 | target_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{target_name}.jpg" 171 | target_image = self.preprocess(PIL.Image.open(target_image_path).convert("RGB")) 172 | return reference_image, target_image, image_captions 173 | 174 | elif self.split == 'val': 175 | reference_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{reference_name}.jpg" 176 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert("RGB")) 177 | target_name = self.triplets[index]['target'].replace(".jpg", "") 178 | return reference_image, reference_name, target_name, image_captions 179 | 180 | elif self.split == 'test': 181 | reference_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{reference_name}.jpg" 182 | reference_image = 
self.preprocess(PIL.Image.open(reference_image_path).convert("RGB")) 183 | return reference_name, reference_image, image_captions 184 | 185 | elif self.mode == 'classic': 186 | image_name = self.image_names[index] 187 | image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{image_name}.jpg" 188 | image = self.preprocess(PIL.Image.open(image_path).convert("RGB")) 189 | return image_name, image 190 | 191 | else: 192 | raise ValueError("mode should be in ['relative', 'classic']") 193 | except Exception as e: 194 | print(f"Exception: {e}") 195 | 196 | def __len__(self): 197 | if self.mode == 'relative': 198 | return len(self.triplets) 199 | elif self.mode == 'classic': 200 | return len(self.image_names) 201 | else: 202 | raise ValueError("mode should be in ['relative', 'classic']") 203 | 204 | 205 | class CIRRDataset(Dataset): 206 | """ 207 | CIRR dataset class which manage CIRR data 208 | The dataset can be used in 'relative' or 'classic' mode: 209 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 210 | - In 'relative' mode the dataset yield tuples made of: 211 | - (reference_image, target_image, rel_caption) when split == train 212 | - (reference_name, target_name, rel_caption, group_members) when split == val 213 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 214 | """ 215 | 216 | def __init__(self, split: str, mode: str, preprocess: callable): 217 | """ 218 | :param split: dataset split, should be in ['test', 'train', 'val'] 219 | :param mode: dataset mode, should be in ['relative', 'classic']: 220 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 221 | - In 'relative' mode the dataset yield tuples made of: 222 | - (reference_image, target_image, rel_caption) when split == train 223 | - (reference_name, target_name, rel_caption, group_members) when split == val 224 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 225 | :param preprocess: function which preprocesses the image 226 | """ 227 | self.preprocess = preprocess 228 | # self.preprocess_384 = preprocess_384 229 | self.mode = mode 230 | self.split = split 231 | 232 | if split not in ['test1', 'train', 'val']: 233 | raise ValueError("split should be in ['test1', 'train', 'val']") 234 | if mode not in ['relative', 'classic']: 235 | raise ValueError("mode should be in ['relative', 'classic']") 236 | 237 | # get triplets made by (reference_image, target_image, relative caption) 238 | with open(base_path / 'cirr_datasets' / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f: 239 | self.triplets = json.load(f) 240 | 241 | # get a mapping from image name to relative path 242 | with open(base_path / 'cirr_datasets' / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f: 243 | self.name_to_relpath = json.load(f) 244 | 245 | print(f"CIRR {split} dataset in {mode} mode initialized") 246 | 247 | def __getitem__(self, index): 248 | try: 249 | if self.mode == 'relative': 250 | group_members = self.triplets[index]['img_set']['members'] 251 | reference_name = self.triplets[index]['reference'] 252 | rel_caption = self.triplets[index]['caption'] 253 | 254 | if self.split == 'train': 255 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 256 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 257 | target_hard_name = self.triplets[index]['target_hard'] 258 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name] 259 | target_image = 
self.preprocess(PIL.Image.open(target_image_path)) 260 | return reference_image, target_image, rel_caption 261 | 262 | elif self.split == 'val': 263 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 264 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 265 | target_hard_name = self.triplets[index]['target_hard'] 266 | # return reference_name, target_hard_name, rel_caption, group_members 267 | return reference_image, reference_name, target_hard_name, rel_caption, group_members 268 | 269 | elif self.split == 'test1': 270 | pair_id = self.triplets[index]['pairid'] 271 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 272 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 273 | return pair_id, reference_name, reference_image, rel_caption, group_members 274 | 275 | elif self.mode == 'classic': 276 | image_name = list(self.name_to_relpath.keys())[index] 277 | image_path = base_path / 'cirr_datasets' / self.name_to_relpath[image_name] 278 | im = PIL.Image.open(image_path) 279 | image = self.preprocess(im) 280 | return image_name, image 281 | 282 | else: 283 | raise ValueError("mode should be in ['relative', 'classic']") 284 | 285 | except Exception as e: 286 | print(f"Exception: {e}") 287 | 288 | def __len__(self): 289 | if self.mode == 'relative': 290 | return len(self.triplets) 291 | elif self.mode == 'classic': 292 | return len(self.name_to_relpath) 293 | else: 294 | raise ValueError("mode should be in ['relative', 'classic']") 295 | 296 | 297 | class CIRRDataset_val_submission(Dataset): 298 | """ 299 | CIRR dataset class which manage CIRR data 300 | The dataset can be used in 'relative' or 'classic' mode: 301 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 302 | - In 'relative' mode the dataset yield tuples made of: 303 | - (reference_image, target_image, rel_caption) when split == train 304 | - (reference_name, target_name, rel_caption, group_members) when split == val 305 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 306 | """ 307 | 308 | def __init__(self, split: str, mode: str, preprocess: callable): 309 | """ 310 | :param split: dataset split, should be in ['test', 'train', 'val'] 311 | :param mode: dataset mode, should be in ['relative', 'classic']: 312 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 313 | - In 'relative' mode the dataset yield tuples made of: 314 | - (reference_image, target_image, rel_caption) when split == train 315 | - (reference_name, target_name, rel_caption, group_members) when split == val 316 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 317 | :param preprocess: function which preprocesses the image 318 | """ 319 | self.preprocess = preprocess 320 | # self.preprocess_384 = preprocess_384 321 | self.mode = mode 322 | self.split = split 323 | 324 | if split not in ['test1', 'train', 'val']: 325 | raise ValueError("split should be in ['test1', 'train', 'val']") 326 | if mode not in ['relative', 'classic']: 327 | raise ValueError("mode should be in ['relative', 'classic']") 328 | 329 | # get triplets made by (reference_image, target_image, relative caption) 330 | with open(base_path / 'cirr_datasets' / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f: 331 | self.triplets = json.load(f) 332 | 333 | # get a mapping from image name to relative path 334 | with open(base_path / 'cirr_datasets' / 'cirr' / 'image_splits' / 
f'split.rc2.{split}.json') as f: 335 | self.name_to_relpath = json.load(f) 336 | 337 | print(f"CIRR {split} dataset in {mode} mode initialized") 338 | 339 | def __getitem__(self, index): 340 | try: 341 | if self.mode == 'relative': 342 | group_members = self.triplets[index]['img_set']['members'] 343 | reference_name = self.triplets[index]['reference'] 344 | rel_caption = self.triplets[index]['caption'] 345 | 346 | if self.split == 'train': 347 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 348 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 349 | target_hard_name = self.triplets[index]['target_hard'] 350 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name] 351 | target_image = self.preprocess(PIL.Image.open(target_image_path)) 352 | return reference_image, target_image, rel_caption 353 | 354 | elif self.split == 'val': 355 | pair_id = self.triplets[index]['pairid'] 356 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 357 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 358 | target_hard_name = self.triplets[index]['target_hard'] 359 | # return reference_name, target_hard_name, rel_caption, group_members 360 | return pair_id, reference_name, reference_image, rel_caption, group_members, target_hard_name 361 | 362 | elif self.split == 'test1': 363 | pair_id = self.triplets[index]['pairid'] 364 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 365 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 366 | return pair_id, reference_name, reference_image, rel_caption, group_members 367 | 368 | elif self.mode == 'classic': 369 | image_name = list(self.name_to_relpath.keys())[index] 370 | image_path = base_path / 'cirr_datasets' / self.name_to_relpath[image_name] 371 | im = PIL.Image.open(image_path) 372 | image = self.preprocess(im) 373 | return image_name, image 374 | 375 | else: 376 | raise ValueError("mode should be in ['relative', 'classic']") 377 | 378 | except Exception as e: 379 | print(f"Exception: {e}") 380 | 381 | def __len__(self): 382 | if self.mode == 'relative': 383 | return len(self.triplets) 384 | elif self.mode == 'classic': 385 | return len(self.name_to_relpath) 386 | else: 387 | raise ValueError("mode should be in ['relative', 'classic']") 388 | 389 | 390 | class CIRRDataset_hinge(Dataset): 391 | """ 392 | CIRR dataset class which manage CIRR data 393 | The dataset can be used in 'relative' or 'classic' mode: 394 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 395 | - In 'relative' mode the dataset yield tuples made of: 396 | - (reference_image, target_image, rel_caption) when split == train 397 | - (reference_name, target_name, rel_caption, group_members) when split == val 398 | - (pair_id, reference_name, rel_caption, group_members) when split == test1 399 | """ 400 | 401 | def __init__(self, split: str, mode: str, preprocess: callable): 402 | """ 403 | :param split: dataset split, should be in ['test', 'train', 'val'] 404 | :param mode: dataset mode, should be in ['relative', 'classic']: 405 | - In 'classic' mode the dataset yield tuples made of (image_name, image) 406 | - In 'relative' mode the dataset yield tuples made of: 407 | - (reference_image, target_image, rel_caption) when split == train 408 | - (reference_name, target_name, rel_caption, group_members) when split == val 409 | - (pair_id, reference_name, 
rel_caption, group_members) when split == test1 410 | :param preprocess: function which preprocesses the image 411 | """ 412 | self.preprocess = preprocess 413 | # self.preprocess_384 = preprocess_384 414 | self.mode = mode 415 | self.split = split 416 | 417 | if split not in ['test1', 'train', 'val']: 418 | raise ValueError("split should be in ['test1', 'train', 'val']") 419 | if mode not in ['relative', 'classic']: 420 | raise ValueError("mode should be in ['relative', 'classic']") 421 | 422 | # get triplets made by (reference_image, target_image, relative caption) 423 | with open(base_path / 'cirr_datasets' / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f: 424 | self.triplets = json.load(f) 425 | 426 | # get a mapping from image name to relative path 427 | with open(base_path / 'cirr_datasets' / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f: 428 | self.name_to_relpath = json.load(f) 429 | 430 | print(f"CIRR {split} dataset in {mode} mode initialized") 431 | 432 | def __getitem__(self, index): 433 | try: 434 | if self.mode == 'relative': 435 | group_members = self.triplets[index]['img_set']['members'] 436 | reference_name = self.triplets[index]['reference'] 437 | rel_caption = self.triplets[index]['caption'] 438 | 439 | if self.split == 'train': 440 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 441 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 442 | target_hard_name = self.triplets[index]['target_hard'] 443 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name] 444 | target_image = self.preprocess(PIL.Image.open(target_image_path)) 445 | return reference_image, target_image, rel_caption 446 | 447 | elif self.split == 'val': 448 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 449 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 450 | target_hard_name = self.triplets[index]['target_hard'] 451 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name] 452 | target_image = self.preprocess(PIL.Image.open(target_image_path)) 453 | # return reference_name, target_hard_name, rel_caption, group_members 454 | return reference_image, reference_name, target_hard_name, rel_caption, group_members, target_image 455 | 456 | elif self.split == 'test1': 457 | pair_id = self.triplets[index]['pairid'] 458 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name] 459 | reference_image = self.preprocess(PIL.Image.open(reference_image_path)) 460 | return pair_id, reference_name, reference_image, rel_caption, group_members 461 | 462 | elif self.mode == 'classic': 463 | image_name = list(self.name_to_relpath.keys())[index] 464 | image_path = base_path / 'cirr_datasets' / self.name_to_relpath[image_name] 465 | im = PIL.Image.open(image_path) 466 | image = self.preprocess(im) 467 | return image_name, image 468 | 469 | else: 470 | raise ValueError("mode should be in ['relative', 'classic']") 471 | 472 | except Exception as e: 473 | print(f"Exception: {e}") 474 | 475 | def __len__(self): 476 | if self.mode == 'relative': 477 | return len(self.triplets) 478 | elif self.mode == 'classic': 479 | return len(self.name_to_relpath) 480 | else: 481 | raise ValueError("mode should be in ['relative', 'classic']") 482 | -------------------------------------------------------------------------------- /src/validate.py: 
-------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from argparse import ArgumentParser 3 | from operator import itemgetter 4 | from pathlib import Path 5 | from statistics import mean 6 | from typing import List, Tuple 7 | 8 | import clip 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from clip.model import CLIP 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | 16 | from data_utils import squarepad_transform, FashionIQDataset, targetpad_transform, CIRRDataset 17 | from utils import extract_index_features_clip, collate_fn, element_wise_sum, device , extract_index_features_blip2 18 | from lavis.models import load_model 19 | 20 | 21 | def compute_fiq_val_metrics_blip2(relative_val_dataset: FashionIQDataset, blip_textual,blip_multimodal, index_features: torch.tensor, 22 | index_names: List[str], combining_function: callable) -> Tuple[float, float]: 23 | """ 24 | Compute validation metrics on FashionIQ dataset 25 | :param relative_val_dataset: FashionIQ validation dataset in relative mode 26 | :param clip_model: CLIP model 27 | :param index_features: validation index features 28 | :param index_names: validation index names 29 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 30 | features 31 | :return: the computed validation metrics 32 | """ 33 | 34 | # Generate predictions 35 | predicted_features, target_names = generate_fiq_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset, 36 | combining_function, index_names, index_features) 37 | 38 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation metrics") 39 | 40 | # Normalize the index features 41 | index_features = F.normalize(index_features, dim=-1).float() 42 | 43 | # Compute the distances and sort the results 44 | distances = 1 - predicted_features @ index_features.T 45 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 46 | sorted_index_names = np.array(index_names)[sorted_indices] 47 | 48 | # Compute the ground-truth labels wrt the predictions 49 | labels = torch.tensor( 50 | sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1)) 51 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 52 | 53 | # Compute the metrics 54 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 55 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 56 | 57 | return recall_at10, recall_at50 58 | 59 | def compute_fiq_val_metrics_clip(relative_val_dataset: FashionIQDataset, clip_model: CLIP, index_features: torch.tensor, 60 | index_names: List[str], combining_function: callable) -> Tuple[float, float]: 61 | """ 62 | Compute validation metrics on FashionIQ dataset 63 | :param relative_val_dataset: FashionIQ validation dataset in relative mode 64 | :param clip_model: CLIP model 65 | :param index_features: validation index features 66 | :param index_names: validation index names 67 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 68 | features 69 | :return: the computed validation metrics 70 | """ 71 | 72 | # Generate predictions 73 | predicted_features, target_names = generate_fiq_val_predictions_clip(clip_model, relative_val_dataset, 74 | combining_function, index_names, index_features) 75 | 76 | print(f"Compute FashionIQ 
{relative_val_dataset.dress_types} validation metrics") 77 | 78 | # Normalize the index features 79 | index_features = F.normalize(index_features, dim=-1).float() 80 | 81 | # Compute the distances and sort the results 82 | distances = 1 - predicted_features @ index_features.T 83 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 84 | sorted_index_names = np.array(index_names)[sorted_indices] 85 | 86 | # Compute the ground-truth labels wrt the predictions 87 | labels = torch.tensor( 88 | sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1)) 89 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 90 | 91 | # Compute the metrics 92 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 93 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 94 | 95 | return recall_at10, recall_at50 96 | 97 | 98 | def artemis_compute_fiq_val_metrics(relative_val_dataset: FashionIQDataset, blip_textual,blip_multimodal, index_features: torch.tensor, 99 | index_names: List[str], artemis) -> Tuple[float, float]: 100 | """ 101 | Compute validation metrics on FashionIQ dataset 102 | :param relative_val_dataset: FashionIQ validation dataset in relative mode 103 | :param clip_model: CLIP model 104 | :param index_features: validation index features 105 | :param index_names: validation index names 106 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 107 | features 108 | :return: the computed validation metrics 109 | """ 110 | 111 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation metrics") 112 | 113 | # Normalize the index features 114 | index_features = F.normalize(index_features, dim=-1).float() 115 | print("Compute FashionIQ validation predictions") 116 | 117 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8, 118 | pin_memory=True, collate_fn=collate_fn) 119 | feature_dim = 768 120 | # Initialize predicted features, target_names, group_members and reference_names 121 | artemis.eval() 122 | artemis_scores = torch.empty((0, len(index_names))).to(device, non_blocking=True) 123 | target_names = [] 124 | reference_names = [] 125 | 126 | for reference_images, batch_reference_names, batch_target_names, captions in tqdm( 127 | relative_val_loader): # Load data 128 | reference_images = reference_images.to(device) 129 | 130 | flattened_captions: list = np.array(captions).T.flatten().tolist() 131 | input_captions = [ 132 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}" 133 | for i in range(0, len(flattened_captions), 2)] 134 | 135 | # Compute the predicted features 136 | with torch.no_grad(): 137 | text_feats = blip_textual.extract_features({"text_input":input_captions}, 138 | mode="text").text_embeds[:,0,:] 139 | reference_feats = blip_multimodal.extract_features({"image":reference_images, 140 | "text_input":input_captions}).multimodal_embeds[:,0,:] 141 | batch_artemis_score = torch.empty((0, len(index_names))).to(device, non_blocking=True) 142 | for i in range(reference_feats.shape[0]): 143 | one_artemis_score = artemis.compute_score_artemis(reference_feats[i].unsqueeze(0), text_feats[i].unsqueeze(0), index_features) 144 | batch_artemis_score = torch.vstack((batch_artemis_score, one_artemis_score)) 145 | artemis_scores = torch.vstack((artemis_scores, F.normalize(batch_artemis_score, dim=-1))) 146 | 
target_names.extend(batch_target_names) 147 | reference_names.extend(batch_reference_names) 148 | 149 | # Compute the distances and sort the results 150 | distances = 1 - artemis_scores 151 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 152 | sorted_index_names = np.array(index_names)[sorted_indices] 153 | 154 | # Compute the ground-truth labels wrt the predictions 155 | labels = torch.tensor( 156 | sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1)) 157 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 158 | 159 | # Compute the metrics 160 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 161 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 162 | 163 | return recall_at10, recall_at50 164 | 165 | 166 | def generate_fiq_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset: FashionIQDataset, 167 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \ 168 | Tuple[torch.tensor, List[str]]: 169 | """ 170 | Compute FashionIQ predictions on the validation set 171 | :param clip_model: CLIP model 172 | :param relative_val_dataset: FashionIQ validation dataset in relative mode 173 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 174 | features 175 | :param index_features: validation index features 176 | :param index_names: validation index names 177 | :return: predicted features and target names 178 | """ 179 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation predictions") 180 | 181 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, 182 | num_workers=multiprocessing.cpu_count(), pin_memory=True, collate_fn=collate_fn, 183 | shuffle=False) 184 | 185 | # Get a mapping from index names to index features 186 | # name_to_feat = dict(zip(index_names, index_features)) 187 | feature_dim = 768 188 | # Initialize predicted features and target names 189 | predicted_features = torch.empty((0, feature_dim)).to(device, non_blocking=True) 190 | target_names = [] 191 | 192 | for reference_images, batch_target_names, captions in tqdm(relative_val_loader): # Load data 193 | reference_images = reference_images.to(device) 194 | # Concatenate the captions in a deterministic way 195 | flattened_captions: list = np.array(captions).T.flatten().tolist() 196 | input_captions = [ 197 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}" 198 | for i in range(0, len(flattened_captions), 2)] 199 | 200 | # Compute the predicted features 201 | with torch.no_grad(): 202 | text_feats = blip_textual.extract_features({"text_input":input_captions}, 203 | mode="text").text_embeds[:,0,:] 204 | reference_feats = blip_multimodal.extract_features({"image":reference_images, 205 | "text_input":input_captions}).multimodal_embeds[:,0,:] 206 | 207 | batch_predicted_features = combining_function(reference_feats, text_feats) 208 | 209 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 210 | target_names.extend(batch_target_names) 211 | 212 | return predicted_features, target_names 213 | 214 | 215 | def generate_fiq_val_predictions_clip(clip_model: CLIP, relative_val_dataset: FashionIQDataset, 216 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \ 217 | Tuple[torch.tensor, List[str]]: 218 | 
""" 219 | Compute FashionIQ predictions on the validation set 220 | :param clip_model: CLIP model 221 | :param relative_val_dataset: FashionIQ validation dataset in relative mode 222 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 223 | features 224 | :param index_features: validation index features 225 | :param index_names: validation index names 226 | :return: predicted features and target names 227 | """ 228 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation predictions") 229 | 230 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, 231 | num_workers=multiprocessing.cpu_count(), pin_memory=True, collate_fn=collate_fn, 232 | shuffle=False) 233 | 234 | # Get a mapping from index names to index features 235 | name_to_feat = dict(zip(index_names, index_features)) 236 | 237 | # Initialize predicted features and target names 238 | predicted_features = torch.empty((0, clip_model.visual.output_dim)).to(device, non_blocking=True) 239 | target_names = [] 240 | 241 | for reference_image, reference_names, batch_target_names, captions in tqdm(relative_val_loader): # Load data 242 | 243 | # Concatenate the captions in a deterministic way 244 | flattened_captions: list = np.array(captions).T.flatten().tolist() 245 | input_captions = [ 246 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}" for 247 | i in range(0, len(flattened_captions), 2)] 248 | text_inputs = clip.tokenize(input_captions, context_length=77).to(device, non_blocking=True) 249 | 250 | # Compute the predicted features 251 | with torch.no_grad(): 252 | text_features = clip_model.encode_text(text_inputs)[0] 253 | # Check whether a single element is in the batch due to the exception raised by torch.stack when used with 254 | # a single tensor 255 | if text_features.shape[0] == 1: 256 | reference_image_features = itemgetter(*reference_names)(name_to_feat).unsqueeze(0) 257 | else: 258 | reference_image_features = torch.stack(itemgetter(*reference_names)( 259 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features 260 | batch_predicted_features = combining_function(reference_image_features, text_features) 261 | 262 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 263 | target_names.extend(batch_target_names) 264 | 265 | return predicted_features, target_names 266 | 267 | 268 | def fashioniq_val_retrieval(dress_type: str, combining_function: callable, clip_model: CLIP, preprocess: callable): 269 | """ 270 | Perform retrieval on FashionIQ validation set computing the metrics. 
To combine the features the `combining_function` 271 | is used 272 | :param dress_type: FashionIQ category on which perform the retrieval 273 | :param combining_function:function which takes as input (image_features, text_features) and outputs the combined 274 | features 275 | :param clip_model: CLIP model 276 | :param preprocess: preprocess pipeline 277 | """ 278 | 279 | clip_model = clip_model.float().eval() 280 | 281 | # Define the validation datasets and extract the index features 282 | classic_val_dataset = FashionIQDataset('val', [dress_type], 'classic', preprocess) 283 | index_features, index_names = extract_index_features(classic_val_dataset, clip_model) 284 | relative_val_dataset = FashionIQDataset('val', [dress_type], 'relative', preprocess) 285 | 286 | return compute_fiq_val_metrics(relative_val_dataset, clip_model, index_features, index_names, 287 | combining_function) 288 | 289 | def artemis_compute_cirr_val_metrics(relative_val_dataset: CIRRDataset, blip_textual, blip_multimodal, index_features: torch.tensor, 290 | index_names: List[str], artemis) -> Tuple[ 291 | float, float, float, float, float, float, float]: 292 | """ 293 | Compute validation metrics on CIRR dataset 294 | :param relative_val_dataset: CIRR validation dataset in relative mode 295 | :param clip_model: CLIP model 296 | :param index_features: validation index features 297 | :param index_names: validation index names 298 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 299 | features 300 | :return: the computed validation metrics 301 | """ 302 | # Generate predictions 303 | 304 | print("Compute CIRR validation metrics") 305 | 306 | # Normalize the index features 307 | index_features = F.normalize(index_features, dim=-1).float() 308 | 309 | print("Compute CIRR validation predictions") 310 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8, 311 | pin_memory=True, collate_fn=collate_fn) 312 | feature_dim = 768 313 | # Initialize predicted features, target_names, group_members and reference_names 314 | artemis.eval() 315 | 316 | artemis_scores = torch.empty((0, len(index_names))).to(device, non_blocking=True) 317 | target_names = [] 318 | group_members = [] 319 | reference_names = [] 320 | 321 | for reference_images, batch_reference_names, batch_target_names, captions, batch_group_members in tqdm( 322 | relative_val_loader): # Load data 323 | reference_images = reference_images.to(device) 324 | batch_group_members = np.array(batch_group_members).T.tolist() 325 | # Compute the predicted features 326 | with torch.no_grad(): 327 | text_feats = blip_textual.extract_features({"text_input":captions}, 328 | mode="text").text_embeds[:,0,:] 329 | reference_feats = blip_multimodal.extract_features({"image":reference_images, 330 | "text_input":captions}).multimodal_embeds[:,0,:] 331 | batch_artemis_score = torch.empty((0, len(index_names))).to(device, non_blocking=True) 332 | for i in range(reference_feats.shape[0]): 333 | one_artemis_score = artemis.compute_score_artemis(reference_feats[i].unsqueeze(0), text_feats[i].unsqueeze(0), index_features) 334 | batch_artemis_score = torch.vstack((batch_artemis_score, one_artemis_score)) 335 | 336 | artemis_scores = torch.vstack((artemis_scores, F.normalize(batch_artemis_score, dim=-1))) 337 | target_names.extend(batch_target_names) 338 | group_members.extend(batch_group_members) 339 | reference_names.extend(batch_reference_names) 340 | 341 | 342 | # Compute the distances and 
sort the results 343 | distances = 1 - artemis_scores 344 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 345 | sorted_index_names = np.array(index_names)[sorted_indices] 346 | 347 | # Delete the reference image from the results 348 | reference_mask = torch.tensor( 349 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1)) 350 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 351 | sorted_index_names.shape[1] - 1) 352 | # Compute the ground-truth labels wrt the predictions 353 | labels = torch.tensor( 354 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1)) 355 | 356 | # Compute the subset predictions and ground-truth labels 357 | group_members = np.array(group_members) 358 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 359 | group_labels = labels[group_mask].reshape(labels.shape[0], -1) 360 | 361 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 362 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int()) 363 | 364 | # Compute the metrics 365 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100 366 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100 367 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 368 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 369 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100 370 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100 371 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100 372 | 373 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 374 | 375 | def compute_cirr_val_metrics_clip(relative_val_dataset: CIRRDataset, clip_model: CLIP, index_features: torch.tensor, 376 | index_names: List[str], combining_function: callable) -> Tuple[ 377 | float, float, float, float, float, float, float]: 378 | """ 379 | Compute validation metrics on CIRR dataset 380 | :param relative_val_dataset: CIRR validation dataset in relative mode 381 | :param clip_model: CLIP model 382 | :param index_features: validation index features 383 | :param index_names: validation index names 384 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 385 | features 386 | :return: the computed validation metrics 387 | """ 388 | # Generate predictions 389 | predicted_features, reference_names, target_names, group_members = \ 390 | generate_cirr_val_predictions_clip(clip_model, relative_val_dataset, combining_function, index_names, index_features) 391 | 392 | print("Compute CIRR validation metrics") 393 | 394 | # Normalize the index features 395 | index_features = F.normalize(index_features, dim=-1).float() 396 | 397 | # Compute the distances and sort the results 398 | distances = 1 - predicted_features @ index_features.T 399 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 400 | sorted_index_names = np.array(index_names)[sorted_indices] 401 | 402 | # Delete the reference image from the results 403 | reference_mask = torch.tensor( 404 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1)) 405 | sorted_index_names = 
sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 406 | sorted_index_names.shape[1] - 1) 407 | # Compute the ground-truth labels wrt the predictions 408 | labels = torch.tensor( 409 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1)) 410 | 411 | # Compute the subset predictions and ground-truth labels 412 | group_members = np.array(group_members) 413 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 414 | group_labels = labels[group_mask].reshape(labels.shape[0], -1) 415 | 416 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 417 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int()) 418 | 419 | # Compute the metrics 420 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100 421 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100 422 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 423 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 424 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100 425 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100 426 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100 427 | 428 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 429 | 430 | def generate_cirr_val_predictions_clip(clip_model: CLIP, relative_val_dataset: CIRRDataset, 431 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \ 432 | Tuple[torch.tensor, List[str], List[str], List[List[str]]]: 433 | """ 434 | Compute CIRR predictions on the validation set 435 | :param clip_model: CLIP model 436 | :param relative_val_dataset: CIRR validation dataset in relative mode 437 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 438 | features 439 | :param index_features: validation index features 440 | :param index_names: validation index names 441 | :return: predicted features, reference names, target names and group members 442 | """ 443 | print("Compute CIRR validation predictions") 444 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8, 445 | pin_memory=True, collate_fn=collate_fn) 446 | 447 | # Get a mapping from index names to index features 448 | name_to_feat = dict(zip(index_names, index_features)) 449 | 450 | # Initialize predicted features, target_names, group_members and reference_names 451 | predicted_features = torch.empty((0, clip_model.visual.output_dim)).to(device, non_blocking=True) 452 | target_names = [] 453 | group_members = [] 454 | reference_names = [] 455 | 456 | for batch_reference_names, batch_target_names, captions, batch_group_members in tqdm( 457 | relative_val_loader): # Load data 458 | text_inputs = clip.tokenize(captions).to(device, non_blocking=True) 459 | batch_group_members = np.array(batch_group_members).T.tolist() 460 | 461 | # Compute the predicted features 462 | with torch.no_grad(): 463 | text_features = clip_model.encode_text(text_inputs)[0] 464 | # Check whether a single element is in the batch due to the exception raised by torch.stack when used with 465 | # a single tensor 466 | if text_features.shape[0] == 1: 467 | reference_image_features = 
itemgetter(*batch_reference_names)(name_to_feat).unsqueeze(0) 468 | else: 469 | reference_image_features = torch.stack(itemgetter(*batch_reference_names)( 470 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features 471 | batch_predicted_features = combining_function(reference_image_features, text_features) 472 | 473 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 474 | target_names.extend(batch_target_names) 475 | group_members.extend(batch_group_members) 476 | reference_names.extend(batch_reference_names) 477 | 478 | return predicted_features, reference_names, target_names, group_members 479 | 480 | 481 | def compute_cirr_val_metrics_blip2(relative_val_dataset: CIRRDataset, blip_textual, blip_multimodal, index_features: torch.tensor, 482 | index_names: List[str], combining_function: callable) -> Tuple[ 483 | float, float, float, float, float, float, float]: 484 | """ 485 | Compute validation metrics on CIRR dataset 486 | :param relative_val_dataset: CIRR validation dataset in relative mode 487 | :param clip_model: CLIP model 488 | :param index_features: validation index features 489 | :param index_names: validation index names 490 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 491 | features 492 | :return: the computed validation metrics 493 | """ 494 | # Generate predictions 495 | predicted_features, reference_names, target_names, group_members = \ 496 | generate_cirr_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset, combining_function, index_names, index_features) 497 | 498 | print("Compute CIRR validation metrics") 499 | 500 | # Normalize the index features 501 | index_features = F.normalize(index_features, dim=-1).float() 502 | 503 | # Compute the distances and sort the results 504 | distances = 1 - predicted_features @ index_features.T 505 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 506 | sorted_index_names = np.array(index_names)[sorted_indices] 507 | 508 | # Delete the reference image from the results 509 | reference_mask = torch.tensor( 510 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1)) 511 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 512 | sorted_index_names.shape[1] - 1) 513 | # Compute the ground-truth labels wrt the predictions 514 | labels = torch.tensor( 515 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1)) 516 | 517 | # Compute the subset predictions and ground-truth labels 518 | group_members = np.array(group_members) 519 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 520 | group_labels = labels[group_mask].reshape(labels.shape[0], -1) 521 | 522 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 523 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int()) 524 | 525 | # Compute the metrics 526 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100 527 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100 528 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 529 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 530 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / 
len(group_labels)).item() * 100 531 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100 532 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100 533 | 534 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 535 | 536 | 537 | def generate_cirr_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset: CIRRDataset, 538 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \ 539 | Tuple[torch.tensor, List[str], List[str], List[List[str]]]: 540 | """ 541 | Compute CIRR predictions on the validation set 542 | :param clip_model: CLIP model 543 | :param relative_val_dataset: CIRR validation dataset in relative mode 544 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 545 | features 546 | :param index_features: validation index features 547 | :param index_names: validation index names 548 | :return: predicted features, reference names, target names and group members 549 | """ 550 | print("Compute CIRR validation predictions") 551 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8, 552 | pin_memory=True, collate_fn=collate_fn) 553 | 554 | # Get a mapping from index names to index features 555 | # name_to_feat = dict(zip(index_names, index_features)) 556 | 557 | feature_dim = 768 558 | # Initialize predicted features, target_names, group_members and reference_names 559 | predicted_features = torch.empty((0, feature_dim)).to(device, non_blocking=True) 560 | target_names = [] 561 | group_members = [] 562 | reference_names = [] 563 | 564 | for reference_images, batch_reference_names, batch_target_names, captions, batch_group_members in tqdm( 565 | relative_val_loader): # Load data 566 | reference_images = reference_images.to(device) 567 | batch_group_members = np.array(batch_group_members).T.tolist() 568 | # Compute the predicted features 569 | with torch.no_grad(): 570 | text_feats = blip_textual.extract_features({"text_input":captions}, 571 | mode="text").text_embeds[:,0,:] 572 | reference_feats = blip_multimodal.extract_features({"image":reference_images, 573 | "text_input":captions}).multimodal_embeds[:,0,:] 574 | 575 | batch_predicted_features = combining_function(text_feats, reference_feats) 576 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 577 | target_names.extend(batch_target_names) 578 | group_members.extend(batch_group_members) 579 | reference_names.extend(batch_reference_names) 580 | 581 | return predicted_features, reference_names, target_names, group_members 582 | 583 | 584 | def compute_cirr_val_metrics_blip1(relative_val_dataset: CIRRDataset, blip_textual, index_features: torch.tensor, 585 | index_names: List[str], combining_function: callable) -> Tuple[ 586 | float, float, float, float, float, float, float]: 587 | """ 588 | Compute validation metrics on CIRR dataset 589 | :param relative_val_dataset: CIRR validation dataset in relative mode 590 | :param clip_model: CLIP model 591 | :param index_features: validation index features 592 | :param index_names: validation index names 593 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 594 | features 595 | :return: the computed validation metrics 596 | """ 597 | # Generate predictions 598 | predicted_features, reference_names, 
target_names, group_members = \ 599 | generate_cirr_val_predictions_blip1(blip_textual, relative_val_dataset, combining_function, index_names, index_features) 600 | 601 | print("Compute CIRR validation metrics") 602 | 603 | # Normalize the index features 604 | index_features = F.normalize(index_features, dim=-1).float() 605 | 606 | # Compute the distances and sort the results 607 | distances = 1 - predicted_features @ index_features.T 608 | sorted_indices = torch.argsort(distances, dim=-1).cpu() 609 | sorted_index_names = np.array(index_names)[sorted_indices] 610 | 611 | # Delete the reference image from the results 612 | reference_mask = torch.tensor( 613 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1)) 614 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0], 615 | sorted_index_names.shape[1] - 1) 616 | # Compute the ground-truth labels wrt the predictions 617 | labels = torch.tensor( 618 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1)) 619 | 620 | # Compute the subset predictions and ground-truth labels 621 | group_members = np.array(group_members) 622 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool) 623 | group_labels = labels[group_mask].reshape(labels.shape[0], -1) 624 | 625 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) 626 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int()) 627 | 628 | # Compute the metrics 629 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100 630 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100 631 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 632 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 633 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100 634 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100 635 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100 636 | 637 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 638 | 639 | 640 | def generate_cirr_val_predictions_blip1(blip_textual, relative_val_dataset: CIRRDataset, 641 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \ 642 | Tuple[torch.tensor, List[str], List[str], List[List[str]]]: 643 | """ 644 | Compute CIRR predictions on the validation set 645 | :param clip_model: CLIP model 646 | :param relative_val_dataset: CIRR validation dataset in relative mode 647 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 648 | features 649 | :param index_features: validation index features 650 | :param index_names: validation index names 651 | :return: predicted features, reference names, target names and group members 652 | """ 653 | print("Compute CIRR validation predictions") 654 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8, 655 | pin_memory=True, collate_fn=collate_fn) 656 | 657 | # Get a mapping from index names to index features 658 | name_to_feat = dict(zip(index_names, index_features)) 659 | 660 | feature_dim = 256 661 | # Initialize predicted features, target_names, group_members and reference_names 662 | 
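# NOTE: `combining_function` here is typically `element_wise_sum`, imported from utils.py.
# A minimal sketch of that fusion (an assumption; see utils.py for the actual definition):
#     def element_wise_sum(image_features, text_features):
#         return F.normalize(image_features + text_features, dim=-1)
# The fused query features are later compared to the normalized index features via cosine similarity.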
predicted_features = torch.empty((0, feature_dim)).to(device, non_blocking=True) 663 | target_names = [] 664 | group_members = [] 665 | reference_names = [] 666 | 667 | for batch_reference_images, batch_reference_names, batch_target_names, captions, batch_group_members in tqdm( 668 | relative_val_loader): # Load data 669 | 670 | batch_group_members = np.array(batch_group_members).T.tolist() 671 | # Compute the predicted features 672 | with torch.no_grad(): 673 | text_features = blip_textual(captions, max_length=77, device=device)[:,0,:] 674 | if text_features.shape[0] == 1: 675 | reference_image_features = itemgetter(*batch_reference_names)(name_to_feat).unsqueeze(0) 676 | else: 677 | reference_image_features = torch.stack(itemgetter(*batch_reference_names)( 678 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features 679 | batch_predicted_features = combining_function(reference_image_features, text_features) 680 | 681 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1))) 682 | target_names.extend(batch_target_names) 683 | group_members.extend(batch_group_members) 684 | reference_names.extend(batch_reference_names) 685 | 686 | return predicted_features, reference_names, target_names, group_members 687 | 688 | 689 | 690 | 691 | def cirr_val_retrieval(combining_function: callable, blip_visual, blip_textual, blip_multimodal, preprocess): 692 | """ 693 | Perform retrieval on CIRR validation set computing the metrics. To combine the features the `combining_function` 694 | is used 695 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined 696 | features 697 | :param clip_model: CLIP model 698 | :param preprocess: preprocess pipeline 699 | """ 700 | 701 | blip_visual = blip_visual.float().eval() 702 | blip_textual = blip_textual.float().eval() 703 | blip_multimodal = blip_multimodal.float().eval() 704 | # Define the validation datasets and extract the index features 705 | classic_val_dataset = CIRRDataset('val', 'classic', preprocess) 706 | index_features, index_names = extract_index_features_blip2(classic_val_dataset, blip_visual) 707 | relative_val_dataset = CIRRDataset('val', 'relative', preprocess) 708 | 709 | return compute_cirr_val_metrics_blip2(relative_val_dataset, blip_textual, blip_multimodal, index_features, index_names, 710 | combining_function) 711 | 712 | 713 | def main(): 714 | parser = ArgumentParser() 715 | parser.add_argument("--dataset", type=str, required=True, help="should be either 'CIRR' or 'fashionIQ'") 716 | parser.add_argument("--combining-function", type=str, required=True, 717 | help="Which combining function use, should be in ['combiner', 'sum']") 718 | # parser.add_argument("--combiner-path", type=Path, help="path to trained Combiner") 719 | # parser.add_argument("--projection-dim", default=640 * 4, type=int, help='Combiner projection dim') 720 | # parser.add_argument("--hidden-dim", default=640 * 8, type=int, help="Combiner hidden dim") 721 | # parser.add_argument("--clip-model-name", default="RN50x4", type=str, help="CLIP model to use, e.g 'RN50', 'RN50x4'") 722 | # parser.add_argument("--clip-model-path", type=Path, help="Path to the fine-tuned CLIP model") 723 | 724 | parser.add_argument("--blip2-textual-path", type=str, help="Path to the fine-tuned BLIP2 model") 725 | parser.add_argument("--blip2-visual-path", type=str, help="Path to the fine-tuned BLIP2 model") 726 | 
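# Illustrative invocation of this validation script (the checkpoint paths are placeholders, not shipped files):
#   python src/validate.py --dataset CIRR --combining-function sum \
#       --blip2-textual-path <path/to/tuned_blip_text.pt> \
#       --blip2-visual-path <path/to/tuned_blip_visual.pt> \
#       --blip2-multimodal-path <path/to/tuned_blip_multi.pt> \
#       --transform targetpad --target-ratio 1.25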
parser.add_argument("--blip2-multimodal-path", type=str, help="Path to the fine-tuned BLIP2 model") 727 | 728 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio") 729 | parser.add_argument("--transform", default="targetpad", type=str, 730 | help="Preprocess pipeline, should be in ['clip', 'squarepad', 'targetpad'] ") 731 | 732 | 733 | args = parser.parse_args() 734 | 735 | # clip_model, clip_preprocess = clip.load(args.clip_model_name, device=device, jit=False) 736 | # input_dim = clip_model.visual.input_resolution 737 | # feature_dim = clip_model.visual.output_dim 738 | 739 | blip_textual_path = args.blip2_textual_path 740 | blip_visual_path = args.blip2_visual_path 741 | blip_multimodal_path = args.blip2_multimodal_path 742 | 743 | blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 744 | x_saved_state_dict = torch.load(blip_textual_path, map_location=device) 745 | # print(x_saved_state_dict.keys()) 746 | blip_textual.load_state_dict(x_saved_state_dict["Blip2Qformer"]) 747 | 748 | blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 749 | r_saved_state_dict = torch.load(blip_multimodal_path, map_location=device) 750 | blip_multimodal.load_state_dict(r_saved_state_dict["Blip2Qformer"]) 751 | 752 | blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 753 | t_saved_state_dict = torch.load(blip_visual_path, map_location=device) 754 | blip_visual.load_state_dict(t_saved_state_dict["Blip2Qformer"]) 755 | 756 | 757 | # blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 758 | 759 | # blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 760 | 761 | # blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device) 762 | 763 | # if args.clip_model_path: 764 | # print('Trying to load the CLIP model') 765 | # saved_state_dict = torch.load(args.clip_model_path, map_location=device) 766 | # clip_model.load_state_dict(saved_state_dict["CLIP"]) 767 | # print('CLIP model loaded successfully') 768 | input_dim = 224 769 | if args.transform == 'targetpad': 770 | print('Target pad preprocess pipeline is used') 771 | preprocess = targetpad_transform(args.target_ratio, input_dim) 772 | elif args.transform == 'squarepad': 773 | print('Square pad preprocess pipeline is used') 774 | preprocess = squarepad_transform(input_dim) 775 | # else: 776 | # print('CLIP default preprocess pipeline is used') 777 | # preprocess = clip_preprocess 778 | 779 | if args.combining_function.lower() == 'sum': 780 | combining_function = element_wise_sum 781 | # elif args.combining_function.lower() == 'combiner': 782 | # combiner = Combiner(feature_dim, args.projection_dim, args.hidden_dim).to(device, non_blocking=True) 783 | # state_dict = torch.load(args.combiner_path, map_location=device) 784 | # combiner.load_state_dict(state_dict["Combiner"]) 785 | # combiner.eval() 786 | # combining_function = combiner.combine_features 787 | else: 788 | raise ValueError("combiner_path should be in ['sum', 'combiner']") 789 | 790 | if args.dataset.lower() == 'cirr': 791 | group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 = \ 792 | cirr_val_retrieval(combining_function, blip_visual, 
blip_textual, blip_multimodal , preprocess) 793 | 794 | print(f"{group_recall_at1 = }") 795 | print(f"{group_recall_at2 = }") 796 | print(f"{group_recall_at3 = }") 797 | print(f"{recall_at1 = }") 798 | print(f"{recall_at5 = }") 799 | print(f"{recall_at10 = }") 800 | print(f"{recall_at50 = }") 801 | 802 | elif args.dataset.lower() == 'fashioniq': 803 | average_recall10_list = [] 804 | average_recall50_list = [] 805 | 806 | shirt_recallat10, shirt_recallat50 = fashioniq_val_retrieval('shirt', combining_function, clip_model, 807 | preprocess) 808 | average_recall10_list.append(shirt_recallat10) 809 | average_recall50_list.append(shirt_recallat50) 810 | 811 | dress_recallat10, dress_recallat50 = fashioniq_val_retrieval('dress', combining_function, clip_model, 812 | preprocess) 813 | average_recall10_list.append(dress_recallat10) 814 | average_recall50_list.append(dress_recallat50) 815 | 816 | toptee_recallat10, toptee_recallat50 = fashioniq_val_retrieval('toptee', combining_function, clip_model, 817 | preprocess) 818 | average_recall10_list.append(toptee_recallat10) 819 | average_recall50_list.append(toptee_recallat50) 820 | 821 | print(f"\n{shirt_recallat10 = }") 822 | print(f"{shirt_recallat50 = }") 823 | 824 | print(f"{dress_recallat10 = }") 825 | print(f"{dress_recallat50 = }") 826 | 827 | print(f"{toptee_recallat10 = }") 828 | print(f"{toptee_recallat50 = }") 829 | 830 | print(f"average recall10 = {mean(average_recall10_list)}") 831 | print(f"average recall50 = {mean(average_recall50_list)}") 832 | else: 833 | raise ValueError("Dataset should be either 'CIRR' or 'FashionIQ") 834 | 835 | 836 | if __name__ == '__main__': 837 | main() 838 | -------------------------------------------------------------------------------- /src/blip_fine_tune.py: -------------------------------------------------------------------------------- 1 | from comet_ml import Experiment 2 | import json 3 | import multiprocessing 4 | from argparse import ArgumentParser 5 | from datetime import datetime 6 | from pathlib import Path 7 | from statistics import mean, geometric_mean, harmonic_mean 8 | from typing import List 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import optim, nn 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from data_utils import base_path, squarepad_transform, targetpad_transform, CIRRDataset, FashionIQDataset 18 | import clip 19 | from lavis.models import load_model 20 | from utils import collate_fn, update_train_running_results, set_train_bar_description, save_model, \ 21 | extract_index_features_blip2, generate_randomized_fiq_caption, extract_index_features_clip, \ 22 | device, element_wise_sum, cosine_lr_schedule 23 | from validate import compute_cirr_val_metrics_blip2, compute_fiq_val_metrics_blip2, compute_fiq_val_metrics_clip, \ 24 | artemis_compute_cirr_val_metrics ,compute_cirr_val_metrics_clip, artemis_compute_fiq_val_metrics 25 | from twin_attention_compositor_blip2 import TwinAttentionCompositorBLIP2 26 | from hinge_based_cross_attention_blip2 import HingebasedCrossAttentionBLIP2 27 | from twin_attention_compositor_clip import TwinAttentionCompositorCLIP 28 | from hinge_based_cross_attention_clip import HingebasedCrossAttentionCLIP 29 | import random 30 | from artemis import Artemis 31 | import ssl 32 | 33 | base_path = Path(__file__).absolute().parents[1].absolute() 34 | 35 | def blip_finetune_fiq(train_dress_types: List[str], val_dress_types: List[str], 36 | num_epochs: int, 
batch_size: int, 37 | validation_frequency: int, transform: str, save_training: bool, save_best: bool, 38 | **kwargs): 39 | """ 40 | Fine-tune blip on the FashionIQ dataset using as combining function the image-text element-wise sum 41 | :param train_dress_types: FashionIQ categories to train on 42 | :param val_dress_types: FashionIQ categories to validate on 43 | :param num_epochs: number of epochs 44 | :param blip_model_name: blip model you want to use: "RN50", "RN101", "RN50x4"... 45 | :param learning_rate: fine-tuning leanring rate 46 | :param batch_size: batch size 47 | :param validation_frequency: validation frequency expressed in epoch 48 | :param transform: preprocess transform you want to use. Should be in ['blip', 'squarepad', 'targetpad']. When 49 | targetpad is also required to provide `target_ratio` kwarg. 50 | :param save_training: when True save the weights of the fine-tuned blip model 51 | :param encoder: which blip encoder to fine-tune, should be in ['both', 'text', 'image'] 52 | :param save_best: when True save only the weights of the best blip model wrt the average_recall metric 53 | :param kwargs: if you use the `targetpad` transform you should prove `target_ratio` as kwarg 54 | """ 55 | 56 | experiment_name = kwargs["experiment_name"] 57 | # training_start = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 58 | training_path: Path = Path( 59 | base_path / f"models/blip_cirr_{experiment_name}") 60 | training_path.mkdir(exist_ok=False, parents=True) 61 | 62 | # Save all the hyperparameters on a file 63 | with open(training_path / "training_hyperparameters.json", 'w+') as file: 64 | json.dump(training_hyper_params, file, sort_keys=True, indent=4) 65 | 66 | # initialize encoders 67 | encoder_arch = kwargs["encoder_arch"] 68 | model_name = kwargs["model_name"] 69 | 70 | if encoder_arch == "blip2": 71 | blip_textual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device) 72 | blip_multimodal = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device) 73 | blip_visual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device) 74 | 75 | elif encoder_arch == "clip": 76 | clip_model, clip_preprocess = clip.load(model_name, device=device, jit=False) 77 | clip_model.eval().float() 78 | 79 | # initialize support modules 80 | embeds_dim = kwargs["embeds_dim"] 81 | if encoder_arch == "blip2": 82 | tac = TwinAttentionCompositorBLIP2(embeds_dim).to(device) 83 | hca = HingebasedCrossAttentionBLIP2(embeds_dim).to(device) 84 | 85 | elif encoder_arch == "clip": 86 | tac = TwinAttentionCompositorCLIP().to(device) 87 | hca = HingebasedCrossAttentionCLIP(embeds_dim).to(device) 88 | 89 | cir_frame = kwargs["cir_frame"] 90 | artemis = Artemis(embeds_dim).to(device) 91 | 92 | 93 | # define the combining func 94 | combining_function = element_wise_sum 95 | # preprocess 96 | if encoder_arch == "blip2": 97 | input_dim = 224 98 | elif encoder_arch == "clip": 99 | input_dim = clip_model.visual.input_resolution 100 | 101 | 102 | if transform == "clip": 103 | preprocess = clip_preprocess 104 | print('CLIP default preprocess pipeline is used') 105 | elif transform == "squarepad": 106 | preprocess = squarepad_transform(input_dim) 107 | print('Square pad preprocess pipeline is used') 108 | elif transform == "targetpad": 109 | target_ratio = kwargs['target_ratio'] 110 | preprocess = targetpad_transform(target_ratio, input_dim) 111 | print(f'Target pad with {target_ratio = } preprocess pipeline is used') 112 | else: 113 | 
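# Only 'clip', 'squarepad' and 'targetpad' are handled by the branches above;
# any other value of `transform` falls through to this error.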
raise ValueError("Preprocess transform should be in ['blip', 'squarepad', 'targetpad']") 114 | 115 | idx_to_dress_mapping = {} 116 | relative_val_datasets = [] 117 | classic_val_datasets = [] 118 | 119 | # When fine-tuning only the text encoder we can precompute the index features since they do not change over 120 | # the epochs 121 | encoder = kwargs["encoder"] 122 | if encoder == 'text': 123 | index_features_list = [] 124 | index_names_list = [] 125 | 126 | # Define the validation datasets 127 | for idx, dress_type in enumerate(val_dress_types): 128 | idx_to_dress_mapping[idx] = dress_type 129 | relative_val_dataset = FashionIQDataset('val', [dress_type], 'relative', preprocess, ) 130 | relative_val_datasets.append(relative_val_dataset) 131 | classic_val_dataset = FashionIQDataset('val', [dress_type], 'classic', preprocess, ) 132 | classic_val_datasets.append(classic_val_dataset) 133 | if encoder == 'text': 134 | if encoder_arch == "blip2": 135 | index_features_and_names = extract_index_features_blip2(classic_val_dataset, clip_model) 136 | index_features_list.append(index_features_and_names[0]) 137 | index_names_list.append(index_features_and_names[1]) 138 | 139 | if encoder_arch == "clip": 140 | index_features_and_names = extract_index_features_clip(classic_val_dataset, clip_model) 141 | index_features_list.append(index_features_and_names[0]) 142 | index_names_list.append(index_features_and_names[1]) 143 | 144 | 145 | # Define the train datasets and the combining function 146 | relative_train_dataset = FashionIQDataset('train', train_dress_types, 'relative', preprocess) 147 | relative_train_loader = DataLoader(dataset=relative_train_dataset, batch_size=batch_size, 148 | num_workers=0, pin_memory=False, collate_fn=collate_fn, 149 | drop_last=True, shuffle=True) 150 | 151 | # Define the optimizer, the loss and the grad scaler 152 | learning_rate = kwargs["learning_rate"] 153 | min_lr = kwargs["min_lr"] 154 | max_epoch = kwargs["max_epoch"] 155 | 156 | 157 | # Define the optimizer, the loss and the grad scaler 158 | if encoder_arch == "blip2": 159 | # # blip2_encoder 160 | if encoder == 'multi': # only finetuning text_encoder 161 | optimizer = optim.AdamW([ # param in blip_multimodal 162 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()], 163 | 'lr': learning_rate, 164 | 'weight_decay': 0.05}, 165 | {'params': blip_multimodal.query_tokens, 166 | 'lr': learning_rate, 167 | 'weight_decay': 0.05}, 168 | # param in blip_textual 169 | {'params': [param for param in blip_textual.Qformer.bert.parameters()], 170 | 'lr': learning_rate, 171 | 'weight_decay': 0.05}, 172 | 173 | # params in support modules 174 | {'params': [param for param in tac.parameters()], 175 | 'lr': learning_rate * 2, 176 | 'weight_decay': 0.05}, 177 | 178 | {'params': [param for param in hca.parameters()], 179 | 'lr': learning_rate * 2, 180 | 'weight_decay': 0.05}, 181 | 182 | {'params': [param for param in artemis.parameters()], 183 | 'lr': learning_rate * 10, 184 | 'weight_decay': 0.05}, 185 | ]) 186 | 187 | elif encoder == 'both': # finetuning textual and visual concurrently 188 | optimizer = optim.AdamW([ 189 | # param in blip_multimodal 190 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()], 191 | 'lr': learning_rate, 192 | 'weight_decay': 0.05}, 193 | {'params': blip_multimodal.query_tokens, 194 | 'lr': learning_rate, 195 | 'weight_decay': 0.05}, 196 | 197 | # param in blip_textual 198 | {'params': [param for param in blip_textual.Qformer.bert.parameters()], 199 | 'lr': 
learning_rate, 200 | 'weight_decay': 0.05}, 201 | 202 | # param in blip_visual 203 | {'params': [param for param in blip_visual.Qformer.bert.parameters()], 204 | 'lr': learning_rate, 205 | 'weight_decay': 0.05}, 206 | {'params': blip_visual.query_tokens, 207 | 'lr': learning_rate, 208 | 'weight_decay': 0.05}, 209 | 210 | # params in support modules 211 | 212 | {'params': [param for param in tac.parameters()], 213 | 'lr': learning_rate * 2, 214 | 'weight_decay': 0.05}, 215 | 216 | {'params': [param for param in hca.parameters()], 217 | 'lr': learning_rate * 2, 218 | 'weight_decay': 0.05}, 219 | 220 | {'params': [param for param in artemis.parameters()], 221 | 'lr': learning_rate * 10, 222 | 'weight_decay': 0.05}, 223 | ]) 224 | 225 | else: 226 | raise ValueError("encoders to finetune must be 'multi' or 'both'") 227 | elif encoder_arch == "clip": 228 | # clip_encoder 229 | optimizer = optim.AdamW([ # param in blip_multimodal 230 | {'params': [param for name, param in clip_model.named_parameters() 231 | if 'visual' not in name], 232 | 'lr': learning_rate, 233 | 'weight_decay': 0.05}, 234 | 235 | # params in support modules 236 | {'params': [param for param in tac.parameters()], 237 | 'lr': learning_rate * 2, 238 | 'weight_decay': 0.05}, 239 | 240 | {'params': [param for param in hca.parameters()], 241 | 'lr': learning_rate * 2, 242 | 'weight_decay': 0.05}, 243 | ]) 244 | 245 | crossentropy_criterion = nn.CrossEntropyLoss() 246 | scaler = torch.cuda.amp.GradScaler() 247 | 248 | # When save_best == True initialize the best result to zero 249 | if save_best: 250 | best_avg_recall = 0 251 | 252 | # Define dataframes for CSV logging 253 | training_log_frame = pd.DataFrame() 254 | validation_log_frame = pd.DataFrame() 255 | 256 | # define weights for different modules 257 | tac_weight = kwargs["tac_weight"] 258 | hca_weight = kwargs["hca_weight"] 259 | 260 | # Start with the training loop 261 | print('Training loop started') 262 | for epoch in range(num_epochs): 263 | with experiment.train(): 264 | # encoder = "text" or "both" 265 | # set models to train mode 266 | if encoder_arch == "blip2": 267 | # # blip2_encoder 268 | blip_multimodal.Qformer.bert.train() 269 | # blip_multimodal.vision_proj.train() 270 | blip_multimodal.query_tokens.requires_grad = True 271 | blip_textual.Qformer.bert.train() 272 | # blip_textual.text_proj.train() 273 | 274 | # both adds param in visual_encoder 275 | # blip2_encoder 276 | if encoder == "both": 277 | blip_visual.Qformer.bert.train() 278 | # blip_visual.vision_proj.train() 279 | blip_visual.query_tokens.requires_grad = True 280 | 281 | elif encoder_arch == "clip": 282 | # clip_encoder 283 | clip_model.train() 284 | 285 | # support modules 286 | if tac_weight > 0: 287 | tac.train() 288 | if hca_weight > 0: 289 | hca.train() 290 | if cir_frame == "artemis": 291 | artemis.train() 292 | 293 | 294 | train_running_results = {'images_in_epoch': 0, 'accumulated_train_loss': 0} 295 | train_bar = tqdm(relative_train_loader, ncols=150) 296 | 297 | # adjust learning rate 298 | cosine_lr_schedule(optimizer, epoch, max_epoch, learning_rate, min_lr, onlyGroup0=True) 299 | 300 | for idx, (reference_images, target_images, captions) in enumerate(train_bar): 301 | images_in_batch = reference_images.size(0) 302 | step = len(train_bar) * epoch + idx 303 | optimizer.zero_grad() 304 | 305 | # move ref and tar img to device 306 | reference_images = reference_images.to(device, non_blocking=True) 307 | target_images = target_images.to(device, non_blocking=True) 308 | 309 | # Randomize the 
training caption in four way: (a) cap1 and cap2 (b) cap2 and cap1 (c) cap1 (d) cap2 310 | flattened_captions: list = np.array(captions).T.flatten().tolist() 311 | input_captions = generate_randomized_fiq_caption(flattened_captions) 312 | 313 | # Extract the features, compute the logits and the loss 314 | 315 | with torch.cuda.amp.autocast(): 316 | if encoder_arch == "blip2": 317 | # text 318 | text_embeds = blip_textual.extract_features({"text_input":input_captions}, 319 | mode="text").text_embeds 320 | text_feats = text_embeds[:,0,:] 321 | 322 | # target 323 | target_embeds = blip_visual.extract_features({"image":target_images}, 324 | mode="image").image_embeds 325 | target_feats = F.normalize(target_embeds[:,0,:], dim=-1) 326 | 327 | # reference 328 | reference_embeds = blip_multimodal.extract_features({"image":reference_images, 329 | "text_input":input_captions}).multimodal_embeds 330 | reference_feats = reference_embeds[:,0,:] 331 | 332 | # embeds encoded with visual_encoder 333 | reference_embeds_for_tac = blip_visual.extract_features({"image":reference_images}, 334 | mode="image").image_embeds 335 | elif encoder_arch == "clip": 336 | # reference 337 | reference_feats, reference_embeds = clip_model.encode_image(reference_images) 338 | reference_embeds_for_tac = reference_embeds 339 | 340 | # text 341 | text_inputs = clip.tokenize(input_captions, context_length=77, truncate=True).to(device, non_blocking=True) 342 | text_feats, text_embeds = clip_model.encode_text(text_inputs) 343 | 344 | # target 345 | target_feats, target_embeds = clip_model.encode_image(target_images) 346 | target_feats = F.normalize(target_feats) 347 | 348 | # ============ Query-Target Contrastive =========== 349 | 350 | if cir_frame == "artemis": 351 | # artemis 352 | artemis_scores = artemis.compute_score_broadcast_artemis(reference_feats, text_feats, target_feats) 353 | # artemis_logits = artemis.temperature.exp() * artemis_scores 354 | elif cir_frame == "sum": 355 | # sum_predicted 356 | predicted_feats = combining_function(reference_feats, text_feats) 357 | matching_logits = 100 * predicted_feats @ target_feats.T 358 | # ============ Query-Target Align =========== 359 | 360 | # align(tac) 361 | if tac_weight > 0: 362 | visual_gap_feats = tac(reference_embeds_for_tac, target_embeds) 363 | aligning_logits = 10 * text_feats @ visual_gap_feats.T 364 | 365 | # ============ Reference-Caption-Target Contrastive =========== 366 | if hca_weight > 0: 367 | psudo_T = hca(reference_embeds = reference_embeds, 368 | caption_embeds = text_embeds, 369 | target_embeds = target_embeds) 370 | 371 | reasoning_logits = 10 * psudo_T @ reference_feats.T 372 | 373 | 374 | # ============ LOSS =========== 375 | # align_loss / tac 376 | # align_loss = crossentropy_criterion(align_logits, ground_truth) 377 | 378 | # hca_loss 379 | # hca_tcr_loss = crossentropy_criterion(hca_logits, ground_truth) 380 | 381 | if cir_frame == "artemis": 382 | # artemis_loss 383 | # artemis_loss = crossentropy_criterion(artemis_logits, ground_truth) 384 | if tac_weight > 0: 385 | contrastive_logits = tac_weight * aligning_logits + \ 386 | (1 - tac_weight) * artemis_scores 387 | else: 388 | contrastive_logits = artemis_scores 389 | elif cir_frame == "sum": 390 | # contrastive loss 391 | # contrastive_loss = crossentropy_criterion(contrast_logits, ground_truth) 392 | if tac_weight > 0: 393 | contrastive_logits = tac_weight * aligning_logits + \ 394 | (1 - tac_weight) * matching_logits 395 | else: 396 | contrastive_logits = matching_logits 397 | 398 | # 
========== Sum_Loss =============== 399 | # hca_loss 400 | ground_truth = torch.arange(images_in_batch, dtype=torch.long, device=device) 401 | contrastive_loss = crossentropy_criterion(contrastive_logits, ground_truth) 402 | if hca_weight > 0: 403 | reasoning_loss = crossentropy_criterion(reasoning_logits, ground_truth) 404 | loss = hca_weight * reasoning_loss + (1 - hca_weight) * contrastive_loss 405 | else: 406 | loss = contrastive_loss 407 | 408 | # Backpropagate and update the weights 409 | scaler.scale(loss).backward() 410 | scaler.step(optimizer) 411 | scaler.update() 412 | 413 | experiment.log_metric('step_loss', loss.detach().cpu().item(), step=step) 414 | update_train_running_results(train_running_results, loss, images_in_batch) 415 | set_train_bar_description(train_bar, epoch, num_epochs, train_running_results) 416 | 417 | train_epoch_loss = float( 418 | train_running_results['accumulated_train_loss'] / train_running_results['images_in_epoch']) 419 | experiment.log_metric('epoch_loss', train_epoch_loss, epoch=epoch) 420 | 421 | # Training CSV logging 422 | training_log_frame = pd.concat( 423 | [training_log_frame, 424 | pd.DataFrame(data={'epoch': epoch, 'train_epoch_loss': train_epoch_loss}, index=[0])]) 425 | training_log_frame.to_csv(str(training_path / 'train_metrics.csv'), index=False) 426 | 427 | if epoch % validation_frequency == 0: 428 | with experiment.validate(): 429 | recalls_at10 = [] 430 | recalls_at50 = [] 431 | 432 | # Compute and log validation metrics for each validation dataset (which corresponds to a different 433 | # FashionIQ category) 434 | 435 | for relative_val_dataset, classic_val_dataset, idx in zip(relative_val_datasets, classic_val_datasets, 436 | idx_to_dress_mapping): 437 | if encoder == 'text': 438 | index_features, index_names = index_features_list[idx], index_names_list[idx] 439 | else: 440 | if encoder_arch == "blip2": 441 | index_features, index_names = extract_index_features_blip2(classic_val_dataset, blip_visual) 442 | else: 443 | index_features, index_names = extract_index_features_clip(classic_val_dataset, clip_model) 444 | 445 | if cir_frame == "sum": 446 | if encoder_arch == "blip2": 447 | 448 | recall_at10, recall_at50 = compute_fiq_val_metrics_blip2(relative_val_dataset, 449 | blip_textual, 450 | blip_multimodal, 451 | index_features, 452 | index_names, 453 | combining_function) 454 | else: 455 | recall_at10, recall_at50 = compute_fiq_val_metrics_clip(relative_val_dataset, 456 | blip_textual, 457 | blip_multimodal, 458 | index_features, 459 | index_names, 460 | combining_function) 461 | elif cir_frame == "artemis": 462 | recall_at10, recall_at50 = artemis_compute_fiq_val_metrics(relative_val_dataset, 463 | blip_textual, 464 | blip_multimodal, 465 | index_features, 466 | index_names, 467 | artemis) 468 | 469 | recalls_at10.append(recall_at10) 470 | recalls_at50.append(recall_at50) 471 | 472 | results_dict = {} 473 | for i in range(len(recalls_at10)): 474 | results_dict[f'{idx_to_dress_mapping[i]}_recall_at10'] = recalls_at10[i] 475 | results_dict[f'{idx_to_dress_mapping[i]}_recall_at50'] = recalls_at50[i] 476 | results_dict.update({ 477 | f'average_recall_at10': mean(recalls_at10), 478 | f'average_recall_at50': mean(recalls_at50), 479 | f'average_recall': (mean(recalls_at50) + mean(recalls_at10)) / 2 480 | }) 481 | 482 | print(json.dumps(results_dict, indent=4)) 483 | experiment.log_metrics( 484 | results_dict, 485 | epoch=epoch 486 | ) 487 | 488 | # Validation CSV logging 489 | log_dict = {'epoch': epoch} 490 | 
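# results_dict carries the per-category recall_at10 / recall_at50 values plus their averages;
# 'average_recall' (the mean of average_recall_at10 and average_recall_at50) is the score used
# below to decide whether the current checkpoint is the best one to save.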
log_dict.update(results_dict) 491 | validation_log_frame = pd.concat([validation_log_frame, pd.DataFrame(data=log_dict, index=[0])]) 492 | validation_log_frame.to_csv(str(training_path / 'validation_metrics.csv'), index=False) 493 | 494 | if save_training: 495 | if save_best and results_dict['average_recall'] > best_avg_recall: 496 | best_avg_recall = results_dict['average_recall'] 497 | if encoder_arch == "blip2": 498 | save_model('tuned_blip_text_arithmetic', epoch, blip_textual, training_path) 499 | save_model('tuned_blip_multi_arithmetic', epoch, blip_multimodal, training_path) 500 | save_model('tuned_blip_visual_arithmetic', epoch, blip_visual, training_path) 501 | elif encoder_arch == "clip": 502 | save_model('tuned_clip_arithmetic', epoch, clip_model, training_path) 503 | 504 | save_model('tuned_tac', epoch, tac, training_path) 505 | save_model('tuned_hca', epoch, hca, training_path) 506 | if not save_best: 507 | print("Warning!!!! Now you don't save any models, please set save_best==True") 508 | 509 | 510 | def blip_finetune_cirr(num_epochs: int, batch_size: int, 511 | validation_frequency: int, transform: str, save_training: bool, save_best: bool, 512 | **kwargs): 513 | """ 514 | Fine-tune blip on the CIRR dataset using as combining function the image-text element-wise sum 515 | :param num_epochs: number of epochs 516 | :param blip_model_name: blip model you want to use: "RN50", "RN101", "RN50x4"... 517 | :param learning_rate: fine-tuning learning rate 518 | :param batch_size: batch size 519 | :param validation_frequency: validation frequency expressed in epoch 520 | :param transform: preprocess transform you want to use. Should be in ['blip', 'squarepad', 'targetpad']. When 521 | targetpad is also required to provide `target_ratio` kwarg. 
522 | :param save_training: when True save the weights of the Combiner network 523 | :param encoder: which blip encoder to fine-tune, should be in ['both', 'text', 'image'] 524 | :param save_best: when True save only the weights of the best Combiner wrt three different averages of the metrics 525 | :param kwargs: if you use the `targetpad` transform you should prove `target_ratio` :return: 526 | """ 527 | experiment_name = kwargs["experiment_name"] 528 | # training_start = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 529 | training_path: Path = Path( 530 | base_path / f"models/blip_cirr_{experiment_name}") 531 | training_path.mkdir(exist_ok=False, parents=True) 532 | 533 | # Save all the hyperparameters on a file 534 | with open(training_path / "training_hyperparameters.json", 'w+') as file: 535 | json.dump(training_hyper_params, file, sort_keys=True, indent=4) 536 | 537 | # initialize encoders 538 | encoder_arch = kwargs["encoder_arch"] 539 | 540 | # initialize the encoders with different arch 541 | model_name = kwargs["model_name"] 542 | if encoder_arch == "blip2": 543 | # blip2_encoders 544 | blip_textual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device) 545 | blip_multimodal = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device) 546 | blip_visual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device) 547 | 548 | elif encoder_arch == "clip": 549 | clip_model, clip_preprocess = clip.load(model_name, device=device, jit=False) 550 | clip_model.eval().float() 551 | # initialize support modules 552 | embeds_dim = kwargs["embeds_dim"] 553 | if encoder_arch == "blip2": 554 | tac = TwinAttentionCompositorBLIP2(embeds_dim).to(device) 555 | hca = HingebasedCrossAttentionBLIP2(embeds_dim).to(device) 556 | 557 | elif encoder_arch == "clip": 558 | tac = TwinAttentionCompositorCLIP().to(device) 559 | hca = HingebasedCrossAttentionCLIP(embeds_dim).to(device) 560 | 561 | cir_frame = kwargs["cir_frame"] 562 | artemis = Artemis(embeds_dim).to(device) 563 | 564 | # defined the combining func 565 | combining_function = element_wise_sum 566 | 567 | # preprocess 568 | if encoder_arch == "blip2": 569 | input_dim = 224 570 | elif encoder_arch == "clip": 571 | input_dim = clip_model.visual.input_resolution 572 | 573 | if transform == "clip": 574 | preprocess = clip_preprocess 575 | print('CLIP default preprocess pipeline is used') 576 | elif transform == "squarepad": 577 | preprocess = squarepad_transform(input_dim) 578 | print('Square pad preprocess pipeline is used') 579 | elif transform == "targetpad": 580 | target_ratio = kwargs['target_ratio'] 581 | preprocess = targetpad_transform(target_ratio, input_dim) 582 | print(f'Target pad with {target_ratio = } preprocess pipeline is used') 583 | else: 584 | raise ValueError("Preprocess transform should be in ['squarepad', 'targetpad']") 585 | 586 | # Define the validation datasets 587 | relative_val_dataset = CIRRDataset('val', 'relative', preprocess) 588 | classic_val_dataset = CIRRDataset('val', 'classic', preprocess) 589 | 590 | # When fine-tuning only the text encoder we can precompute the index features since they do not change over 591 | # the epochs 592 | encoder = kwargs["encoder"] 593 | 594 | if encoder_arch == "blip2": 595 | val_index_features, val_index_names = extract_index_features_blip2(classic_val_dataset, blip_visual) 596 | if encoder_arch == "clip": 597 | val_index_features, val_index_names = extract_index_features_clip(classic_val_dataset, 
clip_model) 598 | if encoder == 'text': 599 | if encoder_arch == "blip2": 600 | val_index_features, val_index_names = extract_index_features_blip2(classic_val_dataset, blip_visual) 601 | if encoder_arch == "clip": 602 | val_index_features, val_index_names = extract_index_features_clip(classic_val_dataset, clip_model) 603 | 604 | # debug for validation 605 | if encoder_arch == "blip2": 606 | results = artemis_compute_cirr_val_metrics(relative_val_dataset, 607 | blip_textual, 608 | blip_multimodal, 609 | val_index_features, 610 | val_index_names, 611 | artemis) 612 | else: 613 | results = compute_cirr_val_metrics_clip(relative_val_dataset, clip_model, val_index_features, 614 | val_index_names, combining_function) 615 | 616 | 617 | # Define the train dataset and the combining function 618 | relative_train_dataset = CIRRDataset('train', 'relative', preprocess) 619 | relative_train_loader = DataLoader(dataset=relative_train_dataset, batch_size=batch_size, 620 | num_workers=0, pin_memory=False, collate_fn=collate_fn, 621 | drop_last=True, shuffle=True) 622 | 623 | # Define the optimizer, the loss and the grad scaler 624 | learning_rate = kwargs["learning_rate"] 625 | min_lr = kwargs["min_lr"] 626 | max_epoch = kwargs["max_epoch"] 627 | 628 | # Define the optimizer, the loss and the grad scaler 629 | if encoder_arch == "blip2": 630 | # # blip2_encoder 631 | if encoder == 'multi': # only finetuning text_encoder 632 | optimizer = optim.AdamW([ # param in blip_multimodal 633 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()], 634 | 'lr': learning_rate, 635 | 'weight_decay': 0.05}, 636 | {'params': blip_multimodal.query_tokens, 637 | 'lr': learning_rate, 638 | 'weight_decay': 0.05}, 639 | # param in blip_textual 640 | {'params': [param for param in blip_textual.Qformer.bert.parameters()], 641 | 'lr': learning_rate, 642 | 'weight_decay': 0.05}, 643 | 644 | # params in support modules 645 | {'params': [param for param in tac.parameters()], 646 | 'lr': learning_rate * 2, 647 | 'weight_decay': 0.05}, 648 | 649 | {'params': [param for param in hca.parameters()], 650 | 'lr': learning_rate * 2, 651 | 'weight_decay': 0.05}, 652 | 653 | {'params': [param for param in artemis.parameters()], 654 | 'lr': learning_rate * 10, 655 | 'weight_decay': 0.05}, 656 | ]) 657 | 658 | elif encoder == 'both': # finetuning textual and visual concurrently 659 | optimizer = optim.AdamW([ 660 | # param in blip_multimodal 661 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()], 662 | 'lr': learning_rate, 663 | 'weight_decay': 0.05}, 664 | {'params': blip_multimodal.query_tokens, 665 | 'lr': learning_rate, 666 | 'weight_decay': 0.05}, 667 | 668 | # param in blip_textual 669 | {'params': [param for param in blip_textual.Qformer.bert.parameters()], 670 | 'lr': learning_rate, 671 | 'weight_decay': 0.05}, 672 | 673 | # param in blip_visual 674 | {'params': [param for param in blip_visual.Qformer.bert.parameters()], 675 | 'lr': learning_rate, 676 | 'weight_decay': 0.05}, 677 | {'params': blip_visual.query_tokens, 678 | 'lr': learning_rate, 679 | 'weight_decay': 0.05}, 680 | 681 | # params in support modules 682 | 683 | {'params': [param for param in tac.parameters()], 684 | 'lr': learning_rate * 2, 685 | 'weight_decay': 0.05}, 686 | 687 | {'params': [param for param in hca.parameters()], 688 | 'lr': learning_rate * 2, 689 | 'weight_decay': 0.05}, 690 | 691 | {'params': [param for param in artemis.parameters()], 692 | 'lr': learning_rate * 10, 693 | 'weight_decay': 0.05}, 694 | ]) 695 | 
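# Learning-rate layout of the parameter groups built above: the Q-Former weights and query tokens
# use the base `learning_rate`, the TAC and HCA support modules use 2x that rate, and the ARTEMIS
# head uses 10x. `cosine_lr_schedule(..., onlyGroup0=True)` further down appears to decay only the
# first group (see utils.py for its definition).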
696 | else: 697 | raise ValueError("encoders to finetune must be 'multi' or 'both'") 698 | elif encoder_arch == "clip": 699 | # clip_encoder 700 | optimizer = optim.AdamW([ # param in blip_multimodal 701 | {'params': [param for name, param in clip_model.named_parameters() 702 | if 'visual' not in name], 703 | 'lr': learning_rate, 704 | 'weight_decay': 0.05}, 705 | 706 | # params in support modules 707 | {'params': [param for param in tac.parameters()], 708 | 'lr': learning_rate * 2, 709 | 'weight_decay': 0.05}, 710 | 711 | {'params': [param for param in hca.parameters()], 712 | 'lr': learning_rate * 2, 713 | 'weight_decay': 0.05}, 714 | ]) 715 | 716 | # define loss function and scaler 717 | crossentropy_criterion = nn.CrossEntropyLoss() 718 | scaler = torch.cuda.amp.GradScaler() 719 | 720 | # When save_best == True initialize the best results to zero 721 | if save_best: 722 | # best_harmonic = 0 723 | # best_geometric = 0 724 | best_arithmetic = 0 725 | 726 | # Define dataframes for CSV logging 727 | training_log_frame = pd.DataFrame() 728 | validation_log_frame = pd.DataFrame() 729 | 730 | # define weights for different modules 731 | tac_weight = kwargs["tac_weight"] 732 | hca_weight = kwargs["hca_weight"] 733 | 734 | # epoch loop 735 | for epoch in range(num_epochs): 736 | with experiment.train(): 737 | # encoder = "text" or "both" 738 | # set models to train mode 739 | if encoder_arch == "blip2": 740 | # # blip2_encoder 741 | blip_multimodal.Qformer.bert.train() 742 | # blip_multimodal.vision_proj.train() 743 | blip_multimodal.query_tokens.requires_grad = True 744 | blip_textual.Qformer.bert.train() 745 | # blip_textual.text_proj.train() 746 | 747 | # both adds param in visual_encoder 748 | # blip2_encoder 749 | if encoder == "both": 750 | blip_visual.Qformer.bert.train() 751 | # blip_visual.vision_proj.train() 752 | blip_visual.query_tokens.requires_grad = True 753 | 754 | elif encoder_arch == "clip": 755 | # clip_encoder 756 | clip_model.train() 757 | 758 | # support modules 759 | if tac_weight > 0: 760 | tac.train() 761 | if hca_weight > 0: 762 | hca.train() 763 | if cir_frame == "artemis": 764 | artemis.train() 765 | 766 | train_running_results = {'images_in_epoch': 0, 'accumulated_train_loss': 0} 767 | train_bar = tqdm(relative_train_loader, ncols=150) 768 | 769 | # adjust learning rate in every epoch 770 | cosine_lr_schedule(optimizer, epoch, max_epoch, learning_rate, min_lr, onlyGroup0=True) 771 | 772 | # iteration loop 773 | for idx, (reference_images, target_images, captions) in enumerate(train_bar): 774 | images_in_batch = reference_images.size(0) 775 | step = len(train_bar) * epoch + idx 776 | optimizer.zero_grad() 777 | 778 | # move ref and tar img to device 779 | reference_images = reference_images.to(device, non_blocking=True) 780 | target_images = target_images.to(device, non_blocking=True) 781 | 782 | # Extract the features, compute the logits and the loss 783 | 784 | with torch.cuda.amp.autocast(): 785 | if encoder_arch == "blip2": 786 | # text 787 | text_embeds = blip_textual.extract_features({"text_input":captions}, 788 | mode="text").text_embeds 789 | text_feats = text_embeds[:,0,:] 790 | 791 | # target 792 | target_embeds = blip_visual.extract_features({"image":target_images}, 793 | mode="image").image_embeds 794 | target_feats = F.normalize(target_embeds[:,0,:], dim=-1) 795 | 796 | # reference 797 | reference_embeds = blip_multimodal.extract_features({"image":reference_images, 798 | "text_input":captions}).multimodal_embeds 799 | reference_feats = 
reference_embeds[:,0,:] 800 | 801 | # embeds encoded with visual_encoder 802 | reference_embeds_for_tac = blip_visual.extract_features({"image":reference_images}, 803 | mode="image").image_embeds 804 | elif encoder_arch == "clip": 805 | # reference 806 | reference_feats, reference_embeds = clip_model.encode_image(reference_images) 807 | reference_embeds_for_tac = reference_embeds 808 | 809 | # text 810 | text_inputs = clip.tokenize(captions, context_length=77, truncate=True).to(device,non_blocking=True) 811 | text_feats, text_embeds = clip_model.encode_text(text_inputs) 812 | 813 | # target 814 | target_feats, target_embeds = clip_model.encode_image(target_images) 815 | target_feats = F.normalize(target_feats) 816 | 817 | 818 | # ============ Query-Target Contrastive =========== 819 | 820 | if cir_frame == "artemis": 821 | # artemis 822 | artemis_scores = artemis.compute_score_broadcast_artemis(reference_feats, text_feats, target_feats) 823 | # artemis_logits = artemis.temperature.exp() * artemis_scores 824 | elif cir_frame == "sum": 825 | # sum_predicted 826 | predicted_feats = combining_function(reference_feats, text_feats) 827 | matching_logits = 100 * predicted_feats @ target_feats.T 828 | 829 | # ============ Query-Target Align =========== 830 | 831 | # align(tac) 832 | if tac_weight > 0: 833 | visual_gap_feats = tac(reference_embeds_for_tac, target_embeds) 834 | aligning_logits = 10 * text_feats @ visual_gap_feats.T 835 | 836 | # ============ Reference-Caption-Target Contrastive =========== 837 | if hca_weight > 0: 838 | psudo_T = hca(reference_embeds = reference_embeds, 839 | caption_embeds = text_embeds, 840 | target_embeds = target_embeds) 841 | 842 | reasoning_logits = 10 * psudo_T @ reference_feats.T 843 | 844 | # ============ LOSS =========== 845 | # align_loss / tac 846 | # align_loss = crossentropy_criterion(align_logits, ground_truth) 847 | 848 | # hca_loss 849 | # hca_tcr_loss = crossentropy_criterion(hca_logits, ground_truth) 850 | 851 | if cir_frame == "artemis": 852 | # artemis_loss 853 | # artemis_loss = crossentropy_criterion(artemis_logits, ground_truth) 854 | if tac_weight > 0: 855 | contrastive_logits = tac_weight * aligning_logits + \ 856 | (1 - tac_weight) * artemis_scores 857 | else: 858 | contrastive_logits = artemis_scores 859 | elif cir_frame == "sum": 860 | # contrastive loss 861 | # contrastive_loss = crossentropy_criterion(contrast_logits, ground_truth) 862 | if tac_weight > 0: 863 | contrastive_logits = tac_weight * aligning_logits + \ 864 | (1 - tac_weight) * matching_logits 865 | else: 866 | contrastive_logits = matching_logits 867 | 868 | # ========== Sum_Loss =============== 869 | # hca_loss 870 | ground_truth = torch.arange(images_in_batch, dtype=torch.long, device=device) 871 | contrastive_loss = crossentropy_criterion(contrastive_logits, ground_truth) 872 | if hca_weight > 0: 873 | reasoning_loss = crossentropy_criterion(reasoning_logits, ground_truth) 874 | loss = hca_weight * reasoning_loss + (1 - hca_weight) * contrastive_loss 875 | else: 876 | loss = contrastive_loss 877 | 878 | # Backpropagate and update the weights 879 | scaler.scale(loss).backward() 880 | scaler.step(optimizer) 881 | scaler.update() 882 | 883 | experiment.log_metric('step_loss', loss.detach().cpu().item(), step=step) 884 | update_train_running_results(train_running_results, loss, images_in_batch) 885 | set_train_bar_description(train_bar, epoch, num_epochs, train_running_results) 886 | 887 | train_epoch_loss = float( 888 | train_running_results['accumulated_train_loss'] / 
train_running_results['images_in_epoch']) 889 | experiment.log_metric('epoch_loss', train_epoch_loss, epoch=epoch) 890 | 891 | # Training CSV logging 892 | training_log_frame = pd.concat( 893 | [training_log_frame, 894 | pd.DataFrame(data={'epoch': epoch, 'train_epoch_loss': train_epoch_loss}, index=[0])]) 895 | training_log_frame.to_csv(str(training_path / 'train_metrics.csv'), index=False) 896 | 897 | if epoch % validation_frequency == 0: 898 | with experiment.validate(): 899 | if encoder_arch == "blip2": 900 | val_index_features, val_index_names = extract_index_features_blip2(classic_val_dataset, blip_visual) 901 | if cir_frame == "sum": 902 | results = compute_cirr_val_metrics_blip2(relative_val_dataset, 903 | blip_textual, 904 | blip_multimodal, 905 | val_index_features, 906 | val_index_names, 907 | combining_function) 908 | elif cir_frame == "artemis": 909 | results = artemis_compute_cirr_val_metrics(relative_val_dataset, 910 | blip_textual, 911 | blip_multimodal, 912 | val_index_features, 913 | val_index_names, 914 | artemis) 915 | elif encoder_arch == "clip": 916 | results = compute_cirr_val_metrics_clip(relative_val_dataset, clip_model, val_index_features, 917 | val_index_names, combining_function) 918 | 919 | 920 | group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 = results 921 | 922 | results_dict = { 923 | 'group_recall_at1': group_recall_at1, 924 | 'group_recall_at2': group_recall_at2, 925 | 'group_recall_at3': group_recall_at3, 926 | 'recall_at1': recall_at1, 927 | 'recall_at5': recall_at5, 928 | 'recall_at10': recall_at10, 929 | 'recall_at50': recall_at50, 930 | 'mean(R@5+R_s@1)': (group_recall_at1 + recall_at5) / 2, 931 | 'arithmetic_mean': mean(results), 932 | 'harmonic_mean': harmonic_mean(results), 933 | 'geometric_mean': geometric_mean(results) 934 | } 935 | print(json.dumps(results_dict, indent=4)) 936 | 937 | experiment.log_metrics( 938 | results_dict, 939 | epoch=epoch 940 | ) 941 | 942 | # Validation CSV logging 943 | log_dict = {'epoch': epoch} 944 | log_dict.update(results_dict) 945 | validation_log_frame = pd.concat([validation_log_frame, pd.DataFrame(data=log_dict, index=[0])]) 946 | validation_log_frame.to_csv(str(training_path / 'validation_metrics.csv'), index=False) 947 | 948 | if save_training: 949 | if save_best and results_dict['arithmetic_mean'] > best_arithmetic: 950 | best_arithmetic = results_dict['arithmetic_mean'] 951 | # save encoders 952 | if encoder_arch == "blip2": 953 | save_model('tuned_blip_text_arithmetic', epoch, blip_textual, training_path) 954 | save_model('tuned_blip_multi_arithmetic', epoch, blip_multimodal, training_path) 955 | save_model('tuned_blip_visual_arithmetic', epoch, blip_visual, training_path) 956 | elif encoder_arch == "clip": 957 | save_model('tuned_clip_arithmetic', epoch, clip_model, training_path) 958 | # save support modules anyway 959 | save_model('tuned_tac_arithmetic', epoch, tac, training_path) 960 | save_model('tuned_hca_arithmetic', epoch, hca, training_path) 961 | # save artemis modules 962 | if cir_frame == "artemis": 963 | save_model('tuned_artemis_arithmetic', epoch, artemis, training_path) 964 | if not save_best: 965 | print("Warning!!!! 
No models will be saved; pass --save-best to save the best checkpoints.") 966 | 967 | 968 | if __name__ == '__main__': 969 | parser = ArgumentParser() 970 | # dataset 971 | parser.add_argument("--dataset", type=str, required=True, help="should be either 'CIRR' or 'fashionIQ'") 972 | # comet environment 973 | parser.add_argument("--api-key", type=str, help="API key for Comet logging") 974 | parser.add_argument("--workspace", type=str, help="workspace of Comet logging") 975 | parser.add_argument("--experiment-name", type=str, help="name of the experiment on Comet") 976 | # fine_tune_encoder_modal 977 | parser.add_argument("--encoder", type=str, default="text", help="the encoder(s) to be fine-tuned") 978 | parser.add_argument("--encoder-arch", type=str, default="blip2", help="the encoder architecture, either 'blip2' or 'clip'") 979 | parser.add_argument("--model-name", type=str, default="blip2_feature_extractor", help="the backbone model: 'blip2_feature_extractor' for blip2 or 'RN50x4' for clip") 980 | # training args 981 | parser.add_argument("--num-epochs", default=300, type=int, help="number of training epochs") 982 | parser.add_argument("--learning-rate", default=1e-5, type=float, help="Learning rate") 983 | parser.add_argument("--batch-size", default=512, type=int, help="Batch size") 984 | # cosine learning rate scheduler 985 | parser.add_argument("--min-lr", default=0, type=float, help="cosine learning rate scheduler minimum learning rate") 986 | parser.add_argument("--max-epoch", default=10, type=int, help="cosine learning rate scheduler max epoch") 987 | # image preprocessing 988 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio") 989 | parser.add_argument("--transform", default="targetpad", type=str, 990 | help="Preprocess pipeline, should be in ['blip', 'squarepad', 'targetpad'] ") 991 | # training settings 992 | parser.add_argument("--validation-frequency", default=1, type=int, help="Validation frequency expressed in epochs") 993 | parser.add_argument("--save-training", dest="save_training", action='store_true', 994 | help="Whether to save models during training") 995 | parser.add_argument("--save-best", dest="save_best", action='store_true', 996 | help="Save only the best model during training") 997 | parser.add_argument("--cir-frame", default="sum", type=str, help="CIR framework, either 'sum' or 'artemis'") 998 | parser.add_argument("--tac-weight", default=0.1, type=float, help="tac loss weight") 999 | parser.add_argument("--hca-weight", default=0.1, type=float, help="hca loss weight") 1000 | parser.add_argument("--hca-temperature", default=0.92, type=float, help="hca temperature") 1001 | parser.add_argument("--tac-temperature", default=2.3, type=float, help="tac temperature") 1002 | parser.add_argument("--embeds-dim", default=768, type=int, help="embedding dimension of the encoder outputs") 1003 | 1004 | # fix seed for stable results 1005 | seed = 42 1006 | random.seed(seed) 1007 | np.random.seed(seed) 1008 | torch.manual_seed(seed) 1009 | torch.cuda.manual_seed(seed) 1010 | torch.backends.cudnn.deterministic = True 1011 | torch.backends.cudnn.benchmark = False 1012 | 1013 | args = parser.parse_args() 1014 | if args.dataset.lower() not in ['fashioniq', 'cirr']: 1015 | raise ValueError("Dataset should be either 'CIRR' or 'FashionIQ'") 1016 | 1017 | training_hyper_params = { 1018 | "num_epochs": args.num_epochs, 1019 | "learning_rate": args.learning_rate, 1020 | "max_epoch": args.max_epoch, 1021 | "min_lr": args.min_lr, 1022 | "batch_size": args.batch_size, 1023 | "validation_frequency": args.validation_frequency, 1024 | "transform": args.transform, 1025 | 
"target_ratio": args.target_ratio, 1026 | "save_training": args.save_training, 1027 | "save_best": args.save_best, 1028 | "experiment_name":args.experiment_name, 1029 | "encoder":args.encoder, 1030 | "encoder_arch":args.encoder_arch, 1031 | "model_name":args.model_name, 1032 | "cir_frame":args.cir_frame, 1033 | "tac_weight": args.tac_weight, 1034 | "hca_weight": args.hca_weight, 1035 | "hca_temperature": args.hca_temperature, 1036 | "tac_temperature": args.tac_temperature, 1037 | "embeds_dim": args.embeds_dim 1038 | } 1039 | if args.api_key and args.workspace: 1040 | print("Comet logging ENABLED") 1041 | experiment = Experiment( 1042 | api_key=args.api_key, 1043 | project_name=f"{args.dataset} blip fine-tuning", 1044 | workspace=args.workspace, 1045 | disabled=False 1046 | ) 1047 | if args.experiment_name: 1048 | experiment.set_name(args.experiment_name) 1049 | else: 1050 | print("Comet loging DISABLED, in order to enable it you need to provide an api key and a workspace") 1051 | experiment = Experiment( 1052 | api_key="", 1053 | project_name="", 1054 | workspace="", 1055 | disabled=True 1056 | ) 1057 | 1058 | experiment.log_code(folder=str(base_path / 'src')) 1059 | experiment.log_parameters(training_hyper_params) 1060 | 1061 | ssl._create_default_https_context = ssl._create_unverified_context 1062 | 1063 | if args.dataset.lower() == 'cirr': 1064 | blip_finetune_cirr(**training_hyper_params) 1065 | elif args.dataset.lower() == 'fashioniq': 1066 | training_hyper_params.update( 1067 | {'train_dress_types': ['dress', 'toptee', 'shirt'], 'val_dress_types': ['dress', 'toptee', 'shirt']}) 1068 | blip_finetune_fiq(**training_hyper_params) 1069 | --------------------------------------------------------------------------------