├── MSCI
│   └── code
│       ├── model
│       │   ├── __init__.py
│       │   ├── discriminator.py
│       │   ├── model_factory_final.py
│       │   ├── common.py
│       │   └── Mutifuse_new.py
│       ├── tools
│       │   ├── __init__.py
│       │   ├── mixup.py
│       │   └── optimization.py
│       ├── clip_modules
│       │   ├── __init__.py
│       │   ├── bpe_simple_vocab_16e6.txt.gz
│       │   ├── text_encoder.py
│       │   ├── interface.py
│       │   ├── tokenization_clip.py
│       │   └── clip_model.py
│       ├── data
│       │   ├── feasibility_cgqa.pt
│       │   ├── feasibility_mit-states.pt
│       │   ├── feasibility_ut-zappos.pt
│       │   └── view.py
│       ├── check.py
│       ├── requirements.txt
│       ├── download_data
│       │   ├── reorganize_utzap.py
│       │   ├── download_data.sh
│       │   └── feasibility.py
│       ├── LICENSE
│       ├── config
│       │   └── msci
│       │       ├── mit-states-ow.yml
│       │       ├── ut-zappos-ow.yml
│       │       ├── cgqa-ow.yml
│       │       ├── ut-zappos.yml
│       │       ├── cgqa.yml
│       │       └── mit-states.yml
│       ├── .gitignore
│       ├── utils.py
│       ├── parameters.py
│       ├── README.md
│       ├── train_base.py
│       ├── dataset.py
│       └── test_base.py
├── github_structure.jpg
└── README.md

/MSCI/code/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/MSCI/code/tools/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/MSCI/code/clip_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .interface import CLIPInterface
--------------------------------------------------------------------------------
/github_structure.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ltpwy/MSCI/HEAD/github_structure.jpg
--------------------------------------------------------------------------------
/MSCI/code/data/feasibility_cgqa.pt:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ltpwy/MSCI/HEAD/MSCI/code/data/feasibility_cgqa.pt
--------------------------------------------------------------------------------
/MSCI/code/data/feasibility_mit-states.pt:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ltpwy/MSCI/HEAD/MSCI/code/data/feasibility_mit-states.pt
--------------------------------------------------------------------------------
/MSCI/code/data/feasibility_ut-zappos.pt:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ltpwy/MSCI/HEAD/MSCI/code/data/feasibility_ut-zappos.pt
--------------------------------------------------------------------------------
/MSCI/code/clip_modules/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ltpwy/MSCI/HEAD/MSCI/code/clip_modules/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/MSCI/code/check.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | # load the saved model file
4 | model = torch.load('./data/feasibility_mit-states.pt')
5 | feasibility = model['feasibility']
6 | 
7 | # print the shape of the tensor
8 | print(feasibility.shape)
9 | # print the model structure
10 | 
11 | 
--------------------------------------------------------------------------------
/MSCI/code/data/view.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | # load the model's state_dict
4 | state_dict = torch.load("feasibility_mit-states.pt")  # replace with the path to your .pt file
5 | 
6 | # print the keys and their corresponding shapes
7 | for key, value in state_dict.items():
8 |     print(f"{key}: {value}")
9 | 
--------------------------------------------------------------------------------
/MSCI/code/requirements.txt:
--------------------------------------------------------------------------------
1 | bypy==1.8
2 | certifi==2022.9.24
3 | charset-normalizer==2.1.1
4 | clip==0.1.0
5 | dill==0.3.6
6 | einops==0.5.0
7 | ftfy==6.1.1
8 | idna==3.4
9 | multiprocess==0.70.14
10 | numpy==1.23.3
11 | opencv-python==4.6.0.66
12 | pandas==1.5.0
13 | Pillow==9.2.0
14 | python-dateutil==2.8.2
15 | pytz==2022.4
16 | PyYAML==6.0
17 | regex==2022.9.13
18 | requests==2.28.1
19 | requests-toolbelt==0.10.1
20 | scipy==1.9.2
21 | six==1.16.0
22 | tqdm==4.64.1
23 | typing_extensions==4.4.0
24 | urllib3==1.26.12
25 | wcwidth==0.2.5
26 | 
--------------------------------------------------------------------------------
/MSCI/code/tools/mixup.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def mixup_data(x, y_comp, y_attr, y_obj, alpha=1.0):
6 |     lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
7 |     batch_size = x.shape[0]
8 |     index = torch.randperm(batch_size).cuda()
9 | 
10 |     mixed_x = lam * x + (1 - lam) * x[index]
11 |     y_comp_a, y_comp_b = y_comp, y_comp[index]
12 |     y_attr_a, y_attr_b = y_attr, y_attr[index]
13 |     y_obj_a, y_obj_b = y_obj, y_obj[index]
14 |     return mixed_x, y_comp_a, y_comp_b, y_attr_a, y_attr_b, y_obj_a, y_obj_b, lam
15 | 
16 | 
17 | 
18 | 
--------------------------------------------------------------------------------
/MSCI/code/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 | # The discriminator takes attribute features and object features as input and outputs the probability that they form an original composition
7 | class CombinationDiscriminator(nn.Module):
8 |     def __init__(self, feature_dim):
9 |         super(CombinationDiscriminator, self).__init__()
10 |         self.fc = nn.Sequential(
11 |             nn.Linear(feature_dim * 2, 128),
12 |             nn.ReLU(),
13 |             nn.Linear(128, 64),
14 |             nn.ReLU(),
15 |             nn.Linear(64, 1),  # score for being an original composition
16 |             nn.Sigmoid()
17 |         )
18 | 
19 |     def forward(self, attr_features, obj_features):
20 |         combined_features = torch.cat([attr_features, obj_features], dim=-1)
21 |         return self.fc(combined_features)
22 | 
--------------------------------------------------------------------------------
/MSCI/code/model/model_factory_final.py:
--------------------------------------------------------------------------------
1 | # multi-path paradigm
2 | # from model.clip_multi_path import CLIP_Multi_Path
3 | # from model.coop_multi_path import COOP_Multi_Path
4 | from model.Mutifuse_new import MSCI
5 | 
6 | def get_model(config, attributes, classes, offset):
7 |     if config.model_name == 'MSCI':
8 |         model = MSCI(config, attributes=attributes, classes=classes, offset=offset)
9 |     # elif config.model_name == 'clip_multi_path':
10 |     #     model = CLIP_Multi_Path(config, attributes=attributes, classes=classes, offset=offset)
11 |     # elif config.model_name == 'coop_multi_path':
12 |     #     model = COOP_Multi_Path(config, attributes=attributes, classes=classes, offset=offset)
13 |     else:
14 |         raise NotImplementedError(
15 |             "Error: Unrecognized Model Name {:s}.".format(
16 |                 config.model_name
17 |             )
18 |         )
19 | 
20 | 
21 |     return model
22 | 
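# --- Usage sketch (added for illustration; not part of the original repository) ---
# Shows how the CombinationDiscriminator defined in discriminator.py above can be
# driven: it scores a pair of decomposed attribute/object features with the
# probability that they come from an original composition, which is what the
# adversarial_loss_weight entries in the configs refer to. The feature width,
# batch size, and BCE target below are assumptions made for this sketch, not
# values taken from the MSCI code.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F
    from model.discriminator import CombinationDiscriminator

    feature_dim = 768                          # assumed visual feature width (e.g. ViT-L/14)
    disc = CombinationDiscriminator(feature_dim)

    attr_feat = torch.randn(4, feature_dim)    # placeholder attribute features
    obj_feat = torch.randn(4, feature_dim)     # placeholder object features
    prob_original = disc(attr_feat, obj_feat)  # shape (4, 1), values in (0, 1)

    # A possible adversarial/confusion-style term: push the discriminator towards
    # labelling these pairs as original compositions.
    target = torch.ones_like(prob_original)
    adv_loss = F.binary_cross_entropy(prob_original, target)
    print(prob_original.shape, adv_loss.item())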
-------------------------------------------------------------------------------- /MSCI/code/download_data/reorganize_utzap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Reorganize the UT-Zappos dataset to resemble the MIT-States dataset 9 | root/attr_obj/img1.jpg 10 | root/attr_obj/img2.jpg 11 | root/attr_obj/img3.jpg 12 | ... 13 | """ 14 | 15 | import os 16 | import torch 17 | import shutil 18 | import tqdm 19 | 20 | DATA_FOLDER= "/data/jyy/lll/dataset" 21 | 22 | root = DATA_FOLDER+'/ut-zap50k/' 23 | os.makedirs(root+'/images',exist_ok=True) 24 | 25 | print(root) 26 | data = torch.load(root+'/metadata_compositional-split-natural.t7') 27 | # print(data) 28 | for instance in tqdm.tqdm(data): 29 | image, attr, obj = instance['_image'], instance['attr'], instance['obj'] 30 | old_file = '%s/_images/%s'%(root, image) 31 | new_dir = '%s/images/%s_%s/'%(root, attr, obj) 32 | os.makedirs(new_dir, exist_ok=True) 33 | shutil.copy(old_file, new_dir) -------------------------------------------------------------------------------- /MSCI/code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Carman Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MSCI/code/download_data/download_data.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | CURRENT_DIR=$(pwd) 9 | 10 | mkdir data 11 | cd data 12 | 13 | # download datasets and splits 14 | wget -c http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip -O mitstates.zip 15 | wget -c http://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images.zip -O utzap.zip 16 | wget -c http://www.cs.cmu.edu/~spurushw/publication/compositional/compositional_split_natural.tar.gz -O compositional_split_natural.tar.gz 17 | wget -c https://s3.mlcloud.uni-tuebingen.de/czsl/cgqa-updated.zip -O cgqa.zip 18 | 19 | 20 | # MIT-States 21 | unzip mitstates.zip 'release_dataset/images/*' -d mit-states/ 22 | mv mit-states/release_dataset/images mit-states/images/ 23 | rm -r mit-states/release_dataset 24 | rename "s/ /_/g" mit-states/images/* 25 | 26 | # UT-Zappos50k 27 | unzip utzap.zip -d ut-zap50k/ 28 | mv ut-zap50k/ut-zap50k-images ut-zap50k/_images/ 29 | 30 | # C-GQA 31 | unzip cgqa.zip -d cgqa/ 32 | 33 | # Download new splits for Purushwalkam et. al 34 | tar -zxvf splits.tar.gz 35 | 36 | cd $CURRENT_DIR 37 | python download_data/reorganize_utzap.py 38 | 39 | mv data/ut-zap50k data/ut-zappos 40 | -------------------------------------------------------------------------------- /MSCI/code/config/msci/mit-states-ow.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name: MSCI 3 | prompt_template: ["a photo of x x", "a photo of x", "a photo of x"] 4 | ctx_init: ["a photo of ", "a photo of ", "a photo of "] 5 | clip_model: "ViT-L/14" 6 | # clip_arch: 7 | cmt_layers: 3 8 | init_lamda: 0.1 9 | cross_attn_dropout: 0.1 10 | adapter_dim: 64 11 | adapter_dropout: 0.1 12 | # branch 13 | pair_loss_weight: 1.0 14 | pair_inference_weight: 0.3 15 | attr_loss_weight: 1.0 16 | attr_inference_weight: 0.7 17 | obj_loss_weight: 1.0 18 | obj_inference_weight: 0.7 19 | 20 | train: 21 | dataset: mit-states 22 | # dataset_path: 23 | optimizer: Adam 24 | scheduler: StepLR 25 | step_size: 5 26 | gamma: 0.5 27 | lr: 0.0001 28 | attr_dropout: 0.3 29 | weight_decay: 0.00001 30 | context_length: 8 31 | train_batch_size: 64 32 | gradient_accumulation_steps: 1 33 | # seed: 34 | epochs: 20 35 | epoch_start: 0 36 | # save_path: 37 | val_metric: best_AUC 38 | save_final_model: True 39 | adversarial_loss_weight: 0.4 40 | selected_low_layers: 4 41 | selected_high_layers: 4 42 | stage_1_dropout: 0.1 43 | stage_2_dropout: 0.1 44 | fusion_dropout: 0.1 45 | stage_1_num_heads: 16 46 | stage_2_num_heads: 16 47 | stage_1_num_cmt_layers: 1 48 | stage_2_num_cmt_layers: 1 49 | 50 | test: 51 | eval_batch_size: 48 52 | open_world: True 53 | # load_model: 54 | topk: 1 55 | text_encoder_batch_size: 1024 56 | threshold: 0.36183673469387756 57 | threshold_trials: 50 58 | bias: 0 59 | text_first: True -------------------------------------------------------------------------------- /MSCI/code/config/msci/ut-zappos-ow.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name: MSCI 3 | prompt_template: ["a photo of x x", "a photo of x", "a photo of x"] 4 | ctx_init: ["a photo of ", "a photo of ", "a photo of "] 5 | clip_model: "ViT-L/14" 6 | # clip_arch: 7 | cmt_layers: 2 8 | init_lamda: 0.1 9 | cross_attn_dropout: 0 10 | adapter_dim: 64 11 | adapter_dropout: 0.1 12 | # branch 13 | pair_loss_weight: 1.0 14 | pair_inference_weight: 1.0 15 | attr_loss_weight: 1.0 16 | attr_inference_weight: 1.0 17 | obj_loss_weight: 1.0 18 | obj_inference_weight: 1.0 19 | 20 | train: 21 | dataset: ut-zappos 22 
| # dataset_path: 23 | optimizer: AdamW 24 | scheduler: StepLR 25 | step_size: 5 26 | gamma: 0.5 27 | lr: 0.00025 #0.0005 # 0.00025 28 | attr_dropout: 0.3 29 | weight_decay: 0.00001 #0.00001 30 | context_length: 8 31 | train_batch_size: 48 # 64 32 | gradient_accumulation_steps: 1 33 | # seed: 34 | epochs: 15 35 | epoch_start: 0 36 | # save_path: 37 | val_metric: best_loss 38 | save_final_model: True 39 | selected_low_layers: 3 40 | selected_high_layers: 3 41 | stage_1_dropout: 0.1 42 | stage_2_dropout: 0.1 43 | fusion_dropout: 0.2 44 | stage_1_num_heads: 12 45 | stage_2_num_heads: 12 46 | stage_1_num_cmt_layers: 1 47 | stage_2_num_cmt_layers: 1 48 | 49 | test: 50 | eval_batch_size: 32 51 | open_world: True 52 | # load_model: 53 | topk: 1 54 | text_encoder_batch_size: 1024 55 | threshold: 0.4631293780949651 56 | threshold_trials: 50 57 | bias: 0.001 58 | text_first: True -------------------------------------------------------------------------------- /MSCI/code/config/msci/cgqa-ow.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name: MSCI 3 | prompt_template: ["a photo of x x", "a photo of x", "a photo of x"] 4 | ctx_init: ["a photo of ", "a photo of ", "a photo of "] 5 | clip_model: "ViT-L/14" 6 | # clip_arch: 7 | cmt_layers: 2 8 | init_lamda: 0.1 9 | cross_attn_dropout: 0 10 | adapter_dim: 64 11 | adapter_dropout: 0.1 12 | # branch 13 | pair_loss_weight: 1.0 14 | pair_inference_weight: 0.3 15 | attr_loss_weight: 1.0 16 | attr_inference_weight: 0.7 17 | obj_loss_weight: 1.0 18 | obj_inference_weight: 0.7 19 | covariance_loss_weight: 0.00001 20 | 21 | train: 22 | dataset: cgqa 23 | # dataset_path: 24 | optimizer: Adam 25 | scheduler: StepLR 26 | step_size: 5 27 | gamma: 0.5 28 | lr: 0.0000125 29 | attr_dropout: 0 30 | weight_decay: 0.00001 31 | context_length: 8 32 | train_batch_size: 48 33 | gradient_accumulation_steps: 1 34 | # seed: 35 | epochs: 15 36 | epoch_start: 0 37 | # save_path: 38 | val_metric: AUC 39 | save_final_model: True 40 | selected_low_layers: 4 41 | selected_high_layers: 4 42 | stage_1_dropout: 0.1 43 | stage_2_dropout: 0.1 44 | fusion_dropout: 0.2 45 | stage_1_num_heads: 12 46 | stage_2_num_heads: 12 47 | stage_1_num_cmt_layers: 1 48 | stage_2_num_cmt_layers: 1 49 | # load_model: False # False or model path 50 | 51 | test: 52 | eval_batch_size: 6 53 | open_world: True 54 | # load_model: 55 | topk: 1 56 | text_encoder_batch_size: 1024 57 | threshold: 0.4061361697255348 58 | threshold_trials: 50 59 | bias: 0.001 60 | text_first: True -------------------------------------------------------------------------------- /MSCI/code/config/msci/ut-zappos.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name: MSCI 3 | prompt_template: ["a photo of x x", "a photo of x", "a photo of x"] 4 | ctx_init: ["a photo of ", "a photo of ", "a photo of "] 5 | clip_model: "ViT-L/14" 6 | # clip_arch: 7 | cmt_layers: 2 8 | init_lamda: 0.1 9 | cross_attn_dropout: 0.1 10 | adapter_dim: 64 11 | adapter_dropout: 0.1 12 | # branch 13 | pair_loss_weight: 1.0 14 | pair_inference_weight: 1.0 15 | attr_loss_weight: 1.0 16 | attr_inference_weight: 1.0 17 | obj_loss_weight: 1.0 18 | obj_inference_weight: 1.0 19 | covariance_loss_weight: 0 #0.00001 20 | 21 | 22 | 23 | 24 | train: 25 | dataset: ut-zappos 26 | # dataset_path: 27 | optimizer: AdamW 28 | scheduler: StepLR 29 | step_size: 5 30 | gamma: 0.5 31 | lr: 0.00025 #0.0005 # 0.00025 32 | attr_dropout: 0.3 33 | weight_decay: 0.00001 #0.00001 
34 | context_length: 8 35 | train_batch_size: 48 # 64 36 | gradient_accumulation_steps: 1 37 | # seed: 38 | epochs: 15 39 | epoch_start: 0 40 | # save_path: 41 | val_metric: best_AUC 42 | save_final_model: True 43 | selected_low_layers: 3 44 | selected_high_layers: 3 45 | stage_1_dropout: 0.1 46 | stage_2_dropout: 0.1 47 | fusion_dropout: 0.2 48 | stage_1_num_heads: 12 49 | stage_2_num_heads: 12 50 | stage_1_num_cmt_layers: 1 51 | stage_2_num_cmt_layers: 1 52 | # load_model: False # False or model path 53 | 54 | test: 55 | eval_batch_size: 64 56 | open_world: False 57 | # load_model: 58 | topk: 1 59 | text_encoder_batch_size: 1024 60 | threshold_trials: 50 61 | bias: 0.001 62 | text_first: True -------------------------------------------------------------------------------- /MSCI/code/config/msci/cgqa.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name: MSCI 3 | prompt_template: ["a photo of x x", "a photo of x", "a photo of x"] 4 | ctx_init: ["a photo of ", "a photo of ", "a photo of "] 5 | clip_model: "ViT-L/14" 6 | # clip_arch: 7 | cmt_layers: 2 8 | init_lamda: 0.1 9 | cross_attn_dropout: 0 10 | adapter_dim: 64 11 | adapter_dropout: 0.1 12 | # branch 13 | pair_loss_weight: 1.0 14 | pair_inference_weight: 1.0 15 | attr_loss_weight: 0.1 16 | attr_inference_weight: 1.0 17 | obj_loss_weight: 0.1 18 | obj_inference_weight: 1.0 19 | covariance_loss_weight: 0.00001 20 | #inference weight set as 0.1,0.9,0.9 for testing in close-world setting 21 | 22 | train: 23 | dataset: cgqa 24 | # dataset_path: 25 | optimizer: Adam 26 | scheduler: StepLR 27 | step_size: 5 28 | gamma: 0.5 29 | lr: 0.0000125 30 | attr_dropout: 0 31 | weight_decay: 0.00001 32 | context_length: 8 33 | train_batch_size: 48 34 | gradient_accumulation_steps: 1 35 | # seed: 36 | epochs: 15 37 | epoch_start: 0 38 | # save_path: 39 | val_metric: AUC 40 | save_final_model: True 41 | selected_low_layers: 4 42 | selected_high_layers: 4 43 | stage_1_dropout: 0.1 44 | stage_2_dropout: 0.1 45 | fusion_dropout: 0.2 46 | stage_1_num_heads: 12 47 | stage_2_num_heads: 12 48 | stage_1_num_cmt_layers: 1 49 | stage_2_num_cmt_layers: 1 50 | # load_model: False # False or model path 51 | 52 | test: 53 | eval_batch_size: 32 54 | open_world: False 55 | # load_model: 56 | topk: 1 57 | text_encoder_batch_size: 1024 58 | # threshold: 0.4 59 | threshold_trials: 50 60 | bias: 0.001 61 | text_first: True 62 | -------------------------------------------------------------------------------- /MSCI/code/config/msci/mit-states.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name: MSCI 3 | prompt_template: ["a photo of x x", "a photo of x", "a photo of x"] 4 | ctx_init: ["a photo of ", "a photo of ", "a photo of "] 5 | clip_model: "ViT-L/14" 6 | # clip_arch: 7 | cmt_layers: 3 8 | init_lamda: 0.1 9 | cross_attn_dropout: 0.1 10 | adapter_dim: 64 11 | adapter_dropout: 0.1 12 | # branch 13 | pair_loss_weight: 1.0 14 | pair_inference_weight: 1.0 15 | attr_loss_weight: 1.0 16 | attr_inference_weight: 1.0 17 | obj_loss_weight: 1.0 18 | obj_inference_weight: 1.0 19 | #inference_weight set as 0.1,0.9,0.9 for testing in close-world setting 20 | 21 | 22 | train: 23 | dataset: mit-states 24 | # dataset_path: 25 | optimizer: Adam 26 | scheduler: StepLR 27 | step_size: 5 28 | gamma: 0.5 29 | lr: 0.0001 30 | attr_dropout: 0.3 31 | weight_decay: 0.00001 32 | context_length: 8 33 | train_batch_size: 64 34 | gradient_accumulation_steps: 1 35 | # seed: 36 | epochs: 
20 37 | epoch_start: 0 38 | # save_path: 39 | val_metric: best_AUC 40 | save_final_model: True 41 | adversarial_loss_weight: 0.4 42 | selected_low_layers: 4 43 | selected_high_layers: 4 44 | stage_1_dropout: 0.1 45 | stage_2_dropout: 0.1 46 | fusion_dropout: 0.1 47 | stage_1_num_heads: 12 48 | stage_2_num_heads: 12 49 | stage_1_num_cmt_layers: 1 50 | stage_2_num_cmt_layers: 1 51 | 52 | # load_model: False # False or model path 53 | 54 | test: 55 | eval_batch_size: 32 56 | open_world: False 57 | # load_model: 58 | topk: 1 59 | text_encoder_batch_size: 1024 60 | # threshold: 0.4 61 | threshold_trials: 50 62 | bias: 0.001 63 | text_first: True 64 | -------------------------------------------------------------------------------- /MSCI/code/clip_modules/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CustomTextEncoder(torch.nn.Module): 4 | def __init__(self, clip_model, tokenizer, dtype=torch.float16): 5 | super().__init__() 6 | self.dtype = dtype 7 | 8 | self.transformer = clip_model.transformer 9 | self.positional_embedding = clip_model.positional_embedding 10 | self.ln_final = clip_model.ln_final 11 | self.text_projection = clip_model.text_projection 12 | self.token_embedding = clip_model.token_embedding 13 | 14 | self.tokenizer = tokenizer 15 | 16 | def tokenize(self, text): 17 | return torch.cat([self.tokenizer(tok) for tok in text]) 18 | 19 | def encode_text(self, text, enable_pos_emb=True): 20 | token_ids = self.tokenize(text) 21 | text_features = self.forward(token_ids, None, enable_pos_emb) 22 | return text_features 23 | 24 | def forward(self, token_ids, token_tensors=None, enable_pos_emb=False): 25 | """The forward function to compute representations for the prompts. 26 | 27 | Args: 28 | token_ids (torch.tensor): the token ids, which 29 | contains the token. 30 | token_tensors (torch.Tensor, optional): the tensor 31 | embeddings for the token ids. Defaults to None. 32 | enable_pos_emb (bool, optional): adds the learned 33 | positional embeddings if true. Defaults to False. 34 | 35 | Returns: 36 | torch.Tensor: the vector representation of the prompt. 
37 | """ 38 | if token_tensors is not None: 39 | text_features = token_tensors 40 | else: 41 | text_features = self.token_embedding(token_ids) 42 | 43 | text_features = text_features.type(self.dtype) 44 | x = ( 45 | text_features + self.positional_embedding.type(self.dtype) 46 | if enable_pos_emb 47 | else text_features 48 | ) 49 | x = x.permute(1, 0, 2) 50 | x = self.transformer(x) 51 | x = x.permute(1, 0, 2) 52 | x = self.ln_final(x) 53 | tf = ( 54 | x[ 55 | torch.arange(x.shape[0]), token_ids.argmax(dim=-1) 56 | ] # POS of 57 | @ self.text_projection 58 | ) 59 | return tf 60 | -------------------------------------------------------------------------------- /MSCI/code/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | # data/ 131 | glove* 132 | .DS_Store 133 | .idea 134 | 135 | wandb/ 136 | 137 | # for Troika 138 | /save_models 139 | ignored_file.py -------------------------------------------------------------------------------- /MSCI/code/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | import yaml 7 | import json 8 | 9 | from tools.optimization import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup 10 | 11 | DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 12 | 13 | 14 | def set_seed(seed): 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | 22 | def load_args(filename, args): 23 | with open(filename, 'r') as stream: 24 | data_loaded = yaml.safe_load(stream) 25 | for key, group in data_loaded.items(): 26 | for key, val in group.items(): 27 | setattr(args, key, val) 28 | 29 | 30 | def write_json(filename, content): 31 | with open(filename, 'w') as f: 32 | json.dump(content, f) 33 | 34 | 35 | def load_json(filename): 36 | with open(filename, "r") as f: 37 | return json.load(f) 38 | 39 | 40 | def get_optimizer(model, config): 41 | if config.optimizer == 'Adam': 42 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay) 43 | elif config.optimizer == 'SGD': 44 | optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, weight_decay=config.weight_decay) 45 | elif config.optimizer == 'AdamW': 46 | optimizer = AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay) 47 | return optimizer 48 | 49 | 50 | def get_scheduler(optimizer, config, num_batches=-1): 51 | if not hasattr(config, 'scheduler'): 52 | return None 53 | if config.scheduler == 'StepLR': 54 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma) 55 | elif config.scheduler == 'linear_w_warmup' or config.scheduler == 'cosine_w_warmup': 56 | assert num_batches != -1 57 | num_training_steps = num_batches * config.epochs 58 | num_warmup_steps = int(config.warmup_proportion * num_training_steps) 59 | if config.scheduler == 'linear_w_warmup': 60 | scheduler = get_linear_schedule_with_warmup(optimizer, 61 | num_warmup_steps=num_warmup_steps, 62 | num_training_steps=num_training_steps) 63 | if config.scheduler == 'cosine_w_warmup': 64 | scheduler = get_cosine_schedule_with_warmup(optimizer, 65 | num_warmup_steps=num_warmup_steps, 66 | num_training_steps=num_training_steps) 67 | return scheduler 68 | 69 | 70 | def step_scheduler(scheduler, config, bid, num_batches): 71 | if config.scheduler in ['StepLR']: 72 | if bid + 1 == num_batches: # end of the epoch 73 | scheduler.step() 74 | elif config.scheduler in ['linear_w_warmup', 'cosine_w_warmup']: 75 | scheduler.step() 76 | 77 | return scheduler 78 | 
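# --- Usage sketch (added for illustration; not part of the original utils.py) ---
# Shows how get_optimizer / get_scheduler / step_scheduler above are wired together
# in train_base.py: with the StepLR config used by MSCI, step_scheduler only advances
# the scheduler on the last batch of each epoch, whereas the warmup schedulers are
# advanced every batch. The tiny model and the config values below are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_model = torch.nn.Linear(8, 2)
    demo_config = SimpleNamespace(
        optimizer="Adam", lr=1e-4, weight_decay=1e-5,
        scheduler="StepLR", step_size=5, gamma=0.5, epochs=2,
    )
    optimizer = get_optimizer(demo_model, demo_config)
    scheduler = get_scheduler(optimizer, demo_config)

    num_batches = 10
    for epoch in range(demo_config.epochs):
        for bid in range(num_batches):
            optimizer.zero_grad()
            demo_model(torch.randn(4, 8)).sum().backward()
            optimizer.step()
            # StepLR is stepped only when bid + 1 == num_batches (end of the epoch)
            scheduler = step_scheduler(scheduler, demo_config, bid, num_batches)
        print(f"epoch {epoch + 1}: lr = {optimizer.param_groups[0]['lr']}")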
-------------------------------------------------------------------------------- /MSCI/code/parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | 5 | 6 | # model config 7 | parser.add_argument("--model_name", help="model name", type=str) 8 | parser.add_argument("--lr", help="learning rate", type=float, default=5e-05) 9 | parser.add_argument("--dataset", help="name of the dataset", type=str, default='mit-states') 10 | parser.add_argument("--weight_decay", help="weight decay", type=float, default=1e-05) 11 | parser.add_argument("--clip_model", help="clip model type", type=str, default="ViT-L/14") 12 | parser.add_argument("--epochs", help="number of epochs", default=20, type=int) 13 | parser.add_argument("--epoch_start", help="start epoch", default=0, type=int) 14 | parser.add_argument("--train_batch_size", help="train batch size", default=48, type=int) 15 | parser.add_argument("--eval_batch_size", help="eval batch size", default=16, type=int) 16 | parser.add_argument("--num_workers", help="number of workers", default=4, type=int) 17 | parser.add_argument("--context_length", help="sets the context length of the clip model", default=8, type=int) 18 | parser.add_argument("--attr_dropout", help="add dropout to attributes", type=float, default=0.3) 19 | parser.add_argument("--yml_path", help="yml path", type=str) 20 | parser.add_argument("--clip_arch", help="clip path", type=str) 21 | parser.add_argument("--dataset_path", help="dataset path", type=str) 22 | parser.add_argument("--save_path", help="save path", type=str) 23 | parser.add_argument("--save_every_n", default=5, type=int, help="saves the model every n epochs") 24 | parser.add_argument("--save_final_model", help="indicate if you want to save the model state dict()", action="store_true") 25 | parser.add_argument("--load_model", default=None, help="load the trained model") 26 | parser.add_argument("--seed", help="seed value", default=0, type=int) 27 | parser.add_argument("--gradient_accumulation_steps", help="number of gradient accumulation steps", default=1, type=int) 28 | parser.add_argument("--same_prim_sample", help="if sample same prim samples", action="store_true") 29 | 30 | parser.add_argument("--open_world", help="evaluate on open world setup", default=False) 31 | parser.add_argument("--bias", help="eval bias", type=float, default=1e3) 32 | parser.add_argument("--topk", help="eval topk", type=int, default=1) 33 | parser.add_argument("--text_encoder_batch_size", help="batch size of the text encoder", default=16, type=int) 34 | parser.add_argument('--threshold', type=float, default=None, help="optional threshold") 35 | parser.add_argument('--threshold_low', type=float, default=None, help="optional threshold") 36 | parser.add_argument('--threshold_high', type=float, default=None, help="optional threshold") 37 | parser.add_argument('--threshold_trials', type=int, default=50, help="how many threshold values to try") 38 | 39 | parser.add_argument("--adapter_dim", help="middle dimension of Adapter", type=int, default=64) 40 | parser.add_argument("--init_lamda", help="lamda initialization value", type=float, default=0.1) 41 | parser.add_argument("--cmt_layers", help="Number of layers in cross-attention", type=int, default=2) 42 | parser.add_argument("--sampling_rate", help="Frequency of layer feature extraction", type=int, default=12) 43 | parser.add_argument("--adversarial_loss_weight", help="weight of the confusion_loss", 
type=int, default=0.1) 44 | parser.add_argument("--selected_low_layers", help="number of selected layers", type=int, default=3) 45 | parser.add_argument("--selected_high_layers", help="number of selected layers", type=int, default=3) 46 | parser.add_argument("--stage_1_dropout", help="number of selected layers", type=float, default=0.2) 47 | parser.add_argument("--stage_2_dropout", help="number of selected layers", type=float, default=0.2) 48 | parser.add_argument("--fusion_dropout", help="number of selected layers", type=float, default=0.1) 49 | parser.add_argument("--stage_1_num_heads", help="number of selected layers", type=int, default=12) 50 | parser.add_argument("--stage_2_num_heads", help="number of selected layers", type=int, default=12) 51 | parser.add_argument("--stage_1_num_cmt_layers", help="number of selected layers", type=int, default=1) 52 | parser.add_argument("--stage_2_num_cmt_layers", help="number of selected layers", type=int, default=1) 53 | 54 | -------------------------------------------------------------------------------- /MSCI/code/README.md: -------------------------------------------------------------------------------- 1 | # MSCI: Addressing CLIP's Inherent Limitations for Compositional Zero-Shot Learning 2 | 3 | ## Project Setup and Requirements 4 | 5 | To run the project, follow the steps below. 6 | 7 | ### Install Required Environment 8 | 9 | First, install the necessary environment by running the following command: 10 | 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ### Download Backbone Model (ViT-L14) 16 | 17 | Next, you need to download the backbone (ViT-L14) model using `wget`. Use the following command: 18 | 19 | ```bash 20 | cd 21 | wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt 22 | ``` 23 | 24 | ### Dataset Download 25 | 26 | We conduct experiments on three datasets: Mit-states, Ut-zappos, and C-GQA. Please download these datasets and place them in the `MSCI/code/download_data` directory. Use the links below to download them: 27 | 28 | - **Mit-states**: [Mit-states dataset](https://web.mit.edu/phillipi/Public/states_and_transformations/index.html) 29 | - **Ut-zappos**: [Ut-zappos dataset](https://vision.cs.utexas.edu/projects/finegrained/utzap50k/) 30 | - **C-GQA**: [C-GQA dataset](https://github.com/ExplainableML/czsl) 31 | 32 | Once downloaded, run the following command to set up the datasets: 33 | 34 | ```bash 35 | sh download_data.sh 36 | ``` 37 | 38 | ## Model Training 39 | 40 | ### Training in Closed-World Setting 41 | 42 | To train the model in the closed-world setting, use the following command: 43 | 44 | ```bash 45 | python -u train_base.py \ 46 | --clip_arch /ViT-L-14.pt \ 47 | --dataset_path / \ 48 | --save_path / \ 49 | --yml_path ./config/msci/.yml \ 50 | --num_workers 10 \ 51 | --seed 0 52 | ``` 53 | 54 | ### Evaluating in Closed-World Setting 55 | 56 | To evaluate the model's performance in the closed-world setting, run the following command: 57 | 58 | ```bash 59 | python -u test_base.py \ 60 | --clip_arch /ViT-L-14.pt \ 61 | --dataset_path / \ 62 | --save_path / \ 63 | --yml_path ./config/msci/.yml \ 64 | --num_workers 10 \ 65 | --seed 0 \ 66 | --load_model //val_best.pt 67 | ``` 68 | 69 | ### Evaluating in Open-World Setting 70 | 71 | To evaluate the model's performance in the open-world setting, we need to compute feasibility scores for all candidate combinations and filter based on these scores. 
The configuration files for the feasibility scores of each dataset are embedded in the code, allowing you to directly evaluate the model’s performance in the open-world setting with the following command: 72 | 73 | ```bash 74 | python -u test_base.py \ 75 | --clip_arch /ViT-L-14.pt \ 76 | --dataset_path / \ 77 | --save_path / \ 78 | --yml_path ./config/msci/-ow.yml \ 79 | --num_workers 10 \ 80 | --seed 0 \ 81 | --load_model //val_best.pt 82 | ``` 83 | 84 | ## Notes 85 | 86 | 1. **Ensure Directories Are Correct**: Before running the commands, verify that the paths to the model files, datasets, and save directories are correctly specified. Replace placeholders like ``, ``, and `` with the actual paths. 87 | 88 | 2. **Check for Dependencies**: Make sure all required libraries and dependencies are correctly installed using the provided `requirements.txt` file. This ensures that the environment is set up for running the experiments smoothly. 89 | 90 | 3. **Evaluation Configurations**: Make sure to select the correct configuration file based on your dataset (`` in the commands above), whether for the closed-world or open-world setting. 91 | 92 | 4. **Model Optimization**: The performance of the model may vary across datasets. Make sure to monitor the results and adjust hyperparameters like the learning rate, batch size, and epoch count if necessary. 93 | 94 | 5. **Troubleshooting**: In case of issues with downloading datasets or model weights, check your internet connection or the validity of the provided download links. 95 | 96 | 97 | ## Acknowledgement 98 | 99 | Our code references the following projects: 100 | 101 | * [DFSP](https://github.com/Forest-art/DFSP) 102 | * [AdaptFormer](https://github.com/ShoufaChen/AdaptFormer) 103 | * [Troika](https://github.com/bighuang624/Troika) 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /MSCI/code/clip_modules/interface.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from clip_modules.clip_model import CLIP 5 | from clip_modules.tokenization_clip import SimpleTokenizer 6 | 7 | from .text_encoder import CustomTextEncoder 8 | 9 | 10 | class CLIPInterface(torch.nn.Module): 11 | def __init__( 12 | self, 13 | clip_model: CLIP, 14 | tokenizer: SimpleTokenizer, 15 | config: argparse.ArgumentParser, 16 | token_ids: torch.tensor, 17 | soft_embeddings: torch.nn.Parameter = None, 18 | dtype: torch.dtype = None, 19 | device: torch.device = "cuda:0", 20 | enable_pos_emb: bool = False, 21 | ): 22 | """CLIP interface for our custom modules. 23 | 24 | Args: 25 | clip_model (CLIP): the clip model 26 | config (argparse.ArgumentParser): arguments used for 27 | training 28 | token_ids (torch.tensor): the input token ids to the text 29 | encoder 30 | soft_embeddings (torch.nn.Parameter, optional): the only 31 | parameter that we finetune in the experiment. 32 | Defaults to None. 33 | dtype (torch.dtype, optional): torch dtype for the 34 | transformer. This allows the half precision option. 35 | Defaults to None. 36 | device (torch.device, optional): the device where the model 37 | should be loaded. Defaults to "cuda:0". 38 | enable_pos_emb (bool, optional): if true, adds the learned 39 | positional embeddings. Defaults to False. 
40 | """ 41 | super().__init__() 42 | 43 | self.config = config 44 | 45 | self.clip_model = clip_model 46 | 47 | if dtype is None and device == "cpu": 48 | self.dtype = torch.float32 49 | elif dtype is None: 50 | self.dtype = torch.float16 51 | else: 52 | self.dtype = dtype 53 | 54 | self.device = device 55 | 56 | self.enable_pos_emb = enable_pos_emb 57 | 58 | self.text_encoder = CustomTextEncoder(clip_model, tokenizer, self.dtype) 59 | # for params in self.text_encoder.parameters(): 60 | # params.requires_grad = False 61 | # self.clip_model.text_projection.requires_grad = False 62 | 63 | self.token_ids = token_ids 64 | self.soft_embeddings = soft_embeddings 65 | 66 | def encode_image(self, imgs): 67 | return self.clip_model.encode_image(imgs) 68 | 69 | def encode_text(self, text, enable_pos_emb=True): 70 | return self.text_encoder.encode_text( 71 | text, enable_pos_emb=enable_pos_emb 72 | ) 73 | 74 | def tokenize(self, text): 75 | return self.text_encoder.tokenize(text) 76 | 77 | def set_soft_embeddings(self, se): 78 | if se.shape == self.soft_embeddings.shape: 79 | self.state_dict()['soft_embeddings'].copy_(se) 80 | else: 81 | raise RuntimeError(f"Error: Incorrect Soft Embedding Shape {se.shape}, Expecting {self.soft_embeddings.shape}!") 82 | 83 | def set_frozen_embeddings(self, se): 84 | if se.shape == self.frozen_embeddings.shape: 85 | self.state_dict()['frozen_embeddings'].copy_(se) 86 | else: 87 | raise RuntimeError(f"Error: Incorrect Frozen Embedding Shape {se.shape}, Expecting {self.frozen_embeddings.shape}!") 88 | 89 | def construct_token_tensors(self, idx): 90 | """The function is used to generate token tokens. These 91 | token tensors can be None or custom. For custom token_tensors 92 | the class needs to be inherited and the function should be 93 | replaced. 94 | 95 | Raises: 96 | NotImplementedError: raises error if the model contains 97 | soft embeddings but does not make custom modifications. 
98 | 
99 |         Returns:
100 |             torch.Tensor: returns torch.Tensor or None
101 |         """
102 |         if self.soft_embeddings is None:
103 |             return None
104 |         else:
105 |             # Implement a custom version
106 |             raise NotImplementedError
107 | 
108 |     def forward(self, batch_img, idx):
109 |         batch_img = batch_img.to(self.device)
110 | 
111 |         token_tensors = self.construct_token_tensors(idx)
112 | 
113 |         text_features = self.text_encoder(
114 |             self.token_ids,
115 |             token_tensors,
116 |             enable_pos_emb=self.enable_pos_emb,
117 |         )
118 | 
119 |         #_text_features = text_features[idx, :]
120 |         _text_features = text_features
121 | 
122 |         idx_text_features = _text_features / _text_features.norm(
123 |             dim=-1, keepdim=True
124 |         )
125 |         normalized_img = batch_img / batch_img.norm(dim=-1, keepdim=True)
126 |         logits = (
127 |             self.clip_model.logit_scale.exp()
128 |             * normalized_img
129 |             @ idx_text_features.t()
130 |         )
131 | 
132 |         return logits
133 | 
--------------------------------------------------------------------------------
/MSCI/code/train_base.py:
--------------------------------------------------------------------------------
1 | #
2 | import argparse
3 | import os
4 | import pickle
5 | import pprint
6 | 
7 | import numpy as np
8 | import torch
9 | import tqdm
10 | from torch.nn.modules.loss import CrossEntropyLoss
11 | from torch.utils.data.dataloader import DataLoader
12 | import torch.nn.functional as F
13 | from model.model_factory_final import get_model
14 | from parameters import parser
15 | 
16 | # from test import *
17 | import test_base as test
18 | from dataset import CompositionDataset
19 | from utils import *
20 | 
21 | 
22 | def train_model(model, optimizer, config, train_dataset, val_dataset, test_dataset):
23 |     train_dataloader = DataLoader(
24 |         train_dataset,
25 |         batch_size=config.train_batch_size,
26 |         shuffle=True,
27 |         num_workers=config.num_workers
28 |     )
29 | 
30 |     model.train()
31 |     best_metric = 0
32 |     best_loss = 1e5
33 |     best_epoch = 0
34 |     final_model_state = None
35 | 
36 |     val_results = []
37 | 
38 |     scheduler = get_scheduler(optimizer, config, len(train_dataloader))
39 |     attr2idx = train_dataset.attr2idx
40 |     obj2idx = train_dataset.obj2idx
41 | 
42 |     train_pairs = torch.tensor([(attr2idx[attr], obj2idx[obj])
43 |                                 for attr, obj in train_dataset.train_pairs]).cuda()
44 | 
45 |     train_losses = []
46 | 
47 |     for i in range(config.epoch_start, config.epochs):
48 |         progress_bar = tqdm.tqdm(
49 |             total=len(train_dataloader), desc="epoch % 3d" % (i + 1)
50 |         )
51 | 
52 |         epoch_train_losses = []
53 |         for bid, batch in enumerate(train_dataloader):
54 | 
55 |             logits = model(batch, train_pairs, return_features=True)  # the decoupling loss is added during training
56 | 
57 |             loss = model.loss_calu(logits, batch)
58 | 
59 |             # normalize loss to account for batch accumulation
60 |             loss = loss / config.gradient_accumulation_steps
61 | 
62 |             # backward pass
63 |             loss.backward()
64 | 
65 |             # weights update
66 |             if ((bid + 1) % config.gradient_accumulation_steps == 0) or (bid + 1 == len(train_dataloader)):
67 |                 optimizer.step()
68 |                 optimizer.zero_grad()
69 |                 scheduler = step_scheduler(scheduler, config, bid, len(train_dataloader))
70 | 
71 |             epoch_train_losses.append(loss.item())
72 |             progress_bar.set_postfix({"train loss": np.mean(epoch_train_losses[-50:])})
73 |             progress_bar.update()
74 | 
75 |         progress_bar.close()
76 |         progress_bar.write(f"epoch {i + 1} train loss {np.mean(epoch_train_losses)}")
77 |         train_losses.append(np.mean(epoch_train_losses))
78 | 
79 |         if (i + 1) % 1 == 0:
80 |             # print the keys before loading
81 |             #print("Model keys before loading:", model.state_dict().keys())
82 | torch.save(model.state_dict(), os.path.join(config.save_path, f"epoch_{i + 1}.pt")) 83 | 84 | 85 | print("Evaluating val dataset:") 86 | val_result = evaluate(model, val_dataset, config) 87 | val_results.append(val_result) 88 | 89 | if config.val_metric == 'best_loss' and val_result["loss"] < best_loss: 90 | # print('val_loss:', val_result["loss"]) 91 | # print('best_loss:', best_loss) 92 | best_loss = val_result["loss"] 93 | best_epoch = i + 1 94 | torch.save(model.state_dict(), os.path.join( 95 | config.save_path, "best_model.pt")) 96 | this_epoch_result= val_result["AUC"] 97 | best_epoch_result= best_metric 98 | if config.val_metric != 'best_loss' and val_result["AUC"] > best_metric : 99 | best_metric = val_result["AUC"] 100 | best_epoch = i + 1 101 | torch.save(model.state_dict(), os.path.join( 102 | config.save_path, "best_model.pt")) 103 | print('best_epoch:', best_epoch) 104 | print('loss:',val_result["loss"]) 105 | final_model_state = model.state_dict() 106 | if this_epoch_result > best_epoch_result: 107 | print("--- Evaluating test dataset on Closed World ---") 108 | model.load_state_dict(torch.load(os.path.join( 109 | config.save_path, "best_model.pt" 110 | ))) 111 | evaluate(model, test_dataset, config) 112 | 113 | if config.save_final_model: 114 | torch.save(final_model_state, os.path.join(config.save_path, f'final_model_wy.pt')) 115 | 116 | 117 | def evaluate(model, dataset, config): 118 | model.eval() 119 | evaluator = test.Evaluator(dataset, model=None) 120 | all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss_avg = test.predict_logits( 121 | model, dataset, config) 122 | test_stats = test.test( 123 | dataset, 124 | evaluator, 125 | all_logits, 126 | all_attr_gt, 127 | all_obj_gt, 128 | all_pair_gt, 129 | config 130 | ) 131 | test_saved_results = dict() 132 | result = "" 133 | key_set = ["best_seen", "best_unseen", "best_hm", "AUC", "attr_acc", "obj_acc"] 134 | for key in key_set: 135 | result = result + key + " " + str(round(test_stats[key], 4)) + "| " 136 | test_saved_results[key] = round(test_stats[key], 4) 137 | print(result) 138 | test_saved_results['loss'] = loss_avg 139 | return test_saved_results 140 | 141 | 142 | if __name__ == "__main__": 143 | config = parser.parse_args() 144 | if config.yml_path: 145 | load_args(config.yml_path, config) 146 | print(config) 147 | # set the seed value 148 | set_seed(config.seed) 149 | 150 | dataset_path = config.dataset_path 151 | 152 | train_dataset = CompositionDataset(dataset_path, 153 | phase='train', 154 | split='compositional-split-natural', 155 | same_prim_sample=config.same_prim_sample) 156 | 157 | val_dataset = CompositionDataset(dataset_path, 158 | phase='val', 159 | split='compositional-split-natural') 160 | 161 | test_dataset = CompositionDataset(dataset_path, 162 | phase='test', 163 | split='compositional-split-natural') 164 | 165 | allattrs = train_dataset.attrs 166 | allobj = train_dataset.objs 167 | classes = [cla.replace(".", " ").lower() for cla in allobj] 168 | attributes = [attr.replace(".", " ").lower() for attr in allattrs] 169 | offset = len(attributes) 170 | 171 | model = get_model(config, attributes=attributes, classes=classes, offset=offset).cuda() 172 | optimizer = get_optimizer(model, config) 173 | 174 | os.makedirs(config.save_path, exist_ok=True) 175 | 176 | train_model(model, optimizer, config, train_dataset, val_dataset, test_dataset) 177 | 178 | with open(os.path.join(config.save_path, "config.pkl"), "wb") as fp: 179 | pickle.dump(config, fp) 180 | write_json(os.path.join(config.save_path, 
"config.json"), vars(config)) 181 | print("done!") 182 | -------------------------------------------------------------------------------- /MSCI/code/download_data/feasibility.py: -------------------------------------------------------------------------------- 1 | ''' 2 | python -u download_data/feasibility.py --dataset mit-states --dataset_root /mnt/nas-zhangjiakou/siteng/CZSL_data --data_root ./data 3 | ''' 4 | import argparse 5 | import os 6 | from itertools import product 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from dataset import CompositionDataset 12 | 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | 15 | def compute_cosine_similarity(names, weights, return_dict=True): 16 | pairing_names = list(product(names, names)) 17 | normed_weights = F.normalize(weights,dim=1) 18 | similarity = torch.mm(normed_weights, normed_weights.t()) 19 | if return_dict: 20 | dict_sim = {} 21 | for i,n in enumerate(names): 22 | for j,m in enumerate(names): 23 | dict_sim[(n,m)]=similarity[i,j].item() 24 | return dict_sim 25 | return pairing_names, similarity.to(device) 26 | 27 | 28 | def load_glove_embeddings(vocab): 29 | ''' 30 | Inputs 31 | emb_file: Text file with word embedding pairs e.g. Glove, Processed in lower case. 32 | vocab: List of words 33 | Returns 34 | Embedding Matrix 35 | ''' 36 | vocab = [v.lower() for v in vocab] 37 | emb_file = os.path.join(config.data_root, 'data/glove.6B.300d.txt') 38 | model = {} # populating a dictionary of word and embeddings 39 | for line in open(emb_file, 'r'): 40 | line = line.strip().split(' ') # Word-embedding 41 | wvec = torch.FloatTensor(list(map(float, line[1:]))) 42 | model[line[0]] = wvec 43 | 44 | # Adding some vectors for UT Zappos 45 | custom_map = { 46 | 'faux.fur': 'fake_fur', 47 | 'faux.leather': 'fake_leather', 48 | 'full.grain.leather': 'thick_leather', 49 | 'hair.calf': 'hair_leather', 50 | 'patent.leather': 'shiny_leather', 51 | 'boots.ankle': 'ankle_boots', 52 | 'boots.knee.high': 'knee_high_boots', 53 | 'boots.mid-calf': 'midcalf_boots', 54 | 'shoes.boat.shoes': 'boat_shoes', 55 | 'shoes.clogs.and.mules': 'clogs_shoes', 56 | 'shoes.flats': 'flats_shoes', 57 | 'shoes.heels': 'heels', 58 | 'shoes.loafers': 'loafers', 59 | 'shoes.oxfords': 'oxford_shoes', 60 | 'shoes.sneakers.and.athletic.shoes': 'sneakers', 61 | 'traffic_light': 'traffic_light', 62 | 'trash_can': 'trashcan', 63 | 'dry-erase_board' : 'dry_erase_board', 64 | 'black_and_white' : 'black_white', 65 | 'eiffel_tower' : 'tower', 66 | 'nubuck' : 'grainy_leather', 67 | } 68 | 69 | embeds = [] 70 | for k in vocab: 71 | if k in custom_map: 72 | k = custom_map[k] 73 | if '_' in k: 74 | ks = k.split('_') 75 | emb = torch.stack([model[it] for it in ks]).mean(dim=0) 76 | else: 77 | emb = model[k] 78 | embeds.append(emb) 79 | embeds = torch.stack(embeds) 80 | print('Glove Embeddings loaded, total embeddings: {}'.format(embeds.size())) 81 | return embeds 82 | 83 | 84 | def clip_embeddings(model, tokenizer, words_list): 85 | words_list = [word.replace(".", " ").lower() for word in words_list] 86 | prompts = [f"a photo of {word}" for word in words_list] 87 | 88 | tokenized_prompts = tokenizer(prompts) 89 | with torch.no_grad(): 90 | _text_features = model.text_encoder(tokenized_prompts, enable_pos_emb=True) 91 | text_features = _text_features / _text_features.norm( 92 | dim=-1, keepdim=True 93 | ) 94 | return text_features 95 | 96 | def get_pair_scores_objs(attr, obj, all_objs, attrs_by_obj_train, obj_embedding_sim): 97 | score = -1. 
98 | for o in all_objs: 99 | if o!=obj and attr in attrs_by_obj_train[o]: 100 | temp_score = obj_embedding_sim[(obj,o)] 101 | if temp_score>score: 102 | score=temp_score 103 | return score 104 | 105 | def get_pair_scores_attrs(attr, obj, all_attrs, obj_by_attrs_train, attr_embedding_sim): 106 | score = -1. 107 | for a in all_attrs: 108 | if a != attr and obj in obj_by_attrs_train[a]: 109 | temp_score = attr_embedding_sim[(attr, a)] 110 | if temp_score > score: 111 | score = temp_score 112 | return score 113 | 114 | def compute_feasibility(dataset): 115 | objs = dataset.objs 116 | attrs = dataset.attrs 117 | 118 | print('computing the obj embeddings') 119 | obj_embeddings = load_glove_embeddings(objs).to(device) 120 | obj_embedding_sim = compute_cosine_similarity(objs, obj_embeddings, 121 | return_dict=True) 122 | 123 | print('computing the attr embeddings') 124 | attr_embeddings = load_glove_embeddings(attrs).to(device) 125 | attr_embedding_sim = compute_cosine_similarity(attrs, attr_embeddings, 126 | return_dict=True) 127 | 128 | print('computing the feasibilty score') 129 | feasibility_scores = dataset.seen_mask.clone().float() 130 | for a in attrs: 131 | print('Attribute', a) 132 | for o in objs: 133 | if (a, o) not in dataset.train_pairs: 134 | idx = dataset.pair2idx[(a, o)] 135 | score_obj = get_pair_scores_objs( 136 | a, 137 | o, 138 | dataset.objs, 139 | dataset.attrs_by_obj_train, 140 | obj_embedding_sim 141 | ) 142 | score_attr = get_pair_scores_attrs( 143 | a, 144 | o, 145 | dataset.attrs, 146 | dataset.obj_by_attrs_train, 147 | attr_embedding_sim 148 | ) 149 | score = (score_obj + score_attr) / 2 150 | feasibility_scores[idx] = score 151 | 152 | # feasibility_scores = feasibility_scores 153 | 154 | return feasibility_scores * (1 - dataset.seen_mask.float()) 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--dataset", help="name of the dataset", type=str) 159 | parser.add_argument( 160 | "--clip_model", help="clip model type", type=str, default="ViT-B/32" 161 | ) 162 | parser.add_argument( 163 | "--dataset_root", 164 | help="root of the dataset", 165 | type=str, 166 | ) 167 | parser.add_argument( 168 | "--data_root", 169 | help="root of the data", 170 | type=str, 171 | ) 172 | config = parser.parse_args() 173 | 174 | dataset_path = os.path.join(config.dataset_root, config.dataset) 175 | dataset = CompositionDataset(dataset_path, 176 | phase='test', 177 | split='compositional-split-natural', 178 | open_world=True) 179 | 180 | feasibility = compute_feasibility(dataset) 181 | print('feasibility:',feasibility) 182 | save_path = os.path.join(config.data_root, f'data/feasibility_{config.dataset}.pt') 183 | torch.save({ 184 | 'feasibility': feasibility, 185 | }, save_path) 186 | 187 | print('done!') 188 | -------------------------------------------------------------------------------- /MSCI/code/clip_modules/tokenization_clip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import torch 3 | import html 4 | import os 5 | from functools import lru_cache 6 | 7 | import ftfy 8 | import regex as re 9 | 10 | 11 | @lru_cache() 12 | def default_bpe(): 13 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 14 | 15 | 16 | @lru_cache() 17 | def bytes_to_unicode(): 18 | """ 19 | Returns list of utf-8 byte and a corresponding list of unicode strings. 20 | The reversible bpe codes work on unicode strings. 
21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 23 | This is a signficant percentage of your normal, say, 32K bpe vocab. 24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 25 | And avoids mapping to whitespace/control characters the bpe code barfs on. 26 | """ 27 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 28 | cs = bs[:] 29 | n = 0 30 | for b in range(2**8): 31 | if b not in bs: 32 | bs.append(b) 33 | cs.append(2**8+n) 34 | n += 1 35 | cs = [chr(n) for n in cs] 36 | return dict(zip(bs, cs)) 37 | 38 | 39 | def get_pairs(word): 40 | """Return set of symbol pairs in a word. 41 | Word is represented as tuple of symbols (symbols being variable-length strings). 42 | """ 43 | pairs = set() 44 | prev_char = word[0] 45 | for char in word[1:]: 46 | pairs.add((prev_char, char)) 47 | prev_char = char 48 | return pairs 49 | 50 | 51 | def basic_clean(text): 52 | text = ftfy.fix_text(text) 53 | text = html.unescape(html.unescape(text)) 54 | return text.strip() 55 | 56 | 57 | def whitespace_clean(text): 58 | text = re.sub(r'\s+', ' ', text) 59 | text = text.strip() 60 | return text 61 | 62 | 63 | class SimpleTokenizer(object): 64 | def __init__(self, bpe_path: str = default_bpe()): 65 | self.byte_encoder = bytes_to_unicode() 66 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 67 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 68 | merges = merges[1:49152-256-2+1] 69 | merges = [tuple(merge.split()) for merge in merges] 70 | vocab = list(bytes_to_unicode().values()) 71 | vocab = vocab + [v+'' for v in vocab] 72 | for merge in merges: 73 | vocab.append(''.join(merge)) 74 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 75 | self.encoder = dict(zip(vocab, range(len(vocab)))) 76 | self.decoder = {v: k for k, v in self.encoder.items()} 77 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 78 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 79 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 80 | 81 | self.vocab = self.encoder 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + ( token[-1] + '',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'' 91 | 92 | while True: 93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = 
''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | 137 | def tokenize(self, text): 138 | tokens = [] 139 | text = whitespace_clean(basic_clean(text)).lower() 140 | for token in re.findall(self.pat, text): 141 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 142 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 143 | return tokens 144 | 145 | def convert_tokens_to_ids(self, tokens): 146 | return [self.encoder[bpe_token] for bpe_token in tokens] 147 | 148 | def __call__(self, texts, context_length=77, return_tensors='pt', padding=True, truncation=True): 149 | """ 150 | Returns the tokenized representation of given input string(s) 151 | Parameters 152 | ---------- 153 | texts : Union[str, List[str]] 154 | An input string or a list of input strings to tokenize 155 | context_length : int 156 | The context length to use; all CLIP models use 77 as the context length 157 | 158 | remaining params are just to have same interface with huggingface tokenizer. 159 | They don't do much. 160 | Returns 161 | ------- 162 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 163 | """ 164 | # context_length = 77 165 | if isinstance(texts, str): 166 | texts = [texts] 167 | 168 | sot_token = self.encoder["<|startoftext|>"] 169 | eot_token = self.encoder["<|endoftext|>"] 170 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] 171 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 172 | 173 | for i, tokens in enumerate(all_tokens): 174 | if len(tokens) > context_length: 175 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 176 | new_tokens = [sot_token] + tokens[1:context_length-1] + [eot_token] 177 | result[i, :len(tokens)] = torch.tensor(new_tokens) 178 | else: 179 | result[i, :len(tokens)] = torch.tensor(tokens) 180 | 181 | return result 182 | -------------------------------------------------------------------------------- /MSCI/code/tools/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | from typing import Callable, Iterable, Optional, Tuple, Union 19 | 20 | import torch 21 | from torch import nn 22 | from torch.optim import Optimizer 23 | from torch.optim.lr_scheduler import LambdaLR 24 | 25 | 26 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 27 | """ 28 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 29 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 30 | 31 | Args: 32 | optimizer (:class:`~torch.optim.Optimizer`): 33 | The optimizer for which to schedule the learning rate. 34 | num_warmup_steps (:obj:`int`): 35 | The number of steps for the warmup phase. 36 | num_training_steps (:obj:`int`): 37 | The total number of training steps. 38 | last_epoch (:obj:`int`, `optional`, defaults to -1): 39 | The index of the last epoch when resuming training. 40 | 41 | Return: 42 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 43 | """ 44 | 45 | def lr_lambda(current_step: int): 46 | if current_step < num_warmup_steps: 47 | return float(current_step) / float(max(1, num_warmup_steps)) 48 | return max( 49 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 50 | ) 51 | 52 | return LambdaLR(optimizer, lr_lambda, last_epoch) 53 | 54 | 55 | 56 | def get_cosine_schedule_with_warmup( 57 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 58 | ): 59 | """ 60 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 61 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 62 | initial lr set in the optimizer. 63 | 64 | Args: 65 | optimizer (:class:`~torch.optim.Optimizer`): 66 | The optimizer for which to schedule the learning rate. 67 | num_warmup_steps (:obj:`int`): 68 | The number of steps for the warmup phase. 69 | num_training_steps (:obj:`int`): 70 | The total number of training steps. 71 | num_cycles (:obj:`float`, `optional`, defaults to 0.5): 72 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 73 | following a half-cosine). 74 | last_epoch (:obj:`int`, `optional`, defaults to -1): 75 | The index of the last epoch when resuming training. 76 | 77 | Return: 78 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 79 | """ 80 | 81 | def lr_lambda(current_step): 82 | if current_step < num_warmup_steps: 83 | return float(current_step) / float(max(1, num_warmup_steps)) 84 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 85 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 86 | 87 | return LambdaLR(optimizer, lr_lambda, last_epoch) 88 | 89 | 90 | class AdamW(Optimizer): 91 | """ 92 | Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization 93 | `__. 94 | 95 | Parameters: 96 | params (:obj:`Iterable[nn.parameter.Parameter]`): 97 | Iterable of parameters to optimize or dictionaries defining parameter groups. 98 | lr (:obj:`float`, `optional`, defaults to 1e-3): 99 | The learning rate to use. 
100 | betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)): 101 | Adam's betas parameters (b1, b2). 102 | eps (:obj:`float`, `optional`, defaults to 1e-6): 103 | Adam's epsilon for numerical stability. 104 | weight_decay (:obj:`float`, `optional`, defaults to 0): 105 | Decoupled weight decay to apply. 106 | correct_bias (:obj:`bool`, `optional`, defaults to `True`): 107 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`). 108 | """ 109 | 110 | def __init__( 111 | self, 112 | params: Iterable[nn.parameter.Parameter], 113 | lr: float = 1e-3, 114 | betas: Tuple[float, float] = (0.9, 0.999), 115 | eps: float = 1e-6, 116 | weight_decay: float = 0.0, 117 | correct_bias: bool = True, 118 | ): 119 | if lr < 0.0: 120 | raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") 121 | if not 0.0 <= betas[0] < 1.0: 122 | raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0[") 123 | if not 0.0 <= betas[1] < 1.0: 124 | raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0[") 125 | if not 0.0 <= eps: 126 | raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") 127 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 128 | super().__init__(params, defaults) 129 | 130 | def step(self, closure: Callable = None): 131 | """ 132 | Performs a single optimization step. 133 | 134 | Arguments: 135 | closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss. 136 | """ 137 | loss = None 138 | if closure is not None: 139 | loss = closure() 140 | 141 | for group in self.param_groups: 142 | for p in group["params"]: 143 | if p.grad is None: 144 | continue 145 | grad = p.grad.data 146 | if grad.is_sparse: 147 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 148 | 149 | state = self.state[p] 150 | 151 | # State initialization 152 | if len(state) == 0: 153 | state["step"] = 0 154 | # Exponential moving average of gradient values 155 | state["exp_avg"] = torch.zeros_like(p.data) 156 | # Exponential moving average of squared gradient values 157 | state["exp_avg_sq"] = torch.zeros_like(p.data) 158 | 159 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 160 | beta1, beta2 = group["betas"] 161 | 162 | state["step"] += 1 163 | 164 | # Decay the first and second moment running average coefficient 165 | # In-place operations to update the averages at the same time 166 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 167 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 168 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 169 | 170 | step_size = group["lr"] 171 | if group["correct_bias"]: # No bias correction for Bert 172 | bias_correction1 = 1.0 - beta1 ** state["step"] 173 | bias_correction2 = 1.0 - beta2 ** state["step"] 174 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 175 | 176 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 177 | 178 | # Just adding the square of the weights to the loss function is *not* 179 | # the correct way of using L2 regularization/weight decay with Adam, 180 | # since that will interact with the m and v parameters in strange ways. 181 | # 182 | # Instead we want to decay the weights in a manner that doesn't interact 183 | # with the m/v parameters. This is equivalent to adding the square 184 | # of the weights to the loss with plain (non-momentum) SGD. 
185 | # Add weight decay at the end (fixed version) 186 | if group["weight_decay"] > 0.0: 187 | p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) 188 | 189 | return loss 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSCI: Addressing CLIP's Inherent Limitations for Compositional Zero-Shot Learning 2 | 3 | ## 🧠Model Structure 4 | ![项目结构图](./github_structure.jpg) 5 | 6 | 7 | 8 | 9 | ## ⚙️Project Setup and Requirements 10 | 11 | To run the project, follow the steps below. 12 | 13 | ### Install Required Environment 14 | 15 | First, install the necessary environment by running the following command: 16 | 17 | ```bash 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ### Download Backbone Model (ViT-L14) 22 | 23 | Next, you need to download the backbone (ViT-L14) model using `wget`. Use the following command: 24 | 25 | ```bash 26 | cd 27 | wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt 28 | ``` 29 | 30 | ### Dataset Download 31 | 32 | We conduct experiments on three datasets: Mit-states, Ut-zappos, and C-GQA. Please download these datasets and place them in the `MSCI/code/download_data` directory. Use the links below to download them: 33 | 34 | - **Mit-states**: [Mit-states dataset](https://web.mit.edu/phillipi/Public/states_and_transformations/index.html) 35 | - **Ut-zappos**: [Ut-zappos dataset](https://vision.cs.utexas.edu/projects/finegrained/utzap50k/) 36 | - **C-GQA**: [C-GQA dataset](https://github.com/ExplainableML/czsl) 37 | 38 | Once downloaded, run the following command to set up the datasets: 39 | 40 | ```bash 41 | sh download_data.sh 42 | ``` 43 | 44 | ## 🏋️Model Training 45 | 46 | ### Training in Closed-World Setting 47 | 48 | To train the model in the closed-world setting, use the following command: 49 | 50 | ```bash 51 | python -u train_base.py \ 52 | --clip_arch /ViT-L-14.pt \ 53 | --dataset_path / \ 54 | --save_path / \ 55 | --yml_path ./config/msci/.yml \ 56 | --num_workers 10 \ 57 | --seed 0 58 | ``` 59 | 60 | ### Evaluating in Closed-World Setting 61 | 62 | To evaluate the model's performance in the closed-world setting, run the following command: 63 | 64 | ```bash 65 | python -u test_base.py \ 66 | --clip_arch /ViT-L-14.pt \ 67 | --dataset_path / \ 68 | --save_path / \ 69 | --yml_path ./config/msci/.yml \ 70 | --num_workers 10 \ 71 | --seed 0 \ 72 | --load_model //val_best.pt 73 | ``` 74 | 75 | ### Evaluating in Open-World Setting 76 | 77 | To evaluate the model's performance in the open-world setting, we need to compute feasibility scores for all candidate combinations and filter based on these scores. The configuration files for the feasibility scores of each dataset are embedded in the code, allowing you to directly evaluate the model’s performance in the open-world setting with the following command: 78 | 79 | ```bash 80 | python -u test_base.py \ 81 | --clip_arch /ViT-L-14.pt \ 82 | --dataset_path / \ 83 | --save_path / \ 84 | --yml_path ./config/msci/-ow.yml \ 85 | --num_workers 10 \ 86 | --seed 0 \ 87 | --load_model //val_best.pt 88 | ``` 89 | 90 | 91 | ## 📊Model Performance Comparison 92 | 93 |

**Performance in Closed-World Setting**

<table>
  <tr>
    <th rowspan="2">Model</th>
    <th rowspan="2">Venue</th>
    <th colspan="4">MIT-States</th>
    <th colspan="4">UT-Zappos</th>
    <th colspan="4">C-GQA</th>
  </tr>
  <tr>
    <th>S</th><th>U</th><th>H</th><th>AUC</th>
    <th>S</th><th>U</th><th>H</th><th>AUC</th>
    <th>S</th><th>U</th><th>H</th><th>AUC</th>
  </tr>
  <tr><td>CSP</td><td>ICLR</td><td>46.6</td><td>49.9</td><td>36.3</td><td>19.4</td><td>64.2</td><td>66.2</td><td>46.6</td><td>33.0</td><td>28.8</td><td>26.8</td><td>20.5</td><td>6.2</td></tr>
  <tr><td>DFSP</td><td>CVPR</td><td>46.9</td><td>52.0</td><td>37.3</td><td>20.6</td><td>66.7</td><td>71.7</td><td>47.2</td><td>36.9</td><td>38.2</td><td>32.9</td><td>27.1</td><td>10.5</td></tr>
  <tr><td>HPL</td><td>IJCAI</td><td>47.5</td><td>50.6</td><td>37.3</td><td>20.2</td><td>63.0</td><td>68.8</td><td>48.2</td><td>35.0</td><td>30.8</td><td>28.4</td><td>22.4</td><td>7.2</td></tr>
  <tr><td>GIPCOL</td><td>WACV</td><td>48.5</td><td>49.6</td><td>36.6</td><td>19.9</td><td>65.0</td><td>68.5</td><td>48.8</td><td>36.2</td><td>31.9</td><td>28.4</td><td>22.5</td><td>7.1</td></tr>
  <tr><td>Troika</td><td>CVPR</td><td>49.0</td><td>53.0</td><td>39.3</td><td>22.1</td><td>66.8</td><td>73.8</td><td>54.6</td><td>41.7</td><td>41.0</td><td>35.7</td><td>29.4</td><td>12.4</td></tr>
  <tr><td>CDS-CZSL</td><td>CVPR</td><td>50.3</td><td>52.9</td><td>39.2</td><td>22.4</td><td>63.9</td><td>74.8</td><td>52.7</td><td>39.5</td><td>38.3</td><td>34.2</td><td>28.1</td><td>11.1</td></tr>
  <tr><td>PLID</td><td>ECCV</td><td>49.7</td><td>52.4</td><td>39.0</td><td>22.1</td><td>67.3</td><td>68.8</td><td>52.4</td><td>38.7</td><td>38.8</td><td>33.0</td><td>27.9</td><td>11.0</td></tr>
  <tr><td>MSCI</td><td>IJCAI</td><td>50.2</td><td>53.4</td><td>39.9</td><td>22.8</td><td>67.4</td><td>75.5</td><td>59.2</td><td>45.8</td><td>42.4</td><td>38.2</td><td>31.7</td><td>14.2</td></tr>
</table>
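For reference, the S / U / H / AUC numbers in both tables follow the usual CZSL evaluation protocol (the same bias-sweep scheme implemented by the `Evaluator` in `test_base.py`): a calibration bias added to unseen-pair scores is swept over a range, seen and unseen accuracies are recorded at each bias value, H is the best harmonic mean along the sweep, and AUC is the area under the resulting seen-vs-unseen curve. The snippet below is only a minimal sketch of that summary step; the function name and its array inputs are illustrative and not part of this repository.

```python
import numpy as np

def summarize_bias_sweep(seen_acc, unseen_acc):
    """Illustrative only: seen_acc / unseen_acc hold the accuracies
    recorded at each calibration-bias value during the sweep."""
    seen_acc = np.asarray(seen_acc, dtype=float)
    unseen_acc = np.asarray(unseen_acc, dtype=float)
    # Best harmonic mean over the sweep (the H column).
    h = 2 * seen_acc * unseen_acc / np.clip(seen_acc + unseen_acc, 1e-12, None)
    best_h = float(h.max())
    # Area under the seen-vs-unseen accuracy curve (the AUC column).
    auc = float(abs(np.trapz(seen_acc, unseen_acc)))
    return best_h, auc
```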

**Performance in Open-World Setting**

<table>
  <tr>
    <th rowspan="2">Model</th>
    <th rowspan="2">Venue</th>
    <th colspan="4">MIT-States</th>
    <th colspan="4">UT-Zappos</th>
    <th colspan="4">C-GQA</th>
  </tr>
  <tr>
    <th>S</th><th>U</th><th>H</th><th>AUC</th>
    <th>S</th><th>U</th><th>H</th><th>AUC</th>
    <th>S</th><th>U</th><th>H</th><th>AUC</th>
  </tr>
  <tr><td>CSP</td><td>ICLR</td><td>46.3</td><td>15.7</td><td>17.4</td><td>5.7</td><td>64.1</td><td>44.1</td><td>38.9</td><td>22.7</td><td>28.7</td><td>5.2</td><td>6.9</td><td>1.2</td></tr>
  <tr><td>DFSP</td><td>CVPR</td><td>47.5</td><td>18.5</td><td>19.3</td><td>6.8</td><td>66.8</td><td>60.0</td><td>44.0</td><td>30.3</td><td>38.3</td><td>7.2</td><td>10.4</td><td>2.4</td></tr>
  <tr><td>HPL</td><td>IJCAI</td><td>46.4</td><td>18.9</td><td>19.8</td><td>6.9</td><td>63.4</td><td>48.1</td><td>40.2</td><td>24.6</td><td>30.1</td><td>5.8</td><td>7.5</td><td>1.4</td></tr>
  <tr><td>GIPCOL</td><td>WACV</td><td>48.5</td><td>16.0</td><td>17.9</td><td>6.3</td><td>65.0</td><td>45.0</td><td>40.1</td><td>23.5</td><td>31.6</td><td>5.5</td><td>7.3</td><td>1.3</td></tr>
  <tr><td>Troika</td><td>CVPR</td><td>48.8</td><td>18.7</td><td>20.1</td><td>7.2</td><td>66.4</td><td>61.2</td><td>47.8</td><td>33.0</td><td>40.8</td><td>7.9</td><td>10.9</td><td>2.7</td></tr>
  <tr><td>CDS-CZSL</td><td>CVPR</td><td>49.4</td><td>21.8</td><td>22.1</td><td>8.5</td><td>64.7</td><td>61.3</td><td>48.2</td><td>32.3</td><td>37.6</td><td>8.2</td><td>11.6</td><td>2.7</td></tr>
  <tr><td>PLID</td><td>ECCV</td><td>49.1</td><td>18.7</td><td>20.4</td><td>7.3</td><td>67.6</td><td>55.5</td><td>46.6</td><td>30.8</td><td>39.1</td><td>7.5</td><td>10.6</td><td>2.5</td></tr>
  <tr><td>MSCI</td><td>IJCAI</td><td>49.2</td><td>20.6</td><td>21.2</td><td>7.9</td><td>67.4</td><td>63.0</td><td>53.2</td><td>37.3</td><td>42.0</td><td>10.6</td><td>13.7</td><td>3.8</td></tr>
</table>
152 | 153 | - **S / U / H**: Seen / Unseen / Harmonic Mean 154 | - **AUC**: Area Under Curve 155 | - **Bold**: Best result 156 | - *Italic*: Second-best result 157 | 158 | 159 | 160 | 161 | 162 | ## 📝Notes 163 | 164 | 1. **Ensure Directories Are Correct**: Before running the commands, verify that the paths to the model files, datasets, and save directories are correctly specified. Replace placeholders like ``, ``, and `` with the actual paths. 165 | 166 | 2. **Check for Dependencies**: Make sure all required libraries and dependencies are correctly installed using the provided `requirements.txt` file. This ensures that the environment is set up for running the experiments smoothly. 167 | 168 | 3. **Evaluation Configurations**: Make sure to select the correct configuration file based on your dataset (`` in the commands above), whether for the closed-world or open-world setting. 169 | 170 | 4. **Troubleshooting**: In case of issues with downloading datasets or model weights, check your internet connection or the validity of the provided download links. 171 | 172 | 173 | ## 🙏Acknowledgement 174 | 175 | Our code references the following projects: 176 | 177 | * [DFSP](https://github.com/Forest-art/DFSP) 178 | * [AdaptFormer](https://github.com/ShoufaChen/AdaptFormer) 179 | * [Troika](https://github.com/bighuang624/Troika) 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /MSCI/code/dataset.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | from random import choice 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, 9 | Normalize, RandomHorizontalFlip, 10 | RandomPerspective, RandomRotation, Resize, 11 | ToTensor) 12 | from torchvision.transforms.transforms import RandomResizedCrop 13 | 14 | BICUBIC = InterpolationMode.BICUBIC 15 | n_px = 224 16 | 17 | 18 | def transform_image(split="train", imagenet=False): 19 | if imagenet: 20 | # from czsl repo. 
21 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 22 | transform = Compose( 23 | [ 24 | RandomResizedCrop(n_px), 25 | RandomHorizontalFlip(), 26 | ToTensor(), 27 | Normalize( 28 | mean, 29 | std, 30 | ), 31 | ] 32 | ) 33 | return transform 34 | 35 | if split == "test" or split == "val": 36 | transform = Compose( 37 | [ 38 | Resize(n_px, interpolation=BICUBIC), 39 | CenterCrop(n_px), 40 | lambda image: image.convert("RGB"), 41 | ToTensor(), 42 | Normalize( 43 | (0.48145466, 0.4578275, 0.40821073), 44 | (0.26862954, 0.26130258, 0.27577711), 45 | ), 46 | ] 47 | ) 48 | else: 49 | transform = Compose( 50 | [ 51 | # RandomResizedCrop(n_px, interpolation=BICUBIC), 52 | Resize(n_px, interpolation=BICUBIC), 53 | CenterCrop(n_px), 54 | RandomHorizontalFlip(), 55 | RandomPerspective(), 56 | RandomRotation(degrees=5), 57 | lambda image: image.convert("RGB"), 58 | ToTensor(), 59 | Normalize( 60 | (0.48145466, 0.4578275, 0.40821073), 61 | (0.26862954, 0.26130258, 0.27577711), 62 | ), 63 | ] 64 | ) 65 | 66 | return transform 67 | 68 | class ImageLoader: 69 | def __init__(self, root): 70 | self.img_dir = root 71 | 72 | def __call__(self, img): 73 | file = '%s/%s' % (self.img_dir, img) 74 | img = Image.open(file).convert('RGB') 75 | return img 76 | 77 | 78 | class CompositionDataset(Dataset): 79 | def __init__( 80 | self, 81 | root, 82 | phase, 83 | split='compositional-split-natural', 84 | open_world=False, 85 | imagenet=False, 86 | same_prim_sample=False 87 | ): 88 | self.root = root 89 | self.phase = phase 90 | self.split = split 91 | self.open_world = open_world 92 | self.same_prim_sample = same_prim_sample 93 | 94 | self.feat_dim = None 95 | self.transform = transform_image(phase, imagenet=imagenet) 96 | self.loader = ImageLoader(self.root + '/images/') 97 | 98 | self.attrs, self.objs, self.pairs, \ 99 | self.train_pairs, self.val_pairs, \ 100 | self.test_pairs = self.parse_split() 101 | 102 | if self.open_world: 103 | self.pairs = list(product(self.attrs, self.objs)) 104 | 105 | self.train_data, self.val_data, self.test_data = self.get_split_info() 106 | if self.phase == 'train': 107 | self.data = self.train_data 108 | elif self.phase == 'val': 109 | self.data = self.val_data 110 | else: 111 | self.data = self.test_data 112 | 113 | self.obj2idx = {obj: idx for idx, obj in enumerate(self.objs)} 114 | self.attr2idx = {attr: idx for idx, attr in enumerate(self.attrs)} 115 | self.pair2idx = {pair: idx for idx, pair in enumerate(self.pairs)} 116 | 117 | print('# train pairs: %d | # val pairs: %d | # test pairs: %d' % (len( 118 | self.train_pairs), len(self.val_pairs), len(self.test_pairs))) 119 | print('# train images: %d | # val images: %d | # test images: %d' % 120 | (len(self.train_data), len(self.val_data), len(self.test_data))) 121 | 122 | self.train_pair_to_idx = dict( 123 | [(pair, idx) for idx, pair in enumerate(self.train_pairs)] 124 | ) 125 | 126 | if self.open_world: 127 | mask = [1 if pair in set(self.train_pairs) else 0 for pair in self.pairs] 128 | self.seen_mask = torch.BoolTensor(mask) * 1. 
129 | 130 | self.obj_by_attrs_train = {k: [] for k in self.attrs} 131 | for (a, o) in self.train_pairs: 132 | self.obj_by_attrs_train[a].append(o) 133 | 134 | # Intantiate attribut-object relations, needed just to evaluate mined pairs 135 | self.attrs_by_obj_train = {k: [] for k in self.objs} 136 | for (a, o) in self.train_pairs: 137 | self.attrs_by_obj_train[o].append(a) 138 | 139 | if self.phase == 'train' and self.same_prim_sample: 140 | self.same_attr_diff_obj_dict = {pair: list() for pair in self.train_pairs} 141 | self.same_obj_diff_attr_dict = {pair: list() for pair in self.train_pairs} 142 | for i_sample, sample in enumerate(self.train_data): 143 | sample_attr, sample_obj = sample[1], sample[2] 144 | for pair_key in self.same_attr_diff_obj_dict.keys(): 145 | if (pair_key[1] == sample_obj) and (pair_key[0] != sample_attr): 146 | self.same_obj_diff_attr_dict[pair_key].append(i_sample) 147 | elif (pair_key[1] != sample_obj) and (pair_key[0] == sample_attr): 148 | self.same_attr_diff_obj_dict[pair_key].append(i_sample) 149 | 150 | 151 | def get_split_info(self): 152 | data = torch.load(self.root + '/metadata_{}.t7'.format(self.split)) 153 | train_data, val_data, test_data = [], [], [] 154 | for instance in data: 155 | image, attr, obj, settype = instance['image'], instance[ 156 | 'attr'], instance['obj'], instance['set'] 157 | 158 | if attr == 'NA' or (attr, 159 | obj) not in self.pairs or settype == 'NA': 160 | # ignore instances with unlabeled attributes 161 | # ignore instances that are not in current split 162 | continue 163 | 164 | data_i = [image, attr, obj] 165 | if settype == 'train': 166 | train_data.append(data_i) 167 | elif settype == 'val': 168 | val_data.append(data_i) 169 | else: 170 | test_data.append(data_i) 171 | 172 | return train_data, val_data, test_data 173 | 174 | def parse_split(self): 175 | def parse_pairs(pair_list): 176 | with open(pair_list, 'r') as f: 177 | pairs = f.read().strip().split('\n') 178 | # pairs = [t.split() if not '_' in t else t.split('_') for t in pairs] 179 | pairs = [t.split() for t in pairs] 180 | pairs = list(map(tuple, pairs)) 181 | attrs, objs = zip(*pairs) 182 | return attrs, objs, pairs 183 | 184 | tr_attrs, tr_objs, tr_pairs = parse_pairs( 185 | '%s/%s/train_pairs.txt' % (self.root, self.split)) 186 | vl_attrs, vl_objs, vl_pairs = parse_pairs( 187 | '%s/%s/val_pairs.txt' % (self.root, self.split)) 188 | ts_attrs, ts_objs, ts_pairs = parse_pairs( 189 | '%s/%s/test_pairs.txt' % (self.root, self.split)) 190 | 191 | all_attrs, all_objs = sorted( 192 | list(set(tr_attrs + vl_attrs + ts_attrs))), sorted( 193 | list(set(tr_objs + vl_objs + ts_objs))) 194 | all_pairs = sorted(list(set(tr_pairs + vl_pairs + ts_pairs))) 195 | 196 | return all_attrs, all_objs, all_pairs, tr_pairs, vl_pairs, ts_pairs 197 | 198 | def __getitem__(self, index): 199 | image, attr, obj = self.data[index] 200 | img = self.loader(image) 201 | img = self.transform(img) 202 | 203 | if self.phase == 'train': 204 | data = [ 205 | img, self.attr2idx[attr], self.obj2idx[obj], self.train_pair_to_idx[(attr, obj)] 206 | ] 207 | else: 208 | data = [ 209 | img, self.attr2idx[attr], self.obj2idx[obj], self.pair2idx[(attr, obj)] 210 | ] 211 | 212 | if self.phase == 'train' and self.same_prim_sample: 213 | [same_attr_image, same_attr, diff_obj], same_attr_mask = self.same_A_diff_B(label_A=attr, label_B=obj, phase='attr') 214 | [same_obj_image, diff_attr, same_obj], same_obj_mask = self.same_A_diff_B(label_A=obj, label_B=attr, phase='obj') 215 | same_attr_img = 
self.transform(self.loader(same_attr_image)) 216 | same_obj_img = self.transform(self.loader(same_obj_image)) 217 | data += [same_attr_img, self.attr2idx[same_attr], self.obj2idx[diff_obj], 218 | self.train_pair_to_idx[(same_attr, diff_obj)], same_attr_mask, 219 | same_obj_img, self.attr2idx[diff_attr], self.obj2idx[same_obj], 220 | self.train_pair_to_idx[(diff_attr, same_obj)], same_obj_mask] 221 | 222 | return data 223 | 224 | def same_A_diff_B(self, label_A, label_B, phase='attr'): 225 | if phase=='attr': 226 | candidate_list = self.same_attr_diff_obj_dict[(label_A, label_B)] 227 | else: 228 | candidate_list = self.same_obj_diff_attr_dict[(label_B, label_A)] 229 | if len(candidate_list) != 0: 230 | idx = choice(candidate_list) 231 | mask = 1 232 | else: 233 | idx = choice(list(range(len(self.data)))) 234 | mask = 0 235 | return self.data[idx], mask 236 | 237 | def __len__(self): 238 | return len(self.data) 239 | -------------------------------------------------------------------------------- /MSCI/code/model/common.py: -------------------------------------------------------------------------------- 1 | from stringprep import b1_set 2 | from turtle import shape 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import models 7 | import argparse 8 | import numpy as np 9 | from collections import OrderedDict 10 | from einops import rearrange, reduce, repeat 11 | from einops.layers.torch import Rearrange, Reduce 12 | 13 | class CustomTextEncoder(torch.nn.Module): 14 | def __init__(self, clip_model, tokenizer, dtype=torch.float16): 15 | super().__init__() 16 | self.dtype = dtype 17 | 18 | self.transformer = clip_model.transformer 19 | self.positional_embedding = clip_model.positional_embedding 20 | self.ln_final = clip_model.ln_final 21 | self.text_projection = clip_model.text_projection 22 | self.token_embedding = clip_model.token_embedding 23 | 24 | self.tokenizer = tokenizer 25 | 26 | def tokenize(self, text): 27 | return torch.cat([self.tokenizer(tok) for tok in text]) 28 | 29 | def encode_text(self, text, enable_pos_emb=True): 30 | token_ids = self.tokenize(text) 31 | text_features = self.forward(token_ids, None, enable_pos_emb) 32 | return text_features 33 | 34 | def forward(self, token_ids, token_tensors, enable_pos_emb): 35 | """The forward function to compute representations for the prompts. 36 | 37 | Args: 38 | token_ids (torch.tensor): the token ids, which 39 | contains the token. 40 | token_tensors (torch.Tensor, optional): the tensor 41 | embeddings for the token ids. Defaults to None. 42 | enable_pos_emb (bool, optional): adds the learned 43 | positional embeddigngs if true. Defaults to False. 44 | 45 | Returns: 46 | torch.Tensor: the vector representation of the prompt. 
47 | """ 48 | if token_tensors is not None: 49 | text_features = token_tensors 50 | else: 51 | text_features = self.token_embedding(token_ids) 52 | 53 | text_features = text_features.type(self.dtype) 54 | x = ( 55 | text_features + self.positional_embedding.type(self.dtype) 56 | if enable_pos_emb 57 | else text_features 58 | ) 59 | x = x.permute(1, 0, 2) 60 | text_feature = self.transformer(x) 61 | 62 | x = text_feature.permute(1, 0, 2) 63 | x = self.ln_final(x) 64 | tf = ( 65 | x[ 66 | torch.arange(x.shape[0]), token_ids.argmax(dim=-1) 67 | ] # POS of 68 | @ self.text_projection 69 | ) 70 | return tf, text_feature 71 | 72 | 73 | class MLP(nn.Module): 74 | ''' 75 | Baseclass to create a simple MLP 76 | Inputs 77 | inp_dim: Int, Input dimension 78 | out-dim: Int, Output dimension 79 | num_layer: Number of hidden layers 80 | relu: Bool, Use non linear function at output 81 | bias: Bool, Use bias 82 | ''' 83 | def __init__(self, inp_dim, out_dim, num_layers = 1, relu = True, bias = True, dropout = False, norm = False, layers = []): 84 | super(MLP, self).__init__() 85 | mod = [] 86 | incoming = inp_dim 87 | for layer in range(num_layers - 1): 88 | if len(layers) == 0: 89 | outgoing = incoming 90 | else: 91 | outgoing = layers.pop(0) 92 | mod.append(nn.Linear(incoming, outgoing, bias = bias)) 93 | 94 | incoming = outgoing 95 | if norm: 96 | mod.append(nn.LayerNorm(outgoing)) 97 | # mod.append(nn.BatchNorm1d(outgoing)) 98 | mod.append(nn.ReLU(inplace = True)) 99 | # mod.append(nn.LeakyReLU(inplace=True, negative_slope=0.2)) 100 | if dropout: 101 | mod.append(nn.Dropout(p = 0.3)) 102 | 103 | mod.append(nn.Linear(incoming, out_dim, bias = bias)) 104 | 105 | if relu: 106 | mod.append(nn.ReLU(inplace = True)) 107 | # mod.append(nn.LeakyReLU(inplace=True, negative_slope=0.2)) 108 | self.mod = nn.Sequential(*mod) 109 | 110 | def forward(self, x): 111 | return self.mod(x) 112 | 113 | class LayerNorm(nn.LayerNorm): 114 | """Subclass torch's LayerNorm to handle fp16.""" 115 | 116 | def forward(self, x: torch.Tensor): 117 | orig_type = x.dtype 118 | ret = super().forward(x.type(torch.float32)) 119 | return ret.type(orig_type) 120 | 121 | 122 | class QuickGELU(nn.Module): 123 | def forward(self, x: torch.Tensor): 124 | return x * torch.sigmoid(1.702 * x) 125 | 126 | 127 | class ResidualAttentionBlock(nn.Module): 128 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 129 | super().__init__() 130 | 131 | self.attn = nn.MultiheadAttention(d_model, n_head) 132 | self.ln_1 = LayerNorm(d_model) 133 | self.mlp = nn.Sequential(OrderedDict([ 134 | ("c_fc", nn.Linear(d_model, d_model * 4)), 135 | ("gelu", QuickGELU()), 136 | ("drop", nn.Dropout(0.3)), 137 | ("c_proj", nn.Linear(d_model * 4, d_model)) 138 | ])) 139 | self.ln_2 = LayerNorm(d_model) 140 | self.attn_mask = attn_mask 141 | 142 | def attention(self, x: torch.Tensor): 143 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 144 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 145 | 146 | def forward(self, x: torch.Tensor): 147 | x = x + self.attention(self.ln_1(x)) 148 | x = x + self.mlp(self.ln_2(x)) 149 | return x 150 | 151 | 152 | class CrossResidualAttentionBlock(nn.Module): 153 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 154 | super().__init__() 155 | 156 | self.attn = nn.MultiheadAttention(d_model, n_head) 157 | self.ln_x = LayerNorm(d_model) 158 | self.ln_y = LayerNorm(d_model) 159 | self.mlp = 
nn.Sequential(OrderedDict([ 160 | ("c_fc", nn.Linear(d_model, d_model * 4)), 161 | ("gelu", QuickGELU()), 162 | ("c_proj", nn.Linear(d_model * 4, d_model)) 163 | ])) 164 | self.ln_2 = LayerNorm(d_model) 165 | self.attn_mask = attn_mask 166 | 167 | def attention(self, x: torch.Tensor, y: torch.Tensor): 168 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 169 | return self.attn(x, y, y, need_weights=False, attn_mask=self.attn_mask)[0] 170 | 171 | def forward(self, x: torch.Tensor, y: torch.Tensor): 172 | x = x + self.attention(self.ln_x(x), self.ln_y(y)) 173 | x = x + self.mlp(self.ln_2(x)) 174 | return x 175 | 176 | 177 | 178 | class FusionTextImageBlock(nn.Module): 179 | def __init__(self, 180 | width_img: int, 181 | width_txt: int, 182 | attributes: int, 183 | classes: int, 184 | layers: int, 185 | attn_mask: torch.Tensor = None, 186 | context_length: int = 8, 187 | fusion: str = "BiFusion"): 188 | super().__init__() 189 | self.fusion = fusion 190 | self.width_img = width_img 191 | self.width_txt = width_txt 192 | self.layers = layers 193 | self.context_length = context_length 194 | self.attributes = attributes 195 | self.classes = classes 196 | self.img2txt_transform_layer1 = nn.Linear(width_img, width_txt) 197 | self.img2txt_transform_layer2 = nn.Linear(257, context_length * (attributes + classes)) 198 | self.txt2img_transform_layer1 = nn.Linear(width_txt, width_img) 199 | self.txt2img_transform_layer2 = nn.Linear(context_length * (attributes + classes), 257) 200 | self.dropout = nn.Dropout(0.3) 201 | self.crossblock_img = CrossResidualAttentionBlock(width_img, width_img//64, attn_mask) 202 | self.crossblock_txt = CrossResidualAttentionBlock(width_txt, width_txt//64, attn_mask) 203 | self.resblocks_img = nn.Sequential(*[ResidualAttentionBlock(width_img, width_img//64, attn_mask) for _ in range(layers)]) 204 | self.resblocks_txt = nn.Sequential(*[ResidualAttentionBlock(width_txt, width_txt//64, attn_mask) for _ in range(layers)]) 205 | self.txt_fine_tune = nn.Linear(self.width_txt, self.width_txt) 206 | 207 | 208 | def decompose(self, text_feature, idx): 209 | t, l, c = text_feature.shape 210 | att_idx, obj_idx = idx[:, 0].cpu().numpy(), idx[:, 1].cpu().numpy() 211 | text_att = torch.zeros(t, self.attributes, c).cuda() 212 | text_obj = torch.zeros(t, self.classes, c).cuda() 213 | for i in range(self.attributes): 214 | text_att[:, i, :] = text_feature[:, np.where(att_idx==i)[0], :].mean(-2) 215 | for i in range(self.classes): 216 | text_obj[:, i, :] = text_feature[:, np.where(obj_idx==i)[0], :].mean(-2) 217 | text_decom_feature = torch.cat([text_att, text_obj], dim=1) 218 | return text_decom_feature 219 | 220 | 221 | def compose(self, text_feature, idx): 222 | t, l, c = text_feature.shape 223 | att_idx, obj_idx = idx[:, 0].cpu().numpy(), idx[:, 1].cpu().numpy() 224 | text_com_feature = torch.zeros(t, len(idx), c).cuda() 225 | text_com_feature = text_feature[:, att_idx, :] * text_feature[:, obj_idx + self.attributes, :] 226 | text_com_feature = self.txt_fine_tune(text_com_feature) 227 | return text_com_feature 228 | 229 | 230 | 231 | def img2txt(self, x: torch.Tensor): 232 | x = self.img2txt_transform_layer1(x) 233 | x = x.permute(2,1,0) 234 | x = self.img2txt_transform_layer2(x) 235 | x = x.permute(2,1,0).reshape(-1, (self.attributes + self.classes), self.width_txt) 236 | x = self.dropout(x) 237 | return x 238 | 239 | def txt2img(self, x:torch.Tensor, idx, b: int): 240 | x = self.decompose(x, idx) 241 | x = 
self.txt2img_transform_layer1(x) 242 | x = rearrange(x, 't l c -> c (t l)') 243 | x = self.txt2img_transform_layer2(x) 244 | x = self.dropout(x) 245 | x = x.permute(1,0).unsqueeze(1).repeat(1,b,1) 246 | return x 247 | 248 | 249 | def forward(self, x_image: torch.Tensor, x_text: torch.Tensor, idx, b: int): 250 | if self.fusion == "BiFusion": 251 | x_img = self.crossblock_img(x_image, self.txt2img(x_text, idx, b)) 252 | x_txt = self.img2txt(x_image) 253 | x_text = self.decompose(x_text, idx) 254 | x_txt = self.crossblock_txt(x_text.repeat(b, 1, 1), x_txt) 255 | x_txt = self.resblocks_txt(x_txt) 256 | x_txt = self.compose(x_txt, idx) 257 | x_txt = x_txt.reshape(b, self.context_length, -1, self.width_txt) 258 | x_img = self.resblocks_img(x_img) 259 | return x_img, x_txt 260 | elif self.fusion == "img2txt": 261 | x_txt = self.img2txt(x_image) 262 | x_text = self.decompose(x_text, idx) 263 | x_txt = self.crossblock_txt(x_text.repeat(b, 1, 1), x_txt) 264 | x_txt = self.resblocks_txt(x_txt) 265 | x_txt = self.compose(x_txt, idx) 266 | x_txt = x_txt.reshape(b, self.context_length, -1, self.width_txt) 267 | x_img = self.resblocks_img(x_image) 268 | return x_img, x_txt 269 | elif self.fusion == "txt2img": 270 | x_img = self.crossblock_img(x_image, self.txt2img(x_text, idx, b)) 271 | x_img = self.resblocks_img(x_img) 272 | x_txt = self.resblocks_txt(x_text) 273 | return x_img, x_txt 274 | elif self.fusion == "OnlySPM": 275 | return x_image, x_text -------------------------------------------------------------------------------- /MSCI/code/clip_modules/clip_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import os 5 | import hashlib 6 | import warnings 7 | import numpy as np 8 | import urllib 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | from tqdm import tqdm 13 | 14 | 15 | class Bottleneck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1): 19 | super().__init__() 20 | 21 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 22 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | 25 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 29 | 30 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 31 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 32 | 33 | self.relu = nn.ReLU(inplace=True) 34 | self.downsample = None 35 | self.stride = stride 36 | 37 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 38 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 39 | self.downsample = nn.Sequential(OrderedDict([ 40 | ("-1", nn.AvgPool2d(stride)), 41 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 42 | ("1", nn.BatchNorm2d(planes * self.expansion)) 43 | ])) 44 | 45 | def forward(self, x: torch.Tensor): 46 | identity = x 47 | 48 | out = self.relu(self.bn1(self.conv1(x))) 49 | out = self.relu(self.bn2(self.conv2(out))) 50 | out = self.avgpool(out) 51 | out = self.bn3(self.conv3(out)) 52 | 53 | if self.downsample is not None: 54 | identity = self.downsample(x) 55 | 56 | out += identity 57 | out = self.relu(out) 58 | return out 59 | 60 | 61 | class AttentionPool2d(nn.Module): 62 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 63 | super().__init__() 64 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 65 | self.k_proj = nn.Linear(embed_dim, embed_dim) 66 | self.q_proj = nn.Linear(embed_dim, embed_dim) 67 | self.v_proj = nn.Linear(embed_dim, embed_dim) 68 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 69 | self.num_heads = num_heads 70 | 71 | def forward(self, x): 72 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 73 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 74 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 75 | x, _ = F.multi_head_attention_forward( 76 | query=x, key=x, value=x, 77 | embed_dim_to_check=x.shape[-1], 78 | num_heads=self.num_heads, 79 | q_proj_weight=self.q_proj.weight, 80 | k_proj_weight=self.k_proj.weight, 81 | v_proj_weight=self.v_proj.weight, 82 | in_proj_weight=None, 83 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 84 | bias_k=None, 85 | bias_v=None, 86 | add_zero_attn=False, 87 | dropout_p=0, 88 | out_proj_weight=self.c_proj.weight, 89 | out_proj_bias=self.c_proj.bias, 90 | use_separate_proj_weight=True, 91 | training=self.training, 92 | need_weights=False 93 | ) 94 | 95 | return x[0] 96 | 97 | 98 | class ModifiedResNet(nn.Module): 99 | """ 100 | A ResNet class that is similar to torchvision's but contains the following changes: 101 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
102 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 103 | - The final pooling layer is a QKV attention instead of an average pool 104 | """ 105 | 106 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 107 | super().__init__() 108 | self.output_dim = output_dim 109 | self.input_resolution = input_resolution 110 | 111 | # the 3-layer stem 112 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 113 | self.bn1 = nn.BatchNorm2d(width // 2) 114 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 115 | self.bn2 = nn.BatchNorm2d(width // 2) 116 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 117 | self.bn3 = nn.BatchNorm2d(width) 118 | self.avgpool = nn.AvgPool2d(2) 119 | self.relu = nn.ReLU(inplace=True) 120 | 121 | # residual layers 122 | self._inplanes = width # this is a *mutable* variable used during construction 123 | self.layer1 = self._make_layer(width, layers[0]) 124 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 125 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 126 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 127 | 128 | embed_dim = width * 32 # the ResNet feature dimension 129 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 130 | 131 | def _make_layer(self, planes, blocks, stride=1): 132 | layers = [Bottleneck(self._inplanes, planes, stride)] 133 | 134 | self._inplanes = planes * Bottleneck.expansion 135 | for _ in range(1, blocks): 136 | layers.append(Bottleneck(self._inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | def stem(x): 142 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 143 | x = self.relu(bn(conv(x))) 144 | x = self.avgpool(x) 145 | return x 146 | 147 | x = x.type(self.conv1.weight.dtype) 148 | x = stem(x) 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | x = self.attnpool(x) 154 | 155 | return x 156 | 157 | 158 | class LayerNorm(nn.LayerNorm): 159 | """Subclass torch's LayerNorm to handle fp16.""" 160 | 161 | def forward(self, x: torch.Tensor): 162 | orig_type = x.dtype 163 | ret = super().forward(x.type(torch.float32)) 164 | return ret.type(orig_type) 165 | 166 | 167 | class QuickGELU(nn.Module): 168 | def forward(self, x: torch.Tensor): 169 | return x * torch.sigmoid(1.702 * x) 170 | 171 | 172 | class ResidualAttentionBlock(nn.Module): 173 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 174 | super().__init__() 175 | 176 | self.attn = nn.MultiheadAttention(d_model, n_head) 177 | self.ln_1 = LayerNorm(d_model) 178 | self.mlp = nn.Sequential(OrderedDict([ 179 | ("c_fc", nn.Linear(d_model, d_model * 4)), 180 | ("gelu", QuickGELU()), 181 | ("c_proj", nn.Linear(d_model * 4, d_model)) 182 | ])) 183 | self.ln_2 = LayerNorm(d_model) 184 | self.attn_mask = attn_mask 185 | 186 | def attention(self, x: torch.Tensor): 187 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 188 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 189 | 190 | def forward(self, x: torch.Tensor): 191 | x = x + self.attention(self.ln_1(x)) 192 | x = x + self.mlp(self.ln_2(x)) 193 | return x 194 | 195 | 196 | class Transformer(nn.Module): 197 | def 
__init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 198 | super().__init__() 199 | self.width = width 200 | self.layers = layers 201 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 202 | 203 | def forward(self, x: torch.Tensor): 204 | return self.resblocks(x) 205 | 206 | 207 | class VisualTransformer(nn.Module): 208 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 209 | super().__init__() 210 | self.input_resolution = input_resolution 211 | self.output_dim = output_dim 212 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 213 | 214 | scale = width ** -0.5 215 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 216 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 217 | self.ln_pre = LayerNorm(width) 218 | 219 | self.transformer = Transformer(width, layers, heads) 220 | 221 | self.ln_post = LayerNorm(width) 222 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 223 | 224 | def forward(self, x: torch.Tensor): 225 | x = self.conv1(x) # shape = [*, width, grid, grid] 226 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 227 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 228 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 229 | x = x + self.positional_embedding.to(x.dtype) 230 | x = self.ln_pre(x) 231 | 232 | x = x.permute(1, 0, 2) # NLD -> LND 233 | x = self.transformer(x) 234 | x = x.permute(1, 0, 2) # LND -> NLD 235 | 236 | x = self.ln_post(x[:, 0, :]) 237 | 238 | if self.proj is not None: 239 | x = x @ self.proj 240 | 241 | return x 242 | 243 | 244 | class CLIP(nn.Module): 245 | def __init__(self, 246 | embed_dim: int, 247 | # vision 248 | image_resolution: int, 249 | vision_layers: Union[Tuple[int, int, int, int], int], 250 | vision_width: int, 251 | vision_patch_size: int, 252 | # text 253 | context_length: int, 254 | vocab_size: int, 255 | transformer_width: int, 256 | transformer_heads: int, 257 | transformer_layers: int 258 | ): 259 | super().__init__() 260 | 261 | self.context_length = context_length 262 | 263 | if isinstance(vision_layers, (tuple, list)): 264 | vision_heads = vision_width * 32 // 64 265 | self.visual = ModifiedResNet( 266 | layers=vision_layers, 267 | output_dim=embed_dim, 268 | heads=vision_heads, 269 | input_resolution=image_resolution, 270 | width=vision_width 271 | ) 272 | else: 273 | vision_heads = vision_width // 64 274 | self.visual = VisualTransformer( 275 | input_resolution=image_resolution, 276 | patch_size=vision_patch_size, 277 | width=vision_width, 278 | layers=vision_layers, 279 | heads=vision_heads, 280 | output_dim=embed_dim 281 | ) 282 | 283 | self.transformer = Transformer( 284 | width=transformer_width, 285 | layers=transformer_layers, 286 | heads=transformer_heads, 287 | attn_mask=self.build_attention_mask() 288 | ) 289 | 290 | self.vocab_size = vocab_size 291 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 292 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 293 | self.ln_final = LayerNorm(transformer_width) 294 | 295 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 296 | 
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 297 | 298 | self.initialize_parameters() 299 | 300 | def initialize_parameters(self): 301 | nn.init.normal_(self.token_embedding.weight, std=0.02) 302 | nn.init.normal_(self.positional_embedding, std=0.01) 303 | 304 | if isinstance(self.visual, ModifiedResNet): 305 | if self.visual.attnpool is not None: 306 | std = self.visual.attnpool.c_proj.in_features ** -0.5 307 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 310 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 311 | 312 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 313 | for name, param in resnet_block.named_parameters(): 314 | if name.endswith("bn3.weight"): 315 | nn.init.zeros_(param) 316 | 317 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 318 | attn_std = self.transformer.width ** -0.5 319 | fc_std = (2 * self.transformer.width) ** -0.5 320 | for block in self.transformer.resblocks: 321 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 322 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 323 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 324 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 325 | 326 | if self.text_projection is not None: 327 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 328 | 329 | def build_attention_mask(self): 330 | # lazily create causal attention mask, with full attention between the vision tokens 331 | # pytorch uses additive attention mask; fill with -inf 332 | mask = torch.empty(self.context_length, self.context_length) 333 | mask.fill_(float("-inf")) 334 | mask.triu_(1) # zero out the lower diagonal 335 | return mask 336 | 337 | @property 338 | def dtype(self): 339 | return self.visual.conv1.weight.dtype 340 | 341 | def encode_image(self, image): 342 | return self.visual(image.type(self.dtype)) 343 | 344 | def encode_text(self, text, return_all_tokens=False): 345 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 346 | 347 | x = x + self.positional_embedding.type(self.dtype) 348 | x = x.permute(1, 0, 2) # NLD -> LND 349 | x = self.transformer(x) 350 | x = x.permute(1, 0, 2) # LND -> NLD 351 | x = self.ln_final(x).type(self.dtype) 352 | 353 | if return_all_tokens: 354 | return x @ self.text_projection 355 | 356 | # x.shape = [batch_size, n_ctx, transformer.width] 357 | # take features from the eot embedding (eot_token is the highest number in each sequence) 358 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 359 | 360 | return x 361 | 362 | def forward(self, image, text): 363 | image_features = self.encode_image(image) 364 | text_features = self.encode_text(text) 365 | 366 | # normalized features 367 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 368 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 369 | 370 | # cosine similarity as logits 371 | logit_scale = self.logit_scale.exp() 372 | logits_per_image = logit_scale * image_features @ text_features.t() 373 | logits_per_text = logit_scale * text_features @ image_features.t() 374 | 375 | # shape = [global_batch_size, global_batch_size] 376 | return logits_per_image, logits_per_text 377 | 378 | 379 | def convert_weights(model: nn.Module): 380 | """Convert 
applicable model parameters to fp16""" 381 | 382 | def _convert_weights_to_fp16(l): 383 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 384 | l.weight.data = l.weight.data.half() 385 | if l.bias is not None: 386 | l.bias.data = l.bias.data.half() 387 | 388 | if isinstance(l, nn.MultiheadAttention): 389 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 390 | tensor = getattr(l, attr) 391 | if tensor is not None: 392 | tensor.data = tensor.data.half() 393 | 394 | for name in ["text_projection", "proj"]: 395 | if hasattr(l, name): 396 | attr = getattr(l, name) 397 | if attr is not None: 398 | attr.data = attr.data.half() 399 | 400 | model.apply(_convert_weights_to_fp16) 401 | 402 | 403 | def build_model(state_dict: dict, context_length: int): 404 | vit = "visual.proj" in state_dict 405 | 406 | if vit: 407 | vision_width = state_dict["visual.conv1.weight"].shape[0] 408 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 409 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 410 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 411 | image_resolution = vision_patch_size * grid_size 412 | else: 413 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 414 | vision_layers = tuple(counts) 415 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 416 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 417 | vision_patch_size = None 418 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 419 | image_resolution = output_width * 32 420 | 421 | embed_dim = state_dict["text_projection"].shape[1] 422 | # context_length = state_dict["positional_embedding"].shape[0] 423 | 424 | if context_length != 77: 425 | if context_length > 77: 426 | warnings.warn( 427 | f"context length is set to {context_length}. 
" 428 | f"Positional embeddings may not work " 429 | ) 430 | zeros = torch.zeros((context_length - 77, embed_dim)) 431 | state_dict["positional_embedding"] = torch.cat( 432 | (state_dict["positional_embedding"], zeros), dim=0 433 | ) 434 | 435 | else: 436 | state_dict["positional_embedding"] = state_dict[ 437 | "positional_embedding" 438 | ][:context_length, :] 439 | 440 | vocab_size = state_dict["token_embedding.weight"].shape[0] 441 | transformer_width = state_dict["ln_final.weight"].shape[0] 442 | transformer_heads = transformer_width // 64 443 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 444 | 445 | model = CLIP( 446 | embed_dim, 447 | image_resolution, vision_layers, vision_width, vision_patch_size, 448 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 449 | ) 450 | 451 | for key in ["input_resolution", "context_length", "vocab_size"]: 452 | if key in state_dict: 453 | del state_dict[key] 454 | 455 | # convert_weights(model) 456 | model.load_state_dict(state_dict) 457 | return model.eval() 458 | 459 | 460 | _MODELS = { 461 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 462 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 463 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 464 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 465 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 466 | "ViT-L-14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 467 | } 468 | 469 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 470 | os.makedirs(root, exist_ok=True) 471 | filename = os.path.basename(url) 472 | 473 | expected_sha256 = url.split("/")[-2] 474 | download_target = os.path.join(root, filename) 475 | 476 | if os.path.exists(download_target) and not os.path.isfile(download_target): 477 | raise RuntimeError(f"{download_target} exists and is not a regular file") 478 | 479 | if os.path.isfile(download_target): 480 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 481 | return download_target 482 | else: 483 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 484 | 485 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 486 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 487 | while True: 488 | buffer = source.read(8192) 489 | if not buffer: 490 | break 491 | 492 | output.write(buffer) 493 | loop.update(len(buffer)) 494 | 495 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 496 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 497 | 498 | return download_target 499 | 500 | def available_models(): 501 | """Returns the names of available CLIP models""" 502 | return list(_MODELS.keys()) 503 | 504 | 505 | def load_clip(name: str, device: Union[str, torch.device] = "cuda" if 
torch.cuda.is_available() else "cpu", jit=True, context_length=77): 506 | """Load a CLIP model 507 | Parameters 508 | ---------- 509 | name : str 510 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 511 | device : Union[str, torch.device] 512 | The device to put the loaded model 513 | Returns 514 | ------- 515 | model : torch.nn.Module 516 | The CLIP model 517 | """ 518 | 519 | jit = False 520 | if name in _MODELS: 521 | model_path = _download(_MODELS[name]) 522 | elif os.path.isfile(name): 523 | model_path = name 524 | else: 525 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 526 | 527 | try: 528 | # loading JIT archive 529 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 530 | state_dict = None 531 | except RuntimeError: 532 | # loading saved state dict 533 | if jit: 534 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 535 | jit = False 536 | state_dict = torch.load(model_path, map_location="cpu") 537 | 538 | if not jit: 539 | model = build_model(state_dict or model.state_dict(), context_length).to(device) 540 | if str(device) == "cpu": 541 | model.float() 542 | return model 543 | 544 | # patch the device names 545 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 546 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 547 | 548 | def patch_device(module): 549 | graphs = [module.graph] if hasattr(module, "graph") else [] 550 | if hasattr(module, "forward1"): 551 | graphs.append(module.forward1.graph) 552 | 553 | for graph in graphs: 554 | for node in graph.findAllNodes("prim::Constant"): 555 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 556 | node.copyAttributes(device_node) 557 | 558 | model.apply(patch_device) 559 | patch_device(model.encode_image) 560 | patch_device(model.encode_text) 561 | 562 | # patch dtype to float32 on CPU 563 | if str(device) == "cpu": 564 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 565 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 566 | float_node = float_input.node() 567 | 568 | def patch_float(module): 569 | graphs = [module.graph] if hasattr(module, "graph") else [] 570 | if hasattr(module, "forward1"): 571 | graphs.append(module.forward1.graph) 572 | 573 | for graph in graphs: 574 | for node in graph.findAllNodes("aten::to"): 575 | inputs = list(node.inputs()) 576 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 577 | if inputs[i].node()["value"] == 5: 578 | inputs[i].node().copyAttributes(float_node) 579 | 580 | model.apply(patch_float) 581 | patch_float(model.encode_image) 582 | patch_float(model.encode_text) 583 | 584 | model.float() 585 | 586 | return model -------------------------------------------------------------------------------- /MSCI/code/test_base.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import copy 4 | import json 5 | import os 6 | from itertools import product 7 | 8 | import numpy as np 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from scipy.stats import hmean 12 | from torch.utils.data.dataloader import DataLoader 13 | from tqdm import tqdm 14 | import cv2 15 | 16 | from utils import * 17 | from parameters import parser 18 | from dataset 
import CompositionDataset 19 | from model.model_factory_final import get_model 20 | import pandas as pd 21 | 22 | 23 | cudnn.benchmark = True 24 | 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | 27 | 28 | class Evaluator: 29 | """ 30 | Evaluator class, adapted from: 31 | https://github.com/Tushar-N/attributes-as-operators 32 | 33 | With modifications from: 34 | https://github.com/ExplainableML/czsl 35 | """ 36 | 37 | def __init__(self, dset, model): 38 | 39 | self.dset = dset 40 | 41 | # Convert text pairs to idx tensors: [('sliced', 'apple'), ('ripe', 42 | # 'apple'), ...] --> torch.LongTensor([[0,1],[1,1], ...]) 43 | pairs = [(dset.attr2idx[attr], dset.obj2idx[obj]) 44 | for attr, obj in dset.pairs] 45 | self.train_pairs = [(dset.attr2idx[attr], dset.obj2idx[obj]) 46 | for attr, obj in dset.train_pairs] 47 | self.pairs = torch.LongTensor(pairs) 48 | 49 | # Mask over pairs that occur in closed world 50 | # Select set based on phase 51 | if dset.phase == 'train': 52 | print('Evaluating with train pairs') 53 | test_pair_set = set(dset.train_pairs) 54 | test_pair_gt = set(dset.train_pairs) 55 | elif dset.phase == 'val': 56 | print('Evaluating with validation pairs') 57 | test_pair_set = set(dset.val_pairs + dset.train_pairs) 58 | test_pair_gt = set(dset.val_pairs) 59 | else: 60 | print('Evaluating with test pairs') 61 | test_pair_set = set(dset.test_pairs + dset.train_pairs) 62 | test_pair_gt = set(dset.test_pairs) 63 | 64 | self.test_pair_dict = [ 65 | (dset.attr2idx[attr], 66 | dset.obj2idx[obj]) for attr, 67 | obj in test_pair_gt] 68 | self.test_pair_dict = dict.fromkeys(self.test_pair_dict, 0) 69 | 70 | # dict values are pair val, score, total 71 | for attr, obj in test_pair_gt: 72 | pair_val = dset.pair2idx[(attr, obj)] 73 | key = (dset.attr2idx[attr], dset.obj2idx[obj]) 74 | self.test_pair_dict[key] = [pair_val, 0, 0] 75 | 76 | # open world 77 | if dset.open_world: 78 | masks = [1 for _ in dset.pairs] 79 | else: 80 | masks = [1 if pair in test_pair_set else 0 for pair in dset.pairs] 81 | 82 | # masks = [1 if pair in test_pair_set else 0 for pair in dset.pairs] 83 | 84 | self.closed_mask = torch.BoolTensor(masks) 85 | # Mask of seen concepts 86 | seen_pair_set = set(dset.train_pairs) 87 | mask = [1 if pair in seen_pair_set else 0 for pair in dset.pairs] 88 | self.seen_mask = torch.BoolTensor(mask) 89 | 90 | # Object specific mask over which pairs occur in the object oracle 91 | # setting 92 | oracle_obj_mask = [] 93 | for _obj in dset.objs: 94 | mask = [1 if _obj == obj else 0 for attr, obj in dset.pairs] 95 | oracle_obj_mask.append(torch.BoolTensor(mask)) 96 | self.oracle_obj_mask = torch.stack(oracle_obj_mask, 0) 97 | 98 | # Decide if the model under evaluation is a manifold model or not 99 | self.score_model = self.score_manifold_model 100 | 101 | # Generate mask for each settings, mask scores, and get prediction labels 102 | def generate_predictions(self, scores, obj_truth, bias=0.0, topk=1): # (Batch, #pairs) 103 | ''' 104 | Inputs 105 | scores: Output scores 106 | obj_truth: Ground truth object 107 | Returns 108 | results: dict of results in 3 settings 109 | ''' 110 | 111 | def get_pred_from_scores(_scores, topk): 112 | """ 113 | Given list of scores, returns top 10 attr and obj predictions 114 | Check later 115 | """ 116 | _, pair_pred = _scores.topk( 117 | topk, dim=1) # sort returns indices of k largest values 118 | pair_pred = pair_pred.contiguous().view(-1) 119 | attr_pred, obj_pred = self.pairs[pair_pred][:, 0].view( 120 | -1, topk 121 | ), 
self.pairs[pair_pred][:, 1].view(-1, topk) 122 | return (attr_pred, obj_pred) 123 | 124 | results = {} 125 | orig_scores = scores.clone() 126 | mask = self.seen_mask.repeat( 127 | scores.shape[0], 1 128 | ) # Repeat mask along pairs dimension 129 | scores[~mask] += bias # Add bias to test pairs 130 | 131 | # Unbiased setting 132 | 133 | # Open world setting --no mask, all pairs of the dataset 134 | results.update({"open": get_pred_from_scores(scores, topk)}) 135 | results.update( 136 | {"unbiased_open": get_pred_from_scores(orig_scores, topk)} 137 | ) 138 | # Closed world setting - set the score for all Non test pairs to -1e10, 139 | # this excludes the pairs from set not in evaluation 140 | mask = self.closed_mask.repeat(scores.shape[0], 1) 141 | closed_scores = scores.clone() 142 | closed_scores[~mask] = -1e10 143 | closed_orig_scores = orig_scores.clone() 144 | closed_orig_scores[~mask] = -1e10 145 | results.update({"closed": get_pred_from_scores(closed_scores, topk)}) 146 | results.update( 147 | {"unbiased_closed": get_pred_from_scores(closed_orig_scores, topk)} 148 | ) 149 | 150 | return results 151 | 152 | def score_clf_model(self, scores, obj_truth, topk=1): 153 | ''' 154 | Wrapper function to call generate_predictions for CLF models 155 | ''' 156 | attr_pred, obj_pred = scores 157 | 158 | # Go to CPU 159 | attr_pred, obj_pred, obj_truth = attr_pred.to( 160 | 'cpu'), obj_pred.to('cpu'), obj_truth.to('cpu') 161 | 162 | # Gather scores (P(a), P(o)) for all relevant (a,o) pairs 163 | # Multiply P(a) * P(o) to get P(pair) 164 | # Return only attributes that are in our pairs 165 | attr_subset = attr_pred.index_select(1, self.pairs[:, 0]) 166 | obj_subset = obj_pred.index_select(1, self.pairs[:, 1]) 167 | scores = (attr_subset * obj_subset) # (Batch, #pairs) 168 | 169 | results = self.generate_predictions(scores, obj_truth) 170 | results['biased_scores'] = scores 171 | 172 | return results 173 | 174 | def score_manifold_model(self, scores, obj_truth, bias=0.0, topk=1): 175 | ''' 176 | Wrapper function to call generate_predictions for manifold models 177 | ''' 178 | # Go to CPU 179 | scores = {k: v.to('cpu') for k, v in scores.items()} 180 | obj_truth = obj_truth.to(device) 181 | 182 | # Gather scores for all relevant (a,o) pairs 183 | scores = torch.stack( 184 | [scores[(attr, obj)] for attr, obj in self.dset.pairs], 1 185 | ) # (Batch, #pairs) 186 | orig_scores = scores.clone() 187 | results = self.generate_predictions(scores, obj_truth, bias, topk) 188 | results['scores'] = orig_scores 189 | return results 190 | 191 | def score_fast_model(self, scores, obj_truth, bias=0.0, topk=1): 192 | ''' 193 | Wrapper function to call generate_predictions for manifold models 194 | ''' 195 | 196 | results = {} 197 | # Repeat mask along pairs dimension 198 | mask = self.seen_mask.repeat(scores.shape[0], 1) 199 | scores[~mask] += bias # Add bias to test pairs 200 | 201 | mask = self.closed_mask.repeat(scores.shape[0], 1) 202 | closed_scores = scores.clone() 203 | closed_scores[~mask] = -1e10 204 | 205 | # sort returns indices of k largest values 206 | _, pair_pred = closed_scores.topk(topk, dim=1) 207 | # _, pair_pred = scores.topk(topk, dim=1) # sort returns indices of k 208 | # largest values 209 | pair_pred = pair_pred.contiguous().view(-1) 210 | attr_pred, obj_pred = self.pairs[pair_pred][:, 0].view(-1, topk), \ 211 | self.pairs[pair_pred][:, 1].view(-1, topk) 212 | 213 | results.update({'closed': (attr_pred, obj_pred)}) 214 | return results 215 | 216 | def evaluate_predictions( 217 | self, 218 | 
predictions, 219 | attr_truth, 220 | obj_truth, 221 | pair_truth, 222 | allpred, 223 | topk=1): 224 | # Go to CPU 225 | attr_truth, obj_truth, pair_truth = ( 226 | attr_truth.to("cpu"), 227 | obj_truth.to("cpu"), 228 | pair_truth.to("cpu"), 229 | ) 230 | 231 | pairs = list(zip(list(attr_truth.numpy()), list(obj_truth.numpy()))) 232 | 233 | seen_ind, unseen_ind = [], [] 234 | for i in range(len(attr_truth)): 235 | if pairs[i] in self.train_pairs: 236 | seen_ind.append(i) 237 | else: 238 | unseen_ind.append(i) 239 | 240 | seen_ind, unseen_ind = torch.LongTensor(seen_ind), torch.LongTensor( 241 | unseen_ind 242 | ) 243 | 244 | def _process(_scores): 245 | # Top k pair accuracy 246 | # Attribute, object and pair 247 | attr_match = ( 248 | attr_truth.unsqueeze(1).repeat(1, topk) == _scores[0][:, :topk] 249 | ) 250 | obj_match = ( 251 | obj_truth.unsqueeze(1).repeat(1, topk) == _scores[1][:, :topk] 252 | ) 253 | 254 | # Match of object pair 255 | match = (attr_match * obj_match).any(1).float() 256 | attr_match = attr_match.any(1).float() 257 | obj_match = obj_match.any(1).float() 258 | # Match of seen and unseen pairs 259 | seen_match = match[seen_ind] 260 | unseen_match = match[unseen_ind] 261 | # Calculating class average accuracy 262 | 263 | seen_score, unseen_score = torch.ones(512, 5), torch.ones(512, 5) 264 | 265 | return attr_match, obj_match, match, seen_match, unseen_match, torch.Tensor( 266 | seen_score + unseen_score), torch.Tensor(seen_score), torch.Tensor(unseen_score) 267 | 268 | def _add_to_dict(_scores, type_name, stats): 269 | base = [ 270 | "_attr_match", 271 | "_obj_match", 272 | "_match", 273 | "_seen_match", 274 | "_unseen_match", 275 | "_ca", 276 | "_seen_ca", 277 | "_unseen_ca", 278 | ] 279 | for val, name in zip(_scores, base): 280 | stats[type_name + name] = val 281 | 282 | stats = dict() 283 | 284 | # Closed world 285 | closed_scores = _process(predictions["closed"]) 286 | unbiased_closed = _process(predictions["unbiased_closed"]) 287 | _add_to_dict(closed_scores, "closed", stats) 288 | _add_to_dict(unbiased_closed, "closed_ub", stats) 289 | 290 | # Calculating AUC 291 | scores = predictions["scores"] 292 | # getting score for each ground truth class 293 | correct_scores = scores[torch.arange(scores.shape[0]), pair_truth][ 294 | unseen_ind 295 | ] 296 | 297 | # Getting top predicted score for these unseen classes 298 | max_seen_scores = predictions['scores'][unseen_ind][:, self.seen_mask].topk(topk, dim=1)[ 299 | 0][:, topk - 1] 300 | 301 | # Getting difference between these scores 302 | unseen_score_diff = max_seen_scores - correct_scores 303 | 304 | # Getting matched classes at max bias for diff 305 | unseen_matches = stats["closed_unseen_match"].bool() 306 | correct_unseen_score_diff = unseen_score_diff[unseen_matches] - 1e-4 307 | 308 | # sorting these diffs 309 | correct_unseen_score_diff = torch.sort(correct_unseen_score_diff)[0] 310 | magic_binsize = 20 311 | # getting step size for these bias values 312 | bias_skip = max(len(correct_unseen_score_diff) // magic_binsize, 1) 313 | # Getting list 314 | biaslist = correct_unseen_score_diff[::bias_skip] 315 | 316 | seen_match_max = float(stats["closed_seen_match"].mean()) 317 | unseen_match_max = float(stats["closed_unseen_match"].mean()) 318 | seen_accuracy, unseen_accuracy = [], [] 319 | 320 | # Go to CPU 321 | base_scores = {k: v.to("cpu") for k, v in allpred.items()} 322 | obj_truth = obj_truth.to("cpu") 323 | 324 | # Gather scores for all relevant (a,o) pairs 325 | base_scores = torch.stack( 326 | [allpred[(attr, 
obj)] for attr, obj in self.dset.pairs], 1 327 | ) # (Batch, #pairs) 328 | 329 | for bias in biaslist: 330 | scores = base_scores.clone() 331 | results = self.score_fast_model( 332 | scores, obj_truth, bias=bias, topk=topk) 333 | results = results['closed'] # we only need biased 334 | results = _process(results) 335 | seen_match = float(results[3].mean()) 336 | unseen_match = float(results[4].mean()) 337 | seen_accuracy.append(seen_match) 338 | unseen_accuracy.append(unseen_match) 339 | 340 | seen_accuracy.append(seen_match_max) 341 | unseen_accuracy.append(unseen_match_max) 342 | seen_accuracy, unseen_accuracy = np.array(seen_accuracy), np.array( 343 | unseen_accuracy 344 | ) 345 | area = np.trapz(seen_accuracy, unseen_accuracy) 346 | 347 | for key in stats: 348 | stats[key] = float(stats[key].mean()) 349 | 350 | try: 351 | harmonic_mean = hmean([seen_accuracy, unseen_accuracy], axis=0) 352 | except BaseException: 353 | harmonic_mean = 0 354 | 355 | max_hm = np.max(harmonic_mean) 356 | idx = np.argmax(harmonic_mean) 357 | if idx == len(biaslist): 358 | bias_term = 1e3 359 | else: 360 | bias_term = biaslist[idx] 361 | stats["biasterm"] = float(bias_term) 362 | stats["best_unseen"] = np.max(unseen_accuracy) 363 | stats["best_seen"] = np.max(seen_accuracy) 364 | stats["AUC"] = area 365 | stats["hm_unseen"] = unseen_accuracy[idx] 366 | stats["hm_seen"] = seen_accuracy[idx] 367 | stats["best_hm"] = max_hm 368 | return stats 369 | 370 | 371 | def predict_logits(model, dataset, config): 372 | """Function to predict the cosine similarities between the 373 | images and the attribute-object representations. The function 374 | also returns the ground truth for attributes, objects, and pair 375 | of attribute-objects. 376 | 377 | Args: 378 | model (nn.Module): the model 379 | text_rep (nn.Tensor): the attribute-object representations. 
380 | dataset (CompositionDataset): the composition dataset (validation/test) 381 | device (str): the device (either cpu/cuda:0) 382 | config (argparse.ArgumentParser): config/args 383 | 384 | Returns: 385 | tuple: the logits, attribute labels, object labels, 386 | pair attribute-object labels 387 | """ 388 | model.eval() 389 | all_attr_gt, all_obj_gt, all_pair_gt = ( 390 | [], 391 | [], 392 | [], 393 | ) 394 | attr2idx = dataset.attr2idx 395 | obj2idx = dataset.obj2idx 396 | # print(text_rep.shape) 397 | pairs_dataset = dataset.pairs 398 | pairs = torch.tensor([(attr2idx[attr], obj2idx[obj]) 399 | for attr, obj in pairs_dataset]).cuda() 400 | dataloader = DataLoader( 401 | dataset, 402 | batch_size=config.eval_batch_size, 403 | shuffle=False, 404 | num_workers=config.num_workers) 405 | all_logits = torch.Tensor() 406 | loss = 0 407 | with torch.no_grad(): 408 | for idx, data in tqdm( 409 | enumerate(dataloader), total=len(dataloader), desc="Testing" 410 | ): 411 | # batch_img = data[0].cuda() 412 | predict = model(data, pairs) 413 | logits = model.logit_infer(predict, pairs) 414 | loss += model.loss_calu(predict, data).item() 415 | attr_truth, obj_truth, pair_truth = data[1], data[2], data[3] 416 | logits = logits.cpu() 417 | all_logits = torch.cat([all_logits, logits], dim=0) 418 | all_attr_gt.append(attr_truth) 419 | all_obj_gt.append(obj_truth) 420 | all_pair_gt.append(pair_truth) 421 | 422 | all_attr_gt, all_obj_gt, all_pair_gt = ( 423 | torch.cat(all_attr_gt).to("cpu"), 424 | torch.cat(all_obj_gt).to("cpu"), 425 | torch.cat(all_pair_gt).to("cpu"), 426 | ) 427 | 428 | return all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss / len(dataloader) 429 | 430 | 431 | def predict_logits_text_first(model, dataset, config): 432 | """Function to predict the cosine similarities between the 433 | images and the attribute-object representations. The function 434 | also returns the ground truth for attributes, objects, and pair 435 | of attribute-objects. 436 | 437 | Args: 438 | model (nn.Module): the model 439 | text_rep (nn.Tensor): the attribute-object representations. 
440 | dataset (CompositionDataset): the composition dataset (validation/test) 441 | device (str): the device (either cpu/cuda:0) 442 | config (argparse.ArgumentParser): config/args 443 | 444 | Returns: 445 | tuple: the logits, attribute labels, object labels, 446 | pair attribute-object labels 447 | """ 448 | model.eval() 449 | all_attr_gt, all_obj_gt, all_pair_gt = ( 450 | [], 451 | [], 452 | [], 453 | ) 454 | attr2idx = dataset.attr2idx 455 | obj2idx = dataset.obj2idx 456 | # print(text_rep.shape) 457 | print(attr2idx) 458 | print(obj2idx) 459 | pairs_dataset = dataset.pairs 460 | pairs = torch.tensor([(attr2idx[attr], obj2idx[obj]) 461 | for attr, obj in pairs_dataset]).cuda() 462 | print(pairs) 463 | dataloader = DataLoader( 464 | dataset, 465 | batch_size=config.eval_batch_size, 466 | shuffle=False, 467 | num_workers=config.num_workers) 468 | all_logits = torch.Tensor() 469 | loss = 0 470 | with torch.no_grad(): 471 | text_feats = [[], [], []] 472 | num_text_batch = pairs.shape[0] // config.text_encoder_batch_size 473 | for i_text_batch in range(num_text_batch): 474 | cur_pair = pairs[i_text_batch*config.text_encoder_batch_size:(i_text_batch+1)*config.text_encoder_batch_size, :] 475 | cur_text_feats = model.encode_text_for_open(cur_pair) 476 | for i_item in range(len(text_feats)): 477 | text_feats[i_item].append(cur_text_feats[i_item]) 478 | if pairs.shape[0] % config.text_encoder_batch_size != 0: 479 | cur_pair = pairs[num_text_batch*config.text_encoder_batch_size:, :] 480 | cur_text_feats = model.encode_text_for_open(cur_pair) 481 | for i_item in range(len(text_feats)): 482 | text_feats[i_item].append(cur_text_feats[i_item]) 483 | for i_item in range(len(text_feats)): 484 | text_feats[i_item] = torch.cat(text_feats[i_item], dim=0) 485 | for idx, data in tqdm( 486 | enumerate(dataloader), total=len(dataloader), desc="Testing" 487 | ): 488 | # batch_img = data[0].cuda() 489 | predict = model.forward_for_open(data, text_feats) 490 | logits = model.logit_infer(predict, pairs) 491 | loss += model.loss_calu(predict, data).item() 492 | attr_truth, obj_truth, pair_truth = data[1], data[2], data[3] 493 | logits = logits.cpu() 494 | all_logits = torch.cat([all_logits, logits], dim=0) 495 | all_attr_gt.append(attr_truth) 496 | all_obj_gt.append(obj_truth) 497 | all_pair_gt.append(pair_truth) 498 | 499 | all_attr_gt, all_obj_gt, all_pair_gt = ( 500 | torch.cat(all_attr_gt).to("cpu"), 501 | torch.cat(all_obj_gt).to("cpu"), 502 | torch.cat(all_pair_gt).to("cpu"), 503 | ) 504 | 505 | # ? delete the text encoder to save CUDA memory 506 | # del model.transformer 507 | # torch.cuda.empty_cache() 508 | 509 | return all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss / len(dataloader) 510 | 511 | 512 | def threshold_with_feasibility( 513 | logits, 514 | seen_mask, 515 | threshold=None, 516 | feasiblity=None): 517 | """Function to remove infeasible compositions. 518 | 519 | Args: 520 | logits (torch.Tensor): the cosine similarities between 521 | the images and the attribute-object pairs. 522 | seen_mask (torch.tensor): the seen mask with binary 523 | threshold (float, optional): the threshold value. 524 | Defaults to None. 525 | feasiblity (torch.Tensor, optional): the feasibility. 526 | Defaults to None. 527 | 528 | Returns: 529 | torch.Tensor: the logits after filtering out the 530 | infeasible compositions. 531 | """ 532 | score = copy.deepcopy(logits) 533 | # Note: Pairs are already aligned here 534 | mask = (feasiblity >= threshold).float() 535 | # score = score*mask + (1.-mask)*(-1.) 
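    # Pairs that are neither seen in training (seen_mask) nor above the feasibility
    # threshold (mask) are zeroed out below, so only seen pairs and sufficiently
    # feasible unseen pairs remain candidates in the open-world ranking.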
536 | score = score * (mask + seen_mask) 537 | 538 | return score 539 | 540 | 541 | def test( 542 | test_dataset, 543 | evaluator, 544 | all_logits, 545 | all_attr_gt, 546 | all_obj_gt, 547 | all_pair_gt, 548 | config): 549 | """Function computes accuracy on the validation and 550 | test dataset. 551 | 552 | Args: 553 | test_dataset (CompositionDataset): the validation/test 554 | dataset 555 | evaluator (Evaluator): the evaluator object 556 | all_logits (torch.Tensor): the cosine similarities between 557 | the images and the attribute-object pairs. 558 | all_attr_gt (torch.tensor): the attribute ground truth 559 | all_obj_gt (torch.tensor): the object ground truth 560 | all_pair_gt (torch.tensor): the attribute-object pair ground 561 | truth 562 | config (argparse.ArgumentParser): the config 563 | 564 | Returns: 565 | dict: the result with all the metrics 566 | """ 567 | """此功能在验证和测试数据集上计算准确性。 568 | 569 | 参数: 570 | test_dataset (CompositionDataset): 验证/测试数据集 571 | evaluator (Evaluator): 评估器对象 572 | all_logits (torch.Tensor): 图像与属性-对象对之间的余弦相似度。 573 | all_attr_gt (torch.tensor): 属性的真实值 574 | all_obj_gt (torch.tensor): 对象的真实值 575 | all_pair_gt (torch.tensor): 属性-对象对的真实值 576 | config (argparse.ArgumentParser): 配置参数 577 | 578 | 返回: 579 | dict: 所有度量的结果 580 | """ 581 | 582 | predictions = { 583 | pair_name: all_logits[:, i] 584 | for i, pair_name in enumerate(test_dataset.pairs) 585 | } 586 | all_pred = [predictions] 587 | 588 | all_pred_dict = {} 589 | for k in all_pred[0].keys(): 590 | all_pred_dict[k] = torch.cat( 591 | [all_pred[i][k] for i in range(len(all_pred))] 592 | ).float() 593 | 594 | results = evaluator.score_model( 595 | all_pred_dict, all_obj_gt, bias=1e3, topk=1 596 | ) 597 | 598 | results['predicted_attributes'] = results['unbiased_closed'][0].squeeze(-1).tolist() 599 | results['predicted_objects'] = results['unbiased_closed'][1].squeeze(-1).tolist() 600 | results['true_attributes'] = all_attr_gt.tolist() 601 | results['true_objects'] = all_obj_gt.tolist() 602 | 603 | results['true_pairs'] = all_pair_gt.tolist() 604 | 605 | # Create a DataFrame to store results 606 | df = pd.DataFrame({ 607 | 'Predicted Attributes': results['predicted_attributes'], 608 | 'Predicted Objects': results['predicted_objects'], 609 | 'True Attributes': results['true_attributes'], 610 | 'True Objects': results['true_objects'], 611 | 'True Pairs': results['true_pairs'] 612 | }) 613 | 614 | # Save the DataFrame to a CSV file 615 | df.to_csv('test_results.csv', index=False) 616 | 617 | attr_acc = float(torch.mean( 618 | (results['unbiased_closed'][0].squeeze(-1) == all_attr_gt).float())) 619 | obj_acc = float(torch.mean( 620 | (results['unbiased_closed'][1].squeeze(-1) == all_obj_gt).float())) 621 | 622 | stats = evaluator.evaluate_predictions( 623 | results, 624 | all_attr_gt, 625 | all_obj_gt, 626 | all_pair_gt, 627 | all_pred_dict, 628 | topk=1, 629 | ) 630 | 631 | stats['attr_acc'] = attr_acc 632 | stats['obj_acc'] = obj_acc 633 | 634 | return stats 635 | 636 | 637 | if __name__ == "__main__": 638 | config = parser.parse_args() 639 | if config.yml_path: 640 | load_args(config.yml_path, config) 641 | 642 | # set the seed value 643 | print("----") 644 | test_type = 'OPEN WORLD' if config.open_world else 'CLOSED WORLD' 645 | print(f"{test_type} evaluation details") 646 | print("----") 647 | print(f"dataset: {config.dataset}") 648 | 649 | 650 | dataset_path = config.dataset_path 651 | 652 | print('loading validation dataset') 653 | val_dataset = CompositionDataset(dataset_path, 654 | phase='val', 655 | 
split='compositional-split-natural', 656 | open_world=config.open_world) 657 | 658 | allattrs = val_dataset.attrs 659 | allobj = val_dataset.objs 660 | classes = [cla.replace(".", " ").lower() for cla in allobj] 661 | attributes = [attr.replace(".", " ").lower() for attr in allattrs] 662 | offset = len(attributes) 663 | 664 | model = get_model(config, attributes=attributes, classes=classes, offset=offset).cuda() 665 | if config.load_model: 666 | model.load_state_dict(torch.load(config.load_model)) 667 | predict_logits_func = predict_logits 668 | # ? can be deleted if not needed 669 | if (hasattr(config, 'text_first') and config.text_first): 670 | print('text first') 671 | predict_logits_func = predict_logits_text_first 672 | 673 | print('evaluating on the validation set') 674 | if config.open_world and config.threshold is None: 675 | evaluator = Evaluator(val_dataset, model=None) 676 | feasibility_path = os.path.join( 677 | DIR_PATH, f'data/feasibility_{config.dataset}.pt') 678 | unseen_scores = torch.load( 679 | feasibility_path, 680 | map_location='cpu')['feasibility'] 681 | seen_mask = val_dataset.seen_mask.to('cpu') 682 | min_feasibility = (unseen_scores + seen_mask * 10.).min() 683 | max_feasibility = (unseen_scores - seen_mask * 10.).max() 684 | thresholds = np.linspace( 685 | min_feasibility, 686 | max_feasibility, 687 | num=config.threshold_trials) 688 | best_auc = 0. 689 | best_th = -10 690 | val_stats = None 691 | with torch.no_grad(): 692 | all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss_avg = predict_logits_func( 693 | model, val_dataset, config) 694 | for th in thresholds: 695 | temp_logits = threshold_with_feasibility( 696 | all_logits, val_dataset.seen_mask, threshold=th, feasiblity=unseen_scores) 697 | results = test( 698 | val_dataset, 699 | evaluator, 700 | temp_logits, 701 | all_attr_gt, 702 | all_obj_gt, 703 | all_pair_gt, 704 | config 705 | ) 706 | auc = results['AUC'] 707 | if auc > best_auc: 708 | best_auc = auc 709 | best_th = th 710 | print('New best AUC', best_auc) 711 | print('Threshold', best_th) 712 | val_stats = copy.deepcopy(results) 713 | else: 714 | best_th = config.threshold 715 | evaluator = Evaluator(val_dataset, model=None) 716 | feasibility_path = os.path.join( 717 | DIR_PATH, f'data/feasibility_{config.dataset}.pt') 718 | unseen_scores = torch.load( 719 | feasibility_path, 720 | map_location='cpu')['feasibility'] 721 | with torch.no_grad(): 722 | all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss_avg = predict_logits_func( 723 | model, val_dataset, config) 724 | if config.open_world: 725 | print('using threshold: ', best_th) 726 | all_logits = threshold_with_feasibility( 727 | all_logits, val_dataset.seen_mask, threshold=best_th, feasiblity=unseen_scores) 728 | results = test( 729 | val_dataset, 730 | evaluator, 731 | all_logits, 732 | all_attr_gt, 733 | all_obj_gt, 734 | all_pair_gt, 735 | config 736 | ) 737 | val_stats = copy.deepcopy(results) 738 | result = "" 739 | for key in val_stats: 740 | result = result + key + " " + str(round(val_stats[key], 4)) + "| " 741 | print(result) 742 | 743 | # del val_dataset 744 | # torch.cuda.empty_cache() 745 | print('loading test dataset') 746 | test_dataset = CompositionDataset(dataset_path, 747 | phase='test', 748 | split='compositional-split-natural', 749 | open_world=config.open_world) 750 | print('evaluating on the test set') 751 | with torch.no_grad(): 752 | evaluator = Evaluator(test_dataset, model=None) 753 | all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss_avg = predict_logits_func( 754 
| model, test_dataset, config) 755 | if config.open_world and best_th is not None: 756 | print('using threshold: ', best_th) 757 | all_logits = threshold_with_feasibility( 758 | all_logits, 759 | test_dataset.seen_mask, 760 | threshold=best_th, 761 | feasiblity=unseen_scores) 762 | test_stats = test( 763 | test_dataset, 764 | evaluator, 765 | all_logits, 766 | all_attr_gt, 767 | all_obj_gt, 768 | all_pair_gt, 769 | config 770 | ) 771 | 772 | result = "" 773 | for key in test_stats: 774 | result = result + key + " " + \ 775 | str(round(test_stats[key], 4)) + "| " 776 | print(result) 777 | 778 | results = { 779 | 'val': val_stats, 780 | 'test': test_stats, 781 | } 782 | 783 | if best_th is not None: 784 | results['best_threshold'] = best_th 785 | 786 | if config.load_model: 787 | title = config.load_model[:-2] 788 | else: 789 | os.makedirs(config.save_path, exist_ok=True) 790 | title = config.save_path + '/' 791 | if config.open_world: 792 | result_path = title + "open.calibrated.json" 793 | else: 794 | result_path = title + "closed.json" 795 | 796 | with open(result_path, 'w+') as fp: 797 | json.dump(results, fp) 798 | 799 | print("done!") -------------------------------------------------------------------------------- /MSCI/code/model/Mutifuse_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | from functools import reduce 7 | from operator import mul 8 | from copy import deepcopy 9 | from torch.nn.modules.utils import _pair 10 | from torch.nn.modules.loss import CrossEntropyLoss 11 | from clip_modules.clip_model import load_clip, QuickGELU 12 | from clip_modules.tokenization_clip import SimpleTokenizer 13 | from model.common import * 14 | 15 | 16 | class Adapter(nn.Module): 17 | # Referece: https://github.com/ShoufaChen/AdaptFormer 18 | def __init__(self, 19 | d_model=None, 20 | bottleneck=None, 21 | dropout=0.0, 22 | init_option="lora", 23 | adapter_scalar="0.1", 24 | adapter_layernorm_option="none"): 25 | super().__init__() 26 | self.n_embd = d_model 27 | self.down_size = bottleneck 28 | 29 | # _before 30 | self.adapter_layernorm_option = adapter_layernorm_option 31 | 32 | self.adapter_layer_norm_before = None 33 | if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": 34 | self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd) 35 | 36 | if adapter_scalar == "learnable_scalar": 37 | self.scale = nn.Parameter(torch.ones(1)) 38 | else: 39 | self.scale = float(adapter_scalar) 40 | 41 | self.down_proj = nn.Linear(self.n_embd, self.down_size) 42 | self.non_linear_func = nn.ReLU() 43 | self.up_proj = nn.Linear(self.down_size, self.n_embd) 44 | 45 | self.dropout = dropout 46 | self.init_option = init_option 47 | 48 | self._reset_parameters() 49 | 50 | def _reset_parameters(self): 51 | if self.init_option == "bert": 52 | raise NotImplementedError 53 | elif self.init_option == "lora": 54 | with torch.no_grad(): 55 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) 56 | nn.init.zeros_(self.up_proj.weight) 57 | nn.init.zeros_(self.down_proj.bias) 58 | nn.init.zeros_(self.up_proj.bias) 59 | 60 | def forward(self, x, add_residual=True, residual=None): 61 | residual = x if residual is None else residual 62 | if self.adapter_layernorm_option == 'in': 63 | x = self.adapter_layer_norm_before(x) 64 | 65 | down = self.down_proj(x) 66 | down = self.non_linear_func(down) 67 | down = nn.functional.dropout(down, 
p=self.dropout, training=self.training) 68 | up = self.up_proj(down) 69 | 70 | up = up * self.scale 71 | 72 | if self.adapter_layernorm_option == 'out': 73 | up = self.adapter_layer_norm_before(up) 74 | 75 | if add_residual: 76 | output = up + residual 77 | else: 78 | output = up 79 | 80 | return output 81 | 82 | 83 | 84 | 85 | class Disentangler(nn.Module): 86 | def __init__(self, emb_dim): 87 | super(Disentangler, self).__init__() 88 | self.fc1 = nn.Linear(emb_dim, emb_dim) 89 | self.bn1_fc = nn.BatchNorm1d(emb_dim) 90 | 91 | def forward(self, x): 92 | x = F.relu(self.bn1_fc(self.fc1(x))) 93 | x = F.dropout(x, training=self.training) 94 | return x 95 | 96 | 97 | class MultiHeadAttention(nn.Module): 98 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 99 | super().__init__() 100 | self.num_heads = num_heads 101 | head_dim = dim // num_heads 102 | 103 | self.scale = qk_scale or head_dim ** -0.5 104 | 105 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 106 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) 107 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) 108 | 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, q, k, v): 114 | B, N, C = q.shape 115 | B, M, C = k.shape 116 | q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 117 | k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 118 | v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 119 | 120 | attn = (q @ k.transpose(-2, -1)) * self.scale 121 | attn = attn.softmax(dim=-1) 122 | attn = self.attn_drop(attn) 123 | 124 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 125 | x = self.proj(x) 126 | x = self.proj_drop(x) 127 | return x 128 | 129 | 130 | class CrossAttentionLayer(nn.Module): # 交叉注意力+残差 131 | def __init__(self, d_model, nhead, dropout=0.1, ): 132 | super().__init__() 133 | self.cross_attn = MultiHeadAttention(d_model, nhead, proj_drop=dropout) 134 | self.norm = nn.LayerNorm(d_model) 135 | 136 | self.dropout = nn.Dropout(dropout) 137 | 138 | self.mlp = nn.Sequential( 139 | nn.Linear(d_model, d_model * 4), 140 | QuickGELU(), 141 | nn.Dropout(dropout), 142 | nn.Linear(d_model * 4, d_model) 143 | ) 144 | 145 | def forward(self, q, kv): 146 | q = q + self.cross_attn(q, kv, kv) 147 | q = q + self.dropout(self.mlp(self.norm(q))) 148 | return q 149 | 150 | 151 | 152 | 153 | class FeatureVariance(nn.Module): 154 | def __init__(self, feature_dim): 155 | super().__init__() 156 | self.feature_dim = feature_dim 157 | 158 | def forward(self, features, text_features): 159 | # features: [batch_size, seq_len, feature_dim] 160 | # text_features: [num_text, feature_dim] 161 | # Reshape features to [batch_size * seq_len, feature_dim] for batch matrix multiplication 162 | #features_norm = F.normalize(features, p=2, dim=-1) 163 | #text_features_norm = F.normalize(text_features, p=2, dim=-1) 164 | 165 | features_flat = features.view(-1, self.feature_dim) 166 | # Calculate dot product similarity 167 | similarity_scores = torch.matmul(features_flat, text_features.t()) # [batch_size * seq_len, num_text] 168 | # Calculate variance of similarity scores along text feature dimension 169 | variance = torch.var(similarity_scores, dim=1, unbiased=False) # [batch_size * seq_len] 170 | # Reshape back to [batch_size, seq_len] 171 | return variance.view(features.size(0), features.size(1)) 
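# A minimal shape walk-through for FeatureVariance (an illustrative sketch; the batch
# size, token count, and 768-d projected ViT-L/14 width used here are assumptions,
# not values fixed by this file):
#
#   fv = FeatureVariance(feature_dim=768)
#   patch_feats = torch.randn(4, 257, 768)   # [batch_size, 1 + num_patches, feature_dim]
#   pair_texts = torch.randn(100, 768)       # [num_text, feature_dim]
#   var = fv(patch_feats, pair_texts)        # -> [4, 257]: for each token, the variance
#                                            #    of its similarity scores over all texts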
172 | 173 | class LocalFeatureAdd(nn.Module): 174 | def __init__(self, feature_dim): 175 | super().__init__() 176 | self.variance_calc = FeatureVariance(feature_dim) 177 | 178 | def forward(self, low_features, high_features, text_features): 179 | low_variances = self.variance_calc(low_features, text_features) 180 | global_low_variance=low_variances[:,0].unsqueeze(1) 181 | low_variances_ratio=low_variances/global_low_variance 182 | high_variances = self.variance_calc(high_features, text_features) 183 | global_high_variance=high_variances[:,0].unsqueeze(1) 184 | high_variances_ratio=high_variances/global_high_variance 185 | ratio=low_variances_ratio/high_variances_ratio 186 | enhance_ratio=torch.clamp(ratio, min=1) 187 | #print(enhance_ratio.shape) 188 | enhanced_high_feature=high_features*enhance_ratio.unsqueeze(2) 189 | return enhanced_high_feature 190 | 191 | 192 | 193 | 194 | 195 | 196 | class Stage1LocalAlignment(nn.Module): 197 | def __init__(self, d_model, num_heads=12, dropout=0.1, num_cmt_layers_first=1): 198 | super().__init__() 199 | self.norm = nn.LayerNorm(d_model) 200 | #self.spatial_attn = SpatialAttention(d_model) 201 | self.cross_attn_layers = nn.ModuleList([ 202 | CrossAttentionLayer(d_model, num_heads, dropout) for _ in range(num_cmt_layers_first) 203 | ]) 204 | self.dropout = nn.Dropout(dropout) 205 | #self.se_attn_layers = SelfAttentionLayer(d_model, num_heads, dropout=dropout) 206 | 207 | def forward(self, q, kv_low_level): 208 | kv_low_level=self.norm(kv_low_level) 209 | # kv_low_level=self.se_attn_layers(kv_low_level) 210 | for cross_attn in self.cross_attn_layers: 211 | q = cross_attn(q, kv_low_level) 212 | return q 213 | 214 | 215 | 216 | 217 | 218 | class Stage2GlobalFusion(nn.Module): 219 | def __init__(self, d_model, num_heads=12, dropout=0.1, num_cmt_layers_second=3): 220 | super().__init__() 221 | self.cross_attn_layers = nn.ModuleList([ 222 | CrossAttentionLayer(d_model, num_heads, dropout) for _ in range(num_cmt_layers_second) 223 | ]) 224 | self.norm = nn.LayerNorm(d_model) 225 | self.dropout = nn.Dropout(dropout) 226 | 227 | def forward(self, q, kv_high_level): 228 | kv_high_level = self.norm(kv_high_level) 229 | for cross_attn in self.cross_attn_layers: 230 | q = cross_attn(q, kv_high_level) 231 | return q 232 | 233 | 234 | 235 | 236 | class GatedFusion(nn.Module): 237 | def __init__(self, d_model, dropout=0.1): 238 | super().__init__() 239 | self.attention_gate = nn.Sequential( 240 | nn.Linear(d_model * 2, d_model), 241 | nn.LayerNorm(d_model), 242 | nn.ReLU(), 243 | nn.Dropout(dropout) # Changing activation to ReLU for experimentation 244 | ) 245 | 246 | self.norm = nn.LayerNorm(d_model) 247 | # self.dropout = nn.Dropout(dropout) 248 | 249 | def forward(self, feature1, feature2): 250 | combined_features = torch.cat([feature1, feature2], dim=-1) 251 | gate_feature = self.attention_gate(combined_features) 252 | #gated_features = gate_values * feature1 + (1 - gate_values) * feature2 253 | # gated_features = self.dropout(gated_features) 254 | # output = self.norm(gated_features) 255 | return gate_feature 256 | 257 | 258 | 259 | ''' 260 | class FeatureAggregator(nn.Module): 261 | def __init__(self, d_model=768, text_dim=768, hidden_dim=256, num_heads=8, dropout=0.1): 262 | super(FeatureAggregator, self).__init__() 263 | # 确保 hidden_dim 是 num_heads 的整数倍 264 | if hidden_dim % num_heads != 0: 265 | raise ValueError("hidden_dim must be divisible by num_heads") 266 | 267 | self.query_projection = nn.Linear(d_model, hidden_dim) 268 | self.text_projection = 
nn.Linear(text_dim, hidden_dim) 269 | self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout) 270 | self.enhance_projection = nn.Linear(hidden_dim, d_model) 271 | self.norm = nn.LayerNorm(d_model) 272 | 273 | def forward(self, high_features, low_features, text_features): 274 | # 计算低层和高层特征之间的差异 275 | diff_features = F.relu(low_features - high_features) # 只保留正差异 276 | 277 | # 投影低层差异特征和文本特征到共同空间 278 | projected_diff = self.query_projection(diff_features) # [batch_size, seq_len, hidden_dim] 279 | projected_text = self.text_projection(text_features) # [num_texts, hidden_dim] 280 | 281 | # 注意力机制输入需要是 [seq_len, batch_size, hidden_dim] 282 | projected_diff = projected_diff.transpose(0, 1) # [seq_len, batch_size, hidden_dim] 283 | projected_text = projected_text.unsqueeze(1) # [num_texts, 1, hidden_dim] 284 | key_value = projected_text.expand(-1, projected_diff.size(0), -1) # [num_texts, seq_len, hidden_dim] 285 | 286 | attn_output, _ = self.attention(projected_diff, key_value, key_value) 287 | attn_output = attn_output.transpose(0, 1) # 转换回 [batch_size, seq_len, hidden_dim] 288 | 289 | # 使用注意力输出加强原始低层特征 290 | enhanced_diff = self.enhance_projection(attn_output) 291 | enhanced_features = low_features + enhanced_diff 292 | enhanced_features = self.norm(enhanced_features) 293 | 294 | return enhanced_features 295 | 296 | ''' 297 | 298 | 299 | 300 | class MSCI(nn.Module): 301 | 302 | def __init__(self, config, attributes, classes, offset): 303 | super().__init__() 304 | self.clip = load_clip(name=config.clip_arch, context_length=config.context_length) 305 | self.tokenizer = SimpleTokenizer() 306 | self.config = config 307 | self.attributes = attributes 308 | self.classes = classes 309 | self.attr_dropout = nn.Dropout(config.attr_dropout) 310 | self.cross_attn_dropout = config.cross_attn_dropout if hasattr(config, 'cross_attn_dropout') else 0.1 311 | self.prim_loss_weight = config.prim_loss_weight if hasattr(config, 'prim_loss_weight') else 1 312 | 313 | self.token_ids, self.soft_att_obj, comp_ctx_vectors, attr_ctx_vectors, obj_ctx_vectors = self.construct_soft_prompt() 314 | self.offset = offset 315 | self.num_layers = self.clip.visual.transformer.layers # 总层数 316 | self.enable_pos_emb = True 317 | self.selected_low_layers = config.selected_low_layers # 默认为3 318 | self.selected_high_layers = config.selected_high_layers 319 | dtype = self.clip.dtype 320 | if dtype is None: 321 | self.dtype = torch.float16 322 | else: 323 | self.dtype = dtype 324 | self.text_encoder = CustomTextEncoder(self.clip, self.tokenizer, self.dtype) 325 | 326 | # freeze CLIP's parameters 327 | for p in self.parameters(): 328 | p.requires_grad = False 329 | 330 | # only consider ViT as visual encoder 331 | assert 'ViT' in config.clip_model 332 | 333 | self.additional_visual_params = self.add_visual_tunable_params() 334 | 335 | output_dim = self.clip.visual.output_dim 336 | self.local_feature_add = LocalFeatureAdd(output_dim) 337 | self.soft_att_obj = nn.Parameter(self.soft_att_obj) 338 | self.comp_ctx_vectors = nn.Parameter(comp_ctx_vectors).cuda() 339 | self.attr_ctx_vectors = nn.Parameter(attr_ctx_vectors).cuda() 340 | self.obj_ctx_vectors = nn.Parameter(obj_ctx_vectors).cuda() 341 | 342 | self.attr_disentangler = Disentangler(output_dim) 343 | self.obj_disentangler = Disentangler(output_dim) 344 | 345 | self.cmt = nn.ModuleList([CrossAttentionLayer(output_dim, output_dim // 64, self.cross_attn_dropout) for _ in 346 | range(config.cmt_layers)]) 347 | self.lamda = 
nn.Parameter(torch.ones(output_dim) * config.init_lamda) 348 | self.lamda_2 = nn.Parameter(torch.ones(output_dim) * config.init_lamda) 349 | self.patch_norm = nn.LayerNorm(output_dim) 350 | 351 | # 定义拼接后的线性变换层 352 | self.concat_projection_low = nn.Sequential( 353 | nn.Linear(self.selected_low_layers * output_dim, output_dim), 354 | nn.LayerNorm(output_dim), 355 | nn.ReLU(), 356 | nn.Dropout(0.1) 357 | ) 358 | 359 | self.concat_projection_high = nn.Sequential( 360 | nn.Linear(self.selected_high_layers * output_dim, output_dim), 361 | nn.LayerNorm(output_dim), 362 | nn.ReLU(), 363 | nn.Dropout(0.1) 364 | 365 | ) 366 | self.stage1_local_alignment = Stage1LocalAlignment(d_model=self.clip.visual.output_dim, 367 | num_heads=config.stage_1_num_heads, 368 | dropout=self.config.stage_1_dropout, 369 | num_cmt_layers_first=config.stage_1_num_cmt_layers) 370 | self.stage2_global_fusion = Stage2GlobalFusion(d_model=self.clip.visual.output_dim, 371 | num_heads=config.stage_2_num_heads, 372 | dropout=config.stage_2_dropout, 373 | num_cmt_layers_second=config.stage_2_num_cmt_layers) 374 | # self.stage3_fine_grained_enhancement = Stage3FineGrainedEnhancement(d_model=self.clip.visual.output_dim, 375 | # num_heads=8, dropout=0.2) 376 | 377 | self.fusion_module = GatedFusion(output_dim, dropout=self.config.fusion_dropout) 378 | 379 | # self.fusion_module_2=GatedFusion(output_dim,dropout=self.config.fusion_dropout) 380 | #self.lamda_2 = nn.Parameter(torch.ones(output_dim) * config.init_lamda) 381 | #self.feature_aggregator = FeatureAggregator(output_dim,output_dim) 382 | 383 | 384 | 385 | 386 | def multi_stage_cross_attention(self, q, low_level_features, high_level_features): 387 | # 第一阶段:局部特征对齐 388 | 389 | 390 | 391 | q_1 = self.stage1_local_alignment(q, low_level_features) 392 | q_2 = self.stage2_global_fusion(q_1, high_level_features) 393 | 394 | #q_fused = self.fusion_module(q, q_1) 395 | # print('q_1',q) 396 | # 第二阶段:全局特征融合 397 | 398 | 399 | # print('q_2', q) 400 | return q_1, q_2 401 | # return q_1,q_2 402 | 403 | def aggregate_features_low(self, visual_features): 404 | """ 405 | 使用拼接和线性变换来聚合不定数量的视觉特征层。 406 | :param visual_features: List[Tensor],每个 tensor 大小为 [batch_size, seq_len, feature_dim] 407 | :return: Tensor, 大小为 [batch_size, seq_len, feature_dim] 408 | """ 409 | # 检查输入是否为非空列表 410 | assert isinstance(visual_features, list) and len( 411 | visual_features) > 0, "Input should be a non-empty list of tensors." 412 | 413 | # 获取输入特征的基本维度信息 414 | batch_size, seq_len, feature_dim = visual_features[0].shape 415 | num_selected_layers = len(visual_features) 416 | 417 | # 将所有选择的层特征拼接在一起 418 | concat_features = torch.cat(visual_features, dim=-1) # [batch_size, seq_len, num_selected_layers * feature_dim] 419 | 420 | # 通过线性层映射回 feature_dim 维度 421 | aggregated_features = self.concat_projection_low(concat_features) # [batch_size, seq_len, feature_dim] 422 | 423 | return aggregated_features 424 | 425 | def aggregate_features_high(self, visual_features): 426 | """ 427 | 使用拼接和线性变换来聚合不定数量的视觉特征层。 428 | :param visual_features: List[Tensor],每个 tensor 大小为 [batch_size, seq_len, feature_dim] 429 | :return: Tensor, 大小为 [batch_size, seq_len, feature_dim] 430 | """ 431 | # 检查输入是否为非空列表 432 | assert isinstance(visual_features, list) and len( 433 | visual_features) > 0, "Input should be a non-empty list of tensors." 
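        # Shape note: concatenating k = len(visual_features) layer outputs of shape
        # [batch, seq_len, d] yields [batch, seq_len, k * d]; concat_projection_high
        # then maps this back to [batch, seq_len, d]. k must equal
        # config.selected_high_layers for the linear layer's input width to match.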
434 | 435 | # 获取输入特征的基本维度信息 436 | batch_size, seq_len, feature_dim = visual_features[0].shape 437 | num_selected_layers = len(visual_features) 438 | 439 | # 将所有选择的层特征拼接在一起 440 | concat_features = torch.cat(visual_features, dim=-1) # [batch_size, seq_len, num_selected_layers * feature_dim] 441 | 442 | # 通过线性层映射回 feature_dim 维度 443 | aggregated_features = self.concat_projection_high(concat_features) # [batch_size, seq_len, feature_dim] 444 | 445 | return aggregated_features 446 | 447 | def add_visual_tunable_params(self): 448 | adapter_num = 2 * self.clip.visual.transformer.layers 449 | params = nn.ModuleList([Adapter(d_model=self.clip.visual.transformer.width, 450 | bottleneck=self.config.adapter_dim, 451 | dropout=self.config.adapter_dropout 452 | ) for _ in range(adapter_num)]) 453 | 454 | 455 | return params 456 | 457 | ''' 458 | x_first torch.Size([16, 3, 224, 224]) 459 | x: torch.Size([16, 3, 224, 224]) 460 | x: torch.Size([16, 1024, 256]) 461 | x: torch.Size([16, 256, 1024]) 462 | x: torch.Size([16, 257, 1024]) 463 | x: torch.Size([16, 257, 1024]) 464 | x: torch.Size([257, 16, 1024]) 465 | img_feature torch.Size([16, 257, 1024]) 466 | 467 | 468 | ''' 469 | 470 | def encode_image(self, x: torch.Tensor): 471 | # print('x_first',x.shape) 472 | 473 | return self.encode_image_with_adapter(x) 474 | 475 | def encode_image_with_adapter(self, x: torch.Tensor): 476 | x = self.clip.visual.conv1(x) 477 | x = x.reshape(x.shape[0], x.shape[1], -1) 478 | x = x.permute(0, 2, 1) 479 | x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + 480 | torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) 481 | x = x + self.clip.visual.positional_embedding.to(x.dtype) 482 | x = self.clip.visual.ln_pre(x) 483 | x = x.permute(1, 0, 2) 484 | 485 | low_level_features = [] 486 | mid_level_features = [] 487 | high_level_features = [] 488 | 489 | for i_block in range(self.clip.visual.transformer.layers): 490 | # MHA 491 | adapt_x = self.additional_visual_params[i_block](x, add_residual=False) 492 | residual = x 493 | x = self.clip.visual.transformer.resblocks[i_block].attention( 494 | self.clip.visual.transformer.resblocks[i_block].ln_1(x) 495 | ) 496 | x = x + adapt_x + residual 497 | 498 | # FFN 499 | i_adapter = i_block + self.clip.visual.transformer.layers 500 | adapt_x = self.additional_visual_params[i_adapter](x, add_residual=False) 501 | residual = x 502 | x = self.clip.visual.transformer.resblocks[i_block].mlp( 503 | self.clip.visual.transformer.resblocks[i_block].ln_2(x) 504 | ) 505 | x = x + adapt_x + residual 506 | 507 | # 保存每一层的特征 508 | x_feature = x.permute(1, 0, 2) # LND -> NLD 509 | x_feature = self.clip.visual.ln_post(x_feature) 510 | if self.clip.visual.proj is not None: 511 | x_feature = x_feature @ self.clip.visual.proj # [batch_size, seq_len, feature_dim] 512 | 513 | # 添加位置编码 514 | #x_feature = self.add_positional_encoding(x_feature, x_feature.size(1), x_feature.size(-1)) 515 | 516 | # 提取特定层的特征 517 | if i_block < self.config.selected_low_layers: # Low-level features (前8个Transformer块) 518 | low_level_features.append(x_feature) 519 | # high_level_features.append(x_feature) 520 | 521 | elif self.config.selected_low_layers <= i_block < 24 - self.config.selected_high_layers: # Mid-level features (中间8个Transformer块) 522 | mid_level_features.append(x_feature) 523 | else: # High-level features (最后8个Transformer块) 524 | high_level_features.append(x_feature) 525 | 526 | img_feature = high_level_features[-1] 527 | # print(high_level_features.shape) 528 | # 
stacked_low_level_features = torch.stack(low_level_features, dim=0) 529 | # stacked_high_level_features = torch.stack(high_level_features, dim=0) 530 | #max_pooled_low_level_features = torch.max(stacked_low_level_features, dim=0)[0] # 通过 [0] 来提取第一个返回值 values 531 | #mean_pooled_high_level_features = torch.mean(stacked_high_level_features, dim=0) 532 | # low_level_features=self.aggregate_features_low(low_level_features) 533 | #high_level_features=self.aggregate_features_high(high_level_features) 534 | 535 | ''' 536 | stacked_low_level_features = torch.stack(low_level_features, dim=0) 537 | stacked_high_level_features = torch.stack(high_level_features, dim=0) 538 | 539 | # 应用最大池化于底层特征 540 | max_pooled_low_level_features = torch.max(stacked_low_level_features, dim=0)[0]#通过 [0] 来提取第一个返回值 values 541 | mean_pooled_high_level_features = torch.mean(stacked_high_level_features, dim=0) 542 | 543 | #print(len(low_level_features)) 544 | #print(len(high_level_features)) 545 | #print(len(low_level_features)) 546 | #print(len(high_level_features)) 547 | ''' 548 | 549 | #low_level_features = self.aggregate_features_low(low_level_features) 550 | 551 | 552 | if self.config.selected_high_layers==1: 553 | high_level_features = high_level_features[-1] 554 | 555 | else: 556 | high_level_features=self.aggregate_features_high(high_level_features) 557 | 558 | if self.config.selected_low_layers == 1: 559 | low_level_features = low_level_features[0] 560 | 561 | else: 562 | low_level_features = self.aggregate_features_low(low_level_features) 563 | 564 | return img_feature[:, 0, :],low_level_features, high_level_features 565 | 566 | def encode_text(self, token_ids, token_tensors=None, enable_pos_emb=False): 567 | return self.text_encoder(token_ids, token_tensors, enable_pos_emb) 568 | 569 | def construct_soft_prompt(self): 570 | # token_ids indicates the position of [EOS] 571 | token_ids = self.tokenizer(self.config.prompt_template, 572 | context_length=self.config.context_length).cuda() 573 | 574 | tokenized = torch.cat( 575 | [ 576 | self.tokenizer(tok, context_length=self.config.context_length) 577 | for tok in self.attributes + self.classes 578 | ] 579 | ) 580 | orig_token_embedding = self.clip.token_embedding(tokenized.cuda()) 581 | soft_att_obj = torch.zeros( 582 | (len(self.attributes) + len(self.classes), orig_token_embedding.size(-1)), 583 | ) 584 | for idx, rep in enumerate(orig_token_embedding): 585 | eos_idx = tokenized[idx].argmax() 586 | soft_att_obj[idx, :] = torch.mean(rep[1:eos_idx, :], axis=0) 587 | 588 | ctx_init = self.config.ctx_init 589 | assert isinstance(ctx_init, list) 590 | n_ctx = [len(ctx.split()) for ctx in ctx_init] 591 | prompt = self.tokenizer(ctx_init, 592 | context_length=self.config.context_length).cuda() 593 | with torch.no_grad(): 594 | embedding = self.clip.token_embedding(prompt) 595 | 596 | comp_ctx_vectors = embedding[0, 1: 1 + n_ctx[0], :].to(self.clip.dtype) # 组合上下文前缀放入 597 | attr_ctx_vectors = embedding[1, 1: 1 + n_ctx[1], :].to(self.clip.dtype) # 属性上下文前缀放入 598 | obj_ctx_vectors = embedding[2, 1: 1 + n_ctx[2], :].to(self.clip.dtype) # 物体上下文前缀放入 599 | 600 | return token_ids, soft_att_obj, comp_ctx_vectors, attr_ctx_vectors, obj_ctx_vectors 601 | 602 | def construct_token_tensors(self, pair_idx): 603 | attr_idx, obj_idx = pair_idx[:, 0], pair_idx[:, 1] 604 | token_tensor, num_elements = list(), [len(pair_idx), self.offset, len(self.classes)] 605 | for i_element in range(self.token_ids.shape[0]): 606 | class_token_ids = self.token_ids[i_element].repeat(num_elements[i_element], 
1) 607 | token_tensor.append(self.clip.token_embedding( 608 | class_token_ids.cuda() 609 | ).type(self.clip.dtype)) 610 | 611 | eos_idx = [int(self.token_ids[i_element].argmax()) for i_element in range(self.token_ids.shape[0])] 612 | soft_att_obj = self.attr_dropout(self.soft_att_obj) 613 | 614 | # 组合替换 615 | token_tensor[0][:, eos_idx[0] - 2, :] = soft_att_obj[ 616 | attr_idx 617 | ].type(self.clip.dtype) 618 | token_tensor[0][:, eos_idx[0] - 1, :] = soft_att_obj[ 619 | obj_idx + self.offset 620 | ].type(self.clip.dtype) 621 | token_tensor[0][ 622 | :, 1: len(self.comp_ctx_vectors) + 1, : 623 | ] = self.comp_ctx_vectors.type(self.clip.dtype) 624 | 625 | # 属性替换 626 | 627 | token_tensor[1][:, eos_idx[1] - 1, :] = soft_att_obj[ 628 | :self.offset 629 | ].type(self.clip.dtype) 630 | token_tensor[1][ 631 | :, 1: len(self.attr_ctx_vectors) + 1, : 632 | ] = self.attr_ctx_vectors.type(self.clip.dtype) 633 | 634 | # 物体替换 635 | token_tensor[2][:, eos_idx[2] - 1, :] = soft_att_obj[ 636 | self.offset: 637 | ].type(self.clip.dtype) 638 | token_tensor[2][ 639 | :, 1: len(self.obj_ctx_vectors) + 1, : 640 | ] = self.obj_ctx_vectors.type(self.clip.dtype) 641 | 642 | return token_tensor 643 | 644 | def loss_calu(self, predict, target): 645 | loss_fn = CrossEntropyLoss() 646 | _, batch_attr, batch_obj, batch_target = target 647 | 648 | # 检查 predict 的类型 649 | if isinstance(predict, tuple) and len(predict) == 2: 650 | # 训练阶段,predict 包含 logits 和特征 651 | logits, (high_features, low_features,text_features) = predict 652 | comp_logits, attr_logits, obj_logits = logits 653 | else: 654 | # 评估/测试阶段,predict 只包含 logits 655 | logits = predict 656 | comp_logits, attr_logits, obj_logits = logits 657 | high_features = None 658 | low_features = None 659 | text_features = None 660 | 661 | batch_attr = batch_attr.cuda() 662 | batch_obj = batch_obj.cuda() 663 | batch_target = batch_target.cuda() 664 | 665 | # 计算分类损失 666 | loss_comp = loss_fn(comp_logits, batch_target) 667 | loss_attr = loss_fn(attr_logits, batch_attr) 668 | loss_obj = loss_fn(obj_logits, batch_obj) 669 | loss = loss_comp * self.config.pair_loss_weight + \ 670 | loss_attr * self.config.attr_loss_weight + \ 671 | loss_obj * self.config.obj_loss_weight 672 | 673 | 674 | ''' 675 | # 仅在训练阶段计算协方差损失 676 | if high_features is not None and low_features is not None and text_features is not None: 677 | loss_cov= self.local_feature_add(high_features, low_features,text_features) 678 | #loss += loss_cov * self.config.covariance_loss_weight 679 | #print(loss_cov) 680 | 681 | ''' 682 | 683 | 684 | #print("损失不加") 685 | 686 | return loss 687 | 688 | def logit_infer(self, predict, pairs): 689 | comp_logits, attr_logits, obj_logits = predict 690 | attr_pred = F.softmax(attr_logits, dim=-1) 691 | obj_pred = F.softmax(obj_logits, dim=-1) 692 | for i_comp in range(comp_logits.shape[-1]): 693 | weighted_attr_pred = 1 if self.config.attr_inference_weight == 0 else attr_pred[:, pairs[i_comp][ 694 | 0]] * self.config.attr_inference_weight 695 | weighted_obj_pred = 1 if self.config.obj_inference_weight == 0 else obj_pred[:, pairs[i_comp][ 696 | 1]] * self.config.obj_inference_weight 697 | comp_logits[:, i_comp] = comp_logits[:, 698 | i_comp] * self.config.pair_inference_weight + weighted_attr_pred * weighted_obj_pred 699 | return comp_logits 700 | 701 | def encode_text_for_open(self, idx): 702 | token_tensors = self.construct_token_tensors(idx) 703 | text_features = [] 704 | for i_element in range(self.token_ids.shape[0]): 705 | _text_features, _ = self.encode_text( 706 | 
self.token_ids[i_element], 707 | token_tensors[i_element], 708 | enable_pos_emb=self.enable_pos_emb, 709 | ) 710 | 711 | idx_text_features = _text_features / _text_features.norm( 712 | dim=-1, keepdim=True 713 | ) 714 | text_features.append(idx_text_features) 715 | return text_features 716 | 717 | def forward_for_open(self, batch, text_feats): 718 | batch_img = batch[0].cuda() 719 | b = batch_img.shape[0] 720 | # l, _ = idx.shape 721 | batch_img, low_level_features, high_level_features = self.encode_image(batch_img.type(self.clip.dtype)) 722 | att_feature = self.attr_disentangler(batch_img) 723 | obj_feature = self.obj_disentangler(batch_img) 724 | 725 | batch_img_features = [batch_img, att_feature, obj_feature] 726 | normalized_img_features = [feats / feats.norm(dim=-1, keepdim=True) for feats in batch_img_features] 727 | 728 | logits = list() 729 | for i_element in range(self.token_ids.shape[0]): 730 | idx_text_features = text_feats[i_element] 731 | idx_text_features = idx_text_features / idx_text_features.norm( 732 | dim=-1, keepdim=True 733 | ) 734 | # CMT 735 | cmt_text_features = idx_text_features.unsqueeze(0).expand(b, -1, -1) 736 | 737 | # 多阶段交叉注意力 738 | cmt_1, cmt_2 = self.multi_stage_cross_attention( 739 | cmt_text_features, 740 | low_level_features, 741 | high_level_features 742 | 743 | ) 744 | # cmt1是全局特征交互形成的 745 | # cmt2是低层局部特征补齐形成的 746 | 747 | # print('weight:',self.stage1_local_alignment.spatial_attn.conv.weight) 748 | #cmt=self.fusion_module(cmt_1, cmt_2) 749 | cmt_text_features = idx_text_features + self.lamda * cmt_1+self.lamda_2 * cmt_2 750 | cmt_text_features = cmt_text_features / cmt_text_features.norm(dim=-1, keepdim=True) 751 | 752 | logits.append( 753 | torch.einsum( 754 | "bd, bkd->bk", 755 | normalized_img_features[i_element], 756 | cmt_text_features * self.clip.logit_scale.exp() 757 | ) 758 | ) 759 | return logits 760 | 761 | def forward(self, batch, idx, return_features=False): 762 | batch_img = batch[0].cuda() 763 | b = batch_img.shape[0] 764 | l, _ = idx.shape 765 | 766 | # 编码图像并提取不同层次的特征 767 | batch_img, low_level_features, high_level_features = self.encode_image(batch_img.type(self.clip.dtype)) 768 | 769 | att_feature = self.attr_disentangler(batch_img) 770 | obj_feature = self.obj_disentangler(batch_img) 771 | 772 | batch_img_features = [batch_img, att_feature, obj_feature] 773 | normalized_img_features = [feats / feats.norm(dim=-1, keepdim=True) for feats in batch_img_features] 774 | 775 | token_tensors = self.construct_token_tensors(idx) 776 | 777 | logits = [] 778 | for i_element in range(self.token_ids.shape[0]): 779 | _text_features, _ = self.encode_text( 780 | self.token_ids[i_element], 781 | token_tensors[i_element], 782 | enable_pos_emb=self.enable_pos_emb, 783 | ) 784 | 785 | idx_text_features = _text_features / _text_features.norm( 786 | dim=-1, keepdim=True 787 | ) 788 | 789 | cmt_text_features = idx_text_features.unsqueeze(0).expand(b, -1, -1) 790 | 791 | # print(low_level_features.shape) 792 | # print(high_level_features.shape) 793 | # print(fine_grained_features.shape) 794 | 795 | # 多阶段交叉注意力 796 | #print(idx_text_features.shape) 797 | 798 | #low_level_features=self.local_feature_add(high_level_features,low_level_features,idx_text_features) 799 | 800 | 801 | 802 | #enhanced_high_features=self.local_feature_add(low_level_features,high_level_features,idx_text_features) 803 | cmt_1, cmt_2 = self.multi_stage_cross_attention( 804 | cmt_text_features, 805 | low_level_features, 806 | high_level_features 807 | 808 | ) 809 | 810 | #cmt_text_features = 
self.fusion_module(cmt_1, cmt_2) 811 | # Cmt_text_features = cmt 812 | # print('weight:',self.stage1_local_alignment.spatial_attn.conv.weight) 813 | 814 | # fused_cmt = self.fusion_module(cmt_1, cmt_2) # 确保维度匹配 815 | 816 | #cmt_text_features = idx_text_features + self.lamda * cmt_2.squeeze(1) 817 | cmt_text_features = idx_text_features + self.lamda * cmt_1.squeeze(1)+self.lamda_2 * cmt_2.squeeze(1) 818 | # cmt_text_features= cmt_text_features.squeeze(1) 819 | cmt_text_features = cmt_text_features / cmt_text_features.norm(dim=-1, keepdim=True) 820 | 821 | logits.append( 822 | torch.einsum( 823 | "bd, bkd->bk", 824 | normalized_img_features[i_element], 825 | cmt_text_features * self.clip.logit_scale.exp() 826 | ) 827 | ) 828 | 829 | if return_features: 830 | return logits, (high_level_features, low_level_features, idx_text_features) 831 | else: 832 | return logits 833 | --------------------------------------------------------------------------------