├── assets
│   ├── dummy.txt
│   ├── cala_workflow.jpg
│   └── cala_workflow.pdf
├── cala_workflow.jpg
├── LICENSE
├── src
│   ├── twin_attention_compositor_clip.py
│   ├── twin_attention_compositor_blip2.py
│   ├── artemis.py
│   ├── hinge_based_cross_attention_blip2.py
│   ├── hinge_based_cross_attention_clip.py
│   ├── cirr_test_submission_blip2.py
│   ├── cirr_test_submission.py
│   ├── cirr_artemis_submission_blip2.py
│   ├── utils.py
│   ├── data_utils.py
│   ├── validate.py
│   └── blip_fine_tune.py
└── README.md
/assets/dummy.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/cala_workflow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chiangsonw/CaLa/HEAD/cala_workflow.jpg
--------------------------------------------------------------------------------
/assets/cala_workflow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chiangsonw/CaLa/HEAD/assets/cala_workflow.jpg
--------------------------------------------------------------------------------
/assets/cala_workflow.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chiangsonw/CaLa/HEAD/assets/cala_workflow.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 MadChiang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/twin_attention_compositor_clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import sys
5 | from utils import device
6 |
7 | class TwinAttentionCompositorCLIP(nn.Module):
8 | def __init__(self) -> None:
9 | super().__init__()
10 |
11 | self.fc = nn.Linear(2560,640)
12 | self.relu = nn.ReLU(inplace=True)
13 | self.reference_as_query_attention = torch.nn.MultiheadAttention(embed_dim=640, num_heads=1, dropout=0.0, batch_first=True)
14 | self.target_as_query_attention = torch.nn.MultiheadAttention(embed_dim=640, num_heads=1, dropout=0.0, batch_first=True)
15 |
16 | def forward(self, reference_embeddings:torch.tensor, target_embeddings:torch.tensor):
17 | bs, hi, h, w = reference_embeddings.size()
18 | # reshape embeddings to tokens: bs x (h*w) x hidden, i.e. bs x 81 x 2560
19 | reference_embeddings = reference_embeddings.view(bs,h*w,hi)
20 | target_embeddings = target_embeddings.view(bs,h*w,hi)
21 | # reduce the dimension with a linear layer: bs x 81 x 640
22 | reference_tokens = self.relu(self.fc(reference_embeddings))
23 | target_tokens = self.relu(self.fc(target_embeddings))
24 | cls_token = torch.randn(bs, 1, 640).to(device, non_blocking=True)
25 | # concatenate a cls token: bs x 82 x 640
26 | reference_tokens = torch.cat([cls_token, reference_tokens], dim=1)
27 | target_tokens = torch.cat([cls_token, target_tokens], dim=1)
28 |
29 | # 4 layers
30 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=target_tokens, value=target_tokens)
31 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=output1, value=output1)
32 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=output1, value=output1)
33 | output1, _ = self.reference_as_query_attention(query=reference_tokens, key=output1, value=output1)
34 |
35 | #4 layers
36 | output2, _ = self.target_as_query_attention(query=target_tokens, key=reference_tokens, value=reference_tokens)
37 | output2, _ = self.target_as_query_attention(query=target_tokens, key=output2, value=output2)
38 | output2, _ = self.target_as_query_attention(query=target_tokens, key=output2, value=output2)
39 | output2, _ = self.target_as_query_attention(query=target_tokens, key=output2, value=output2)
40 |
41 |
42 | output1_features = output1[:,0,:]
43 | output2_features = output2[:,0,:]
44 | output_features = (output1_features + output2_features) / 2
45 | return output_features
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/src/twin_attention_compositor_blip2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from utils import device
5 |
6 | class TwinAttentionCompositorBLIP2(nn.Module):
7 | def __init__(self, embedding_dim) -> None:
8 | super().__init__()
9 | # self.fusion = nn.Linear(512,256)
10 | # self.relu1 = nn.ReLU(inplace=True)
11 | self.reference_as_query_attention = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=1, dropout=0.0, batch_first=True)
12 | self.target_as_query_attention = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=1, dropout=0.0, batch_first=True)
13 |
14 | def forward(self, reference_embeddings:torch.tensor, target_embeddings:torch.tensor):
15 | #embeddings to tokens bs x length x hidden bs x 32 x 256
16 |
17 | # 4 layers of attention
18 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=target_embeddings, value=target_embeddings)
19 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=output1, value=output1)
20 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=output1, value=output1)
21 | output1, _ = self.reference_as_query_attention(query=reference_embeddings, key=output1, value=output1)
22 |
23 | # 4 layers of attention
24 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=reference_embeddings, value=reference_embeddings)
25 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=output2, value=output2)
26 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=output2, value=output2)
27 | output2, _ = self.target_as_query_attention(query=target_embeddings, key=output2, value=output2)
28 |
29 | # share weight
30 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=reference_embeddings, value=reference_embeddings)
31 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=output2, value=output2)
32 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=output2, value=output2)
33 | # output2, _ = self.reference_as_query_attention(query=target_embeddings, key=output2, value=output2)
34 |
35 | # use token 0 as the feature (bs x 256) and average the two features
36 | output1_features = output1[:,0,:]
37 | output2_features = output2[:,0,:]
38 | output_features = (output1_features + output2_features) / 2
39 | return output_features
40 |
--------------------------------------------------------------------------------
/src/artemis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from utils import l2norm
5 |
6 | class Artemis(nn.Module):
7 |
8 | def __init__(self, clip_feature_dim) -> None:
9 | super().__init__()
10 | self.embed_dim = clip_feature_dim
11 | self.Transform_m = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim), L2Module())
12 | self.Attention_EM = AttentionMechanism(self.embed_dim)
13 | self.Attention_IS = AttentionMechanism(self.embed_dim)
14 | self.temperature = nn.Parameter(torch.FloatTensor((2.65926,)))
15 |
16 | def apply_attention(self, a, x):
17 | return l2norm(a * x)
18 |
19 | def compute_score_artemis(self, r, m, t, store_intermediary=False):
20 | EM = self.compute_score_EM(r, m, t, store_intermediary)
21 | IS = self.compute_score_IS(r, m, t, store_intermediary)
22 | if store_intermediary:
23 | self.hold_results["EM"] = EM
24 | self.hold_results["IS"] = IS
25 | return EM + IS
26 |
27 | def compute_score_broadcast_artemis(self, r, m, t):
28 | return self.compute_score_broadcast_EM(r, m, t) + self.compute_score_broadcast_IS(r, m, t)
29 |
30 | def compute_score_EM(self, r, m, t, store_intermediary=False):
31 | Tr_m = self.Transform_m(m)
32 | A_EM_t = self.apply_attention(self.Attention_EM(m), t)
33 | if store_intermediary:
34 | self.hold_results["Tr_m"] = Tr_m
35 | self.hold_results["A_EM_t"] = A_EM_t
36 | return (Tr_m * A_EM_t).sum(-1)
37 |
38 | def compute_score_broadcast_EM(self, r, m, t):
39 | batch_size = r.size(0)
40 | A_EM = self.Attention_EM(m) # shape (Bq, d)
41 | Tr_m = self.Transform_m(m) # shape (Bq, d)
42 | # apply each query attention mechanism to all targets
43 | A_EM_all_t = self.apply_attention(A_EM.view(batch_size, 1, self.embed_dim), t.view(1, batch_size, self.embed_dim)) # shape (Bq, Bt, d)
44 | EM_score = (Tr_m.view(batch_size, 1, self.embed_dim) * A_EM_all_t).sum(-1) # shape (Bq, Bt) ; coefficient (i,j) is the EM score between query i and target j
45 | return EM_score
46 |
47 | def compute_score_IS(self, r, m, t, store_intermediary=False):
48 | A_IS_r = self.apply_attention(self.Attention_IS(m), r)
49 | A_IS_t = self.apply_attention(self.Attention_IS(m), t)
50 | if store_intermediary:
51 | self.hold_results["A_IS_r"] = A_IS_r
52 | self.hold_results["A_IS_t"] = A_IS_t
53 | return (A_IS_r * A_IS_t).sum(-1)
54 |
55 | def compute_score_broadcast_IS(self, r, m, t):
56 | batch_size = r.size(0)
57 | A_IS = self.Attention_IS(m) # shape (Bq, d)
58 | A_IS_r = self.apply_attention(A_IS, r) # shape (Bq, d)
59 | # apply each query attention mechanism to all targets
60 | A_IS_all_t = self.apply_attention(A_IS.view(batch_size, 1, self.embed_dim), t.view(1, batch_size, self.embed_dim)) # shape (Bq, Bt, d)
61 | IS_score = (A_IS_r.view(batch_size, 1, self.embed_dim) * A_IS_all_t).sum(-1) # shape (Bq, Bt) ; coefficient (i,j) is the IS score between query i and target j
62 | return IS_score
63 |
64 |
65 | class L2Module(nn.Module):
66 |
67 | def __init__(self):
68 | super(L2Module, self).__init__()
69 |
70 | def forward(self, x):
71 | x = l2norm(x)
72 | return x
73 |
74 | class AttentionMechanism(nn.Module):
75 | """
76 | Module defining the architecture of the attention mechanisms in ARTEMIS.
77 | """
78 |
79 | def __init__(self, embed_dim):
80 | super(AttentionMechanism, self).__init__()
81 |
82 | self.embed_dim = embed_dim
83 | input_dim = self.embed_dim
84 |
85 | self.attention = nn.Sequential(
86 | nn.Linear(input_dim, self.embed_dim),
87 | nn.ReLU(),
88 | nn.Linear(self.embed_dim, self.embed_dim),
89 | nn.Softmax(dim=1)
90 | )
91 |
92 | def forward(self, x):
93 | return self.attention(x)
--------------------------------------------------------------------------------
/src/hinge_based_cross_attention_blip2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from utils import device
5 | import math
6 |
7 | class HingebasedCrossAttentionBLIP2(nn.Module):
8 | def __init__(self, embed_dim) -> None:
9 | super().__init__()
10 | # attention proj
11 | self.query_ref1 = nn.Linear(embed_dim,embed_dim)
12 | self.key_text1 = nn.Linear(embed_dim,embed_dim)
13 | # self.query_text1 = nn.Linear(embed_dim,embed_dim)
14 | self.key_tar1 = nn.Linear(embed_dim,embed_dim)
15 | self.value1 = nn.Linear(embed_dim,embed_dim)
16 | self.dropout1 = nn.Dropout(0.1)
17 |
18 | self.query_ref2 = nn.Linear(embed_dim,embed_dim)
19 | self.key_text2 = nn.Linear(embed_dim,embed_dim)
20 | self.key_tar2 = nn.Linear(embed_dim,embed_dim)
21 | self.value2 = nn.Linear(embed_dim,embed_dim)
22 | self.dropout2 = nn.Dropout(0.1)
23 |
24 | def forward(self, reference_embeds, caption_embeds, target_embeds):
25 | psudo_T = self.hca_T_share_text(reference_embeds, caption_embeds, target_embeds)
26 | return psudo_T
27 |
28 | def hca_T_share_text(self, reference_embeds, caption_embeds, target_embeds):
29 | bs , len_r , dim = reference_embeds.shape
30 | attA = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
31 | attB = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
32 | attC = self.dropout1(F.softmax(torch.matmul(attA, attB), dim=-1))
33 | psudo_T = torch.matmul(attC , self.value1(target_embeds))
34 | return psudo_T[:,0,:]
35 |
36 | def hca_T_R_share_text(self, reference_embeds, caption_embeds, target_embeds):
37 | bs , len_r , dim = reference_embeds.shape
38 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
39 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
40 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1))
41 | psudo_T = torch.matmul(attC1 , self.value1(target_embeds))
42 |
43 | attA2 = self.multiply(self.query_ref2(target_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim)
44 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(reference_embeds)) / math.sqrt(dim)
45 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1))
46 | psudo_R = torch.matmul(attC2 , self.value2(reference_embeds))
47 |
48 | return psudo_T[:,0,:], psudo_R[:,0,:]
49 |
50 | def hca_T_multihead_4layer(self, reference_embeds, caption_embeds, target_embeds):  # note: expects query_ref3/4, key_text3/4, key_tar3/4, value3/4 and dropout3/4 projections, which are not defined in __init__
51 | bs , len_r , dim = reference_embeds.shape
52 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
53 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
54 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1))
55 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds))
56 |
57 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim)
58 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim)
59 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1))
60 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds))
61 |
62 | attA3 = self.multiply(self.query_ref3(reference_embeds), self.key_text3(caption_embeds)) / math.sqrt(dim)
63 | attB3 = self.multiply(self.key_text3(caption_embeds), self.key_tar3(target_embeds)) / math.sqrt(dim)
64 | attC3 = self.dropout3(F.softmax(torch.matmul(attA3, attB3), dim=-1))
65 | psudo_T3 = torch.matmul(attC3 , self.value3(target_embeds))
66 |
67 | attA4 = self.multiply(self.query_ref4(reference_embeds), self.key_text4(caption_embeds)) / math.sqrt(dim)
68 | attB4 = self.multiply(self.key_text4(caption_embeds), self.key_tar4(target_embeds)) / math.sqrt(dim)
69 | attC4 = self.dropout4(F.softmax(torch.matmul(attA4, attB4), dim=-1))
70 | psudo_T4 = torch.matmul(attC4 , self.value4(target_embeds))
71 |
72 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:] + psudo_T3[:,0,:] + psudo_T4[:,0,:]) / 4
73 |
74 | def hca_T_multihead_2layer(self, reference_embeds, caption_embeds, target_embeds):
75 | bs , len_r , dim = reference_embeds.shape
76 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
77 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
78 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1))
79 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds))
80 |
81 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim)
82 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim)
83 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1))
84 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds))
85 |
86 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:]) / 2
87 |
88 | def multiply(self, embedsA, embedsB):
89 | bs, len_a , dim = embedsA.shape
90 | bs, len_b , dim = embedsB.shape
91 |
92 | # reshape to token sequences
93 | embedsA = embedsA.view(bs, -1, dim) # shape bs x len_a x dim
94 | embedsB = embedsB.view(bs, -1, dim) # shape bs x len_b x dim
95 |
96 | # batched dot product
97 | attention_scores_flat = torch.matmul(embedsA, embedsB.transpose(-1, -2)) # transpose the key dimensions
98 |
99 | # restore the shape
100 | attention_scores = attention_scores_flat.view(bs, len_a, len_b)
101 |
102 | return attention_scores
103 |
104 |
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CaLa
2 |
3 | [Papers with Code: Image Retrieval on CIRR](https://paperswithcode.com/sota/image-retrieval-on-cirr?p=cala-complementary-association-learning-for)
4 | [Papers with Code: Image Retrieval on Fashion-IQ](https://paperswithcode.com/sota/image-retrieval-on-fashion-iq?p=cala-complementary-association-learning-for)
5 |
6 | **CaLa (ACM SIGIR 2024)** is a new composed image retrieval framework that considers two complementary associations in the task. CaLa introduces TBIA (text-based image alignment) and CTR (complementary text reasoning) to augment composed image retrieval.
7 |
8 | We highlight the contributions of this paper as follows:
9 |
10 | • We present a new perspective on composed image retrieval: the annotated triplet is viewed as a graph node, and two complementary association clues are disclosed to enhance composed image retrieval.
11 |
12 | • A hinge-based cross-attention and a twin-attention-based visual compositor are proposed to effectively impose the new associations on the network learning.
13 |
14 | • Competitive performance on the CIRR and FashionIQ benchmarks. CaLa benefits several baselines with different backbones and architectures, revealing that it is a widely beneficial module for composed image retrieval.
15 |
16 | More details can be found in our paper: [CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval](https://arxiv.org/pdf/2405.19149)
17 |
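The hinge-based cross-attention and the twin-attention compositor live in `src/hinge_based_cross_attention_blip2.py` and `src/twin_attention_compositor_blip2.py` (CLIP variants sit alongside them). The snippet below is a minimal, hypothetical sketch of how the two modules can be driven with BLIP2 Q-Former token embeddings of shape `(batch, 32, 768)`; the random tensors and the wiring are illustrative only, and the actual training wiring lives in `src/blip_fine_tune.py`. Run it from `src/`:

```python
import torch
from hinge_based_cross_attention_blip2 import HingebasedCrossAttentionBLIP2
from twin_attention_compositor_blip2 import TwinAttentionCompositorBLIP2

# Illustrative stand-ins for BLIP2 Q-Former token embeddings: 32 tokens of width 768
# (matching --embeds-dim 768 used for BLIP2 in the training command below).
reference_embeds = torch.randn(8, 32, 768)
caption_embeds = torch.randn(8, 32, 768)
target_embeds = torch.randn(8, 32, 768)

# Hinge-based cross-attention: the relative caption acts as a hinge between the
# reference and target tokens and yields one pseudo-target feature per triplet.
hca = HingebasedCrossAttentionBLIP2(embed_dim=768)
pseudo_target = hca(reference_embeds, caption_embeds, target_embeds)   # (8, 768)

# Twin attention compositor: cross-attends reference and target tokens in both
# directions and averages the two token-0 outputs.
tac = TwinAttentionCompositorBLIP2(embedding_dim=768)
composed = tac(reference_embeds, target_embeds)                        # (8, 768)

print(pseudo_target.shape, composed.shape)  # both torch.Size([8, 768])
```

Each call returns one vector per triplet; how these vectors enter the training losses is defined in `src/blip_fine_tune.py`.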
18 | This is the workflow of our CaLa framework.
19 |

20 |
21 | ## News
22 |
23 | ## Models and Weights
24 |
25 | ## Usage
26 |
27 | ### Prerequisites
28 |
29 | The following commands will create a local Anaconda environment with the necessary packages installed.
30 |
31 | ```bash
32 | conda create -n cala -y python=3.8
33 | conda activate cala
34 | conda install -y -c pytorch pytorch=1.11.0 torchvision=0.12.0
35 | conda install -y -c anaconda pandas=1.4.2
36 | pip install comet-ml==3.21.0
37 | pip install git+https://github.com/openai/CLIP.git
38 | pip install salesforce-lavis
39 | ```
40 |
41 | ### Data Preparation
42 |
43 | To work properly with the codebase, the FashionIQ and CIRR datasets should have the following structure:
44 |
45 | ```
46 | project_base_path
47 | └─── CaLa
48 | └─── src
49 | | blip_fine_tune.py
50 | | data_utils.py
51 | | utils.py
52 | | ...
53 |
54 | └─── fashionIQ_dataset
55 | └─── captions
56 | | cap.dress.test.json
57 | | cap.dress.train.json
58 | | cap.dress.val.json
59 | | ...
60 |
61 | └─── images
62 | | B00006M009.jpg
63 | | B00006M00B.jpg
64 | | B00006M6IH.jpg
65 | | ...
66 |
67 | └─── image_splits
68 | | split.dress.test.json
69 | | split.dress.train.json
70 | | split.dress.val.json
71 | | ...
72 |
73 | └─── cirr_dataset
74 | └─── train
75 | └─── 0
76 | | train-10108-0-img0.png
77 | | train-10108-0-img1.png
78 | | train-10108-1-img0.png
79 | | ...
80 |
81 | └─── 1
82 | | train-10056-0-img0.png
83 | | train-10056-0-img1.png
84 | | train-10056-1-img0.png
85 | | ...
86 |
87 | ...
88 |
89 | └─── dev
90 | | dev-0-0-img0.png
91 | | dev-0-0-img1.png
92 | | dev-0-1-img0.png
93 | | ...
94 |
95 | └─── test1
96 | | test1-0-0-img0.png
97 | | test1-0-0-img1.png
98 | | test1-0-1-img0.png
99 | | ...
100 |
101 | └─── cirr
102 | └─── captions
103 | | cap.rc2.test1.json
104 | | cap.rc2.train.json
105 | | cap.rc2.val.json
106 |
107 | └─── image_splits
108 | | split.rc2.test1.json
109 | | split.rc2.train.json
110 | | split.rc2.val.json
111 | ```
112 |
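Once the folders are in place, the dataset helpers in `src/data_utils.py` can be used as a quick sanity check. The snippet below is a hypothetical sketch based on how the submission scripts call these helpers (split names follow the usual CIRR convention; run it from `src/`):

```python
from data_utils import CIRRDataset, targetpad_transform

# Target-ratio padding transform, with the values used by the BLIP2 submission script
# (target ratio 1.25, 224x224 inputs).
preprocess = targetpad_transform(1.25, 224)

# 'classic' mode iterates over index (gallery) images, 'relative' mode over
# (reference image, relative caption) queries.
classic_val = CIRRDataset('val', 'classic', preprocess)
relative_val = CIRRDataset('val', 'relative', preprocess)

print(len(classic_val), len(relative_val))
```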
113 |
114 | ### Adjustments for dependencies
115 |
116 | For fine-tuning the BLIP2 encoders, you need to comment out the following line in lavis within your conda environment.
117 | ```python
118 | # In lavis/models/blip2_models/blip2_qformer.py line 367
119 | # @torch.no_grad() # comment out this line.
120 | ```
121 | Commenting out this line lets gradients be computed for the BLIP2 model so that its parameters can be updated.
122 |
123 | For fine-tuning the CLIP encoders, you need to replace the following code in the CLIP package so that RN50x4 features can interact with the Q-Former.
124 | ```python
125 | # Replace CLIP/clip/models.py lines 152-154 (the first block below) with the second block, so the visual encoder also returns the pre-attnpool feature map.
126 | 152# x = self.attnpool(x)
127 | 153#
128 | 154# return x
129 |
130 | 152# y=x
131 | 153# x = self.attnpool(x)
132 | 154#
133 | 155# return x,y
134 |
135 | # Replace CLIP/clip/models.py lines 343-356 with the following code: before taking the EOT token, keep the full token sequence as the text sequence features.
136 | 346# x = x + self.positional_embedding.type(self.dtype)
137 | 347# x = x.permute(1, 0, 2) # NLD -> LND
138 | 348# x = self.transformer(x)
139 | 349# x = x.permute(1, 0, 2) # LND -> NLD
140 | 350# x = self.ln_final(x).type(self.dtype)
141 | 351#
142 | 352# y = x
143 | 353# # x.shape = [batch_size, n_ctx, transformer.width]
144 | 354# # take features from the eot embedding (eot_token is the highest number in each sequence)
145 | 355# x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
146 |
147 | 356# return x,y
148 |
149 | ```
150 |
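With both edits in place, the CLIP visual and text encoders return `(pooled, sequence)` pairs. The sketch below shows, hypothetically, how the extra sequence outputs can feed the CLIP variants of the CaLa modules; shapes assume the RN50x4 backbone, and the real wiring in `src/blip_fine_tune.py` may differ. Run it from `src/` after applying the edits:

```python
import clip
import torch
from hinge_based_cross_attention_clip import HingebasedCrossAttentionCLIP
from twin_attention_compositor_clip import TwinAttentionCompositorCLIP
from utils import device

# Load the backbone; after the edits above, both encoders return (pooled, sequence) pairs.
clip_model, _ = clip.load("RN50x4", device=device, jit=False)
clip_model = clip_model.float().eval()

images = torch.randn(4, 3, 288, 288, device=device)   # RN50x4 input resolution is 288
texts = clip.tokenize(["illustrative relative caption"] * 4).to(device)

with torch.no_grad():
    _, image_tokens = clip_model.encode_image(images)  # pre-attnpool map: (4, 2560, 9, 9)
    _, text_tokens = clip_model.encode_text(texts)     # token sequence:   (4, 77, 640)

# The CLIP variants reshape the 9x9x2560 map into 81 tokens and project them to 640 dims
# internally (hence --embeds-dim 640 for CLIP).
hca = HingebasedCrossAttentionCLIP(embed_dim=640).to(device)
tac = TwinAttentionCompositorCLIP().to(device)

# Reference and target are the same tensor here purely for illustration; in training
# they come from the reference and target images of a triplet.
pseudo_target = hca(image_tokens, text_tokens, image_tokens)   # (4, 640)
composed = tac(image_tokens, image_tokens)                     # (4, 640)
print(pseudo_target.shape, composed.shape)
```

The pooled outputs (the first element of each pair) remain available for the standard retrieval features, as in the CLIP4Cir baseline.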
151 | ### Training
152 |
153 |
154 | ```shell
155 | # cala finetune
156 | CUDA_VISIBLE_DEVICES='GPU_IDs' python src/blip_fine_tune.py --dataset {'CIRR' or 'FashionIQ'} \
157 | --num-epochs 30 --batch-size 64 \
158 | --max-epoch 15 --min-lr 0 \
159 | --learning-rate 5e-6 \
160 | --transform targetpad --target-ratio 1.25 \
161 | --save-training --save-best --validation-frequency 1 \
162 | --encoder {'both' or 'text' or 'multi'} \
163 | --encoder-arch {clip or blip2} \
164 | --cir-frame {sum or artemis} \
165 | --tac-weight 0.45 \
166 | --hca-weight 0.1 \
167 | --embeds-dim {640 for clip and 768 for blip2} \
168 | --model-name {RN50x4 for clip and None for blip} \
169 | --api-key {Comet-api-key} \
170 | --workspace {Comet-workspace} \
171 | --experiment-name {Comet-experiment-name} \
172 | ```
173 |
174 |
175 | ### CIRR Testing
176 |
177 |
178 | ```shell
179 | CUDA_VISIBLE_DEVICES='GPU_IDs' python src/cirr_test_submission_blip2.py --submission-name {cirr_submission} \
180 | --combining-function {sum or artemis} \
181 | --blip2-textual-path {saved_blip2_textual.pt} \
182 | --blip2-multimodal-path {saved_blip2_multimodal.pt} \
183 | --blip2-visual-path {saved_blip2_visual.pt}
184 |
185 | ```
186 |
187 | ```shell
188 | python src/validate.py \
189 | --dataset {'CIRR' or 'FashionIQ'} \
190 | --combining-function {'combiner' or 'sum'} \
191 | --combiner-path {path to trained Combiner} \
192 | --projection-dim 2560 \
193 | --hidden-dim 5120 \
194 | --clip-model-name RN50x4 \
195 | --clip-model-path {path-to-fine-tuned-CLIP} \
196 | --target-ratio 1.25 \
197 | --transform targetpad
198 | ```
199 |
200 |
201 | ## Reference
202 | If you use CaLa in your research, please cite it using the following BibTeX entry:
203 |
204 | ```bibtex
205 | @article{jiang2024cala,
206 | title={CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval},
207 | author={Jiang, Xintong and Wang, Yaxiong and Li, Mengjian and Wu, Yujiao and Hu, Bingwen and Qian, Xueming},
208 | journal={arXiv preprint arXiv:2405.19149},
209 | year={2024}
210 | }
211 | ```
212 |
213 | ## Acknowledgement
214 | Our implementation is based on [CLIP4Cir](https://github.com/ABaldrati/CLIP4Cir) and [LAVIS](https://github.com/salesforce/LAVIS).
215 |
--------------------------------------------------------------------------------
/src/hinge_based_cross_attention_clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import sys
5 | from utils import device
6 | import math
7 |
8 | class HingebasedCrossAttentionCLIP(nn.Module):
9 | def __init__(self, embed_dim) -> None:
10 | super().__init__()
11 | # attention proj
12 | self.query_ref1 = nn.Linear(embed_dim,embed_dim)
13 | self.key_text1 = nn.Linear(embed_dim,embed_dim)
14 | # self.query_text1 = nn.Linear(embed_dim,embed_dim)
15 | self.key_tar1 = nn.Linear(embed_dim,embed_dim)
16 | self.value1 = nn.Linear(embed_dim,embed_dim)
17 | self.dropout1 = nn.Dropout(0.1)
18 |
19 | self.query_ref2 = nn.Linear(embed_dim,embed_dim)
20 | self.key_text2 = nn.Linear(embed_dim,embed_dim)
21 | self.key_tar2 = nn.Linear(embed_dim,embed_dim)
22 | self.value2 = nn.Linear(embed_dim,embed_dim)
23 | self.dropout2 = nn.Dropout(0.1)
24 |
25 | self.fc1 = nn.Linear(2560,640)
26 | self.relu1 = nn.ReLU(inplace=True)
27 |
28 |
29 | def forward(self, reference_embeds, caption_embeds, target_embeds):
30 | psudo_T = self.hca_T_share_text(reference_embeds, caption_embeds, target_embeds)
31 | return psudo_T
32 |
33 | def hca_T_share_text(self, reference_embeds, caption_embeds, target_embeds):
34 |
35 | bs, hi, h, w = reference_embeds.size()
36 | # reshape embeddings to tokens: bs x (h*w) x hidden, i.e. bs x 81 x 2560
37 | reference_embeds = reference_embeds.view(bs,h*w,hi)
38 | target_embeds = target_embeds.view(bs,h*w,hi)
39 | # reduce the dimension with a linear layer: bs x 81 x 640
40 | reference_embeds = self.relu1(self.fc1(reference_embeds))
41 | target_embeds = self.relu1(self.fc1(target_embeds))
42 |
43 | attA = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(640)
44 | attB = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(640)
45 | attC = self.dropout1(F.softmax(torch.matmul(attA, attB), dim=-1))
46 | psudo_T = torch.matmul(attC , self.value1(target_embeds))
47 | return psudo_T[:,0,:]
48 |
49 | def hca_T_R_share_text(self, reference_embeds, caption_embeds, target_embeds):
50 | bs , len_r , dim = reference_embeds.shape
51 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
52 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
53 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1))
54 | psudo_T = torch.matmul(attC1 , self.value1(target_embeds))
55 |
56 | attA2 = self.multiply(self.query_ref2(target_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim)
57 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(reference_embeds)) / math.sqrt(dim)
58 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1))
59 | psudo_R = torch.matmul(attC2 , self.value2(reference_embeds))
60 |
61 | return psudo_T[:,0,:], psudo_R[:,0,:]
62 |
63 | def hca_T_multihead_4(self, reference_embeds, caption_embeds, target_embeds):  # note: expects query_ref3/4, key_text3/4, key_tar3/4, value3/4 and dropout3/4 projections, which are not defined in __init__
64 | bs , len_r , dim = reference_embeds.shape
65 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
66 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
67 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1))
68 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds))
69 |
70 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim)
71 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim)
72 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1))
73 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds))
74 |
75 | attA3 = self.multiply(self.query_ref3(reference_embeds), self.key_text3(caption_embeds)) / math.sqrt(dim)
76 | attB3 = self.multiply(self.key_text3(caption_embeds), self.key_tar3(target_embeds)) / math.sqrt(dim)
77 | attC3 = self.dropout3(F.softmax(torch.matmul(attA3, attB3), dim=-1))
78 | psudo_T3 = torch.matmul(attC3 , self.value3(target_embeds))
79 |
80 | attA4 = self.multiply(self.query_ref4(reference_embeds), self.key_text4(caption_embeds)) / math.sqrt(dim)
81 | attB4 = self.multiply(self.key_text4(caption_embeds), self.key_tar4(target_embeds)) / math.sqrt(dim)
82 | attC4 = self.dropout4(F.softmax(torch.matmul(attA4, attB4), dim=-1))
83 | psudo_T4 = torch.matmul(attC4 , self.value4(target_embeds))
84 |
85 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:] + psudo_T3[:,0,:] + psudo_T4[:,0,:]) / 4
86 |
87 | def hca_T_multihead_2(self, reference_embeds, caption_embeds, target_embeds):
88 | bs , len_r , dim = reference_embeds.shape
89 | attA1 = self.multiply(self.query_ref1(reference_embeds), self.key_text1(caption_embeds)) / math.sqrt(dim)
90 | attB1 = self.multiply(self.key_text1(caption_embeds), self.key_tar1(target_embeds)) / math.sqrt(dim)
91 | attC1 = self.dropout1(F.softmax(torch.matmul(attA1, attB1), dim=-1))
92 | psudo_T1 = torch.matmul(attC1 , self.value1(target_embeds))
93 |
94 | attA2 = self.multiply(self.query_ref2(reference_embeds), self.key_text2(caption_embeds)) / math.sqrt(dim)
95 | attB2 = self.multiply(self.key_text2(caption_embeds), self.key_tar2(target_embeds)) / math.sqrt(dim)
96 | attC2 = self.dropout2(F.softmax(torch.matmul(attA2, attB2), dim=-1))
97 | psudo_T2 = torch.matmul(attC2 , self.value2(target_embeds))
98 |
99 | return (psudo_T1[:,0,:] + psudo_T2[:,0,:]) / 2
100 |
101 | # def rct_block_R(self, reference_embeds, caption_embeds, target_embeds):
102 | # bs , len_r , dim = reference_embeds.shape
103 | # attA = self.multiply(self.query1(target_embeds), self.key1(caption_embeds)) / math.sqrt(dim)
104 | # attB = self.multiply(self.query2(caption_embeds), self.key2(reference_embeds)) / math.sqrt(dim)
105 |
106 | # attC = self.dropout(F.softmax(torch.matmul(attA, attB), dim=-1))
107 | # psudo_R = torch.matmul(attC , self.value(reference_embeds))
108 | # return psudo_R
109 |
110 | # def rct_block_cap(self, reference_embeds, caption_embeds, target_embeds):
111 | # bs , len_r , dim = reference_embeds.shape
112 | # attA = self.multiply(self.query1(reference_embeds), self.key1(caption_embeds)) / math.sqrt(dim)
113 | # attB = self.multiply(self.query2(target_embeds), self.key2(caption_embeds)) / math.sqrt(dim)
114 |
115 | # attC = self.dropout(F.softmax(attA * attB, dim=-1))
116 | # psudo_C = torch.matmul(attC , self.value(caption_embeds))
117 | # return psudo_C
118 |
119 |
120 | # def rct_block_no_linear(self, reference_embeds, caption_embeds, target_embeds):
121 | # bs , len_r , dim = reference_embeds.shape
122 | # attA = self.multiply(reference_embeds, caption_embeds) / math.sqrt(dim)
123 | # attB = self.multiply(caption_embeds, target_embeds) / math.sqrt(dim)
124 |
125 | # attC = self.dropout(F.softmax(torch.matmul(attA, attB), dim=-1))
126 | # psudo_T = torch.matmul(attC , (target_embeds))
127 | # return psudo_T
128 |
129 | def multiply(self, embedsA, embedsB):
130 | bs, len_a , dim = embedsA.shape
131 | bs, len_b , dim = embedsB.shape
132 |
133 | # reshape to token sequences
134 | embedsA = embedsA.view(bs, -1, dim) # shape bs x len_a x dim
135 | embedsB = embedsB.view(bs, -1, dim) # shape bs x len_b x dim
136 |
137 | # batched dot product
138 | attention_scores_flat = torch.matmul(embedsA, embedsB.transpose(-1, -2)) # transpose the key dimensions
139 |
140 | # restore the shape
141 | attention_scores = attention_scores_flat.view(bs, len_a, len_b)
142 |
143 | return attention_scores
144 |
145 |
146 |
--------------------------------------------------------------------------------
/src/cirr_test_submission_blip2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | from argparse import ArgumentParser
4 | from operator import itemgetter
5 | from pathlib import Path
6 | from typing import List, Tuple, Dict
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 |
14 | from data_utils import CIRRDataset, targetpad_transform, squarepad_transform, base_path
15 | from utils import element_wise_sum, device, extract_index_features_blip2
16 | from lavis.models import load_model
17 |
18 |
19 | def generate_cirr_test_submissions(combining_function: callable, file_name: str, blip_textual, blip_multimodal,
20 | blip_visual, preprocess: callable):
21 | """
22 | Generate and save CIRR test submission files to be submitted to evaluation server
23 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
24 | features
25 | :param file_name: file_name of the submission
26 | :param blip_textual, blip_multimodal, blip_visual: fine-tuned BLIP2 feature extractors
27 | :param preprocess: preprocess pipeline
28 | """
29 | blip_textual = blip_textual.float().eval()
30 | blip_multimodal = blip_multimodal.float().eval()
31 | blip_visual = blip_visual.float().eval()
32 |
33 | # Define the dataset and extract index features
34 | classic_test_dataset = CIRRDataset('test1', 'classic', preprocess)
35 | index_features, index_names = extract_index_features_blip2(classic_test_dataset, blip_visual)
36 | relative_test_dataset = CIRRDataset('test1', 'relative', preprocess)
37 |
38 | # Generate test prediction dicts for CIRR
39 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset, blip_textual, blip_multimodal,
40 | index_features, index_names,
41 | combining_function)
42 |
43 | submission = {
44 | 'version': 'rc2',
45 | 'metric': 'recall'
46 | }
47 | group_submission = {
48 | 'version': 'rc2',
49 | 'metric': 'recall_subset'
50 | }
51 |
52 | submission.update(pairid_to_predictions)
53 | group_submission.update(pairid_to_group_predictions)
54 |
55 | # Define submission path
56 | submissions_folder_path = base_path / "submission" / 'CIRR'
57 | submissions_folder_path.mkdir(exist_ok=True, parents=True)
58 |
59 | print(f"Saving CIRR test predictions")
60 | with open(submissions_folder_path / f"recall_submission_{file_name}.json", 'w+') as file:
61 | json.dump(submission, file, sort_keys=True)
62 |
63 | with open(submissions_folder_path / f"recall_subset_submission_{file_name}.json", 'w+') as file:
64 | json.dump(group_submission, file, sort_keys=True)
65 |
66 |
67 | def generate_cirr_test_dicts(relative_test_dataset: CIRRDataset, blip_textual, blip_multimodal, index_features: torch.tensor,
68 | index_names: List[str], combining_function: callable) \
69 | -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
70 | """
71 | Compute test prediction dicts for CIRR dataset
72 | :param relative_test_dataset: CIRR test dataset in relative mode
73 | :param blip_textual, blip_multimodal: fine-tuned BLIP2 textual and multimodal encoders
74 | :param index_features: test index features
75 | :param index_names: test index names
76 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
77 | features
78 | :return: Top50 global and Top3 subset prediction for each query (reference_name, caption)
79 | """
80 |
81 | # Generate predictions
82 | predicted_features, reference_names, group_members, pairs_id = \
83 | generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset, combining_function, index_names,
84 | index_features)
85 |
86 | print(f"Compute CIRR prediction dicts")
87 |
88 | # Normalize the index features
89 | index_features = F.normalize(index_features, dim=-1).float()
90 |
91 | # Compute the distances and sort the results
92 | distances = 1 - predicted_features @ index_features.T
93 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
94 | sorted_index_names = np.array(index_names)[sorted_indices]
95 |
96 | # Delete the reference image from the results
97 | reference_mask = torch.tensor(
98 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names),
99 | -1))
100 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
101 | sorted_index_names.shape[1] - 1)
102 | # Compute the subset predictions
103 | group_members = np.array(group_members)
104 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
105 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1)
106 |
107 | # Generate prediction dicts
108 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in
109 | zip(pairs_id, sorted_index_names)}
110 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in
111 | zip(pairs_id, sorted_group_names)}
112 |
113 | return pairid_to_predictions, pairid_to_group_predictions
114 |
115 |
116 | def generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset: CIRRDataset, combining_function: callable,
117 | index_names: List[str], index_features: torch.tensor) -> \
118 | Tuple[torch.tensor, List[str], List[List[str]], List[str]]:
119 | """
120 | Compute CIRR predictions on the test set
121 | :param blip_textual, blip_multimodal: fine-tuned BLIP2 textual and multimodal encoders
122 | :param relative_test_dataset: CIRR test dataset in relative mode
123 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
124 | features
125 | :param index_features: test index features
126 | :param index_names: test index names
127 |
128 | :return: predicted_features, reference_names, group_members and pairs_id
129 | """
130 | print(f"Compute CIRR test predictions")
131 |
132 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32,
133 | num_workers=multiprocessing.cpu_count(), pin_memory=True)
134 |
135 | # Get a mapping from index names to index features
136 | # name_to_feat = dict(zip(index_names, index_features))
137 |
138 | # Initialize pairs_id, predicted_features, group_members and reference_names
139 | pairs_id = []
140 | predicted_features = torch.empty((0, 768)).to(device, non_blocking=True)
141 | group_members = []
142 | reference_names = []
143 |
144 | for batch_pairs_id, batch_reference_names, reference_images, captions, batch_group_members in tqdm(
145 | relative_test_loader): # Load data
146 | reference_images = reference_images.to(device)
147 | batch_group_members = np.array(batch_group_members).T.tolist()
148 | # Compute the predicted features
149 | with torch.no_grad():
150 | text_feats = blip_textual.extract_features({"text_input":captions},
151 | mode="text").text_embeds[:,0,:]
152 | reference_feats = blip_multimodal.extract_features({"image":reference_images,
153 | "text_input":captions}).multimodal_embeds[:,0,:]
154 | batch_predicted_features = combining_function(text_feats, reference_feats)
155 |
156 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
157 | group_members.extend(batch_group_members)
158 | reference_names.extend(batch_reference_names)
159 | pairs_id.extend(batch_pairs_id)
160 |
161 | return predicted_features, reference_names, group_members, pairs_id
162 |
163 |
164 | def main():
165 | parser = ArgumentParser()
166 | parser.add_argument("--submission-name", type=str, required=True, help="submission file name")
167 | parser.add_argument("--combining-function", type=str, required=True,
168 | help="Which combining function use, should be in ['combiner', 'sum']")
169 | parser.add_argument("--blip2-textual-path", type=str, help="Path to the fine-tuned BLIP2 model")
170 | parser.add_argument("--blip2-visual-path", type=str, help="Path to the fine-tuned BLIP2 model")
171 | parser.add_argument("--blip2-multimodal-path", type=str, help="Path to the fine-tuned BLIP2 model")
172 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio")
173 | parser.add_argument("--transform", default="targetpad", type=str,
174 | help="Preprocess pipeline, should be in ['clip', 'squarepad', 'targetpad'] ")
175 | args = parser.parse_args()
176 |
177 | blip_textual_path = args.blip2_textual_path
178 | blip_visual_path = args.blip2_visual_path
179 | blip_multimodal_path = args.blip2_multimodal_path
180 |
181 | blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
182 | x_saved_state_dict = torch.load(blip_textual_path, map_location=device)
183 |
184 | blip_textual.load_state_dict(x_saved_state_dict["Blip2Qformer"])
185 |
186 | blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
187 | r_saved_state_dict = torch.load(blip_multimodal_path, map_location=device)
188 | blip_multimodal.load_state_dict(r_saved_state_dict["Blip2Qformer"])
189 |
190 | blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
191 | t_saved_state_dict = torch.load(blip_visual_path, map_location=device)
192 | blip_visual.load_state_dict(t_saved_state_dict["Blip2Qformer"])
193 | input_dim = 224
194 |
195 | if args.transform == 'targetpad':
196 | print('Target pad preprocess pipeline is used')
197 | preprocess = targetpad_transform(args.target_ratio, input_dim)
198 | elif args.transform == 'squarepad':
199 | print('Square pad preprocess pipeline is used')
200 | preprocess = squarepad_transform(input_dim)
201 |
202 | if args.combining_function.lower() == 'sum':
203 | combining_function = element_wise_sum
204 | else:
205 | raise ValueError("combining_function should be 'sum' for this script")
206 |
207 | generate_cirr_test_submissions(combining_function, args.submission_name, blip_textual, blip_multimodal, blip_visual, preprocess)
208 |
209 |
210 | if __name__ == '__main__':
211 | main()
212 |
--------------------------------------------------------------------------------
/src/cirr_test_submission.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | from argparse import ArgumentParser
4 | from operator import itemgetter
5 | from pathlib import Path
6 | from typing import List, Tuple, Dict
7 |
8 | import clip
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | from clip.model import CLIP
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 |
16 | from combiner_train import extract_index_features
17 | from data_utils import CIRRDataset, targetpad_transform, squarepad_transform, base_path
18 | from combiner import Combiner
19 | from utils import element_wise_sum, device
20 |
21 |
22 | def generate_cirr_test_submissions(combining_function: callable, file_name: str, clip_model: CLIP,
23 | preprocess: callable):
24 | """
25 | Generate and save CIRR test submission files to be submitted to evaluation server
26 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
27 | features
28 | :param file_name: file_name of the submission
29 | :param clip_model: CLIP model
30 | :param preprocess: preprocess pipeline
31 | """
32 |
33 | clip_model = clip_model.float().eval()
34 |
35 | # Define the dataset and extract index features
36 | classic_test_dataset = CIRRDataset('test1', 'classic', preprocess)
37 | index_features, index_names = extract_index_features(classic_test_dataset, clip_model)
38 | relative_test_dataset = CIRRDataset('test1', 'relative', preprocess)
39 |
40 | # Generate test prediction dicts for CIRR
41 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset, clip_model,
42 | index_features, index_names,
43 | combining_function)
44 |
45 | submission = {
46 | 'version': 'rc2',
47 | 'metric': 'recall'
48 | }
49 | group_submission = {
50 | 'version': 'rc2',
51 | 'metric': 'recall_subset'
52 | }
53 |
54 | submission.update(pairid_to_predictions)
55 | group_submission.update(pairid_to_group_predictions)
56 |
57 | # Define submission path
58 | submissions_folder_path = base_path / "submission" / 'CIRR'
59 | submissions_folder_path.mkdir(exist_ok=True, parents=True)
60 |
61 | print(f"Saving CIRR test predictions")
62 | with open(submissions_folder_path / f"recall_submission_{file_name}.json", 'w+') as file:
63 | json.dump(submission, file, sort_keys=True)
64 |
65 | with open(submissions_folder_path / f"recall_subset_submission_{file_name}.json", 'w+') as file:
66 | json.dump(group_submission, file, sort_keys=True)
67 |
68 |
69 | def generate_cirr_test_dicts(relative_test_dataset: CIRRDataset, clip_model: CLIP, index_features: torch.tensor,
70 | index_names: List[str], combining_function: callable) \
71 | -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
72 | """
73 | Compute test prediction dicts for CIRR dataset
74 | :param relative_test_dataset: CIRR test dataset in relative mode
75 | :param clip_model: CLIP model
76 | :param index_features: test index features
77 | :param index_names: test index names
78 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
79 | features
80 | :return: Top50 global and Top3 subset prediction for each query (reference_name, caption)
81 | """
82 |
83 | # Generate predictions
84 | predicted_features, reference_names, group_members, pairs_id = \
85 | generate_cirr_test_predictions(clip_model, relative_test_dataset, combining_function, index_names,
86 | index_features)
87 |
88 | print(f"Compute CIRR prediction dicts")
89 |
90 | # Normalize the index features
91 | index_features = F.normalize(index_features, dim=-1).float()
92 |
93 | # Compute the distances and sort the results
94 | distances = 1 - predicted_features @ index_features.T
95 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
96 | sorted_index_names = np.array(index_names)[sorted_indices]
97 |
98 | # Delete the reference image from the results
99 | reference_mask = torch.tensor(
100 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names),
101 | -1))
102 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
103 | sorted_index_names.shape[1] - 1)
104 | # Compute the subset predictions
105 | group_members = np.array(group_members)
106 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
107 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1)
108 |
109 | # Generate prediction dicts
110 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in
111 | zip(pairs_id, sorted_index_names)}
112 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in
113 | zip(pairs_id, sorted_group_names)}
114 |
115 | return pairid_to_predictions, pairid_to_group_predictions
116 |
117 |
118 | def generate_cirr_test_predictions(clip_model: CLIP, relative_test_dataset: CIRRDataset, combining_function: callable,
119 | index_names: List[str], index_features: torch.tensor) -> \
120 | Tuple[torch.tensor, List[str], List[List[str]], List[str]]:
121 | """
122 | Compute CIRR predictions on the test set
123 | :param clip_model: CLIP model
124 | :param relative_test_dataset: CIRR test dataset in relative mode
125 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
126 | features
127 | :param index_features: test index features
128 | :param index_names: test index names
129 |
130 | :return: predicted_features, reference_names, group_members and pairs_id
131 | """
132 | print(f"Compute CIRR test predictions")
133 |
134 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32,
135 | num_workers=multiprocessing.cpu_count(), pin_memory=True)
136 |
137 | # Get a mapping from index names to index features
138 | name_to_feat = dict(zip(index_names, index_features))
139 |
140 | # Initialize pairs_id, predicted_features, group_members and reference_names
141 | pairs_id = []
142 | predicted_features = torch.empty((0, clip_model.visual.output_dim)).to(device, non_blocking=True)
143 | group_members = []
144 | reference_names = []
145 |
146 | for batch_pairs_id, batch_reference_names, captions, batch_group_members in tqdm(
147 | relative_test_loader): # Load data
148 | text_inputs = clip.tokenize(captions, context_length=77).to(device)
149 | batch_group_members = np.array(batch_group_members).T.tolist()
150 |
151 | # Compute the predicted features
152 | with torch.no_grad():
153 | text_features = clip_model.encode_text(text_inputs)
154 | # Check whether a single element is in the batch due to the exception raised by torch.stack when used with
155 | # a single tensor
156 | if text_features.shape[0] == 1:
157 | reference_image_features = itemgetter(*batch_reference_names)(name_to_feat).unsqueeze(0)
158 | else:
159 | reference_image_features = torch.stack(itemgetter(*batch_reference_names)(
160 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features
161 | batch_predicted_features = combining_function(reference_image_features, text_features)
162 |
163 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
164 | group_members.extend(batch_group_members)
165 | reference_names.extend(batch_reference_names)
166 | pairs_id.extend(batch_pairs_id)
167 |
168 | return predicted_features, reference_names, group_members, pairs_id
169 |
170 |
171 | def main():
172 | parser = ArgumentParser()
173 | parser.add_argument("--submission-name", type=str, required=True, help="submission file name")
174 | parser.add_argument("--combining-function", type=str, required=True,
175 | help="Which combining function use, should be in ['combiner', 'sum']")
176 | parser.add_argument("--combiner-path", type=str, help="path to trained Combiner")
177 | parser.add_argument("--projection-dim", default=640 * 4, type=int, help='Combiner projection dim')
178 | parser.add_argument("--hidden-dim", default=640 * 8, type=int, help="Combiner hidden dim")
179 | parser.add_argument("--clip-model-name", default="RN50x4", type=str, help="CLIP model to use, e.g 'RN50', 'RN50x4'")
180 | parser.add_argument("--clip-model-path", type=Path, help="Path to the fine-tuned CLIP model")
181 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio")
182 | parser.add_argument("--transform", default="targetpad", type=str,
183 | help="Preprocess pipeline, should be in ['clip', 'squarepad', 'targetpad'] ")
184 | args = parser.parse_args()
185 | clip_model, clip_preprocess = clip.load(args.clip_model_name, device=device, jit=False)
186 | input_dim = clip_model.visual.input_resolution
187 | feature_dim = clip_model.visual.output_dim
188 |
189 | if args.clip_model_path:
190 | print('Trying to load the CLIP model')
191 | saved_state_dict = torch.load(args.clip_model_path, map_location=device)
192 | clip_model.load_state_dict(saved_state_dict["CLIP"])
193 | print('CLIP model loaded successfully')
194 |
195 | if args.transform == 'targetpad':
196 | print('Target pad preprocess pipeline is used')
197 | preprocess = targetpad_transform(args.target_ratio, input_dim)
198 | elif args.transform == 'squarepad':
199 | print('Square pad preprocess pipeline is used')
200 | preprocess = squarepad_transform(input_dim)
201 | else:
202 | print('CLIP default preprocess pipeline is used')
203 | preprocess = clip_preprocess
204 |
205 | if args.combining_function.lower() == 'sum':
206 | if args.combiner_path:
207 | print("Be careful, you are using the element-wise sum as combining_function but you have also passed a path"
208 | " to a trained Combiner. Such Combiner will not be used")
209 | combining_function = element_wise_sum
210 | elif args.combining_function.lower() == 'combiner':
211 | combiner = Combiner(feature_dim, args.projection_dim, args.hidden_dim).to(device)
212 | saved_state_dict = torch.load(args.combiner_path, map_location=device)
213 | combiner.load_state_dict(saved_state_dict["Combiner"])
214 | combiner.eval()
215 | combining_function = combiner.combine_features
216 | else:
217 | raise ValueError("combining_function should be in ['sum', 'combiner']")
218 |
219 | generate_cirr_test_submissions(combining_function, args.submission_name, clip_model, preprocess)
220 |
221 |
222 | if __name__ == '__main__':
223 | main()
224 |
--------------------------------------------------------------------------------
/src/cirr_artemis_submission_blip2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | from argparse import ArgumentParser
4 | from operator import itemgetter
5 | from pathlib import Path
6 | from typing import List, Tuple, Dict
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 |
14 | from data_utils import CIRRDataset_val_submission, targetpad_transform, squarepad_transform, base_path
15 | from utils import element_wise_sum, device, extract_index_features_blip2
16 | from lavis.models import load_model
17 | from artemis import Artemis
18 |
19 |
20 | def generate_cirr_test_submissions(artemis, file_name: str, blip_textual, blip_multimodal,
21 | blip_visual, preprocess: callable):
22 | """
23 | Generate and save CIRR test submission files to be submitted to evaluation server
24 | :param artemis: trained ARTEMIS module used to score each (reference image, caption) query against the
25 | index features
26 | :param file_name: file_name of the submission
27 | :param blip_textual, blip_multimodal, blip_visual: fine-tuned BLIP2 feature extractors
28 | :param preprocess: preprocess pipeline
29 | """
30 | blip_textual = blip_textual.float().eval()
31 | blip_multimodal = blip_multimodal.float().eval()
32 | blip_visual = blip_visual.float().eval()
33 | artemis = artemis.float().eval()
34 |
35 | # Define the dataset and extract index features
36 | classic_test_dataset = CIRRDataset_val_submission('test1', 'classic', preprocess)
37 | index_features, index_names = extract_index_features_blip2(classic_test_dataset, blip_visual)
38 | relative_test_dataset = CIRRDataset_val_submission('test1', 'relative', preprocess)
39 |
40 | # Generate test prediction dicts for CIRR
41 | pairid_to_predictions, pairid_to_group_predictions = generate_cirr_test_dicts(relative_test_dataset, blip_textual, blip_multimodal,
42 | index_features, index_names,
43 | artemis)
44 |
45 | submission = {
46 | 'version': 'rc2',
47 | 'metric': 'recall'
48 | }
49 | group_submission = {
50 | 'version': 'rc2',
51 | 'metric': 'recall_subset'
52 | }
53 |
54 | submission.update(pairid_to_predictions)
55 | group_submission.update(pairid_to_group_predictions)
56 |
57 | # Define submission path
58 | submissions_folder_path = base_path / "submission" / 'CIRR'
59 | submissions_folder_path.mkdir(exist_ok=True, parents=True)
60 |
61 | print(f"Saving CIRR test predictions")
62 | with open(submissions_folder_path / f"retrieval_{file_name}.json", 'w+') as file:
63 | json.dump(submission, file, sort_keys=True)
64 |
65 | with open(submissions_folder_path / f"retrieval_subset_submission_{file_name}.json", 'w+') as file:
66 | json.dump(group_submission, file, sort_keys=True)
67 |
68 |
69 | def generate_cirr_test_dicts(relative_test_dataset: CIRRDataset_val_submission, blip_textual, blip_multimodal, index_features: torch.tensor,
70 | index_names: List[str], artemis) \
71 | -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
72 | """
73 | Compute test prediction dicts for CIRR dataset
74 | :param relative_test_dataset: CIRR test dataset in relative mode
75 | :param blip_textual, blip_multimodal: fine-tuned BLIP2 textual and multimodal encoders
76 | :param index_features: test index features
77 | :param index_names: test index names
78 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
79 | features
80 | :return: Top50 global and Top3 subset prediction for each query (reference_name, caption)
81 | """
82 |
83 | # Generate predictions
84 | artemis_scores, reference_names, group_members, pairs_id = \
85 | generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset, artemis, index_names,
86 | index_features)
87 |
88 | print(f"Compute CIRR prediction dicts")
89 |
90 | # Normalize the index features
91 | index_features = F.normalize(index_features, dim=-1).float()
92 |
93 | # Compute the distances and sort the results
94 | distances = 1 - artemis_scores
95 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
96 | sorted_index_names = np.array(index_names)[sorted_indices]
97 |
98 | # Delete the reference image from the results
99 | reference_mask = torch.tensor(
100 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names),
101 | -1))
102 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
103 | sorted_index_names.shape[1] - 1)
104 | # Compute the subset predictions
105 | group_members = np.array(group_members)
106 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
107 | sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1)
108 |
109 | # Generate prediction dicts for test split
110 | pairid_to_predictions = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in
111 | zip(pairs_id, sorted_index_names)}
112 | pairid_to_group_predictions = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in
113 | zip(pairs_id, sorted_group_names)}
114 |
115 | return pairid_to_predictions, pairid_to_group_predictions
116 |
117 |
118 | def generate_cirr_test_predictions(blip_textual, blip_multimodal, relative_test_dataset: CIRRDataset_val_submission, artemis,
119 | index_names: List[str], index_features: torch.tensor) -> \
120 | Tuple[torch.tensor, List[str], List[List[str]], List[str]]:
121 |     """
122 |     Compute CIRR predictions on the test set
123 |     :param blip_textual: BLIP-2 textual feature extractor
124 |     :param blip_multimodal: BLIP-2 multimodal feature extractor
125 |     :param relative_test_dataset: CIRR test dataset in relative mode
126 |     :param artemis: ARTEMIS scoring module used to rank the index images
127 |     :param index_names: test index names
128 |     :param index_features: test index features
129 |
130 |     :return: artemis_scores, reference_names, group_members and pairs_id
131 |     """
132 | print(f"Compute CIRR test predictions")
133 |
134 | relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32,
135 | num_workers=multiprocessing.cpu_count(), pin_memory=True)
136 |
137 | # Get a mapping from index names to index features
138 | # name_to_feat = dict(zip(index_names, index_features))
139 |
140 | # Initialize pairs_id, predicted_features, group_members and reference_names
141 | pairs_id = []
142 | artemis_scores = torch.empty((0, len(index_names))).to(device, non_blocking=True)
143 | group_members = []
144 | reference_names = []
145 |
146 | for batch_pairs_id, batch_reference_names, reference_images, captions, batch_group_members in tqdm(
147 | relative_test_loader): # Load data
148 | reference_images = reference_images.to(device)
149 | batch_group_members = np.array(batch_group_members).T.tolist()
150 | # Compute the predicted features
151 | with torch.no_grad():
152 | text_feats = blip_textual.extract_features({"text_input":captions},
153 | mode="text").text_embeds[:,0,:]
154 | reference_feats = blip_multimodal.extract_features({"image":reference_images,
155 | "text_input":captions}).multimodal_embeds[:,0,:]
156 | batch_artemis_score = torch.empty((0, len(index_names))).to(device, non_blocking=True)
157 | for i in range(reference_feats.shape[0]):
158 | one_artemis_score = artemis.compute_score_artemis(reference_feats[i].unsqueeze(0), text_feats[i].unsqueeze(0), index_features)
159 | batch_artemis_score = torch.vstack((batch_artemis_score, one_artemis_score))
160 |
161 | artemis_scores = torch.vstack((artemis_scores, F.normalize(batch_artemis_score, dim=-1)))
162 | group_members.extend(batch_group_members)
163 | reference_names.extend(batch_reference_names)
164 | pairs_id.extend(batch_pairs_id)
165 | # target_hard_names.extend(target_hard_name)
166 | # captions_list.extend(captions)
167 |
168 | return artemis_scores, reference_names, group_members, pairs_id
169 | #return artemis_scores, reference_names, group_members, pairs_id, target_hard_names, captions_list
170 |
171 |
172 | def main():
173 | parser = ArgumentParser()
174 | parser.add_argument("--submission-name", type=str, required=True, help="submission file name")
175 | parser.add_argument("--combining-function", type=str, required=True,
176 | help="Which combining function use, should be in ['combiner', 'sum']")
177 |     parser.add_argument("--blip2-textual-path", type=str, help="Path to the fine-tuned BLIP2 textual model")
178 |     parser.add_argument("--blip2-visual-path", type=str, help="Path to the fine-tuned BLIP2 visual model")
179 |     parser.add_argument("--blip2-multimodal-path", type=str, help="Path to the fine-tuned BLIP2 multimodal model")
180 |     parser.add_argument("--artemis-path", type=str, help="Path to the fine-tuned ARTEMIS module")
181 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio")
182 | parser.add_argument("--transform", default="targetpad", type=str,
183 |                         help="Preprocess pipeline, should be in ['squarepad', 'targetpad'] ")
184 | args = parser.parse_args()
185 |
186 | blip_textual_path = args.blip2_textual_path
187 | blip_visual_path = args.blip2_visual_path
188 | blip_multimodal_path = args.blip2_multimodal_path
189 | artemis_path = args.artemis_path
190 |
191 | blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
192 | x_saved_state_dict = torch.load(blip_textual_path, map_location=device)
193 |
194 | blip_textual.load_state_dict(x_saved_state_dict["Blip2Qformer"])
195 |
196 | blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
197 | r_saved_state_dict = torch.load(blip_multimodal_path, map_location=device)
198 | blip_multimodal.load_state_dict(r_saved_state_dict["Blip2Qformer"])
199 |
200 | blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
201 | t_saved_state_dict = torch.load(blip_visual_path, map_location=device)
202 | blip_visual.load_state_dict(t_saved_state_dict["Blip2Qformer"])
203 |
204 | artemis = Artemis(768).to(device)
205 | a_saved_state_dict = torch.load(artemis_path, map_location=device)
206 | artemis.load_state_dict(a_saved_state_dict["Artemis"])
207 |
208 |
209 | input_dim = 224
210 | if args.transform == 'targetpad':
211 | print('Target pad preprocess pipeline is used')
212 | preprocess = targetpad_transform(args.target_ratio, input_dim)
213 | elif args.transform == 'squarepad':
214 | print('Square pad preprocess pipeline is used')
215 | preprocess = squarepad_transform(input_dim)
216 |     else: raise ValueError("transform should be in ['squarepad', 'targetpad'] for this script")
217 | if args.combining_function.lower() == 'sum':
218 | combining_function = element_wise_sum
219 | else:
220 |         raise ValueError("combining_function should be in ['sum', 'combiner']")
221 |
222 | generate_cirr_test_submissions(artemis, args.submission_name, blip_textual, blip_multimodal, blip_visual, preprocess)
223 |
224 |
225 | if __name__ == '__main__':
226 | main()
227 |
--------------------------------------------------------------------------------
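Note (illustrative): the two JSON files written by cirr_test_submission_blip2.py follow the CIRR evaluation-server layout, i.e. a small header ('version', 'metric') plus one entry per pair id mapping to the ranked image names. The sketch below only mirrors that structure; the pair id and image names are hypothetical placeholders, and the real entries come from pairid_to_predictions / pairid_to_group_predictions above.

import json

# Hypothetical pair id and image names, purely to illustrate the file layout.
submission = {
    "version": "rc2",
    "metric": "recall",
    "12345": ["img_name_a", "img_name_b", "img_name_c"],  # top-50 ranked index names (truncated here)
}
group_submission = {
    "version": "rc2",
    "metric": "recall_subset",
    "12345": ["img_name_b", "img_name_c", "img_name_a"],  # top-3 names restricted to the query's subset
}

with open("retrieval_example.json", "w") as f:
    json.dump(submission, f, sort_keys=True)
with open("retrieval_subset_submission_example.json", "w") as f:
    json.dump(group_submission, f, sort_keys=True)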
/src/utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import random
3 | from pathlib import Path
4 | from typing import Union, Tuple, List
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from clip.model import CLIP
9 | from torch import nn
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from data_utils import CIRRDataset, FashionIQDataset
13 |
14 | if torch.cuda.is_available():
15 | device = torch.device("cuda")
16 | else:
17 | device = torch.device("cpu")
18 |
19 |
20 | def extract_features_fusion_blip(blip_textual, multimodal_embeds, captions):
21 | self = blip_textual.Qformer.bert
22 | embeddings = self.embeddings
23 | encoder = self.encoder
24 | # pooler = blip_textual.Qformer.bert.pooler
25 |
26 | text_tokens = blip_textual.tokenizer(
27 | captions,
28 | padding="max_length",
29 | truncation=True,
30 | max_length=77,
31 | return_tensors="pt",
32 | ).to(device)
33 |
34 | text_atts = text_tokens.attention_mask
35 | query_atts = torch.ones(multimodal_embeds.size()[:-1], dtype=torch.long).to(device)
36 | # print("query_atts:", query_atts.shape)
37 | # print("text_atts:", text_atts.shape)
38 | attention_mask = torch.cat([query_atts, text_atts], dim=1)
39 | # print("attention_mask:",attention_mask.shape)
40 | # head_mask = blip_textual.Qformer.bert.get_head_mask(head_mask,
41 | # blip_textual.Qformer.bert.config.num_hidden_layers)
42 |
43 | embedding_output = embeddings(
44 | input_ids=text_tokens.input_ids,
45 | query_embeds=multimodal_embeds,)
46 |
47 | input_shape = embedding_output.size()[:-1]
48 | extended_attention_mask = self.get_extended_attention_mask(
49 | attention_mask, input_shape, device, False
50 | )
51 | # print(extended_attention_mask.shape)
52 | head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
53 |
54 | encoder_outputs = encoder(
55 | embedding_output,
56 | attention_mask=extended_attention_mask,
57 | head_mask=head_mask,
58 | return_dict=True,
59 | )
60 | sequence_output = encoder_outputs[0]
61 | return sequence_output
62 |
63 |
64 | # def extract_index_features_blip(dataset: Union[CIRRDataset, FashionIQDataset], model: CLIP, vision_layer) -> \
65 | def extract_index_features_blip2(dataset: Union[CIRRDataset, FashionIQDataset], blip_model) -> \
66 | Tuple[torch.tensor, List[str]]:
67 | """
68 | Extract FashionIQ or CIRR index features
69 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode
70 |     :param blip_model: BLIP-2 feature extractor
71 |     :return: a tensor of index features and a list of image names
72 | """
73 | feature_dim = 768
74 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(),
75 | pin_memory=True, collate_fn=collate_fn)
76 | index_features = torch.empty((0,feature_dim)).to(device, non_blocking=True)
77 | index_names = []
78 | if isinstance(dataset, CIRRDataset):
79 | print(f"extracting CIRR {dataset.split} index features")
80 | elif isinstance(dataset, FashionIQDataset):
81 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features")
82 |
83 | for names, images in tqdm(classic_val_loader):
84 | images = images.to(device, non_blocking=True)
85 | with torch.no_grad():
86 | batch_features = blip_model.extract_features({"image":images}, mode="image").image_embeds[:,0,:]
87 | index_features = torch.vstack((index_features, batch_features))
88 | index_names.extend(names)
89 | return index_features, index_names
90 |
91 | def extract_index_features_blip1(dataset: Union[CIRRDataset, FashionIQDataset], blip_model) -> \
92 | Tuple[torch.tensor, List[str]]:
93 | """
94 | Extract FashionIQ or CIRR index features
95 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode
96 |     :param blip_model: BLIP visual model
97 |     :return: a tensor of index features and a list of image names
98 | """
99 | feature_dim = 256
100 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(),
101 | pin_memory=True, collate_fn=collate_fn)
102 | index_features = torch.empty((0,feature_dim)).to(device, non_blocking=True)
103 | index_names = []
104 | if isinstance(dataset, CIRRDataset):
105 | print(f"extracting CIRR {dataset.split} index features")
106 | elif isinstance(dataset, FashionIQDataset):
107 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features")
108 |
109 | for names, images in tqdm(classic_val_loader):
110 | images = images.to(device, non_blocking=True)
111 | with torch.no_grad():
112 | batch_features = blip_model(images)[:,0,:]
113 | index_features = torch.vstack((index_features, batch_features))
114 | index_names.extend(names)
115 | return index_features, index_names
116 |
117 | def extract_index_features_blip_feature_extractor(dataset: Union[CIRRDataset, FashionIQDataset], model) -> \
118 | Tuple[torch.tensor, List[str]]:
119 | """
120 | Extract FashionIQ or CIRR index features
121 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode
122 |     :param model: BLIP feature-extractor model
123 |     :return: a tensor of index features and a list of image names
124 | """
125 | # feature_dim = model.visual.output_dim
126 | feature_dim = 256
127 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(),
128 | pin_memory=True, collate_fn=collate_fn)
129 | index_features = torch.empty((0,feature_dim)).to(device, non_blocking=True)
130 | index_names = []
131 | if isinstance(dataset, CIRRDataset):
132 | print(f"extracting CIRR {dataset.split} index features")
133 | elif isinstance(dataset, FashionIQDataset):
134 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features")
135 | for names, images in tqdm(classic_val_loader):
136 | images = images.to(device, non_blocking=True)
137 | with torch.no_grad():
138 | # visual_encoder = model.visual_encoder.float()
139 | # batch_vision_embeds = visual_encoder(images)
140 | # batch_vision_embeds = model.ln_vision(batch_vision_embeds).float()
141 | # batch_features = model.extract_features({"image":images}, mode="image").image_embeds
142 |
143 | batch_features = model.extract_features({"image":images}, mode="image").image_embeds_proj[:,0,:]
144 | index_features = torch.vstack((index_features, batch_features))
145 | index_names.extend(names)
146 | return index_features, index_names
147 |
148 | def extract_index_features_clip(dataset: Union[CIRRDataset, FashionIQDataset], clip_model: CLIP) -> \
149 | Tuple[torch.tensor, List[str]]:
150 | """
151 | Extract FashionIQ or CIRR index features
152 | :param dataset: FashionIQ or CIRR dataset in 'classic' mode
153 | :param clip_model: CLIP model
154 |     :return: a tensor of index features and a list of image names
155 | """
156 | feature_dim = clip_model.visual.output_dim
157 | classic_val_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=multiprocessing.cpu_count(),
158 | pin_memory=True, collate_fn=collate_fn)
159 | index_features = torch.empty((0, feature_dim)).to(device, non_blocking=True)
160 | index_names = []
161 | if isinstance(dataset, CIRRDataset):
162 | print(f"extracting CIRR {dataset.split} index features")
163 | elif isinstance(dataset, FashionIQDataset):
164 | print(f"extracting fashionIQ {dataset.dress_types} - {dataset.split} index features")
165 | for names, images in tqdm(classic_val_loader):
166 | images = images.to(device, non_blocking=True)
167 | with torch.no_grad():
168 | batch_features = clip_model.encode_image(images)[0]
169 | # print(clip_model.encode_image(images)[1].shape)
170 | index_features = torch.vstack((index_features, batch_features))
171 | index_names.extend(names)
172 | return index_features, index_names
173 |
174 | def element_wise_sum(image_features: torch.tensor, text_features: torch.tensor) -> torch.tensor:
175 | """
176 | Normalized element-wise sum of image features and text features
177 | :param image_features: non-normalized image features
178 | :param text_features: non-normalized text features
179 | :return: normalized element-wise sum of image and text features
180 | """
181 | return F.normalize(image_features + text_features, dim=-1)
182 |
183 |
184 | def generate_randomized_fiq_caption(flattened_captions: List[str]) -> List[str]:
185 | """
186 |     Function which randomizes the FashionIQ training captions in four ways: (a) cap1 and cap2 (b) cap2 and cap1
187 |     (c) cap1 (d) cap2
188 |     :param flattened_captions: the list of captions to randomize; note that its length is 2*batch_size since
189 |     two captions are associated with each triplet
190 | :return: the randomized caption list (with length = batch_size)
191 | """
192 | captions = []
193 | for i in range(0, len(flattened_captions), 2):
194 | random_num = random.random()
195 | if random_num < 0.25:
196 | captions.append(
197 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}")
198 | elif 0.25 < random_num < 0.5:
199 | captions.append(
200 | f"{flattened_captions[i + 1].strip('.?, ').capitalize()} and {flattened_captions[i].strip('.?, ')}")
201 | elif 0.5 < random_num < 0.75:
202 | captions.append(f"{flattened_captions[i].strip('.?, ').capitalize()}")
203 | else:
204 | captions.append(f"{flattened_captions[i + 1].strip('.?, ').capitalize()}")
205 | return captions
206 | def generate_randomized_fiq_caption_blip(flattened_captions: List[str],txt_processors:callable) -> List[str]:
207 | """
208 |     Function which randomizes the FashionIQ training captions in four ways: (a) cap1 and cap2 (b) cap2 and cap1
209 |     (c) cap1 (d) cap2; each caption is additionally passed through txt_processors
210 |     :param flattened_captions: the list of captions to randomize; note that its length is 2*batch_size since
211 |     two captions are associated with each triplet
212 | :return: the randomized caption list (with length = batch_size)
213 | """
214 | captions = []
215 | for i in range(0, len(flattened_captions), 2):
216 | random_num = random.random()
217 | caption =''
218 | if random_num < 0.25:
219 | caption=f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}"
220 | elif 0.25 < random_num < 0.5:
221 | caption=f"{flattened_captions[i + 1].strip('.?, ').capitalize()} and {flattened_captions[i].strip('.?, ')}"
222 | elif 0.5 < random_num < 0.75:
223 | caption=f"{flattened_captions[i].strip('.?, ').capitalize()}"
224 | else:
225 | caption=f"{flattened_captions[i + 1].strip('.?, ').capitalize()}"
226 | captions.append(txt_processors(caption))
227 | return captions
228 |
229 | def collate_fn(batch: list):
230 | """
231 | Discard None images in a batch when using torch DataLoader
232 | :param batch: input_batch
233 | :return: output_batch = input_batch - None_values
234 | """
235 | batch = list(filter(lambda x: x is not None, batch))
236 | return torch.utils.data.dataloader.default_collate(batch)
237 |
238 |
239 | def update_train_running_results(train_running_results: dict, loss: torch.tensor, images_in_batch: int):
240 | """
241 | Update `train_running_results` dict during training
242 | :param train_running_results: logging training dict
243 | :param loss: computed loss for batch
244 | :param images_in_batch: num images in the batch
245 | """
246 | train_running_results['accumulated_train_loss'] += loss.to('cpu',
247 | non_blocking=True).detach().item() * images_in_batch
248 | train_running_results["images_in_epoch"] += images_in_batch
249 |
250 |
251 | def set_train_bar_description(train_bar, epoch: int, num_epochs: int, train_running_results: dict):
252 | """
253 | Update tqdm train bar during training
254 | :param train_bar: tqdm training bar
255 | :param epoch: current epoch
256 | :param num_epochs: numbers of epochs
257 | :param train_running_results: logging training dict
258 | """
259 | train_bar.set_description(
260 | desc=f"[{epoch}/{num_epochs}] "
261 | f"train loss: {train_running_results['accumulated_train_loss'] / train_running_results['images_in_epoch']:.3f} "
262 | )
263 |
264 |
265 | def save_model(name: str, cur_epoch: int, model_to_save: nn.Module, training_path: Path):
266 | """
267 | Save the weights of the model during training
268 | :param name: name of the file
269 | :param cur_epoch: current epoch
270 | :param model_to_save: pytorch model to be saved
271 | :param training_path: path associated with the training run
272 | """
273 | models_path = training_path / "saved_models"
274 | models_path.mkdir(exist_ok=True, parents=True)
275 | model_name = model_to_save.__class__.__name__
276 | torch.save({
277 | 'epoch': cur_epoch,
278 | model_name: model_to_save.state_dict(),
279 | }, str(models_path / f'{name}.pt'))
280 |
281 | import math
282 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr, onlyGroup0=False):
283 | """Decay the learning rate"""
284 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
285 | param_group_count = 0
286 | for param_group in optimizer.param_groups:
287 | param_group_count += 1
288 |         if not onlyGroup0 or param_group_count <= 1: # when onlyGroup0 is set, only vary group0 parameters' learning rate, i.e., exclude the text_proj layer
289 | param_group['lr'] = lr
290 |
291 | @torch.no_grad()
292 | def _momentum_update(model_pairs, momentum):
293 | for model_pair in model_pairs:
294 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
295 | param_m.data = param_m.data * momentum + param.data * (1. - momentum)
296 |
297 | @torch.no_grad()
298 | def _dequeue_and_enqueue(model, target_feats, idx, queue_size):
299 | # gather keys before updating queue
300 | batch_size = target_feats.shape[0]
301 |
302 | ptr = int(model.queue_ptr)
303 | assert queue_size % batch_size == 0 # for simplicity
304 |
305 | # replace the keys at ptr (dequeue and enqueue)
306 | model.target_queue[:, ptr:ptr + batch_size] = target_feats.T
307 | model.idx_queue[:, ptr:ptr + batch_size] = idx.T
308 | ptr = (ptr + batch_size) % queue_size # move pointer
309 |
310 | model.queue_ptr[0] = ptr
311 |
312 |
313 | def l2norm(x):
314 | """L2-normalize each row of x"""
315 | norm = torch.pow(x, 2).sum(dim=-1, keepdim=True).sqrt()
316 | return torch.div(x, norm)
317 |
--------------------------------------------------------------------------------
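Note (illustrative): a minimal sketch of how two of the helpers in utils.py behave, using random tensors and a toy optimizer. It assumes the repo's dependencies are installed and that src/ is on PYTHONPATH so that utils is importable; it is not part of the training pipeline itself.

import torch
from torch import nn, optim

from utils import element_wise_sum, cosine_lr_schedule

# element_wise_sum fuses two feature batches and L2-normalizes the result row-wise.
image_feats = torch.randn(4, 768)
text_feats = torch.randn(4, 768)
queries = element_wise_sum(image_feats, text_feats)
print(queries.norm(dim=-1))  # every row has (approximately) unit norm

# cosine_lr_schedule decays the learning rate from init_lr to min_lr over max_epoch epochs.
model = nn.Linear(768, 768)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
for epoch in range(10):
    cosine_lr_schedule(optimizer, epoch, max_epoch=10, init_lr=1e-5, min_lr=0.0)
    print(epoch, optimizer.param_groups[0]["lr"])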
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import List
4 |
5 | import PIL
6 | import PIL.Image
7 | import torchvision.transforms.functional as F
8 | from torch.utils.data import Dataset
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 |
11 | PIL.Image.MAX_IMAGE_PIXELS=None
12 | base_path = Path(__file__).absolute().parents[2].absolute()
13 |
14 |
15 | def _convert_image_to_rgb(image):
16 | return image.convert("RGB")
17 |
18 |
19 | class SquarePad:
20 | """
21 | Square pad the input image with zero padding
22 | """
23 |
24 | def __init__(self, size: int):
25 | """
26 | For having a consistent preprocess pipeline with CLIP we need to have the preprocessing output dimension as
27 | a parameter
28 | :param size: preprocessing output dimension
29 | """
30 | self.size = size
31 |
32 | def __call__(self, image):
33 | w, h = image.size
34 | max_wh = max(w, h)
35 | hp = int((max_wh - w) / 2)
36 | vp = int((max_wh - h) / 2)
37 | padding = [hp, vp, hp, vp]
38 | return F.pad(image, padding, 0, 'constant')
39 |
40 |
41 | class TargetPad:
42 | """
43 | Pad the image if its aspect ratio is above a target ratio.
44 | Pad the image to match such target ratio
45 | """
46 |
47 | def __init__(self, target_ratio: float, size: int):
48 | """
49 | :param target_ratio: target ratio
50 | :param size: preprocessing output dimension
51 | """
52 | self.size = size
53 | self.target_ratio = target_ratio
54 |
55 | def __call__(self, image):
56 | w, h = image.size
57 | actual_ratio = max(w, h) / min(w, h)
58 | if actual_ratio < self.target_ratio: # check if the ratio is above or below the target ratio
59 | return image
60 | scaled_max_wh = max(w, h) / self.target_ratio # rescale the pad to match the target ratio
61 | hp = max(int((scaled_max_wh - w) / 2), 0)
62 | vp = max(int((scaled_max_wh - h) / 2), 0)
63 | padding = [hp, vp, hp, vp]
64 | return F.pad(image, padding, 0, 'constant')
65 |
66 |
67 | def squarepad_transform(dim: int):
68 | """
69 | CLIP-like preprocessing transform on a square padded image
70 | :param dim: image output dimension
71 | :return: CLIP-like torchvision Compose transform
72 | """
73 | return Compose([
74 | SquarePad(dim),
75 | Resize(dim, interpolation=PIL.Image.BICUBIC),
76 | CenterCrop(dim),
77 | _convert_image_to_rgb,
78 | ToTensor(),
79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
80 | ])
81 |
82 |
83 | def targetpad_transform(target_ratio: float, dim: int):
84 | """
85 | CLIP-like preprocessing transform computed after using TargetPad pad
86 | :param target_ratio: target ratio for TargetPad
87 | :param dim: image output dimension
88 | :return: CLIP-like torchvision Compose transform
89 | """
90 | return Compose([
91 | TargetPad(target_ratio, dim),
92 | Resize(dim, interpolation=PIL.Image.BICUBIC),
93 | CenterCrop(dim),
94 | _convert_image_to_rgb,
95 | ToTensor(),
96 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
97 | ])
98 |
99 |
100 | class FashionIQDataset(Dataset):
101 | """
102 |     FashionIQ dataset class which manages FashionIQ data.
103 |     The dataset can be used in 'relative' or 'classic' mode:
104 |         - In 'classic' mode the dataset yields tuples made of (image_name, image)
105 |         - In 'relative' mode the dataset yields tuples made of:
106 |             - (reference_image, target_image, image_captions) when split == train
107 |             - (reference_image, reference_name, target_name, image_captions) when split == val
108 |             - (reference_name, reference_image, image_captions) when split == test
109 |     The dataset manages an arbitrary number of FashionIQ categories, e.g. only dress, dress+toptee+shirt, dress+shirt...
110 | """
111 |
112 | def __init__(self, split: str, dress_types: List[str], mode: str, preprocess: callable):
113 | """
114 | :param split: dataset split, should be in ['test', 'train', 'val']
115 |         :param dress_types: list of FashionIQ categories
116 |         :param mode: dataset mode, should be in ['relative', 'classic']:
117 |             - In 'classic' mode the dataset yields tuples made of (image_name, image)
118 |             - In 'relative' mode the dataset yields tuples made of:
119 |                 - (reference_image, target_image, image_captions) when split == train
120 |                 - (reference_image, reference_name, target_name, image_captions) when split == val
121 |                 - (reference_name, reference_image, image_captions) when split == test
122 | :param preprocess: function which preprocesses the image
123 | """
124 | self.mode = mode
125 | self.dress_types = dress_types
126 | self.split = split
127 |
128 | if mode not in ['relative', 'classic']:
129 | raise ValueError("mode should be in ['relative', 'classic']")
130 | if split not in ['test', 'train', 'val']:
131 | raise ValueError("split should be in ['test', 'train', 'val']")
132 | for dress_type in dress_types:
133 | if dress_type not in ['dress', 'shirt', 'toptee']:
134 | raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']")
135 |
136 | self.preprocess = preprocess
137 |
138 |
139 | # get triplets made by (reference_image, target_image, a pair of relative captions)
140 | self.triplets: List[dict] = []
141 | for dress_type in dress_types:
142 | with open(base_path / 'fashionIQ_dataset' / 'captions' / f'cap.{dress_type}.{split}.json') as f:
143 | self.triplets.extend(json.load(f))
144 |
145 | # get the image names
146 | self.image_names: list = []
147 | for dress_type in dress_types:
148 | with open(base_path / 'fashionIQ_dataset' / 'image_splits' / f'split.{dress_type}.{split}.json') as f:
149 | self.image_names.extend(json.load(f))
150 |
151 | print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized")
152 | for image_name in self.image_names[:]:
153 | file_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{image_name}.jpg"
154 |
155 |             # use Path.exists() to check whether the image file exists
156 | if not file_path.exists():
157 |                 # if the file does not exist, remove it from the list
158 | self.image_names.remove(image_name)
159 |
160 | def __getitem__(self, index):
161 | try:
162 | if self.mode == 'relative':
163 | image_captions = self.triplets[index]['captions']
164 | reference_name = self.triplets[index]['candidate'].replace(".jpg", "")
165 |
166 | if self.split == 'train':
167 | reference_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{reference_name}.jpg"
168 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert("RGB"))
169 | target_name = self.triplets[index]['target'].replace(".jpg", "")
170 | target_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{target_name}.jpg"
171 | target_image = self.preprocess(PIL.Image.open(target_image_path).convert("RGB"))
172 | return reference_image, target_image, image_captions
173 |
174 | elif self.split == 'val':
175 | reference_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{reference_name}.jpg"
176 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert("RGB"))
177 | target_name = self.triplets[index]['target'].replace(".jpg", "")
178 | return reference_image, reference_name, target_name, image_captions
179 |
180 | elif self.split == 'test':
181 | reference_image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{reference_name}.jpg"
182 | reference_image = self.preprocess(PIL.Image.open(reference_image_path).convert("RGB"))
183 | return reference_name, reference_image, image_captions
184 |
185 | elif self.mode == 'classic':
186 | image_name = self.image_names[index]
187 | image_path = base_path / 'fashionIQ_dataset' / 'image_data' / f"{image_name}.jpg"
188 | image = self.preprocess(PIL.Image.open(image_path).convert("RGB"))
189 | return image_name, image
190 |
191 | else:
192 | raise ValueError("mode should be in ['relative', 'classic']")
193 | except Exception as e:
194 | print(f"Exception: {e}")
195 |
196 | def __len__(self):
197 | if self.mode == 'relative':
198 | return len(self.triplets)
199 | elif self.mode == 'classic':
200 | return len(self.image_names)
201 | else:
202 | raise ValueError("mode should be in ['relative', 'classic']")
203 |
204 |
205 | class CIRRDataset(Dataset):
206 | """
207 |     CIRR dataset class which manages CIRR data
208 |     The dataset can be used in 'relative' or 'classic' mode:
209 |         - In 'classic' mode the dataset yields tuples made of (image_name, image)
210 |         - In 'relative' mode the dataset yields tuples made of:
211 |             - (reference_image, target_image, rel_caption) when split == train
212 |             - (reference_image, reference_name, target_hard_name, rel_caption, group_members) when split == val
213 |             - (pair_id, reference_name, reference_image, rel_caption, group_members) when split == test1
214 | """
215 |
216 | def __init__(self, split: str, mode: str, preprocess: callable):
217 | """
218 |         :param split: dataset split, should be in ['test1', 'train', 'val']
219 |         :param mode: dataset mode, should be in ['relative', 'classic']:
220 |             - In 'classic' mode the dataset yields tuples made of (image_name, image)
221 |             - In 'relative' mode the dataset yields tuples made of:
222 |                 - (reference_image, target_image, rel_caption) when split == train
223 |                 - (reference_image, reference_name, target_hard_name, rel_caption, group_members) when split == val
224 |                 - (pair_id, reference_name, reference_image, rel_caption, group_members) when split == test1
225 | :param preprocess: function which preprocesses the image
226 | """
227 | self.preprocess = preprocess
228 | # self.preprocess_384 = preprocess_384
229 | self.mode = mode
230 | self.split = split
231 |
232 | if split not in ['test1', 'train', 'val']:
233 | raise ValueError("split should be in ['test1', 'train', 'val']")
234 | if mode not in ['relative', 'classic']:
235 | raise ValueError("mode should be in ['relative', 'classic']")
236 |
237 | # get triplets made by (reference_image, target_image, relative caption)
238 | with open(base_path / 'cirr_datasets' / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f:
239 | self.triplets = json.load(f)
240 |
241 | # get a mapping from image name to relative path
242 | with open(base_path / 'cirr_datasets' / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f:
243 | self.name_to_relpath = json.load(f)
244 |
245 | print(f"CIRR {split} dataset in {mode} mode initialized")
246 |
247 | def __getitem__(self, index):
248 | try:
249 | if self.mode == 'relative':
250 | group_members = self.triplets[index]['img_set']['members']
251 | reference_name = self.triplets[index]['reference']
252 | rel_caption = self.triplets[index]['caption']
253 |
254 | if self.split == 'train':
255 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
256 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
257 | target_hard_name = self.triplets[index]['target_hard']
258 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name]
259 | target_image = self.preprocess(PIL.Image.open(target_image_path))
260 | return reference_image, target_image, rel_caption
261 |
262 | elif self.split == 'val':
263 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
264 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
265 | target_hard_name = self.triplets[index]['target_hard']
266 | # return reference_name, target_hard_name, rel_caption, group_members
267 | return reference_image, reference_name, target_hard_name, rel_caption, group_members
268 |
269 | elif self.split == 'test1':
270 | pair_id = self.triplets[index]['pairid']
271 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
272 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
273 | return pair_id, reference_name, reference_image, rel_caption, group_members
274 |
275 | elif self.mode == 'classic':
276 | image_name = list(self.name_to_relpath.keys())[index]
277 | image_path = base_path / 'cirr_datasets' / self.name_to_relpath[image_name]
278 | im = PIL.Image.open(image_path)
279 | image = self.preprocess(im)
280 | return image_name, image
281 |
282 | else:
283 | raise ValueError("mode should be in ['relative', 'classic']")
284 |
285 | except Exception as e:
286 | print(f"Exception: {e}")
287 |
288 | def __len__(self):
289 | if self.mode == 'relative':
290 | return len(self.triplets)
291 | elif self.mode == 'classic':
292 | return len(self.name_to_relpath)
293 | else:
294 | raise ValueError("mode should be in ['relative', 'classic']")
295 |
296 |
297 | class CIRRDataset_val_submission(Dataset):
298 | """
299 |     CIRR dataset class (validation / test-submission variant) which manages CIRR data
300 |     The dataset can be used in 'relative' or 'classic' mode:
301 |         - In 'classic' mode the dataset yields tuples made of (image_name, image)
302 |         - In 'relative' mode the dataset yields tuples made of:
303 |             - (reference_image, target_image, rel_caption) when split == train
304 |             - (pair_id, reference_name, reference_image, rel_caption, group_members, target_hard_name) when split == val
305 |             - (pair_id, reference_name, reference_image, rel_caption, group_members) when split == test1
306 | """
307 |
308 | def __init__(self, split: str, mode: str, preprocess: callable):
309 | """
310 |         :param split: dataset split, should be in ['test1', 'train', 'val']
311 |         :param mode: dataset mode, should be in ['relative', 'classic']:
312 |             - In 'classic' mode the dataset yields tuples made of (image_name, image)
313 |             - In 'relative' mode the dataset yields tuples made of:
314 |                 - (reference_image, target_image, rel_caption) when split == train
315 |                 - (pair_id, reference_name, reference_image, rel_caption, group_members, target_hard_name) when split == val
316 |                 - (pair_id, reference_name, reference_image, rel_caption, group_members) when split == test1
317 | :param preprocess: function which preprocesses the image
318 | """
319 | self.preprocess = preprocess
320 | # self.preprocess_384 = preprocess_384
321 | self.mode = mode
322 | self.split = split
323 |
324 | if split not in ['test1', 'train', 'val']:
325 | raise ValueError("split should be in ['test1', 'train', 'val']")
326 | if mode not in ['relative', 'classic']:
327 | raise ValueError("mode should be in ['relative', 'classic']")
328 |
329 | # get triplets made by (reference_image, target_image, relative caption)
330 | with open(base_path / 'cirr_datasets' / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f:
331 | self.triplets = json.load(f)
332 |
333 | # get a mapping from image name to relative path
334 | with open(base_path / 'cirr_datasets' / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f:
335 | self.name_to_relpath = json.load(f)
336 |
337 | print(f"CIRR {split} dataset in {mode} mode initialized")
338 |
339 | def __getitem__(self, index):
340 | try:
341 | if self.mode == 'relative':
342 | group_members = self.triplets[index]['img_set']['members']
343 | reference_name = self.triplets[index]['reference']
344 | rel_caption = self.triplets[index]['caption']
345 |
346 | if self.split == 'train':
347 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
348 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
349 | target_hard_name = self.triplets[index]['target_hard']
350 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name]
351 | target_image = self.preprocess(PIL.Image.open(target_image_path))
352 | return reference_image, target_image, rel_caption
353 |
354 | elif self.split == 'val':
355 | pair_id = self.triplets[index]['pairid']
356 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
357 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
358 | target_hard_name = self.triplets[index]['target_hard']
359 | # return reference_name, target_hard_name, rel_caption, group_members
360 | return pair_id, reference_name, reference_image, rel_caption, group_members, target_hard_name
361 |
362 | elif self.split == 'test1':
363 | pair_id = self.triplets[index]['pairid']
364 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
365 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
366 | return pair_id, reference_name, reference_image, rel_caption, group_members
367 |
368 | elif self.mode == 'classic':
369 | image_name = list(self.name_to_relpath.keys())[index]
370 | image_path = base_path / 'cirr_datasets' / self.name_to_relpath[image_name]
371 | im = PIL.Image.open(image_path)
372 | image = self.preprocess(im)
373 | return image_name, image
374 |
375 | else:
376 | raise ValueError("mode should be in ['relative', 'classic']")
377 |
378 | except Exception as e:
379 | print(f"Exception: {e}")
380 |
381 | def __len__(self):
382 | if self.mode == 'relative':
383 | return len(self.triplets)
384 | elif self.mode == 'classic':
385 | return len(self.name_to_relpath)
386 | else:
387 | raise ValueError("mode should be in ['relative', 'classic']")
388 |
389 |
390 | class CIRRDataset_hinge(Dataset):
391 | """
392 |     CIRR dataset class (hinge variant) which manages CIRR data
393 |     The dataset can be used in 'relative' or 'classic' mode:
394 |         - In 'classic' mode the dataset yields tuples made of (image_name, image)
395 |         - In 'relative' mode the dataset yields tuples made of:
396 |             - (reference_image, target_image, rel_caption) when split == train
397 |             - (reference_image, reference_name, target_hard_name, rel_caption, group_members, target_image) when split == val
398 |             - (pair_id, reference_name, reference_image, rel_caption, group_members) when split == test1
399 | """
400 |
401 | def __init__(self, split: str, mode: str, preprocess: callable):
402 | """
403 |         :param split: dataset split, should be in ['test1', 'train', 'val']
404 |         :param mode: dataset mode, should be in ['relative', 'classic']:
405 |             - In 'classic' mode the dataset yields tuples made of (image_name, image)
406 |             - In 'relative' mode the dataset yields tuples made of:
407 |                 - (reference_image, target_image, rel_caption) when split == train
408 |                 - (reference_image, reference_name, target_hard_name, rel_caption, group_members, target_image) when split == val
409 |                 - (pair_id, reference_name, reference_image, rel_caption, group_members) when split == test1
410 | :param preprocess: function which preprocesses the image
411 | """
412 | self.preprocess = preprocess
413 | # self.preprocess_384 = preprocess_384
414 | self.mode = mode
415 | self.split = split
416 |
417 | if split not in ['test1', 'train', 'val']:
418 | raise ValueError("split should be in ['test1', 'train', 'val']")
419 | if mode not in ['relative', 'classic']:
420 | raise ValueError("mode should be in ['relative', 'classic']")
421 |
422 | # get triplets made by (reference_image, target_image, relative caption)
423 | with open(base_path / 'cirr_datasets' / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f:
424 | self.triplets = json.load(f)
425 |
426 | # get a mapping from image name to relative path
427 | with open(base_path / 'cirr_datasets' / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f:
428 | self.name_to_relpath = json.load(f)
429 |
430 | print(f"CIRR {split} dataset in {mode} mode initialized")
431 |
432 | def __getitem__(self, index):
433 | try:
434 | if self.mode == 'relative':
435 | group_members = self.triplets[index]['img_set']['members']
436 | reference_name = self.triplets[index]['reference']
437 | rel_caption = self.triplets[index]['caption']
438 |
439 | if self.split == 'train':
440 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
441 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
442 | target_hard_name = self.triplets[index]['target_hard']
443 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name]
444 | target_image = self.preprocess(PIL.Image.open(target_image_path))
445 | return reference_image, target_image, rel_caption
446 |
447 | elif self.split == 'val':
448 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
449 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
450 | target_hard_name = self.triplets[index]['target_hard']
451 | target_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[target_hard_name]
452 | target_image = self.preprocess(PIL.Image.open(target_image_path))
453 | # return reference_name, target_hard_name, rel_caption, group_members
454 | return reference_image, reference_name, target_hard_name, rel_caption, group_members, target_image
455 |
456 | elif self.split == 'test1':
457 | pair_id = self.triplets[index]['pairid']
458 | reference_image_path = base_path / 'cirr_datasets' / self.name_to_relpath[reference_name]
459 | reference_image = self.preprocess(PIL.Image.open(reference_image_path))
460 | return pair_id, reference_name, reference_image, rel_caption, group_members
461 |
462 | elif self.mode == 'classic':
463 | image_name = list(self.name_to_relpath.keys())[index]
464 | image_path = base_path / 'cirr_datasets' / self.name_to_relpath[image_name]
465 | im = PIL.Image.open(image_path)
466 | image = self.preprocess(im)
467 | return image_name, image
468 |
469 | else:
470 | raise ValueError("mode should be in ['relative', 'classic']")
471 |
472 | except Exception as e:
473 | print(f"Exception: {e}")
474 |
475 | def __len__(self):
476 | if self.mode == 'relative':
477 | return len(self.triplets)
478 | elif self.mode == 'classic':
479 | return len(self.name_to_relpath)
480 | else:
481 | raise ValueError("mode should be in ['relative', 'classic']")
482 |
--------------------------------------------------------------------------------
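Note (illustrative): a short sketch of how the transforms and datasets in data_utils.py are typically wired to a DataLoader for index-feature extraction. It assumes src/ is on PYTHONPATH and that the cirr_datasets folder is laid out relative to base_path as the dataset classes expect; names and settings mirror the code above.

import multiprocessing

from torch.utils.data import DataLoader

from data_utils import CIRRDataset, targetpad_transform
from utils import collate_fn

# TargetPad preprocess with the settings used throughout the repo (ratio 1.25, 224x224 output).
preprocess = targetpad_transform(target_ratio=1.25, dim=224)

# 'classic' mode yields (image_name, image) pairs, which is what the index-feature extractors consume.
dataset = CIRRDataset('val', 'classic', preprocess)
loader = DataLoader(dataset, batch_size=32, num_workers=multiprocessing.cpu_count(),
                    pin_memory=True, collate_fn=collate_fn)

for names, images in loader:
    print(len(names), images.shape)  # e.g. 32, torch.Size([32, 3, 224, 224])
    break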
/src/validate.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from argparse import ArgumentParser
3 | from operator import itemgetter
4 | from pathlib import Path
5 | from statistics import mean
6 | from typing import List, Tuple
7 |
8 | import clip
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | from clip.model import CLIP
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 |
16 | from data_utils import squarepad_transform, FashionIQDataset, targetpad_transform, CIRRDataset
17 | from utils import extract_index_features_clip, collate_fn, element_wise_sum, device , extract_index_features_blip2
18 | from lavis.models import load_model
19 |
20 |
21 | def compute_fiq_val_metrics_blip2(relative_val_dataset: FashionIQDataset, blip_textual,blip_multimodal, index_features: torch.tensor,
22 | index_names: List[str], combining_function: callable) -> Tuple[float, float]:
23 | """
24 | Compute validation metrics on FashionIQ dataset
25 |     :param relative_val_dataset: FashionIQ validation dataset in relative mode
26 |     :param blip_textual: BLIP-2 textual feature extractor
27 |     :param blip_multimodal: BLIP-2 multimodal feature extractor
28 |     :param index_features: validation index features
29 |     :param index_names: validation index names
30 |     :param combining_function: function which takes (image_features, text_features) and outputs the combined features
31 | :return: the computed validation metrics
32 | """
33 |
34 | # Generate predictions
35 | predicted_features, target_names = generate_fiq_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset,
36 | combining_function, index_names, index_features)
37 |
38 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation metrics")
39 |
40 | # Normalize the index features
41 | index_features = F.normalize(index_features, dim=-1).float()
42 |
43 | # Compute the distances and sort the results
44 | distances = 1 - predicted_features @ index_features.T
45 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
46 | sorted_index_names = np.array(index_names)[sorted_indices]
47 |
48 | # Compute the ground-truth labels wrt the predictions
49 | labels = torch.tensor(
50 | sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1))
51 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
52 |
53 | # Compute the metrics
54 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
55 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
56 |
57 | return recall_at10, recall_at50
58 |
59 | def compute_fiq_val_metrics_clip(relative_val_dataset: FashionIQDataset, clip_model: CLIP, index_features: torch.tensor,
60 | index_names: List[str], combining_function: callable) -> Tuple[float, float]:
61 | """
62 | Compute validation metrics on FashionIQ dataset
63 | :param relative_val_dataset: FashionIQ validation dataset in relative mode
64 | :param clip_model: CLIP model
65 | :param index_features: validation index features
66 | :param index_names: validation index names
67 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
68 | features
69 | :return: the computed validation metrics
70 | """
71 |
72 | # Generate predictions
73 | predicted_features, target_names = generate_fiq_val_predictions_clip(clip_model, relative_val_dataset,
74 | combining_function, index_names, index_features)
75 |
76 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation metrics")
77 |
78 | # Normalize the index features
79 | index_features = F.normalize(index_features, dim=-1).float()
80 |
81 | # Compute the distances and sort the results
82 | distances = 1 - predicted_features @ index_features.T
83 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
84 | sorted_index_names = np.array(index_names)[sorted_indices]
85 |
86 | # Compute the ground-truth labels wrt the predictions
87 | labels = torch.tensor(
88 | sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1))
89 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
90 |
91 | # Compute the metrics
92 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
93 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
94 |
95 | return recall_at10, recall_at50
96 |
97 |
98 | def artemis_compute_fiq_val_metrics(relative_val_dataset: FashionIQDataset, blip_textual,blip_multimodal, index_features: torch.tensor,
99 | index_names: List[str], artemis) -> Tuple[float, float]:
100 | """
101 | Compute validation metrics on FashionIQ dataset
102 |     :param relative_val_dataset: FashionIQ validation dataset in relative mode
103 |     :param blip_textual: BLIP-2 textual feature extractor
104 |     :param blip_multimodal: BLIP-2 multimodal feature extractor
105 |     :param index_features: validation index features
106 |     :param index_names: validation index names
107 |     :param artemis: ARTEMIS scoring module used to rank the index images
108 | :return: the computed validation metrics
109 | """
110 |
111 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation metrics")
112 |
113 | # Normalize the index features
114 | index_features = F.normalize(index_features, dim=-1).float()
115 | print("Compute FashionIQ validation predictions")
116 |
117 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8,
118 | pin_memory=True, collate_fn=collate_fn)
119 | feature_dim = 768
120 | # Initialize predicted features, target_names, group_members and reference_names
121 | artemis.eval()
122 | artemis_scores = torch.empty((0, len(index_names))).to(device, non_blocking=True)
123 | target_names = []
124 | reference_names = []
125 |
126 | for reference_images, batch_reference_names, batch_target_names, captions in tqdm(
127 | relative_val_loader): # Load data
128 | reference_images = reference_images.to(device)
129 |
130 | flattened_captions: list = np.array(captions).T.flatten().tolist()
131 | input_captions = [
132 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}"
133 | for i in range(0, len(flattened_captions), 2)]
134 |
135 | # Compute the predicted features
136 | with torch.no_grad():
137 | text_feats = blip_textual.extract_features({"text_input":input_captions},
138 | mode="text").text_embeds[:,0,:]
139 | reference_feats = blip_multimodal.extract_features({"image":reference_images,
140 | "text_input":input_captions}).multimodal_embeds[:,0,:]
141 | batch_artemis_score = torch.empty((0, len(index_names))).to(device, non_blocking=True)
142 | for i in range(reference_feats.shape[0]):
143 | one_artemis_score = artemis.compute_score_artemis(reference_feats[i].unsqueeze(0), text_feats[i].unsqueeze(0), index_features)
144 | batch_artemis_score = torch.vstack((batch_artemis_score, one_artemis_score))
145 | artemis_scores = torch.vstack((artemis_scores, F.normalize(batch_artemis_score, dim=-1)))
146 | target_names.extend(batch_target_names)
147 | reference_names.extend(batch_reference_names)
148 |
149 | # Compute the distances and sort the results
150 | distances = 1 - artemis_scores
151 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
152 | sorted_index_names = np.array(index_names)[sorted_indices]
153 |
154 | # Compute the ground-truth labels wrt the predictions
155 | labels = torch.tensor(
156 | sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1))
157 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
158 |
159 | # Compute the metrics
160 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
161 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
162 |
163 | return recall_at10, recall_at50
164 |
165 |
166 | def generate_fiq_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset: FashionIQDataset,
167 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \
168 | Tuple[torch.tensor, List[str]]:
169 | """
170 | Compute FashionIQ predictions on the validation set
171 |     :param blip_textual: BLIP-2 textual feature extractor
172 |     :param blip_multimodal: BLIP-2 multimodal feature extractor
173 |     :param relative_val_dataset: FashionIQ validation dataset in relative mode
174 |     :param combining_function: function which takes (image_features, text_features) and outputs the combined features
175 | :param index_features: validation index features
176 | :param index_names: validation index names
177 | :return: predicted features and target names
178 | """
179 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation predictions")
180 |
181 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32,
182 | num_workers=multiprocessing.cpu_count(), pin_memory=True, collate_fn=collate_fn,
183 | shuffle=False)
184 |
185 | # Get a mapping from index names to index features
186 | # name_to_feat = dict(zip(index_names, index_features))
187 | feature_dim = 768
188 | # Initialize predicted features and target names
189 | predicted_features = torch.empty((0, feature_dim)).to(device, non_blocking=True)
190 | target_names = []
191 |
192 | for reference_images, batch_target_names, captions in tqdm(relative_val_loader): # Load data
193 | reference_images = reference_images.to(device)
194 | # Concatenate the captions in a deterministic way
195 | flattened_captions: list = np.array(captions).T.flatten().tolist()
196 | input_captions = [
197 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}"
198 | for i in range(0, len(flattened_captions), 2)]
199 |
200 | # Compute the predicted features
201 | with torch.no_grad():
202 | text_feats = blip_textual.extract_features({"text_input":input_captions},
203 | mode="text").text_embeds[:,0,:]
204 | reference_feats = blip_multimodal.extract_features({"image":reference_images,
205 | "text_input":input_captions}).multimodal_embeds[:,0,:]
206 |
207 | batch_predicted_features = combining_function(reference_feats, text_feats)
208 |
209 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
210 | target_names.extend(batch_target_names)
211 |
212 | return predicted_features, target_names
213 |
214 |
215 | def generate_fiq_val_predictions_clip(clip_model: CLIP, relative_val_dataset: FashionIQDataset,
216 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \
217 | Tuple[torch.tensor, List[str]]:
218 | """
219 | Compute FashionIQ predictions on the validation set
220 | :param clip_model: CLIP model
221 | :param relative_val_dataset: FashionIQ validation dataset in relative mode
222 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
223 | features
224 | :param index_features: validation index features
225 | :param index_names: validation index names
226 | :return: predicted features and target names
227 | """
228 | print(f"Compute FashionIQ {relative_val_dataset.dress_types} validation predictions")
229 |
230 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32,
231 | num_workers=multiprocessing.cpu_count(), pin_memory=True, collate_fn=collate_fn,
232 | shuffle=False)
233 |
234 | # Get a mapping from index names to index features
235 | name_to_feat = dict(zip(index_names, index_features))
236 |
237 | # Initialize predicted features and target names
238 | predicted_features = torch.empty((0, clip_model.visual.output_dim)).to(device, non_blocking=True)
239 | target_names = []
240 |
241 | for reference_image, reference_names, batch_target_names, captions in tqdm(relative_val_loader): # Load data
242 |
243 | # Concatenate the captions in a deterministic way
244 | flattened_captions: list = np.array(captions).T.flatten().tolist()
245 | input_captions = [
246 | f"{flattened_captions[i].strip('.?, ').capitalize()} and {flattened_captions[i + 1].strip('.?, ')}" for
247 | i in range(0, len(flattened_captions), 2)]
248 | text_inputs = clip.tokenize(input_captions, context_length=77).to(device, non_blocking=True)
249 |
250 | # Compute the predicted features
251 | with torch.no_grad():
252 | text_features = clip_model.encode_text(text_inputs)[0]
253 | # Check whether a single element is in the batch due to the exception raised by torch.stack when used with
254 | # a single tensor
255 | if text_features.shape[0] == 1:
256 | reference_image_features = itemgetter(*reference_names)(name_to_feat).unsqueeze(0)
257 | else:
258 | reference_image_features = torch.stack(itemgetter(*reference_names)(
259 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features
260 | batch_predicted_features = combining_function(reference_image_features, text_features)
261 |
262 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
263 | target_names.extend(batch_target_names)
264 |
265 | return predicted_features, target_names
266 |
267 |
268 | def fashioniq_val_retrieval(dress_type: str, combining_function: callable, clip_model: CLIP, preprocess: callable):
269 | """
270 | Perform retrieval on FashionIQ validation set computing the metrics. To combine the features the `combining_function`
271 | is used
272 | :param dress_type: FashionIQ category on which perform the retrieval
273 | :param combining_function:function which takes as input (image_features, text_features) and outputs the combined
274 | features
275 | :param clip_model: CLIP model
276 | :param preprocess: preprocess pipeline
277 | """
278 |
279 | clip_model = clip_model.float().eval()
280 |
281 | # Define the validation datasets and extract the index features
282 | classic_val_dataset = FashionIQDataset('val', [dress_type], 'classic', preprocess)
283 | index_features, index_names = extract_index_features(classic_val_dataset, clip_model)
284 | relative_val_dataset = FashionIQDataset('val', [dress_type], 'relative', preprocess)
285 |
286 | return compute_fiq_val_metrics(relative_val_dataset, clip_model, index_features, index_names,
287 | combining_function)
288 |
289 | def artemis_compute_cirr_val_metrics(relative_val_dataset: CIRRDataset, blip_textual, blip_multimodal, index_features: torch.tensor,
290 | index_names: List[str], artemis) -> Tuple[
291 | float, float, float, float, float, float, float]:
292 | """
293 | Compute validation metrics on CIRR dataset
294 | :param relative_val_dataset: CIRR validation dataset in relative mode
295 |     :param blip_textual: BLIP-2 model used to encode the relative captions
296 | :param index_features: validation index features
297 | :param index_names: validation index names
298 |     :param blip_multimodal: BLIP-2 model used to encode the (reference image, caption) queries
299 |     :param artemis: ARTEMIS module whose compute_score_artemis ranks the index features for each query
300 | :return: the computed validation metrics
301 | """
302 | # Generate predictions
303 |
304 | print("Compute CIRR validation metrics")
305 |
306 | # Normalize the index features
307 | index_features = F.normalize(index_features, dim=-1).float()
308 |
309 | print("Compute CIRR validation predictions")
310 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8,
311 | pin_memory=True, collate_fn=collate_fn)
312 | feature_dim = 768
313 | # Initialize predicted features, target_names, group_members and reference_names
314 | artemis.eval()
315 |
316 | artemis_scores = torch.empty((0, len(index_names))).to(device, non_blocking=True)
317 | target_names = []
318 | group_members = []
319 | reference_names = []
320 |
321 | for reference_images, batch_reference_names, batch_target_names, captions, batch_group_members in tqdm(
322 | relative_val_loader): # Load data
323 | reference_images = reference_images.to(device)
324 | batch_group_members = np.array(batch_group_members).T.tolist()
325 | # Compute the predicted features
326 | with torch.no_grad():
327 | text_feats = blip_textual.extract_features({"text_input":captions},
328 | mode="text").text_embeds[:,0,:]
329 | reference_feats = blip_multimodal.extract_features({"image":reference_images,
330 | "text_input":captions}).multimodal_embeds[:,0,:]
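            # extract_features with mode="text" encodes the caption alone, while the default
            # (multimodal) mode fuses the reference image with the caption through the Q-Former;
            # [:, 0, :] keeps the first token / query embedding as a single feature vector.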
331 | batch_artemis_score = torch.empty((0, len(index_names))).to(device, non_blocking=True)
332 | for i in range(reference_feats.shape[0]):
333 | one_artemis_score = artemis.compute_score_artemis(reference_feats[i].unsqueeze(0), text_feats[i].unsqueeze(0), index_features)
334 | batch_artemis_score = torch.vstack((batch_artemis_score, one_artemis_score))
335 |
336 | artemis_scores = torch.vstack((artemis_scores, F.normalize(batch_artemis_score, dim=-1)))
337 | target_names.extend(batch_target_names)
338 | group_members.extend(batch_group_members)
339 | reference_names.extend(batch_reference_names)
340 |
341 |
342 | # Compute the distances and sort the results
343 | distances = 1 - artemis_scores
344 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
345 | sorted_index_names = np.array(index_names)[sorted_indices]
346 |
347 | # Delete the reference image from the results
348 | reference_mask = torch.tensor(
349 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
350 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
351 | sorted_index_names.shape[1] - 1)
352 | # Compute the ground-truth labels wrt the predictions
353 | labels = torch.tensor(
354 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
355 |
356 | # Compute the subset predictions and ground-truth labels
357 | group_members = np.array(group_members)
358 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
359 | group_labels = labels[group_mask].reshape(labels.shape[0], -1)
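    # Broadcasting (num_queries, num_index, 1) against (num_queries, 1, group_size) gives a boolean
    # cube; summing over the last axis marks which ranked candidates belong to each query's subset.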
360 |
361 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
362 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
363 |
364 | # Compute the metrics
365 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
366 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
367 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
368 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
369 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
370 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
371 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
372 |
373 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50
374 |
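# Illustrative sketch (not called by the functions above): `labels` is built as a boolean matrix
# holding exactly one True per row, with columns in ranking order, so Recall@K is simply the
# fraction of rows whose True entry lands in the first K columns. Assuming a
# (num_queries, num_index) bool tensor, the metric computed inline above is equivalent to:
def _recall_at_k(labels: torch.Tensor, k: int) -> float:
    """Percentage of queries whose ground-truth target is ranked within the top-k."""
    return (torch.sum(labels[:, :k]) / len(labels)).item() * 100
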
375 | def compute_cirr_val_metrics_clip(relative_val_dataset: CIRRDataset, clip_model: CLIP, index_features: torch.tensor,
376 | index_names: List[str], combining_function: callable) -> Tuple[
377 | float, float, float, float, float, float, float]:
378 | """
379 | Compute validation metrics on CIRR dataset
380 | :param relative_val_dataset: CIRR validation dataset in relative mode
381 | :param clip_model: CLIP model
382 | :param index_features: validation index features
383 | :param index_names: validation index names
384 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
385 | features
386 | :return: the computed validation metrics
387 | """
388 | # Generate predictions
389 | predicted_features, reference_names, target_names, group_members = \
390 | generate_cirr_val_predictions_clip(clip_model, relative_val_dataset, combining_function, index_names, index_features)
391 |
392 | print("Compute CIRR validation metrics")
393 |
394 | # Normalize the index features
395 | index_features = F.normalize(index_features, dim=-1).float()
396 |
397 | # Compute the distances and sort the results
398 | distances = 1 - predicted_features @ index_features.T
399 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
400 | sorted_index_names = np.array(index_names)[sorted_indices]
401 |
402 | # Delete the reference image from the results
403 | reference_mask = torch.tensor(
404 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
405 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
406 | sorted_index_names.shape[1] - 1)
407 | # Compute the ground-truth labels wrt the predictions
408 | labels = torch.tensor(
409 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
410 |
411 | # Compute the subset predictions and ground-truth labels
412 | group_members = np.array(group_members)
413 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
414 | group_labels = labels[group_mask].reshape(labels.shape[0], -1)
415 |
416 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
417 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
418 |
419 | # Compute the metrics
420 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
421 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
422 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
423 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
424 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
425 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
426 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
427 |
428 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50
429 |
430 | def generate_cirr_val_predictions_clip(clip_model: CLIP, relative_val_dataset: CIRRDataset,
431 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \
432 | Tuple[torch.tensor, List[str], List[str], List[List[str]]]:
433 | """
434 | Compute CIRR predictions on the validation set
435 | :param clip_model: CLIP model
436 | :param relative_val_dataset: CIRR validation dataset in relative mode
437 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
438 | features
439 | :param index_features: validation index features
440 | :param index_names: validation index names
441 | :return: predicted features, reference names, target names and group members
442 | """
443 | print("Compute CIRR validation predictions")
444 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8,
445 | pin_memory=True, collate_fn=collate_fn)
446 |
447 | # Get a mapping from index names to index features
448 | name_to_feat = dict(zip(index_names, index_features))
449 |
450 | # Initialize predicted features, target_names, group_members and reference_names
451 | predicted_features = torch.empty((0, clip_model.visual.output_dim)).to(device, non_blocking=True)
452 | target_names = []
453 | group_members = []
454 | reference_names = []
455 |
456 | for batch_reference_names, batch_target_names, captions, batch_group_members in tqdm(
457 | relative_val_loader): # Load data
458 | text_inputs = clip.tokenize(captions).to(device, non_blocking=True)
459 | batch_group_members = np.array(batch_group_members).T.tolist()
460 |
461 | # Compute the predicted features
462 | with torch.no_grad():
463 | text_features = clip_model.encode_text(text_inputs)[0]
464 | # Check whether a single element is in the batch due to the exception raised by torch.stack when used with
465 | # a single tensor
466 | if text_features.shape[0] == 1:
467 | reference_image_features = itemgetter(*batch_reference_names)(name_to_feat).unsqueeze(0)
468 | else:
469 | reference_image_features = torch.stack(itemgetter(*batch_reference_names)(
470 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features
471 | batch_predicted_features = combining_function(reference_image_features, text_features)
472 |
473 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
474 | target_names.extend(batch_target_names)
475 | group_members.extend(batch_group_members)
476 | reference_names.extend(batch_reference_names)
477 |
478 | return predicted_features, reference_names, target_names, group_members
479 |
480 |
481 | def compute_cirr_val_metrics_blip2(relative_val_dataset: CIRRDataset, blip_textual, blip_multimodal, index_features: torch.tensor,
482 | index_names: List[str], combining_function: callable) -> Tuple[
483 | float, float, float, float, float, float, float]:
484 | """
485 | Compute validation metrics on CIRR dataset
486 | :param relative_val_dataset: CIRR validation dataset in relative mode
487 |     :param blip_textual / blip_multimodal: BLIP-2 textual and multimodal feature extractors
488 | :param index_features: validation index features
489 | :param index_names: validation index names
490 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
491 | features
492 | :return: the computed validation metrics
493 | """
494 | # Generate predictions
495 | predicted_features, reference_names, target_names, group_members = \
496 | generate_cirr_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset, combining_function, index_names, index_features)
497 |
498 | print("Compute CIRR validation metrics")
499 |
500 | # Normalize the index features
501 | index_features = F.normalize(index_features, dim=-1).float()
502 |
503 | # Compute the distances and sort the results
504 | distances = 1 - predicted_features @ index_features.T
505 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
506 | sorted_index_names = np.array(index_names)[sorted_indices]
507 |
508 | # Delete the reference image from the results
509 | reference_mask = torch.tensor(
510 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
511 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
512 | sorted_index_names.shape[1] - 1)
513 | # Compute the ground-truth labels wrt the predictions
514 | labels = torch.tensor(
515 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
516 |
517 | # Compute the subset predictions and ground-truth labels
518 | group_members = np.array(group_members)
519 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
520 | group_labels = labels[group_mask].reshape(labels.shape[0], -1)
521 |
522 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
523 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
524 |
525 | # Compute the metrics
526 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
527 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
528 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
529 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
530 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
531 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
532 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
533 |
534 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50
535 |
536 |
537 | def generate_cirr_val_predictions_blip2(blip_textual, blip_multimodal, relative_val_dataset: CIRRDataset,
538 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \
539 | Tuple[torch.tensor, List[str], List[str], List[List[str]]]:
540 | """
541 | Compute CIRR predictions on the validation set
542 |     :param blip_textual / blip_multimodal: BLIP-2 textual and multimodal feature extractors
543 | :param relative_val_dataset: CIRR validation dataset in relative mode
544 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
545 | features
546 | :param index_features: validation index features
547 | :param index_names: validation index names
548 | :return: predicted features, reference names, target names and group members
549 | """
550 | print("Compute CIRR validation predictions")
551 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8,
552 | pin_memory=True, collate_fn=collate_fn)
553 |
554 | # Get a mapping from index names to index features
555 | # name_to_feat = dict(zip(index_names, index_features))
556 |
557 | feature_dim = 768
558 | # Initialize predicted features, target_names, group_members and reference_names
559 | predicted_features = torch.empty((0, feature_dim)).to(device, non_blocking=True)
560 | target_names = []
561 | group_members = []
562 | reference_names = []
563 |
564 | for reference_images, batch_reference_names, batch_target_names, captions, batch_group_members in tqdm(
565 | relative_val_loader): # Load data
566 | reference_images = reference_images.to(device)
567 | batch_group_members = np.array(batch_group_members).T.tolist()
568 | # Compute the predicted features
569 | with torch.no_grad():
570 | text_feats = blip_textual.extract_features({"text_input":captions},
571 | mode="text").text_embeds[:,0,:]
572 | reference_feats = blip_multimodal.extract_features({"image":reference_images,
573 | "text_input":captions}).multimodal_embeds[:,0,:]
574 |
575 |         batch_predicted_features = combining_function(reference_feats, text_feats)
576 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
577 | target_names.extend(batch_target_names)
578 | group_members.extend(batch_group_members)
579 | reference_names.extend(batch_reference_names)
580 |
581 | return predicted_features, reference_names, target_names, group_members
582 |
583 |
584 | def compute_cirr_val_metrics_blip1(relative_val_dataset: CIRRDataset, blip_textual, index_features: torch.tensor,
585 | index_names: List[str], combining_function: callable) -> Tuple[
586 | float, float, float, float, float, float, float]:
587 | """
588 | Compute validation metrics on CIRR dataset
589 | :param relative_val_dataset: CIRR validation dataset in relative mode
590 |     :param blip_textual: BLIP text encoder used to encode the relative captions
591 | :param index_features: validation index features
592 | :param index_names: validation index names
593 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
594 | features
595 | :return: the computed validation metrics
596 | """
597 | # Generate predictions
598 | predicted_features, reference_names, target_names, group_members = \
599 | generate_cirr_val_predictions_blip1(blip_textual, relative_val_dataset, combining_function, index_names, index_features)
600 |
601 | print("Compute CIRR validation metrics")
602 |
603 | # Normalize the index features
604 | index_features = F.normalize(index_features, dim=-1).float()
605 |
606 | # Compute the distances and sort the results
607 | distances = 1 - predicted_features @ index_features.T
608 | sorted_indices = torch.argsort(distances, dim=-1).cpu()
609 | sorted_index_names = np.array(index_names)[sorted_indices]
610 |
611 | # Delete the reference image from the results
612 | reference_mask = torch.tensor(
613 | sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
614 | sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
615 | sorted_index_names.shape[1] - 1)
616 | # Compute the ground-truth labels wrt the predictions
617 | labels = torch.tensor(
618 | sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
619 |
620 | # Compute the subset predictions and ground-truth labels
621 | group_members = np.array(group_members)
622 | group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
623 | group_labels = labels[group_mask].reshape(labels.shape[0], -1)
624 |
625 | assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
626 | assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
627 |
628 | # Compute the metrics
629 | recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
630 | recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
631 | recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
632 | recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
633 | group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
634 | group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
635 | group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
636 |
637 | return group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50
638 |
639 |
640 | def generate_cirr_val_predictions_blip1(blip_textual, relative_val_dataset: CIRRDataset,
641 | combining_function: callable, index_names: List[str], index_features: torch.tensor) -> \
642 | Tuple[torch.tensor, List[str], List[str], List[List[str]]]:
643 | """
644 | Compute CIRR predictions on the validation set
645 |     :param blip_textual: BLIP text encoder used to encode the relative captions
646 | :param relative_val_dataset: CIRR validation dataset in relative mode
647 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
648 | features
649 | :param index_features: validation index features
650 | :param index_names: validation index names
651 | :return: predicted features, reference names, target names and group members
652 | """
653 | print("Compute CIRR validation predictions")
654 | relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=8,
655 | pin_memory=True, collate_fn=collate_fn)
656 |
657 | # Get a mapping from index names to index features
658 | name_to_feat = dict(zip(index_names, index_features))
659 |
660 | feature_dim = 256
661 | # Initialize predicted features, target_names, group_members and reference_names
662 | predicted_features = torch.empty((0, feature_dim)).to(device, non_blocking=True)
663 | target_names = []
664 | group_members = []
665 | reference_names = []
666 |
667 | for batch_reference_images, batch_reference_names, batch_target_names, captions, batch_group_members in tqdm(
668 | relative_val_loader): # Load data
669 |
670 | batch_group_members = np.array(batch_group_members).T.tolist()
671 | # Compute the predicted features
672 | with torch.no_grad():
673 | text_features = blip_textual(captions, max_length=77, device=device)[:,0,:]
674 | if text_features.shape[0] == 1:
675 | reference_image_features = itemgetter(*batch_reference_names)(name_to_feat).unsqueeze(0)
676 | else:
677 | reference_image_features = torch.stack(itemgetter(*batch_reference_names)(
678 | name_to_feat)) # To avoid unnecessary computation retrieve the reference image features directly from the index features
679 | batch_predicted_features = combining_function(reference_image_features, text_features)
680 |
681 | predicted_features = torch.vstack((predicted_features, F.normalize(batch_predicted_features, dim=-1)))
682 | target_names.extend(batch_target_names)
683 | group_members.extend(batch_group_members)
684 | reference_names.extend(batch_reference_names)
685 |
686 | return predicted_features, reference_names, target_names, group_members
687 |
688 |
689 |
690 |
691 | def cirr_val_retrieval(combining_function: callable, blip_visual, blip_textual, blip_multimodal, preprocess):
692 | """
693 | Perform retrieval on CIRR validation set computing the metrics. To combine the features the `combining_function`
694 | is used
695 | :param combining_function: function which takes as input (image_features, text_features) and outputs the combined
696 | features
697 |     :param blip_visual / blip_textual / blip_multimodal: BLIP-2 visual, textual and multimodal feature extractors
698 | :param preprocess: preprocess pipeline
699 | """
700 |
701 | blip_visual = blip_visual.float().eval()
702 | blip_textual = blip_textual.float().eval()
703 | blip_multimodal = blip_multimodal.float().eval()
704 | # Define the validation datasets and extract the index features
705 | classic_val_dataset = CIRRDataset('val', 'classic', preprocess)
706 | index_features, index_names = extract_index_features_blip2(classic_val_dataset, blip_visual)
707 | relative_val_dataset = CIRRDataset('val', 'relative', preprocess)
708 |
709 | return compute_cirr_val_metrics_blip2(relative_val_dataset, blip_textual, blip_multimodal, index_features, index_names,
710 | combining_function)
711 |
712 |
713 | def main():
714 | parser = ArgumentParser()
715 | parser.add_argument("--dataset", type=str, required=True, help="should be either 'CIRR' or 'fashionIQ'")
716 | parser.add_argument("--combining-function", type=str, required=True,
717 |                         help="Which combining function to use, should be in ['combiner', 'sum']")
718 | # parser.add_argument("--combiner-path", type=Path, help="path to trained Combiner")
719 | # parser.add_argument("--projection-dim", default=640 * 4, type=int, help='Combiner projection dim')
720 | # parser.add_argument("--hidden-dim", default=640 * 8, type=int, help="Combiner hidden dim")
721 | # parser.add_argument("--clip-model-name", default="RN50x4", type=str, help="CLIP model to use, e.g 'RN50', 'RN50x4'")
722 | # parser.add_argument("--clip-model-path", type=Path, help="Path to the fine-tuned CLIP model")
723 |
724 | parser.add_argument("--blip2-textual-path", type=str, help="Path to the fine-tuned BLIP2 model")
725 | parser.add_argument("--blip2-visual-path", type=str, help="Path to the fine-tuned BLIP2 model")
726 | parser.add_argument("--blip2-multimodal-path", type=str, help="Path to the fine-tuned BLIP2 model")
727 |
728 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio")
729 | parser.add_argument("--transform", default="targetpad", type=str,
730 |                         help="Preprocess pipeline, should be in ['squarepad', 'targetpad']")
731 |
732 |
733 | args = parser.parse_args()
734 |
735 | # clip_model, clip_preprocess = clip.load(args.clip_model_name, device=device, jit=False)
736 | # input_dim = clip_model.visual.input_resolution
737 | # feature_dim = clip_model.visual.output_dim
738 |
739 | blip_textual_path = args.blip2_textual_path
740 | blip_visual_path = args.blip2_visual_path
741 | blip_multimodal_path = args.blip2_multimodal_path
742 |
743 | blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
744 | x_saved_state_dict = torch.load(blip_textual_path, map_location=device)
745 | # print(x_saved_state_dict.keys())
746 | blip_textual.load_state_dict(x_saved_state_dict["Blip2Qformer"])
747 |
748 | blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
749 | r_saved_state_dict = torch.load(blip_multimodal_path, map_location=device)
750 | blip_multimodal.load_state_dict(r_saved_state_dict["Blip2Qformer"])
751 |
752 | blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
753 | t_saved_state_dict = torch.load(blip_visual_path, map_location=device)
754 | blip_visual.load_state_dict(t_saved_state_dict["Blip2Qformer"])
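    # The three checkpoints are assumed to be dicts saved by save_model() in blip_fine_tune.py,
    # keyed by the model class name ("Blip2Qformer") and holding its state_dict.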
755 |
756 |
757 | # blip_textual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
758 |
759 | # blip_multimodal = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
760 |
761 | # blip_visual = load_model(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)
762 |
763 | # if args.clip_model_path:
764 | # print('Trying to load the CLIP model')
765 | # saved_state_dict = torch.load(args.clip_model_path, map_location=device)
766 | # clip_model.load_state_dict(saved_state_dict["CLIP"])
767 | # print('CLIP model loaded successfully')
768 | input_dim = 224
769 | if args.transform == 'targetpad':
770 | print('Target pad preprocess pipeline is used')
771 | preprocess = targetpad_transform(args.target_ratio, input_dim)
772 | elif args.transform == 'squarepad':
773 | print('Square pad preprocess pipeline is used')
774 | preprocess = squarepad_transform(input_dim)
775 |     else:
776 |         raise ValueError("Preprocess transform should be in ['squarepad', 'targetpad']")
777 | 
778 |
779 | if args.combining_function.lower() == 'sum':
780 | combining_function = element_wise_sum
781 | # elif args.combining_function.lower() == 'combiner':
782 | # combiner = Combiner(feature_dim, args.projection_dim, args.hidden_dim).to(device, non_blocking=True)
783 | # state_dict = torch.load(args.combiner_path, map_location=device)
784 | # combiner.load_state_dict(state_dict["Combiner"])
785 | # combiner.eval()
786 | # combining_function = combiner.combine_features
787 | else:
788 |         raise ValueError("combining-function must be 'sum' (the 'combiner' option is currently disabled)")
789 |
790 | if args.dataset.lower() == 'cirr':
791 | group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 = \
792 | cirr_val_retrieval(combining_function, blip_visual, blip_textual, blip_multimodal , preprocess)
793 |
794 | print(f"{group_recall_at1 = }")
795 | print(f"{group_recall_at2 = }")
796 | print(f"{group_recall_at3 = }")
797 | print(f"{recall_at1 = }")
798 | print(f"{recall_at5 = }")
799 | print(f"{recall_at10 = }")
800 | print(f"{recall_at50 = }")
801 |
802 | elif args.dataset.lower() == 'fashioniq':
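        # NOTE: this branch still relies on a CLIP model; the commented-out CLIP arguments and
        # loading code above must be re-enabled for `clip_model` to be defined here.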
803 | average_recall10_list = []
804 | average_recall50_list = []
805 |
806 | shirt_recallat10, shirt_recallat50 = fashioniq_val_retrieval('shirt', combining_function, clip_model,
807 | preprocess)
808 | average_recall10_list.append(shirt_recallat10)
809 | average_recall50_list.append(shirt_recallat50)
810 |
811 | dress_recallat10, dress_recallat50 = fashioniq_val_retrieval('dress', combining_function, clip_model,
812 | preprocess)
813 | average_recall10_list.append(dress_recallat10)
814 | average_recall50_list.append(dress_recallat50)
815 |
816 | toptee_recallat10, toptee_recallat50 = fashioniq_val_retrieval('toptee', combining_function, clip_model,
817 | preprocess)
818 | average_recall10_list.append(toptee_recallat10)
819 | average_recall50_list.append(toptee_recallat50)
820 |
821 | print(f"\n{shirt_recallat10 = }")
822 | print(f"{shirt_recallat50 = }")
823 |
824 | print(f"{dress_recallat10 = }")
825 | print(f"{dress_recallat50 = }")
826 |
827 | print(f"{toptee_recallat10 = }")
828 | print(f"{toptee_recallat50 = }")
829 |
830 | print(f"average recall10 = {mean(average_recall10_list)}")
831 | print(f"average recall50 = {mean(average_recall50_list)}")
832 | else:
833 |         raise ValueError("Dataset should be either 'CIRR' or 'FashionIQ'")
834 |
835 |
836 | if __name__ == '__main__':
837 | main()
838 |
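# Example invocation (checkpoint paths are placeholders; point them at your fine-tuned weights):
#   python src/validate.py --dataset CIRR --combining-function sum \
#       --blip2-textual-path <path/to/tuned_blip_text_checkpoint.pt> \
#       --blip2-visual-path <path/to/tuned_blip_visual_checkpoint.pt> \
#       --blip2-multimodal-path <path/to/tuned_blip_multi_checkpoint.pt> \
#       --transform targetpad --target-ratio 1.25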
--------------------------------------------------------------------------------
/src/blip_fine_tune.py:
--------------------------------------------------------------------------------
1 | from comet_ml import Experiment
2 | import json
3 | import multiprocessing
4 | from argparse import ArgumentParser
5 | from datetime import datetime
6 | from pathlib import Path
7 | from statistics import mean, geometric_mean, harmonic_mean
8 | from typing import List
9 |
10 | import numpy as np
11 | import pandas as pd
12 | import torch
13 | import torch.nn.functional as F
14 | from torch import optim, nn
15 | from torch.utils.data import DataLoader
16 | from tqdm import tqdm
17 | from data_utils import base_path, squarepad_transform, targetpad_transform, CIRRDataset, FashionIQDataset
18 | import clip
19 | from lavis.models import load_model
20 | from utils import collate_fn, update_train_running_results, set_train_bar_description, save_model, \
21 | extract_index_features_blip2, generate_randomized_fiq_caption, extract_index_features_clip, \
22 | device, element_wise_sum, cosine_lr_schedule
23 | from validate import compute_cirr_val_metrics_blip2, compute_fiq_val_metrics_blip2, compute_fiq_val_metrics_clip, \
24 | artemis_compute_cirr_val_metrics ,compute_cirr_val_metrics_clip, artemis_compute_fiq_val_metrics
25 | from twin_attention_compositor_blip2 import TwinAttentionCompositorBLIP2
26 | from hinge_based_cross_attention_blip2 import HingebasedCrossAttentionBLIP2
27 | from twin_attention_compositor_clip import TwinAttentionCompositorCLIP
28 | from hinge_based_cross_attention_clip import HingebasedCrossAttentionCLIP
29 | import random
30 | from artemis import Artemis
31 | import ssl
32 |
33 | base_path = Path(__file__).absolute().parents[1].absolute()
34 |
35 | def blip_finetune_fiq(train_dress_types: List[str], val_dress_types: List[str],
36 | num_epochs: int, batch_size: int,
37 | validation_frequency: int, transform: str, save_training: bool, save_best: bool,
38 | **kwargs):
39 |     """
40 |     Fine-tune BLIP on the FashionIQ dataset using the image-text element-wise sum as combining function
41 |     :param train_dress_types: FashionIQ categories to train on
42 |     :param val_dress_types: FashionIQ categories to validate on
43 |     :param num_epochs: number of epochs
44 |     :param batch_size: batch size
45 |     :param validation_frequency: validation frequency expressed in epochs
46 |     :param transform: preprocess transform to use. Should be in ['clip', 'squarepad', 'targetpad']. When
47 |         `targetpad` is used you are also required to provide the `target_ratio` kwarg.
48 |     :param save_training: when True save the weights of the fine-tuned BLIP model and of the support modules
49 |     :param save_best: when True save only the weights of the best model wrt the average_recall metric
50 |     :param kwargs: additional options, e.g. `experiment_name`, `encoder_arch` ('blip2' or 'clip'), `model_name`,
51 |         `embeds_dim`, `cir_frame` ('sum' or 'artemis'), `encoder` (in ['multi', 'both', 'text']),
52 |         `learning_rate`, `min_lr`, `max_epoch`, `tac_weight`, `hca_weight` and, for `targetpad`,
53 |         `target_ratio`
54 |     """
55 |
56 | experiment_name = kwargs["experiment_name"]
57 | # training_start = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
58 | training_path: Path = Path(
59 |         base_path / f"models/blip_fiq_{experiment_name}")
60 | training_path.mkdir(exist_ok=False, parents=True)
61 |
62 | # Save all the hyperparameters on a file
63 | with open(training_path / "training_hyperparameters.json", 'w+') as file:
64 | json.dump(training_hyper_params, file, sort_keys=True, indent=4)
65 |
66 | # initialize encoders
67 | encoder_arch = kwargs["encoder_arch"]
68 | model_name = kwargs["model_name"]
69 |
70 | if encoder_arch == "blip2":
71 | blip_textual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device)
72 | blip_multimodal = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device)
73 | blip_visual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device)
74 |
75 | elif encoder_arch == "clip":
76 | clip_model, clip_preprocess = clip.load(model_name, device=device, jit=False)
77 | clip_model.eval().float()
78 |
79 | # initialize support modules
80 | embeds_dim = kwargs["embeds_dim"]
81 | if encoder_arch == "blip2":
82 | tac = TwinAttentionCompositorBLIP2(embeds_dim).to(device)
83 | hca = HingebasedCrossAttentionBLIP2(embeds_dim).to(device)
84 |
85 | elif encoder_arch == "clip":
86 | tac = TwinAttentionCompositorCLIP().to(device)
87 | hca = HingebasedCrossAttentionCLIP(embeds_dim).to(device)
88 |
89 | cir_frame = kwargs["cir_frame"]
90 | artemis = Artemis(embeds_dim).to(device)
91 |
92 |
93 | # define the combining func
94 | combining_function = element_wise_sum
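    # element_wise_sum (imported from utils) is the plain fusion baseline; it presumably adds the
    # reference-image and text features and L2-normalizes the result.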
95 | # preprocess
96 | if encoder_arch == "blip2":
97 | input_dim = 224
98 | elif encoder_arch == "clip":
99 | input_dim = clip_model.visual.input_resolution
100 |
101 |
102 | if transform == "clip":
103 | preprocess = clip_preprocess
104 | print('CLIP default preprocess pipeline is used')
105 | elif transform == "squarepad":
106 | preprocess = squarepad_transform(input_dim)
107 | print('Square pad preprocess pipeline is used')
108 | elif transform == "targetpad":
109 | target_ratio = kwargs['target_ratio']
110 | preprocess = targetpad_transform(target_ratio, input_dim)
111 | print(f'Target pad with {target_ratio = } preprocess pipeline is used')
112 | else:
113 |         raise ValueError("Preprocess transform should be in ['clip', 'squarepad', 'targetpad']")
114 |
115 | idx_to_dress_mapping = {}
116 | relative_val_datasets = []
117 | classic_val_datasets = []
118 |
119 | # When fine-tuning only the text encoder we can precompute the index features since they do not change over
120 | # the epochs
121 | encoder = kwargs["encoder"]
122 | if encoder == 'text':
123 | index_features_list = []
124 | index_names_list = []
125 |
126 | # Define the validation datasets
127 | for idx, dress_type in enumerate(val_dress_types):
128 | idx_to_dress_mapping[idx] = dress_type
129 | relative_val_dataset = FashionIQDataset('val', [dress_type], 'relative', preprocess, )
130 | relative_val_datasets.append(relative_val_dataset)
131 | classic_val_dataset = FashionIQDataset('val', [dress_type], 'classic', preprocess, )
132 | classic_val_datasets.append(classic_val_dataset)
133 | if encoder == 'text':
134 | if encoder_arch == "blip2":
135 |                 index_features_and_names = extract_index_features_blip2(classic_val_dataset, blip_visual)
136 | index_features_list.append(index_features_and_names[0])
137 | index_names_list.append(index_features_and_names[1])
138 |
139 | if encoder_arch == "clip":
140 | index_features_and_names = extract_index_features_clip(classic_val_dataset, clip_model)
141 | index_features_list.append(index_features_and_names[0])
142 | index_names_list.append(index_features_and_names[1])
143 |
144 |
145 | # Define the train datasets and the combining function
146 | relative_train_dataset = FashionIQDataset('train', train_dress_types, 'relative', preprocess)
147 | relative_train_loader = DataLoader(dataset=relative_train_dataset, batch_size=batch_size,
148 | num_workers=0, pin_memory=False, collate_fn=collate_fn,
149 | drop_last=True, shuffle=True)
150 |
151 |     # Read the optimization hyper-parameters
152 | learning_rate = kwargs["learning_rate"]
153 | min_lr = kwargs["min_lr"]
154 | max_epoch = kwargs["max_epoch"]
155 |
156 |
157 | # Define the optimizer, the loss and the grad scaler
158 | if encoder_arch == "blip2":
159 | # # blip2_encoder
160 |         if encoder == 'multi': # fine-tune the textual and multimodal Q-Formers only
161 | optimizer = optim.AdamW([ # param in blip_multimodal
162 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()],
163 | 'lr': learning_rate,
164 | 'weight_decay': 0.05},
165 | {'params': blip_multimodal.query_tokens,
166 | 'lr': learning_rate,
167 | 'weight_decay': 0.05},
168 | # param in blip_textual
169 | {'params': [param for param in blip_textual.Qformer.bert.parameters()],
170 | 'lr': learning_rate,
171 | 'weight_decay': 0.05},
172 |
173 | # params in support modules
174 | {'params': [param for param in tac.parameters()],
175 | 'lr': learning_rate * 2,
176 | 'weight_decay': 0.05},
177 |
178 | {'params': [param for param in hca.parameters()],
179 | 'lr': learning_rate * 2,
180 | 'weight_decay': 0.05},
181 |
182 | {'params': [param for param in artemis.parameters()],
183 | 'lr': learning_rate * 10,
184 | 'weight_decay': 0.05},
185 | ])
186 |
187 | elif encoder == 'both': # finetuning textual and visual concurrently
188 | optimizer = optim.AdamW([
189 | # param in blip_multimodal
190 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()],
191 | 'lr': learning_rate,
192 | 'weight_decay': 0.05},
193 | {'params': blip_multimodal.query_tokens,
194 | 'lr': learning_rate,
195 | 'weight_decay': 0.05},
196 |
197 | # param in blip_textual
198 | {'params': [param for param in blip_textual.Qformer.bert.parameters()],
199 | 'lr': learning_rate,
200 | 'weight_decay': 0.05},
201 |
202 | # param in blip_visual
203 | {'params': [param for param in blip_visual.Qformer.bert.parameters()],
204 | 'lr': learning_rate,
205 | 'weight_decay': 0.05},
206 | {'params': blip_visual.query_tokens,
207 | 'lr': learning_rate,
208 | 'weight_decay': 0.05},
209 |
210 | # params in support modules
211 |
212 | {'params': [param for param in tac.parameters()],
213 | 'lr': learning_rate * 2,
214 | 'weight_decay': 0.05},
215 |
216 | {'params': [param for param in hca.parameters()],
217 | 'lr': learning_rate * 2,
218 | 'weight_decay': 0.05},
219 |
220 | {'params': [param for param in artemis.parameters()],
221 | 'lr': learning_rate * 10,
222 | 'weight_decay': 0.05},
223 | ])
224 |
225 | else:
226 | raise ValueError("encoders to finetune must be 'multi' or 'both'")
227 | elif encoder_arch == "clip":
228 | # clip_encoder
229 | optimizer = optim.AdamW([ # param in blip_multimodal
230 | {'params': [param for name, param in clip_model.named_parameters()
231 | if 'visual' not in name],
232 | 'lr': learning_rate,
233 | 'weight_decay': 0.05},
234 |
235 | # params in support modules
236 | {'params': [param for param in tac.parameters()],
237 | 'lr': learning_rate * 2,
238 | 'weight_decay': 0.05},
239 |
240 | {'params': [param for param in hca.parameters()],
241 | 'lr': learning_rate * 2,
242 | 'weight_decay': 0.05},
243 | ])
244 |
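    # Parameter-group layout: encoder parameters use the base learning rate, the support modules
    # (tac / hca) use 2x and, in the BLIP-2 branches, artemis uses 10x, all with weight decay 0.05.
    # cosine_lr_schedule(..., onlyGroup0=True) below presumably decays only group 0 (the encoder
    # parameters), leaving the support-module learning rates constant.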
245 | crossentropy_criterion = nn.CrossEntropyLoss()
246 | scaler = torch.cuda.amp.GradScaler()
247 |
248 | # When save_best == True initialize the best result to zero
249 | if save_best:
250 | best_avg_recall = 0
251 |
252 | # Define dataframes for CSV logging
253 | training_log_frame = pd.DataFrame()
254 | validation_log_frame = pd.DataFrame()
255 |
256 | # define weights for different modules
257 | tac_weight = kwargs["tac_weight"]
258 | hca_weight = kwargs["hca_weight"]
259 |
260 | # Start with the training loop
261 | print('Training loop started')
262 | for epoch in range(num_epochs):
263 | with experiment.train():
264 |             # encoder should be in ['multi', 'both', 'text']
265 | # set models to train mode
266 | if encoder_arch == "blip2":
267 | # # blip2_encoder
268 | blip_multimodal.Qformer.bert.train()
269 | # blip_multimodal.vision_proj.train()
270 | blip_multimodal.query_tokens.requires_grad = True
271 | blip_textual.Qformer.bert.train()
272 | # blip_textual.text_proj.train()
273 |
274 | # both adds param in visual_encoder
275 | # blip2_encoder
276 | if encoder == "both":
277 | blip_visual.Qformer.bert.train()
278 | # blip_visual.vision_proj.train()
279 | blip_visual.query_tokens.requires_grad = True
280 |
281 | elif encoder_arch == "clip":
282 | # clip_encoder
283 | clip_model.train()
284 |
285 | # support modules
286 | if tac_weight > 0:
287 | tac.train()
288 | if hca_weight > 0:
289 | hca.train()
290 | if cir_frame == "artemis":
291 | artemis.train()
292 |
293 |
294 | train_running_results = {'images_in_epoch': 0, 'accumulated_train_loss': 0}
295 | train_bar = tqdm(relative_train_loader, ncols=150)
296 |
297 | # adjust learning rate
298 | cosine_lr_schedule(optimizer, epoch, max_epoch, learning_rate, min_lr, onlyGroup0=True)
299 |
300 | for idx, (reference_images, target_images, captions) in enumerate(train_bar):
301 | images_in_batch = reference_images.size(0)
302 | step = len(train_bar) * epoch + idx
303 | optimizer.zero_grad()
304 |
305 | # move ref and tar img to device
306 | reference_images = reference_images.to(device, non_blocking=True)
307 | target_images = target_images.to(device, non_blocking=True)
308 |
309 |                 # Randomize the training caption in four ways: (a) cap1 and cap2 (b) cap2 and cap1 (c) cap1 (d) cap2
310 | flattened_captions: list = np.array(captions).T.flatten().tolist()
311 | input_captions = generate_randomized_fiq_caption(flattened_captions)
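                # e.g. ["is blue", "has long sleeves"] may become "is blue and has long sleeves",
                # "has long sleeves and is blue", or either caption on its own.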
312 |
313 | # Extract the features, compute the logits and the loss
314 |
315 | with torch.cuda.amp.autocast():
316 | if encoder_arch == "blip2":
317 | # text
318 | text_embeds = blip_textual.extract_features({"text_input":input_captions},
319 | mode="text").text_embeds
320 | text_feats = text_embeds[:,0,:]
321 |
322 | # target
323 | target_embeds = blip_visual.extract_features({"image":target_images},
324 | mode="image").image_embeds
325 | target_feats = F.normalize(target_embeds[:,0,:], dim=-1)
326 |
327 | # reference
328 | reference_embeds = blip_multimodal.extract_features({"image":reference_images,
329 | "text_input":input_captions}).multimodal_embeds
330 | reference_feats = reference_embeds[:,0,:]
331 |
332 | # embeds encoded with visual_encoder
333 | reference_embeds_for_tac = blip_visual.extract_features({"image":reference_images},
334 | mode="image").image_embeds
335 | elif encoder_arch == "clip":
336 | # reference
337 | reference_feats, reference_embeds = clip_model.encode_image(reference_images)
338 | reference_embeds_for_tac = reference_embeds
339 |
340 | # text
341 | text_inputs = clip.tokenize(input_captions, context_length=77, truncate=True).to(device, non_blocking=True)
342 | text_feats, text_embeds = clip_model.encode_text(text_inputs)
343 |
344 | # target
345 | target_feats, target_embeds = clip_model.encode_image(target_images)
346 | target_feats = F.normalize(target_feats)
347 |
348 | # ============ Query-Target Contrastive ===========
349 |
350 | if cir_frame == "artemis":
351 | # artemis
352 | artemis_scores = artemis.compute_score_broadcast_artemis(reference_feats, text_feats, target_feats)
353 | # artemis_logits = artemis.temperature.exp() * artemis_scores
354 | elif cir_frame == "sum":
355 | # sum_predicted
356 | predicted_feats = combining_function(reference_feats, text_feats)
357 | matching_logits = 100 * predicted_feats @ target_feats.T
358 | # ============ Query-Target Align ===========
359 |
360 | # align(tac)
361 | if tac_weight > 0:
362 | visual_gap_feats = tac(reference_embeds_for_tac, target_embeds)
363 | aligning_logits = 10 * text_feats @ visual_gap_feats.T
364 |
365 | # ============ Reference-Caption-Target Contrastive ===========
366 | if hca_weight > 0:
367 |                         pseudo_T = hca(reference_embeds = reference_embeds,
368 |                                             caption_embeds = text_embeds,
369 |                                             target_embeds = target_embeds)
370 | 
371 |                         reasoning_logits = 10 * pseudo_T @ reference_feats.T
372 |
373 |
374 | # ============ LOSS ===========
375 | # align_loss / tac
376 | # align_loss = crossentropy_criterion(align_logits, ground_truth)
377 |
378 | # hca_loss
379 | # hca_tcr_loss = crossentropy_criterion(hca_logits, ground_truth)
380 |
381 | if cir_frame == "artemis":
382 | # artemis_loss
383 | # artemis_loss = crossentropy_criterion(artemis_logits, ground_truth)
384 | if tac_weight > 0:
385 | contrastive_logits = tac_weight * aligning_logits + \
386 | (1 - tac_weight) * artemis_scores
387 | else:
388 | contrastive_logits = artemis_scores
389 | elif cir_frame == "sum":
390 | # contrastive loss
391 | # contrastive_loss = crossentropy_criterion(contrast_logits, ground_truth)
392 | if tac_weight > 0:
393 | contrastive_logits = tac_weight * aligning_logits + \
394 | (1 - tac_weight) * matching_logits
395 | else:
396 | contrastive_logits = matching_logits
397 |
398 | # ========== Sum_Loss ===============
399 | # hca_loss
400 | ground_truth = torch.arange(images_in_batch, dtype=torch.long, device=device)
401 | contrastive_loss = crossentropy_criterion(contrastive_logits, ground_truth)
402 | if hca_weight > 0:
403 | reasoning_loss = crossentropy_criterion(reasoning_logits, ground_truth)
404 | loss = hca_weight * reasoning_loss + (1 - hca_weight) * contrastive_loss
405 | else:
406 | loss = contrastive_loss
407 |
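                    # In summary, with gt = arange(batch_size):
                    #   contrastive_logits = tac_weight * aligning_logits
                    #                        + (1 - tac_weight) * (artemis_scores or matching_logits)
                    #   loss = hca_weight * CE(reasoning_logits, gt) + (1 - hca_weight) * CE(contrastive_logits, gt)
                    # so setting tac_weight or hca_weight to 0 disables the corresponding term.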
408 | # Backpropagate and update the weights
409 | scaler.scale(loss).backward()
410 | scaler.step(optimizer)
411 | scaler.update()
412 |
413 | experiment.log_metric('step_loss', loss.detach().cpu().item(), step=step)
414 | update_train_running_results(train_running_results, loss, images_in_batch)
415 | set_train_bar_description(train_bar, epoch, num_epochs, train_running_results)
416 |
417 | train_epoch_loss = float(
418 | train_running_results['accumulated_train_loss'] / train_running_results['images_in_epoch'])
419 | experiment.log_metric('epoch_loss', train_epoch_loss, epoch=epoch)
420 |
421 | # Training CSV logging
422 | training_log_frame = pd.concat(
423 | [training_log_frame,
424 | pd.DataFrame(data={'epoch': epoch, 'train_epoch_loss': train_epoch_loss}, index=[0])])
425 | training_log_frame.to_csv(str(training_path / 'train_metrics.csv'), index=False)
426 |
427 | if epoch % validation_frequency == 0:
428 | with experiment.validate():
429 | recalls_at10 = []
430 | recalls_at50 = []
431 |
432 | # Compute and log validation metrics for each validation dataset (which corresponds to a different
433 | # FashionIQ category)
434 |
435 | for relative_val_dataset, classic_val_dataset, idx in zip(relative_val_datasets, classic_val_datasets,
436 | idx_to_dress_mapping):
437 | if encoder == 'text':
438 | index_features, index_names = index_features_list[idx], index_names_list[idx]
439 | else:
440 | if encoder_arch == "blip2":
441 | index_features, index_names = extract_index_features_blip2(classic_val_dataset, blip_visual)
442 | else:
443 | index_features, index_names = extract_index_features_clip(classic_val_dataset, clip_model)
444 |
445 | if cir_frame == "sum":
446 | if encoder_arch == "blip2":
447 |
448 | recall_at10, recall_at50 = compute_fiq_val_metrics_blip2(relative_val_dataset,
449 | blip_textual,
450 | blip_multimodal,
451 | index_features,
452 | index_names,
453 | combining_function)
454 |                     else:
455 |                         recall_at10, recall_at50 = compute_fiq_val_metrics_clip(relative_val_dataset,
456 |                                                                                 clip_model,
457 |                                                                                 index_features,
458 |                                                                                 index_names,
459 |                                                                                 combining_function)
460 | 
461 | elif cir_frame == "artemis":
462 | recall_at10, recall_at50 = artemis_compute_fiq_val_metrics(relative_val_dataset,
463 | blip_textual,
464 | blip_multimodal,
465 | index_features,
466 | index_names,
467 | artemis)
468 |
469 | recalls_at10.append(recall_at10)
470 | recalls_at50.append(recall_at50)
471 |
472 | results_dict = {}
473 | for i in range(len(recalls_at10)):
474 | results_dict[f'{idx_to_dress_mapping[i]}_recall_at10'] = recalls_at10[i]
475 | results_dict[f'{idx_to_dress_mapping[i]}_recall_at50'] = recalls_at50[i]
476 | results_dict.update({
477 | f'average_recall_at10': mean(recalls_at10),
478 | f'average_recall_at50': mean(recalls_at50),
479 | f'average_recall': (mean(recalls_at50) + mean(recalls_at10)) / 2
480 | })
481 |
482 | print(json.dumps(results_dict, indent=4))
483 | experiment.log_metrics(
484 | results_dict,
485 | epoch=epoch
486 | )
487 |
488 | # Validation CSV logging
489 | log_dict = {'epoch': epoch}
490 | log_dict.update(results_dict)
491 | validation_log_frame = pd.concat([validation_log_frame, pd.DataFrame(data=log_dict, index=[0])])
492 | validation_log_frame.to_csv(str(training_path / 'validation_metrics.csv'), index=False)
493 |
494 | if save_training:
495 | if save_best and results_dict['average_recall'] > best_avg_recall:
496 | best_avg_recall = results_dict['average_recall']
497 | if encoder_arch == "blip2":
498 | save_model('tuned_blip_text_arithmetic', epoch, blip_textual, training_path)
499 | save_model('tuned_blip_multi_arithmetic', epoch, blip_multimodal, training_path)
500 | save_model('tuned_blip_visual_arithmetic', epoch, blip_visual, training_path)
501 | elif encoder_arch == "clip":
502 | save_model('tuned_clip_arithmetic', epoch, clip_model, training_path)
503 |
504 | save_model('tuned_tac', epoch, tac, training_path)
505 | save_model('tuned_hca', epoch, hca, training_path)
506 | if not save_best:
507 | print("Warning!!!! Now you don't save any models, please set save_best==True")
508 |
509 |
510 | def blip_finetune_cirr(num_epochs: int, batch_size: int,
511 | validation_frequency: int, transform: str, save_training: bool, save_best: bool,
512 | **kwargs):
513 |     """
514 |     Fine-tune BLIP on the CIRR dataset using the image-text element-wise sum as combining function
515 |     :param num_epochs: number of epochs
516 |     :param batch_size: batch size
517 |     :param validation_frequency: validation frequency expressed in epochs
518 |     :param transform: preprocess transform to use. Should be in ['clip', 'squarepad', 'targetpad']. When
519 |         `targetpad` is used you are also required to provide the `target_ratio` kwarg.
520 |     :param save_training: when True save the weights of the fine-tuned BLIP model and of the support modules
521 |     :param save_best: when True save only the weights of the best model wrt the arithmetic average of the metrics
522 |     :param kwargs: additional options, e.g. `experiment_name`, `encoder_arch` ('blip2' or 'clip'), `model_name`,
523 |         `embeds_dim`, `cir_frame` ('sum' or 'artemis'), `encoder` (in ['multi', 'both', 'text']),
524 |         `learning_rate`, `min_lr`, `max_epoch`, `tac_weight`, `hca_weight` and, for `targetpad`,
525 |         `target_ratio`
526 |     """
527 | experiment_name = kwargs["experiment_name"]
528 | # training_start = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
529 | training_path: Path = Path(
530 | base_path / f"models/blip_cirr_{experiment_name}")
531 | training_path.mkdir(exist_ok=False, parents=True)
532 |
533 | # Save all the hyperparameters on a file
534 | with open(training_path / "training_hyperparameters.json", 'w+') as file:
535 | json.dump(training_hyper_params, file, sort_keys=True, indent=4)
536 |
537 | # initialize encoders
538 | encoder_arch = kwargs["encoder_arch"]
539 |
540 | # initialize the encoders with different arch
541 | model_name = kwargs["model_name"]
542 | if encoder_arch == "blip2":
543 | # blip2_encoders
544 | blip_textual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device)
545 | blip_multimodal = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device)
546 | blip_visual = load_model(name=model_name, model_type="pretrain_vitL", is_eval=True, device=device)
547 |
548 | elif encoder_arch == "clip":
549 | clip_model, clip_preprocess = clip.load(model_name, device=device, jit=False)
550 | clip_model.eval().float()
551 | # initialize support modules
552 | embeds_dim = kwargs["embeds_dim"]
553 | if encoder_arch == "blip2":
554 | tac = TwinAttentionCompositorBLIP2(embeds_dim).to(device)
555 | hca = HingebasedCrossAttentionBLIP2(embeds_dim).to(device)
556 |
557 | elif encoder_arch == "clip":
558 | tac = TwinAttentionCompositorCLIP().to(device)
559 | hca = HingebasedCrossAttentionCLIP(embeds_dim).to(device)
560 |
561 | cir_frame = kwargs["cir_frame"]
562 | artemis = Artemis(embeds_dim).to(device)
563 |
564 |     # define the combining function
565 | combining_function = element_wise_sum
566 |
567 | # preprocess
568 | if encoder_arch == "blip2":
569 | input_dim = 224
570 | elif encoder_arch == "clip":
571 | input_dim = clip_model.visual.input_resolution
572 |
573 | if transform == "clip":
574 | preprocess = clip_preprocess
575 | print('CLIP default preprocess pipeline is used')
576 | elif transform == "squarepad":
577 | preprocess = squarepad_transform(input_dim)
578 | print('Square pad preprocess pipeline is used')
579 | elif transform == "targetpad":
580 | target_ratio = kwargs['target_ratio']
581 | preprocess = targetpad_transform(target_ratio, input_dim)
582 | print(f'Target pad with {target_ratio = } preprocess pipeline is used')
583 | else:
584 |         raise ValueError("Preprocess transform should be in ['clip', 'squarepad', 'targetpad']")
585 |
586 | # Define the validation datasets
587 | relative_val_dataset = CIRRDataset('val', 'relative', preprocess)
588 | classic_val_dataset = CIRRDataset('val', 'classic', preprocess)
589 |
590 | # When fine-tuning only the text encoder we can precompute the index features since they do not change over
591 | # the epochs
592 | encoder = kwargs["encoder"]
593 |
594 | if encoder_arch == "blip2":
595 | val_index_features, val_index_names = extract_index_features_blip2(classic_val_dataset, blip_visual)
596 | if encoder_arch == "clip":
597 | val_index_features, val_index_names = extract_index_features_clip(classic_val_dataset, clip_model)
598 | if encoder == 'text':
599 | if encoder_arch == "blip2":
600 | val_index_features, val_index_names = extract_index_features_blip2(classic_val_dataset, blip_visual)
601 | if encoder_arch == "clip":
602 | val_index_features, val_index_names = extract_index_features_clip(classic_val_dataset, clip_model)
603 |
604 | # debug for validation
605 | if encoder_arch == "blip2":
606 | results = artemis_compute_cirr_val_metrics(relative_val_dataset,
607 | blip_textual,
608 | blip_multimodal,
609 | val_index_features,
610 | val_index_names,
611 | artemis)
612 | else:
613 | results = compute_cirr_val_metrics_clip(relative_val_dataset, clip_model, val_index_features,
614 | val_index_names, combining_function)
615 |
616 |
617 | # Define the train dataset and the combining function
618 | relative_train_dataset = CIRRDataset('train', 'relative', preprocess)
619 | relative_train_loader = DataLoader(dataset=relative_train_dataset, batch_size=batch_size,
620 | num_workers=0, pin_memory=False, collate_fn=collate_fn,
621 | drop_last=True, shuffle=True)
622 |
623 |     # Read the optimization hyper-parameters
624 | learning_rate = kwargs["learning_rate"]
625 | min_lr = kwargs["min_lr"]
626 | max_epoch = kwargs["max_epoch"]
627 |
628 | # Define the optimizer, the loss and the grad scaler
629 | if encoder_arch == "blip2":
630 | # # blip2_encoder
631 |         if encoder == 'multi': # fine-tune the textual and multimodal Q-Formers only
632 | optimizer = optim.AdamW([ # param in blip_multimodal
633 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()],
634 | 'lr': learning_rate,
635 | 'weight_decay': 0.05},
636 | {'params': blip_multimodal.query_tokens,
637 | 'lr': learning_rate,
638 | 'weight_decay': 0.05},
639 | # param in blip_textual
640 | {'params': [param for param in blip_textual.Qformer.bert.parameters()],
641 | 'lr': learning_rate,
642 | 'weight_decay': 0.05},
643 |
644 | # params in support modules
645 | {'params': [param for param in tac.parameters()],
646 | 'lr': learning_rate * 2,
647 | 'weight_decay': 0.05},
648 |
649 | {'params': [param for param in hca.parameters()],
650 | 'lr': learning_rate * 2,
651 | 'weight_decay': 0.05},
652 |
653 | {'params': [param for param in artemis.parameters()],
654 | 'lr': learning_rate * 10,
655 | 'weight_decay': 0.05},
656 | ])
657 |
658 | elif encoder == 'both': # finetuning textual and visual concurrently
659 | optimizer = optim.AdamW([
660 | # param in blip_multimodal
661 | {'params': [param for param in blip_multimodal.Qformer.bert.parameters()],
662 | 'lr': learning_rate,
663 | 'weight_decay': 0.05},
664 | {'params': blip_multimodal.query_tokens,
665 | 'lr': learning_rate,
666 | 'weight_decay': 0.05},
667 |
668 | # param in blip_textual
669 | {'params': [param for param in blip_textual.Qformer.bert.parameters()],
670 | 'lr': learning_rate,
671 | 'weight_decay': 0.05},
672 |
673 | # param in blip_visual
674 | {'params': [param for param in blip_visual.Qformer.bert.parameters()],
675 | 'lr': learning_rate,
676 | 'weight_decay': 0.05},
677 | {'params': blip_visual.query_tokens,
678 | 'lr': learning_rate,
679 | 'weight_decay': 0.05},
680 |
681 | # params in support modules
682 |
683 | {'params': [param for param in tac.parameters()],
684 | 'lr': learning_rate * 2,
685 | 'weight_decay': 0.05},
686 |
687 | {'params': [param for param in hca.parameters()],
688 | 'lr': learning_rate * 2,
689 | 'weight_decay': 0.05},
690 |
691 | {'params': [param for param in artemis.parameters()],
692 | 'lr': learning_rate * 10,
693 | 'weight_decay': 0.05},
694 | ])
695 |
696 | else:
697 | raise ValueError("encoder to fine-tune must be 'multi' or 'both'")
698 | elif encoder_arch == "clip":
699 | # clip encoder
700 | optimizer = optim.AdamW([ # params in clip_model (everything except the visual tower)
701 | {'params': [param for name, param in clip_model.named_parameters()
702 | if 'visual' not in name],
703 | 'lr': learning_rate,
704 | 'weight_decay': 0.05},
705 |
706 | # params in support modules
707 | {'params': [param for param in tac.parameters()],
708 | 'lr': learning_rate * 2,
709 | 'weight_decay': 0.05},
710 |
711 | {'params': [param for param in hca.parameters()],
712 | 'lr': learning_rate * 2,
713 | 'weight_decay': 0.05},
714 | ])
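    | # All parameter groups use AdamW with weight_decay=0.05: the backbone encoders train at the
    | # base learning_rate, the TAC and HCA support modules at 2x, and (for blip2) the ARTEMIS head
    | # at 10x. cosine_lr_schedule below is called with onlyGroup0=True, so presumably only the first
    | # (backbone) group is annealed while the support-module groups keep their fixed learning rates.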
715 |
716 | # define loss function and scaler
717 | crossentropy_criterion = nn.CrossEntropyLoss()
718 | scaler = torch.cuda.amp.GradScaler()
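    | # Mixed-precision setup: the forward pass below runs under torch.cuda.amp.autocast, and the
    | # GradScaler scales the loss before backward to avoid fp16 gradient underflow, then unscales
    | # the gradients inside scaler.step(optimizer).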
719 |
720 | # When save_best == True initialize the best results to zero
721 | if save_best:
722 | # best_harmonic = 0
723 | # best_geometric = 0
724 | best_arithmetic = 0
725 |
726 | # Define dataframes for CSV logging
727 | training_log_frame = pd.DataFrame()
728 | validation_log_frame = pd.DataFrame()
729 |
730 | # define weights for different modules
731 | tac_weight = kwargs["tac_weight"]
732 | hca_weight = kwargs["hca_weight"]
733 |
734 | # epoch loop
735 | for epoch in range(num_epochs):
736 | with experiment.train():
737 | # encoder selects which BLIP-2 components are trained ('multi' or 'both', see optimizer setup)
738 | # set models to train mode
739 | if encoder_arch == "blip2":
740 | # blip2 encoder
741 | blip_multimodal.Qformer.bert.train()
742 | # blip_multimodal.vision_proj.train()
743 | blip_multimodal.query_tokens.requires_grad = True
744 | blip_textual.Qformer.bert.train()
745 | # blip_textual.text_proj.train()
746 |
747 | # 'both' additionally puts the visual Q-Former in train mode
749 | if encoder == "both":
750 | blip_visual.Qformer.bert.train()
751 | # blip_visual.vision_proj.train()
752 | blip_visual.query_tokens.requires_grad = True
753 |
754 | elif encoder_arch == "clip":
755 | # clip_encoder
756 | clip_model.train()
757 |
758 | # support modules
759 | if tac_weight > 0:
760 | tac.train()
761 | if hca_weight > 0:
762 | hca.train()
763 | if cir_frame == "artemis":
764 | artemis.train()
765 |
766 | train_running_results = {'images_in_epoch': 0, 'accumulated_train_loss': 0}
767 | train_bar = tqdm(relative_train_loader, ncols=150)
768 |
769 | # adjust learning rate in every epoch
770 | cosine_lr_schedule(optimizer, epoch, max_epoch, learning_rate, min_lr, onlyGroup0=True)
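    | # cosine_lr_schedule (defined in utils) is expected to follow the usual BLIP-style annealing,
    | # i.e. lr_epoch = min_lr + 0.5 * (learning_rate - min_lr) * (1 + cos(pi * epoch / max_epoch)),
    | # applied here to parameter group 0 only because onlyGroup0=True.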
771 |
772 | # iteration loop
773 | for idx, (reference_images, target_images, captions) in enumerate(train_bar):
774 | images_in_batch = reference_images.size(0)
775 | step = len(train_bar) * epoch + idx
776 | optimizer.zero_grad()
777 |
778 | # move ref and tar img to device
779 | reference_images = reference_images.to(device, non_blocking=True)
780 | target_images = target_images.to(device, non_blocking=True)
781 |
782 | # Extract the features, compute the logits and the loss
783 |
784 | with torch.cuda.amp.autocast():
785 | if encoder_arch == "blip2":
786 | # text
787 | text_embeds = blip_textual.extract_features({"text_input":captions},
788 | mode="text").text_embeds
789 | text_feats = text_embeds[:,0,:]
790 |
791 | # target
792 | target_embeds = blip_visual.extract_features({"image":target_images},
793 | mode="image").image_embeds
794 | target_feats = F.normalize(target_embeds[:,0,:], dim=-1)
795 |
796 | # reference
797 | reference_embeds = blip_multimodal.extract_features({"image":reference_images,
798 | "text_input":captions}).multimodal_embeds
799 | reference_feats = reference_embeds[:,0,:]
800 |
801 | # reference embeddings from the visual encoder (input to the TAC module)
802 | reference_embeds_for_tac = blip_visual.extract_features({"image":reference_images},
803 | mode="image").image_embeds
804 | elif encoder_arch == "clip":
805 | # reference
806 | reference_feats, reference_embeds = clip_model.encode_image(reference_images)
807 | reference_embeds_for_tac = reference_embeds
808 |
809 | # text
810 | text_inputs = clip.tokenize(captions, context_length=77, truncate=True).to(device,non_blocking=True)
811 | text_feats, text_embeds = clip_model.encode_text(text_inputs)
812 |
813 | # target
814 | target_feats, target_embeds = clip_model.encode_image(target_images)
815 | target_feats = F.normalize(target_feats)
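    | # Both branches expose the same interface: pooled reference_feats (multimodal for blip2,
    | # image-only for clip), text_feats, L2-normalized target_feats, and the token-level
    | # *_embeds consumed by the TAC and HCA modules below.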
816 |
817 |
818 | # ============ Query-Target Contrastive ===========
819 |
820 | if cir_frame == "artemis":
821 | # artemis
822 | artemis_scores = artemis.compute_score_broadcast_artemis(reference_feats, text_feats, target_feats)
823 | # artemis_logits = artemis.temperature.exp() * artemis_scores
824 | elif cir_frame == "sum":
825 | # sum_predicted
826 | predicted_feats = combining_function(reference_feats, text_feats)
827 | matching_logits = 100 * predicted_feats @ target_feats.T
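    | # Two scoring frames for the query-target branch: 'artemis' scores (reference, caption, target)
    | # triplets with the ARTEMIS module, while 'sum' fuses reference and caption features via
    | # combining_function and uses a scaled similarity matrix against all targets in the batch.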
828 |
829 | # ============ Query-Target Align ===========
830 |
831 | # align(tac)
832 | if tac_weight > 0:
833 | visual_gap_feats = tac(reference_embeds_for_tac, target_embeds)
834 | aligning_logits = 10 * text_feats @ visual_gap_feats.T
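    | # Twin-attention compositor (TAC): from the reference and target visual embeddings it produces
    | # a "visual gap" feature per pair, and aligning_logits is the (batch x batch) similarity between
    | # each caption and each visual gap, scaled by a fixed inverse temperature of 10.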
835 |
836 | # ============ Reference-Caption-Target Contrastive ===========
837 | if hca_weight > 0:
838 | pseudo_T = hca(reference_embeds=reference_embeds,
839 | caption_embeds=text_embeds,
840 | target_embeds=target_embeds)
841 | 
842 | reasoning_logits = 10 * pseudo_T @ reference_feats.T
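    | # Hinge-based cross attention (HCA): reference, caption and target embeddings are fused into a
    | # pseudo-target representation pseudo_T, and reasoning_logits scores it against every reference
    | # feature in the batch (again with a fixed scale of 10).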
843 |
844 | # ============ LOSS ===========
845 | # align_loss / tac
846 | # align_loss = crossentropy_criterion(align_logits, ground_truth)
847 |
848 | # hca_loss
849 | # hca_tcr_loss = crossentropy_criterion(hca_logits, ground_truth)
850 |
851 | if cir_frame == "artemis":
852 | # artemis_loss
853 | # artemis_loss = crossentropy_criterion(artemis_logits, ground_truth)
854 | if tac_weight > 0:
855 | contrastive_logits = tac_weight * aligning_logits + \
856 | (1 - tac_weight) * artemis_scores
857 | else:
858 | contrastive_logits = artemis_scores
859 | elif cir_frame == "sum":
860 | # contrastive loss
861 | # contrastive_loss = crossentropy_criterion(contrast_logits, ground_truth)
862 | if tac_weight > 0:
863 | contrastive_logits = tac_weight * aligning_logits + \
864 | (1 - tac_weight) * matching_logits
865 | else:
866 | contrastive_logits = matching_logits
867 |
868 | # ========== Total Loss ===============
869 | # in-batch ground truth: the i-th query's positive is the i-th target
870 | ground_truth = torch.arange(images_in_batch, dtype=torch.long, device=device)
871 | contrastive_loss = crossentropy_criterion(contrastive_logits, ground_truth)
872 | if hca_weight > 0:
873 | reasoning_loss = crossentropy_criterion(reasoning_logits, ground_truth)
874 | loss = hca_weight * reasoning_loss + (1 - hca_weight) * contrastive_loss
875 | else:
876 | loss = contrastive_loss
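    | # Overall objective (cross-entropy over in-batch similarities, i.e. InfoNCE-style):
    | #   contrastive_logits = tac_weight * aligning_logits + (1 - tac_weight) * <artemis | sum logits>
    | #   loss = hca_weight * CE(reasoning_logits, y) + (1 - hca_weight) * CE(contrastive_logits, y)
    | # with y the diagonal ground truth; setting tac_weight or hca_weight to 0 disables that branch.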
877 |
878 | # Backpropagate and update the weights
879 | scaler.scale(loss).backward()
880 | scaler.step(optimizer)
881 | scaler.update()
882 |
883 | experiment.log_metric('step_loss', loss.detach().cpu().item(), step=step)
884 | update_train_running_results(train_running_results, loss, images_in_batch)
885 | set_train_bar_description(train_bar, epoch, num_epochs, train_running_results)
886 |
887 | train_epoch_loss = float(
888 | train_running_results['accumulated_train_loss'] / train_running_results['images_in_epoch'])
889 | experiment.log_metric('epoch_loss', train_epoch_loss, epoch=epoch)
890 |
891 | # Training CSV logging
892 | training_log_frame = pd.concat(
893 | [training_log_frame,
894 | pd.DataFrame(data={'epoch': epoch, 'train_epoch_loss': train_epoch_loss}, index=[0])])
895 | training_log_frame.to_csv(str(training_path / 'train_metrics.csv'), index=False)
896 |
897 | if epoch % validation_frequency == 0:
898 | with experiment.validate():
899 | if encoder_arch == "blip2":
900 | val_index_features, val_index_names = extract_index_features_blip2(classic_val_dataset, blip_visual)
901 | if cir_frame == "sum":
902 | results = compute_cirr_val_metrics_blip2(relative_val_dataset,
903 | blip_textual,
904 | blip_multimodal,
905 | val_index_features,
906 | val_index_names,
907 | combining_function)
908 | elif cir_frame == "artemis":
909 | results = artemis_compute_cirr_val_metrics(relative_val_dataset,
910 | blip_textual,
911 | blip_multimodal,
912 | val_index_features,
913 | val_index_names,
914 | artemis)
915 | elif encoder_arch == "clip":
916 | results = compute_cirr_val_metrics_clip(relative_val_dataset, clip_model, val_index_features,
917 | val_index_names, combining_function)
918 |
919 |
920 | group_recall_at1, group_recall_at2, group_recall_at3, recall_at1, recall_at5, recall_at10, recall_at50 = results
921 |
922 | results_dict = {
923 | 'group_recall_at1': group_recall_at1,
924 | 'group_recall_at2': group_recall_at2,
925 | 'group_recall_at3': group_recall_at3,
926 | 'recall_at1': recall_at1,
927 | 'recall_at5': recall_at5,
928 | 'recall_at10': recall_at10,
929 | 'recall_at50': recall_at50,
930 | 'mean(R@5+R_s@1)': (group_recall_at1 + recall_at5) / 2,
931 | 'arithmetic_mean': mean(results),
932 | 'harmonic_mean': harmonic_mean(results),
933 | 'geometric_mean': geometric_mean(results)
934 | }
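    | # group_recall_at{1,2,3} are CIRR's subset recalls (R_s@K, ranking within each query's small
    | # candidate subset); 'mean(R@5+R_s@1)' matches the benchmark's usual headline metric, while the
    | # arithmetic mean over all seven recalls drives checkpoint selection below.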
935 | print(json.dumps(results_dict, indent=4))
936 |
937 | experiment.log_metrics(
938 | results_dict,
939 | epoch=epoch
940 | )
941 |
942 | # Validation CSV logging
943 | log_dict = {'epoch': epoch}
944 | log_dict.update(results_dict)
945 | validation_log_frame = pd.concat([validation_log_frame, pd.DataFrame(data=log_dict, index=[0])])
946 | validation_log_frame.to_csv(str(training_path / 'validation_metrics.csv'), index=False)
947 |
948 | if save_training:
949 | if save_best and results_dict['arithmetic_mean'] > best_arithmetic:
950 | best_arithmetic = results_dict['arithmetic_mean']
951 | # save encoders
952 | if encoder_arch == "blip2":
953 | save_model('tuned_blip_text_arithmetic', epoch, blip_textual, training_path)
954 | save_model('tuned_blip_multi_arithmetic', epoch, blip_multimodal, training_path)
955 | save_model('tuned_blip_visual_arithmetic', epoch, blip_visual, training_path)
956 | elif encoder_arch == "clip":
957 | save_model('tuned_clip_arithmetic', epoch, clip_model, training_path)
958 | # save support modules anyway
959 | save_model('tuned_tac_arithmetic', epoch, tac, training_path)
960 | save_model('tuned_hca_arithmetic', epoch, hca, training_path)
961 | # save artemis modules
962 | if cir_frame == "artemis":
963 | save_model('tuned_artemis_arithmetic', epoch, artemis, training_path)
964 | if not save_best:
965 | print("Warning: no checkpoints are being saved; pass --save-best to keep the best model")
966 |
967 |
968 | if __name__ == '__main__':
969 | parser = ArgumentParser()
970 | # dataset
971 | parser.add_argument("--dataset", type=str, required=True, help="should be either 'CIRR' or 'fashionIQ'")
972 | # Comet environment
973 | parser.add_argument("--api-key", type=str, help="API key for Comet logging")
974 | parser.add_argument("--workspace", type=str, help="workspace of Comet logging")
975 | parser.add_argument("--experiment-name", type=str, help="name of the experiment on Comet")
976 | # encoder fine-tuning options
977 | parser.add_argument("--encoder", type=str, default="text", help="which encoder(s) to fine-tune (for blip2 this must be 'multi' or 'both')")
978 | parser.add_argument("--encoder-arch", type=str, default="blip2", help="encoder architecture, either 'blip2' or 'clip'")
979 | parser.add_argument("--model-name", type=str, default="blip2_feature_extractor", help="model name for the encoder: 'blip2_feature_extractor' for blip2 or 'RN50x4' for clip")
980 | # training args
981 | parser.add_argument("--num-epochs", default=300, type=int, help="number of training epochs")
982 | parser.add_argument("--learning-rate", default=1e-5, type=float, help="Learning rate")
983 | parser.add_argument("--batch-size", default=512, type=int, help="Batch size")
984 | # cosine learning rate scheduler
985 | parser.add_argument("--min-lr", default=0, type=float, help="minimum learning rate for the cosine scheduler")
986 | parser.add_argument("--max-epoch", default=10, type=int, help="epoch count over which the cosine scheduler decays the learning rate")
987 | # image preprocessing
988 | parser.add_argument("--target-ratio", default=1.25, type=float, help="TargetPad target ratio")
989 | parser.add_argument("--transform", default="targetpad", type=str,
990 | help="Preprocess pipeline, should be in ['blip', 'squarepad', 'targetpad'] ")
991 | # training settings
992 | parser.add_argument("--validation-frequency", default=1, type=int, help="Validation frequency expressed in epochs")
993 | parser.add_argument("--save-training", dest="save_training", action='store_true',
994 | help="Whether to save model checkpoints during training")
995 | parser.add_argument("--save-best", dest="save_best", action='store_true',
996 | help="Save only the best model during training")
997 | parser.add_argument("--cir-frame", default="sum", type=str, help="CIR framework for the contrastive branch, either 'sum' or 'artemis'")
998 | parser.add_argument("--tac-weight", default=0.1, type=float, help="weight of the TAC alignment loss")
999 | parser.add_argument("--hca-weight", default=0.1, type=float, help="weight of the HCA reasoning loss")
1000 | parser.add_argument("--hca-temperature", default=0.92, type=float, help="temperature of the HCA module")
1001 | parser.add_argument("--tac-temperature", default=2.3, type=float, help="temperature of the TAC module")
1002 | parser.add_argument("--embeds-dim", default=768, type=int, help="dimensionality of the encoder token embeddings")
1003 |
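    | # Illustrative launch command (values are placeholders, adjust to your setup):
    | #   python src/blip_fine_tune.py --dataset CIRR --encoder multi --encoder-arch blip2 \
    | #       --num-epochs 50 --max-epoch 10 --batch-size 128 --learning-rate 1e-5 \
    | #       --transform targetpad --target-ratio 1.25 --cir-frame sum \
    | #       --tac-weight 0.1 --hca-weight 0.1 --validation-frequency 1 \
    | #       --save-training --save-best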
1004 | # fix seed for stable results
1005 | seed = 42
1006 | random.seed(seed)
1007 | np.random.seed(seed)
1008 | torch.manual_seed(seed)
1009 | torch.cuda.manual_seed(seed)
1010 | torch.backends.cudnn.deterministic = True
1011 | torch.backends.cudnn.benchmark = False  # benchmark autotuning would reintroduce non-determinism
1012 |
1013 | args = parser.parse_args()
1014 | if args.dataset.lower() not in ['fashioniq', 'cirr']:
1015 | raise ValueError("Dataset should be either 'CIRR' or 'FashionIQ'")
1016 |
1017 | training_hyper_params = {
1018 | "num_epochs": args.num_epochs,
1019 | "learning_rate": args.learning_rate,
1020 | "max_epoch": args.max_epoch,
1021 | "min_lr": args.min_lr,
1022 | "batch_size": args.batch_size,
1023 | "validation_frequency": args.validation_frequency,
1024 | "transform": args.transform,
1025 | "target_ratio": args.target_ratio,
1026 | "save_training": args.save_training,
1027 | "save_best": args.save_best,
1028 | "experiment_name":args.experiment_name,
1029 | "encoder":args.encoder,
1030 | "encoder_arch":args.encoder_arch,
1031 | "model_name":args.model_name,
1032 | "cir_frame":args.cir_frame,
1033 | "tac_weight": args.tac_weight,
1034 | "hca_weight": args.hca_weight,
1035 | "hca_temperature": args.hca_temperature,
1036 | "tac_temperature": args.tac_temperature,
1037 | "embeds_dim": args.embeds_dim
1038 | }
1039 | if args.api_key and args.workspace:
1040 | print("Comet logging ENABLED")
1041 | experiment = Experiment(
1042 | api_key=args.api_key,
1043 | project_name=f"{args.dataset} blip fine-tuning",
1044 | workspace=args.workspace,
1045 | disabled=False
1046 | )
1047 | if args.experiment_name:
1048 | experiment.set_name(args.experiment_name)
1049 | else:
1050 | print("Comet logging DISABLED; to enable it, provide an API key and a workspace")
1051 | experiment = Experiment(
1052 | api_key="",
1053 | project_name="",
1054 | workspace="",
1055 | disabled=True
1056 | )
1057 |
1058 | experiment.log_code(folder=str(base_path / 'src'))
1059 | experiment.log_parameters(training_hyper_params)
1060 |
1061 | ssl._create_default_https_context = ssl._create_unverified_context
1062 |
1063 | if args.dataset.lower() == 'cirr':
1064 | blip_finetune_cirr(**training_hyper_params)
1065 | elif args.dataset.lower() == 'fashioniq':
1066 | training_hyper_params.update(
1067 | {'train_dress_types': ['dress', 'toptee', 'shirt'], 'val_dress_types': ['dress', 'toptee', 'shirt']})
1068 | blip_finetune_fiq(**training_hyper_params)
1069 |
--------------------------------------------------------------------------------