├── AutoTuning.py
├── Create_MultiModal_Dataset.py
├── LICENSE
├── MBF-SUES
│   ├── Create_MultiModal_Dataset.py
│   ├── Multi_HBP.py
│   ├── Preprocessing.py
│   ├── SUES_bert.py
│   ├── helpers.py
│   ├── model_.py
│   ├── resnetv2.py
│   ├── settings.yaml
│   ├── test_and_evaluate.py
│   ├── train.py
│   ├── utils.py
│   ├── vision_transformer.py
│   └── vision_transformer_hybrid.py
├── Multi_HBP.py
├── Preprocessing.py
├── README.md
├── Shifited_test_and_evaluate.py
├── U1652_bert.py
├── U1652_test_and_evaluate.py
├── Visualization.py
├── activation.py
├── draw_cam_ViT.py
├── helpers.py
├── multi_test_and_evaluate.py
├── resnetv2.py
├── settings.yaml
├── train.py
├── utils.py
├── vision_transformer.py
└── vision_transformer_hybrid.py

/AutoTuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import yaml
4 | from utils import parameter
5 | from train import train
6 | from U1652_test_and_evaluate import eval_and_test
7 | 
8 | 
9 | def Auto_tune(drop_rate, learning_rate):
10 |     # for model in model_list:
11 |     #     parameter("model", model)
12 |     for dr in drop_rate:
13 |         parameter("drop_rate", dr)
14 |         for lr in learning_rate:
15 |             parameter("lr", lr)
16 |             # for wd in weight_decay:
17 |             #     parameter("weight_decay", wd)
18 |             with open("settings.yaml", "r", encoding="utf-8") as f:
19 |                 setting_dict = yaml.load(f, Loader=yaml.FullLoader)
20 |                 print(setting_dict)
21 |                 f.close()
22 |             train()
23 |             try:
24 |                 eval_and_test(384)
25 |             except:
26 |                 print("error")
27 |                 continue
28 | 
29 | 
30 | # height_list = [150, 200, 250, 300]
31 | learning_rate = [0.008, 0.009, 0.01]
32 | drop_rate = [0.2, 0.25]
33 | 
34 | # model_list = ["LPN"]
35 | Auto_tune(drop_rate, learning_rate)
36 | 
--------------------------------------------------------------------------------
/Create_MultiModal_Dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 | from torchvision import datasets, transforms
7 | # from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
8 | 
9 | from PIL import Image
10 | import json
11 | 
12 | 
13 | class Multimodel_Dateset(torch.utils.data.Dataset):
14 |     def __init__(self, data_path, transforms):
15 |         self.transforms = transforms
16 |         self.img_data_path = data_path
17 | 
18 |         if "drone" in os.path.basename(self.img_data_path):
19 |             self.text_path = os.path.join(os.path.dirname(data_path), "text_drone")
20 |         elif "satellite" in os.path.basename(self.img_data_path):
21 |             self.text_path = os.path.join(os.path.dirname(data_path), "text_satellite")
22 |             self.tensor = torch.load(os.path.join(self.text_path, "satellite.pth"))
23 | 
24 |         img_list = glob.glob(os.path.join(data_path, "*"))
25 |         self.classes = os.listdir(data_path)
26 |         self.img_names = []
27 |         for imgs in img_list:
28 |             self.img_names += glob.glob(os.path.join(imgs, '*'))
29 |             len_img = len(glob.glob(os.path.join(imgs, '*')))
30 |         self.labels = range(len(img_list))
31 |         img_arr = np.array(self.labels).reshape(1, -1)
32 |         img_arr = np.repeat(img_arr, len_img).tolist()
33 | 
34 |         self.imgs = list(zip(self.img_names, img_arr))
35 |         # print(imgs[:10])
36 |         # for img_dir in img_list:
37 |         #     for img_file in glob.glob(os.path.join(img_dir, "*")):
38 | 
39 |     def __len__(self):
40 |         return len(self.img_names)
41 | 
42 |     def __getitem__(self, item):
43 |         img = self.img_names[item]
44 |         # text = self.text[os.path.basename(self.img_names[item])]
45 |         if "drone" in 
os.path.basename(self.img_data_path): 46 | name = os.path.basename(img).split('.')[0] + '.pth' 47 | text = torch.load(os.path.join(self.text_path, name)).cpu() 48 | elif "satellite" in os.path.basename(self.img_data_path): 49 | text = self.tensor.cpu() 50 | # print(text.device) 51 | label = self.labels[self.classes.index(os.path.basename(os.path.dirname(img)))] 52 | # print(img, label) 53 | img = Image.open(img).convert('RGB') 54 | img = self.transforms(img) 55 | return img, text, label 56 | 57 | class Multimodel_Dateset_flip(torch.utils.data.Dataset): 58 | def __init__(self, data_path, transforms, gap): 59 | self.transforms = transforms 60 | self.img_data_path = data_path 61 | self.gap = gap 62 | if "drone" in os.path.basename(self.img_data_path): 63 | self.text_path = os.path.join(os.path.dirname(data_path), "text_drone") 64 | elif "satellite" in os.path.basename(self.img_data_path): 65 | self.text_path = os.path.join(os.path.dirname(data_path), "text_satellite") 66 | self.tensor = torch.load(os.path.join(self.text_path, "satellite.pth")) 67 | 68 | img_list = glob.glob(os.path.join(data_path, "*")) 69 | self.classes = os.listdir(data_path) 70 | self.img_names = [] 71 | for imgs in img_list: 72 | self.img_names += glob.glob(os.path.join(imgs, '*')) 73 | len_img = len(glob.glob(os.path.join(imgs, '*'))) 74 | self.labels = range(len(img_list)) 75 | img_arr = np.array(self.labels).reshape(1, -1) 76 | img_arr = np.repeat(img_arr, len_img).tolist() 77 | 78 | self.imgs = list(zip(self.img_names, img_arr)) 79 | # print(imgs[:10]) 80 | # for img_dir in img_list: 81 | # for img_file in glob.glob(os.path.join(img_dir, "*")): 82 | 83 | def __len__(self): 84 | return len(self.img_names) 85 | 86 | def __getitem__(self, item): 87 | img = self.img_names[item] 88 | # text = self.text[os.path.basename(self.img_names[item])] 89 | if "drone" in os.path.basename(self.img_data_path): 90 | name = os.path.basename(img).split('.')[0] + '.pth' 91 | text = torch.load(os.path.join(self.text_path, name)).cpu() 92 | elif "satellite" in os.path.basename(self.img_data_path): 93 | text = self.tensor.cpu() 94 | # print(text.device) 95 | label = self.labels[self.classes.index(os.path.basename(os.path.dirname(img)))] 96 | # print(img, label) 97 | img = Image.open(img).convert('RGB') 98 | height = img.height 99 | flip = img.crop((0, 0, self.gap, height)) 100 | img = img.crop((0, 0, height - self.gap, height)) 101 | flip = flip.transpose(Image.FLIP_LEFT_RIGHT) 102 | joint = Image.new("RGB", (height, height)) 103 | joint.paste(flip, (0, 0, self.gap, height)) 104 | joint.paste(img, (self.gap, 0, height, height)) 105 | img = self.transforms(joint) 106 | # plt.figure("black") 107 | # plt.imshow(joint) 108 | # plt.show() 109 | return img, text, label 110 | 111 | 112 | if __name__ == "__main__": 113 | path = "/home/sues/media/disk2/University-Release-MultiModel/University-Release/test/gallery_satellite" 114 | print(os.path.basename(path)) 115 | 116 | transforms = transforms.Compose([ 117 | transforms.Resize((384, 384), interpolation=3), 118 | transforms.ToTensor(), 119 | ]) 120 | dataset = Multimodel_Dateset(path, transforms=transforms) 121 | loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False) 122 | for img, text, label in loader: 123 | print(img.shape) 124 | print(text.shape) 125 | print(label) 126 | break 127 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | 
Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MBF-SUES/Create_MultiModal_Dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import numpy as np 5 | import torch 6 | from torchvision import datasets, transforms 7 | 8 | 9 | from PIL import Image 10 | import json 11 | 12 | 13 | class Multimodel_Dateset(torch.utils.data.Dataset): 14 | def __init__(self, data_path, transform): 15 | self.transforms = transform 16 | self.img_data_path = data_path 17 | 18 | if "drone" in os.path.basename(self.img_data_path): 19 | self.text_path = os.path.join(os.path.dirname(data_path), "text_drone") 20 | self.drone_tensor = torch.load(os.path.join(self.text_path, "drone.pth")) 21 | elif "satellite" in os.path.basename(self.img_data_path): 22 | self.text_path = os.path.join(os.path.dirname(data_path), "text_satellite") 23 | self.satellite_tensor = torch.load(os.path.join(self.text_path, "satellite.pth")) 24 | 25 | img_list = glob.glob(os.path.join(data_path, "*")) 26 | # print(img_list) 27 | self.classes = os.listdir(data_path) 28 | self.img_names = [] 29 | for imgs in img_list: 30 | self.img_names += glob.glob(os.path.join(imgs, '*')) 31 | len_img = len(glob.glob(os.path.join(imgs, '*'))) 32 | self.labels = range(len(img_list)) 33 | img_arr = np.array(self.labels).reshape(1, -1) 34 | img_arr = np.repeat(img_arr, len_img).tolist() 35 | 36 | self.imgs = list(zip(self.img_names, img_arr)) 37 | # print(imgs[:10]) 38 | # for img_dir in img_list: 39 | # for img_file in glob.glob(os.path.join(img_dir, "*")): 40 | 41 | def __len__(self): 42 | return len(self.img_names) 43 | 44 | def __getitem__(self, item): 45 | img = self.img_names[item] 46 | # text = self.text[os.path.basename(self.img_names[item])] 47 | if "drone" in os.path.basename(self.img_data_path): 48 | # name = os.path.basename(img).split('.')[0] + '.pth' 49 | text = self.drone_tensor.cpu() 50 | elif "satellite" in os.path.basename(self.img_data_path): 51 | text = self.satellite_tensor.cpu() 52 | # print(text.device) 53 | label = self.labels[self.classes.index(os.path.basename(os.path.dirname(img)))] 54 | # print(img, label) 55 | img = Image.open(img).convert('RGB') 56 | img = self.transforms(img) 57 | return img, text, label 58 | 59 | 60 | if __name__ == "__main__": 61 | path = 
"/Users/reza/Documents/SUES-200-512x512/Training/150/drone" 62 | print(os.path.basename(path)) 63 | 64 | transforms = transforms.Compose([ 65 | transforms.ToTensor(), 66 | ]) 67 | dataset = Multimodel_Dateset(path, transform=transforms) 68 | loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False) 69 | for img, text, label in loader: 70 | print(img.shape) 71 | print(text.shape) 72 | print(label) 73 | break 74 | -------------------------------------------------------------------------------- /MBF-SUES/Multi_HBP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from vision_transformer_hybrid import _create_vision_transformer_hybrid 4 | from torch import nn 5 | from torch.nn import init, functional 6 | import torch.nn.functional as F 7 | 8 | # from utils import get_yaml_value 9 | from resnetv2 import ResNetV2 10 | from timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple 11 | from vision_transformer import VisionTransformer, checkpoint_filter_fn, _create_vision_transformer, Block 12 | # from timm.models.vision_transformer_hybrid import _create_vision_transformer_hybrid 13 | from functools import partial 14 | from model_ import weights_init_kaiming, weights_init_classifier 15 | from einops import rearrange 16 | 17 | 18 | class GeM(nn.Module): 19 | def __init__(self, dim=2048, p=3, eps=1e-6): 20 | super(GeM, self).__init__() 21 | self.p = p 22 | self.eps = eps 23 | self.dim = dim 24 | def forward(self, x): 25 | return self.gem(x, p=self.p, eps=self.eps) 26 | 27 | def gem(self, x, p=3, eps=1e-6): 28 | x = torch.transpose(x, 1, -1) 29 | x = x.clamp(min=eps).pow(p) 30 | x = torch.transpose(x, 1, -1) 31 | x = F.avg_pool2d(x, (x.size(-2), x.size(-1))) 32 | x = x.view(x.size(0), x.size(1)) 33 | x = x.pow(1./p) 34 | return x 35 | 36 | def __repr__(self): 37 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ',' + 'dim='+str(self.dim)+')' 38 | 39 | 40 | class ClassBlock(nn.Module): 41 | 42 | def __init__(self, input_dim, class_num, drop_rate, num_bottleneck=512): 43 | super(ClassBlock, self).__init__() 44 | add_block = [] 45 | add_block += [ 46 | nn.Linear(input_dim, num_bottleneck), 47 | nn.GELU(), 48 | nn.BatchNorm1d(num_bottleneck), 49 | nn.Dropout(p=drop_rate) 50 | ] 51 | 52 | add_block = nn.Sequential(*add_block) 53 | add_block.apply(weights_init_kaiming) 54 | 55 | classifier = [] 56 | classifier += [nn.Linear(num_bottleneck, class_num)] 57 | classifier = nn.Sequential(*classifier) 58 | classifier.apply(weights_init_classifier) 59 | 60 | self.add_block = add_block 61 | self.classifier = classifier 62 | 63 | def forward(self, x): 64 | x = self.add_block(x) 65 | feature = x 66 | x = self.classifier(x) 67 | return x, feature 68 | 69 | 70 | class Hybird_ViT(nn.Module): 71 | def __init__(self, classes, drop_rate, block, share_weight=True): 72 | super(Hybird_ViT, self).__init__() 73 | self.block = block 74 | conv_layer = partial(StdConv2dSame, eps=1e-8) 75 | backbone = ResNetV2( 76 | layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=3, 77 | preact=False, stem_type="same", conv_layer=conv_layer, act_layer=nn.ReLU) 78 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, num_classes=0) 79 | model = _create_vision_transformer_hybrid( 80 | 'vit_base_r50_s16_384', backbone=backbone, pretrained=True, **model_kwargs) 81 | self.model_1 = model 82 | if share_weight: 83 | self.model_2 = self.model_1 84 | # else: 85 | # 
self.model_2 = hybrid_model(layers=(3, 4, 9), img_size=24, patch_size=1, num_classes=1000, depth=12) 86 | self.classifier_hbp = ClassBlock(2048*3, classes, drop_rate) 87 | self.classifier_multi = ClassBlock(768*2, classes, drop_rate) 88 | self.classifier = ClassBlock(768, classes, drop_rate) 89 | 90 | self.proj = nn.Conv2d(768, 1024, kernel_size=1, stride=1) 91 | self.bilinear_proj = torch.nn.Sequential(torch.nn.Conv2d(1024, 2048, kernel_size=1, bias=False), 92 | torch.nn.BatchNorm2d(2048), 93 | torch.nn.ReLU()) 94 | 95 | self.bilinear_proj_lpn = torch.nn.Sequential(torch.nn.Conv2d(1024, 2048, kernel_size=1, bias=False), 96 | torch.nn.BatchNorm2d(2048), 97 | torch.nn.ReLU()) 98 | self.Vit_block = Block(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True, init_values=None, 99 | drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-6), 100 | act_layer=nn.GELU) 101 | 102 | self.p = nn.Parameter(torch.ones(1024)*3, requires_grad=True) 103 | self.gem = GeM(dim=1024, p=self.p) 104 | for m in self.bilinear_proj.modules(): 105 | if isinstance(m, torch.nn.Conv2d): 106 | torch.nn.init.xavier_normal_(m.weight) 107 | if m.bias is not None: 108 | torch.nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, torch.nn.BatchNorm2d): 110 | torch.nn.init.constant_(m.weight, 1) 111 | torch.nn.init.constant_(m.bias, 0) 112 | elif isinstance(m, torch.nn.Linear): 113 | torch.nn.init.xavier_normal_(m.weight) 114 | torch.nn.init.constant_(m.bias, 0) 115 | 116 | LPN = 1 117 | if LPN: 118 | for i in range(self.block): 119 | # before lpn 120 | # name = 'classifier' + str(i + 1) 121 | # after lpn 122 | name = 'classifier' + str(i) 123 | setattr(self, name, ClassBlock(1024, classes, drop_rate)) 124 | # print(name) 125 | 126 | def hbp(self, conv1, conv2): 127 | N = conv1.size()[0] 128 | proj_1 = self.bilinear_proj(conv1) 129 | proj_2 = self.bilinear_proj(conv2) 130 | 131 | X = proj_1 * proj_2 132 | # print(X.shape) 133 | X = torch.sum(X.view(X.size()[0], X.size()[1], -1), dim=2) 134 | # print(X.shape) 135 | X = X.view(N, 2048) 136 | X = torch.sqrt(X + 1e-5) 137 | X = torch.nn.functional.normalize(X) 138 | return X 139 | 140 | def restore_vit_feature(self, x): 141 | x = x[:, 1:, :] 142 | x = rearrange(x, "b (h w) y -> b y h w", h=24, w=24) 143 | x = self.proj(x) 144 | return x 145 | 146 | def fusion_features(self, x, t, model): 147 | # print(len(x), len(t)) 148 | y = [] 149 | x, p_f, v_f, l_f = model(x, t) 150 | l_f = self.Vit_block(l_f) 151 | 152 | # direct softmax 153 | 154 | y0, f = self.classifier(x) 155 | 156 | # multi modal softmax 157 | # t = t.view(t.size(0), -1) 158 | # x = torch.concat([x, t], dim=1) 159 | # y1, multi_f = self.classifier_multi(x) 160 | 161 | v_f = self.restore_vit_feature(v_f) 162 | l_f = self.restore_vit_feature(l_f) 163 | 164 | # HBP softmax 3 layer feature X multiply 165 | x1 = self.hbp(p_f, v_f) 166 | x2 = self.hbp(p_f, l_f) 167 | x3 = self.hbp(v_f, l_f) 168 | x = torch.concat([x1, x2, x3], dim=1) 169 | # print(x1.shape) 170 | 171 | y2, hbp_f = self.classifier_hbp(x) 172 | 173 | result = self.get_part_pool(v_f) 174 | 175 | # result = map(self.gem, result) 176 | # lpn_f = torch.concat(result, dim=2) 177 | # print(result[0].shape) 178 | if self.training: 179 | y3, lpn_f = self.part_classifier(result) 180 | else: 181 | lpn_f = self.part_classifier(result) 182 | y3 = [None, None] 183 | # print(len(y3), lpn_f.shape) 184 | 185 | # after lpn HBP softmax 186 | # lpn = self.hbp(result[0], result[1]) 187 | # y4 = self.classifier_hbp(lpn) 188 | 189 | y.append(y0) 190 | # 
y.append(y1) 191 | y.append(y2) 192 | y.append(y3[0]) 193 | y.append(y3[1]) 194 | # y.append(y4) 195 | if self.training: 196 | f_all = torch.concat([f, hbp_f, lpn_f], dim=1) 197 | else: 198 | f = f.view(f.size()[0], f.size()[1], 1) 199 | hbp_f = hbp_f.view(hbp_f.size()[0], hbp_f.size()[1], 1) 200 | # print(f.shape, hbp_f.shape, lpn_f.shape) 201 | f_all = torch.concat([f, hbp_f, lpn_f], dim=2) 202 | # print("fall", f_all.shape) 203 | return y, f_all 204 | 205 | def forward(self, x1, x2, t1, t2): 206 | if x1 is None: 207 | y1 = None 208 | f1 = None 209 | t1 = None 210 | output1 = None 211 | else: 212 | y1, f1 = self.fusion_features(x1, t1, self.model_1) 213 | 214 | if x2 is None: 215 | y2 = None 216 | f2 = None 217 | t2 = None 218 | output2 = None 219 | else: 220 | y2, f2 = self.fusion_features(x2, t2, self.model_2) 221 | 222 | if self.training: 223 | return y1, y2, f1, f2 224 | # output1, output2 225 | else: 226 | return f1, f2 227 | 228 | def get_part_pool(self, x, pool='max', no_overlap=True): 229 | result = [] 230 | if pool == 'avg': 231 | pooling = torch.nn.AdaptiveAvgPool2d((1, 1)) 232 | elif pool == 'max': 233 | pooling = torch.nn.AdaptiveMaxPool2d((1, 1)) 234 | H, W = x.size(2), x.size(3) 235 | c_h, c_w = int(H / 2), int(W / 2) 236 | per_h, per_w = H / (2 * self.block), W / (2 * self.block) 237 | if per_h < 1 and per_w < 1: 238 | new_H, new_W = H + (self.block - c_h) * 2, W + (self.block - c_w) * 2 239 | x = nn.functional.interpolate(x, size=[new_H, new_W], mode='bilinear', align_corners=True) 240 | H, W = x.size(2), x.size(3) 241 | c_h, c_w = int(H / 2), int(W / 2) 242 | per_h, per_w = H / (2 * self.block), W / (2 * self.block) 243 | per_h, per_w = math.floor(per_h), math.floor(per_w) # 向下取整 244 | for i in range(self.block): 245 | i = i + 1 246 | if i < self.block: 247 | # print("x", x.shape) 248 | x_curr = x[:, :, (c_h - i * per_h):(c_h + i * per_h), (c_w - i * per_w):(c_w + i * per_w)] 249 | # print("x_curr", x_curr.shape) 250 | if no_overlap and i > 1: 251 | x_pre = x[:, :, (c_h - (i - 1) * per_h):(c_h + (i - 1) * per_h), 252 | (c_w - (i - 1) * per_w):(c_w + (i - 1) * per_w)] 253 | x_pad = functional.pad(x_pre, (per_h, per_h, per_w, per_w), "constant", 0) 254 | x_curr = x_curr - x_pad 255 | # print("x_curr", x_curr.shape) 256 | avgpool = pooling(x_curr) 257 | # print("pool", avgpool.shape) 258 | result.append(avgpool) 259 | # print(x_curr.shape) 260 | else: 261 | if no_overlap and i > 1: 262 | x_pre = x[:, :, (c_h - (i - 1) * per_h):(c_h + (i - 1) * per_h), 263 | (c_w - (i - 1) * per_w):(c_w + (i - 1) * per_w)] 264 | pad_h = c_h - (i - 1) * per_h 265 | pad_w = c_w - (i - 1) * per_w 266 | # x_pad = F.pad(x_pre,(pad_h,pad_h,pad_w,pad_w),"constant",0) 267 | if x_pre.size(2) + 2 * pad_h == H: 268 | x_pad = functional.pad(x_pre, (pad_h, pad_h, pad_w, pad_w), "constant", 0) 269 | else: 270 | ep = H - (x_pre.size(2) + 2 * pad_h) 271 | x_pad = functional.pad(x_pre, (pad_h + ep, pad_h, pad_w + ep, pad_w), "constant", 0) 272 | x = x - x_pad 273 | avgpool = pooling(x) 274 | result.append(avgpool) 275 | # print(x.shape) 276 | return torch.concat(result, dim=2) 277 | 278 | def part_classifier(self, x): 279 | part = {} 280 | predict = {} 281 | features = [] 282 | for i in range(self.block): 283 | part[i] = x[:, :, i].view(x.size(0), -1) 284 | 285 | name = 'classifier' + str(i) 286 | c = getattr(self, name) 287 | # print(c) 288 | predict[i], feature = c(part[i]) 289 | features.append(feature) 290 | 291 | # print(predict[i][0].shape) 292 | # print(predict) 293 | y = [] 294 | for i in 
range(self.block): 295 | y.append(predict[i]) 296 | if not self.training: 297 | return torch.stack(y, dim=2) 298 | return y, torch.concat(features, dim=1) 299 | 300 | 301 | if __name__ == '__main__': 302 | # create_model() 303 | model = Hybird_ViT(classes=120, drop_rate=0.3, block=2).cuda() 304 | # print(model) 305 | # print(model.model_1.patch_embed.backbone.stages[-1]) 306 | feature = torch.randn(8, 3, 384, 384).cuda() 307 | text = torch.rand(8, 1, 768).cuda() 308 | output = model(feature, feature, text, text) 309 | print(output[0]) 310 | # print(f1.shape) 311 | # -------------------------------------------------------------------------------- /MBF-SUES/Preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from utils import get_yaml_value 4 | from torchvision import datasets, transforms 5 | from Create_MultiModal_Dataset import Multimodel_Dateset 6 | 7 | 8 | def Create_Training_Datasets(train_data_path, batch_size, image_size): 9 | training_data_loader = {} 10 | transform_drone_list = [ 11 | transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), 12 | transforms.RandomCrop((image_size, image_size)), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 16 | ] 17 | 18 | transforms_satellite_list = [ 19 | transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), 20 | transforms.RandomCrop((image_size, image_size)), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 24 | ] 25 | image_dataset = {} 26 | image_dataset["drone"] = Multimodel_Dateset(os.path.join(train_data_path, "drone"), 27 | transform=transforms.Compose(transform_drone_list)) 28 | image_dataset["satellite"] = Multimodel_Dateset(os.path.join(train_data_path, "satellite"), 29 | transform=transforms.Compose(transforms_satellite_list)) 30 | 31 | training_data_loader["drone"] = torch.utils.data.DataLoader(image_dataset["drone"], 32 | batch_size=batch_size, 33 | shuffle=True, 34 | # num_workers=4, # 多进程 35 | pin_memory=True) # 锁页内存 36 | 37 | training_data_loader["satellite"] = torch.utils.data.DataLoader(image_dataset["satellite"], 38 | batch_size=batch_size, 39 | shuffle=True, 40 | # num_workers=4, # 多进程 41 | pin_memory=True) # 锁页内存 42 | 43 | return training_data_loader, image_dataset 44 | 45 | 46 | def Create_Testing_Datasets(test_data_path, batch_size, image_size): 47 | print(test_data_path) 48 | testing_data_loader = {} 49 | image_datasets = {} 50 | transforms_test_list = [ 51 | transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 54 | ] 55 | 56 | image_datasets['query_drone'] = Multimodel_Dateset(os.path.join(test_data_path, "query_drone"), 57 | transform=transforms.Compose(transforms_test_list)) 58 | 59 | image_datasets['query_satellite'] = Multimodel_Dateset(os.path.join(test_data_path, "query_satellite"), 60 | transform=transforms.Compose(transforms_test_list)) 61 | 62 | image_datasets['gallery_drone'] = Multimodel_Dateset(os.path.join(test_data_path, "gallery_drone"), 63 | transform=transforms.Compose(transforms_test_list)) 64 | 65 | image_datasets['gallery_satellite'] = Multimodel_Dateset(os.path.join(test_data_path, "gallery_satellite"), 66 | 
                                                        transform=transforms.Compose(transforms_test_list))
67 | 
68 |     testing_data_loader["query_drone"] = torch.utils.data.DataLoader(image_datasets['query_drone'],
69 |                                                                       batch_size=batch_size,
70 |                                                                       shuffle=False,
71 |                                                                       # num_workers=4,  # multiprocessing
72 |                                                                       pin_memory=True)
73 | 
74 |     testing_data_loader["query_satellite"] = torch.utils.data.DataLoader(image_datasets['query_satellite'],
75 |                                                                           batch_size=batch_size,
76 |                                                                           shuffle=False,
77 |                                                                           # num_workers=4,  # multiprocessing
78 |                                                                           pin_memory=True)  # pinned (page-locked) memory
79 | 
80 |     testing_data_loader["gallery_drone"] = torch.utils.data.DataLoader(image_datasets['gallery_drone'],
81 |                                                                        batch_size=batch_size,
82 |                                                                        shuffle=False,
83 |                                                                        # num_workers=4,  # multiprocessing
84 |                                                                        pin_memory=True)  # pinned (page-locked) memory
85 | 
86 |     testing_data_loader["gallery_satellite"] = torch.utils.data.DataLoader(image_datasets['gallery_satellite'],
87 |                                                                            batch_size=batch_size,
88 |                                                                            shuffle=False,
89 |                                                                            # num_workers=4,  # multiprocessing
90 |                                                                            pin_memory=True)  # pinned (page-locked) memory
91 | 
92 |     return testing_data_loader, image_datasets
93 | 
94 | 
95 | if __name__ == "__main__":
96 |     # Cross_Dataset("../Datasets/SUES-200/Training/150", 224)
97 |     dataloaders, img_dataset = Create_Training_Datasets(train_data_path="../SUES-200-512x512/Training/150",
98 |                                                          batch_size=4,
99 |                                                          image_size=384)
100 |     # print(image_datasets['drone_train'].classes)
101 |     # print(img_dataset['drone'].imgs)
102 |     for img, text, label in dataloaders["drone"]:
103 |         print(text.shape)
104 |         print(img.shape, label.shape)
105 |         break
106 | 
--------------------------------------------------------------------------------
/MBF-SUES/SUES_bert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from utils import create_dir, get_yaml_value
4 | from pytorch_pretrained_bert import BertTokenizer, BertModel
5 | 
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 | class Word_Embeding:
8 |     def __init__(self):
9 |         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # the tokenizer is not a torch module; only tensors and the model go to the device
10 |         self.model = BertModel.from_pretrained('bert-base-uncased').to(device)  # keep the model on the same device as the input tensors
11 |         self.model.eval()
12 | 
13 |     def word_embedding(self, text):
14 | 
15 |         marked_text = text
16 |         tokenized_text = self.tokenizer.tokenize(marked_text)  # tokenize() returns a list of token strings
17 |         indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
18 |         tokens_tensor = torch.tensor([indexed_tokens]).to(device)
19 |         segments_ids = [1] * len(tokenized_text)
20 |         segments_tensors = torch.tensor([segments_ids]).to(device)
21 | 
22 |         with torch.no_grad():
23 |             encoded_layers, _ = self.model(tokens_tensor, segments_tensors)
24 | 
25 |         sentence_embedding = torch.mean(encoded_layers[11], 1)
26 | 
27 |         return sentence_embedding
28 | 
29 | 
30 | 
31 | wd = Word_Embeding()
32 | 
33 | heights = [150, 200, 250, 300]
34 | angles = [45, 50, 60, 70]
35 | 
36 | 
37 | 
38 | param = get_yaml_value("settings.yaml")
39 | train_path = os.path.join(param["dataset_path"], "Training")
40 | test_path = os.path.join(param["dataset_path"], "Testing")
41 | 
42 | for i in range(len(heights)):
43 | 
44 |     train_height_path = os.path.join(train_path, str(heights[i]))
45 |     test_height_path = os.path.join(test_path, str(heights[i]))
46 | 
47 |     train_drone_text_path = os.path.join(train_height_path, "text_drone")
48 |     train_satellite_text_path = os.path.join(train_height_path, "text_satellite")
49 | 
50 |     if not os.path.exists(train_drone_text_path):
51 |         os.mkdir(train_drone_text_path)
52 |     if not os.path.exists(train_satellite_text_path):
53 |         os.mkdir(train_satellite_text_path)
54 | 
55 |     test_drone_text_path = os.path.join(test_height_path, "text_drone")
56 |     test_satellite_text_path = os.path.join(test_height_path, "text_satellite")
57 | 
58 |     if not os.path.exists(test_drone_text_path):
59 |         os.mkdir(test_drone_text_path)
60 |     if not os.path.exists(test_satellite_text_path):
61 |         os.mkdir(test_satellite_text_path)
62 | 
63 |     drone = "The altitude of the drone is %d meters, the angle of camera is %d degree" % (heights[i], angles[i])
64 |     drone_tensor = wd.word_embedding(drone)
65 |     torch.save(drone_tensor, os.path.join(train_drone_text_path, "drone.pth"))
66 |     torch.save(drone_tensor, os.path.join(test_drone_text_path, "drone.pth"))
67 |     # print(os.path.join(train_path, "text_drone", "image-%02d.pth" % (i + 1)))
68 | 
69 |     # satellite
70 |     satellite = "The altitude of the satellite is 1000 kilometers"
71 |     satellite_tensor = wd.word_embedding(satellite)
72 |     torch.save(satellite_tensor, os.path.join(train_satellite_text_path, "satellite.pth"))
73 |     torch.save(satellite_tensor, os.path.join(test_satellite_text_path, "satellite.pth"))
74 | 
75 | 
76 | print("Done")
77 | 
78 | 
--------------------------------------------------------------------------------
/MBF-SUES/model_.py:
--------------------------------------------------------------------------------
1 | import os
2 | import timm
3 | import time
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init, functional
8 | from torchvision import models
9 | 
10 | 
11 | class GeM(nn.Module):
12 |     # GeM zhedong zheng
13 |     def __init__(self, dim=2048, p=3, eps=1e-6):
14 |         super(GeM, self).__init__()
15 |         self.p = nn.Parameter(torch.ones(dim)*p, requires_grad=True).cuda()
16 |         self.eps = eps
17 |         self.dim = dim
18 |     def forward(self, x):
19 |         return self.gem(x, p=self.p, eps=self.eps)
20 | 
21 |     def gem(self, x, p=3, eps=1e-6):
22 |         x = x.cuda()
23 |         x = torch.transpose(x, 1, -1)
24 |         x = x.clamp(min=eps).pow(p)
25 |         x = torch.transpose(x, 1, -1)
26 |         x = functional.avg_pool2d(x, (x.size(-2), x.size(-1)))  # this module imports torch.nn.functional as `functional`
27 |         x = x.view(x.size(0), x.size(1))
28 |         x = x.pow(1./p)
29 |         return x
30 | 
31 |     def __repr__(self):
32 |         return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ',' + 'dim='+str(self.dim)+')'
33 | 
34 | 
35 | 
36 | class ClassBlock(nn.Module):
37 | 
38 |     def __init__(self, input_dim, class_num, drop_rate, num_bottleneck=512):
39 |         super(ClassBlock, self).__init__()
40 |         add_block = []
41 |         add_block += [
42 |             nn.Linear(input_dim, num_bottleneck),
43 |             nn.GELU(),
44 |             nn.BatchNorm1d(num_bottleneck),
45 |             nn.Dropout(p=drop_rate)
46 |         ]
47 | 
48 |         add_block = nn.Sequential(*add_block)
49 |         add_block.apply(weights_init_kaiming)
50 | 
51 |         classifier = []
52 |         classifier += [nn.Linear(num_bottleneck, class_num)]
53 |         classifier = nn.Sequential(*classifier)
54 |         classifier.apply(weights_init_classifier)
55 | 
56 |         self.add_block = add_block
57 |         self.classifier = classifier
58 | 
59 |     def forward(self, x):
60 | 
61 |         x = self.add_block(x)
62 |         feature = x
63 |         x = self.classifier(x)
64 |         return x, feature
65 | 
66 | class ResNet(nn.Module):
67 |     def __init__(self, class_num, drop_rate, share_weight=False):
68 |         super(ResNet, self).__init__()
69 |         self.model_1 = timm.create_model("resnet50", pretrained=True, num_classes=0)
70 | 
71 |         if share_weight:
72 |             self.model_2 = self.model_1
73 |         else:
74 |             self.model_2 = timm.create_model("resnet50", pretrained=True, num_classes=0)
75 | 
76 |         self.classifier = ClassBlock(2048, class_num, drop_rate)
77 | 
78 |     def forward(self, x1, x2):
79 |         if x1 
is None: 80 | y1 = None 81 | else: 82 | x1 = self.model_1(x1) 83 | y1 = self.classifier(x1) 84 | 85 | if x2 is None: 86 | y2 = None 87 | else: 88 | x2 = self.model_2(x2) 89 | y2 = self.classifier(x2) 90 | 91 | return y1, y2 92 | 93 | 94 | class SEResNet_50(nn.Module): 95 | def __init__(self, classes, drop_rate, share_weight = False): 96 | super(SEResNet_50, self).__init__() 97 | self.model_1 = timm.create_model("seresnet50", pretrained=True, num_classes=0) 98 | if share_weight: 99 | self.model_2 = self.model_1 100 | else: 101 | self.model_2 = timm.create_model("seresnet50", pretrained=True, num_classes=0) 102 | self.classifier = ClassBlock(2048, classes, drop_rate) 103 | 104 | def forward(self, x1, x2): 105 | if x1 is None: 106 | y1 = None 107 | else: 108 | x1 = self.model_1(x1) 109 | y1 = self.classifier(x1) 110 | 111 | if x2 is None: 112 | y2 = None 113 | else: 114 | x2 = self.model_2(x2) 115 | y2 = self.classifier(x2) 116 | return y1, y2 117 | 118 | 119 | class DenseNet(nn.Module): 120 | def __init__(self, class_num, drop_rate, share_weight=False): 121 | super(DenseNet, self).__init__() 122 | self.model_1 = timm.create_model("densenet201", pretrained=True, num_classes=0) 123 | if share_weight: 124 | self.model_2 = self.model_1 125 | else: 126 | self.model_2 = timm.create_model("densenet201", pretrained=True, num_classes=0) 127 | self.classifier = ClassBlock(1920, class_num, drop_rate) 128 | 129 | def forward(self, x1, x2): 130 | if x1 is None: 131 | y1 = None 132 | else: 133 | x1 = self.model_1(x1) 134 | y1 = self.classifier(x1) 135 | 136 | if x2 is None: 137 | y2 = None 138 | else: 139 | x2 = self.model_2(x2) 140 | y2 = self.classifier(x2) 141 | return y1, y2 142 | 143 | 144 | class Hybird_ViT(nn.Module): 145 | def __init__(self, classes, drop_rate, share_weight=True): 146 | super(Hybird_ViT, self).__init__() 147 | self.model_1 = timm.create_model("vit_base_r50_s16_384", pretrained=True, num_classes=0) 148 | if share_weight: 149 | self.model_2 = self.model_1 150 | else: 151 | self.model_2 = timm.create_model("vit_base_r50_s16_384", pretrained=True, num_classes=0) 152 | self.classifier = ClassBlock(768, classes, drop_rate) 153 | 154 | def forward(self, x1, x2): 155 | if x1 is None: 156 | y1 = None 157 | f1 = None 158 | else: 159 | x1 = self.model_1(x1) 160 | y1, f1 = self.classifier(x1) 161 | 162 | if x2 is None: 163 | y2 = None 164 | f2 = None 165 | else: 166 | x2 = self.model_2(x2) 167 | y2, f2 = self.classifier(x2) 168 | if self.training: 169 | return y1, y2, f1, f2 170 | else: 171 | return f1, f2 172 | 173 | 174 | class ViT(nn.Module): 175 | def __init__(self, classes, drop_rate, share_weight): 176 | super(ViT, self).__init__() 177 | # checkpoint = torch.load(os.path.join("checkpoint", "drone_checkpoint.pth")) 178 | 179 | self.model_1 = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0) 180 | # self.model_1.load_state_dict(checkpoint) 181 | # model.model_2.load_state_dict(checkpoint["model"]) 182 | 183 | if share_weight: 184 | self.model_2 = self.model_1 185 | else: 186 | self.model_2 = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0) 187 | 188 | # self.model_2 = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0) 189 | # self.model_1.load_state_dict(torch.load("../SS-Study/checkpoint/satellite_checkpoint.pth")) 190 | # self.model_2.load_state_dict(torch.load("../SS-Study/checkpoint/drone_checkpoint.pth")) 191 | # self.model_1 = self.model_2 192 | 193 | self.bn = torch.nn.BatchNorm2d(3) 194 | # self.model_1 = 
timm.create_model("vit_base_patch16_384", pretrained=True, num_classes=0) 195 | # self.model_2 = timm.create_model("vit_base_patch16_384", pretrained=True, num_classes=0) 196 | self.classifier = ClassBlock(768, classes, drop_rate, num_bottleneck=768) 197 | 198 | def forward(self, x1, x2): 199 | if x1 is None: 200 | y1 = None 201 | else: 202 | # x1 = self.bn(x1) 203 | # print(x1.shape) 204 | x1 = self.model_1(x1) 205 | y1 = self.classifier(x1) 206 | 207 | if x2 is None: 208 | y2 = None 209 | else: 210 | # x2 = self.bn(x2) 211 | x2 = self.model_2(x2) 212 | y2 = self.classifier(x2) 213 | return y1, y2 214 | 215 | 216 | class Swin(nn.Module): 217 | def __init__(self, classes, drop_rate, share_weight=True): 218 | super(Swin, self).__init__() 219 | self.model_1 = timm.create_model("swin_base_patch4_window12_384", pretrained=True, num_classes=0) 220 | if share_weight: 221 | self.model_2 = self.model_1 222 | else: 223 | self.model_2 = timm.create_model("swin_base_patch4_window12_384", pretrained=True, num_classes=0) 224 | self.classifier = ClassBlock(1024, classes, drop_rate) 225 | 226 | def forward(self, x1, x2): 227 | if x1 is None: 228 | y1 = None 229 | else: 230 | x1 = self.model_1(x1) 231 | # print(x1.shape) 232 | y1 = self.classifier(x1) 233 | 234 | if x2 is None: 235 | y2 = None 236 | else: 237 | x2 = self.model_2(x2) 238 | y2 = self.classifier(x2) 239 | return y1, y2 240 | 241 | 242 | class ft_net_LPN(nn.Module): 243 | def __init__(self, stride=1, init_model=None, pool='avg', block=4): 244 | super(ft_net_LPN, self).__init__() 245 | # model_ft = timm.create_model("resnet50", pretrained=True, num_classes=0) 246 | model_ft = models.resnet50(pretrained=True) 247 | # avg pooling to global pooling 248 | if stride == 1: 249 | model_ft.layer4[0].downsample[0].stride = (1, 1) 250 | model_ft.layer4[0].conv2.stride = (1, 1) 251 | 252 | self.pool = pool 253 | self.model = model_ft 254 | self.model.relu = nn.ReLU(inplace=True) 255 | self.block = block 256 | if init_model != None: 257 | self.model = init_model.model 258 | self.pool = init_model.pool 259 | #self.classifier.add_block = init_model.classifier.add_block 260 | 261 | def forward(self, x): 262 | # x = self.model(x) 263 | x = self.model.conv1(x) 264 | x = self.model.bn1(x) 265 | x = self.model.relu(x) 266 | x = self.model.maxpool(x) 267 | x = self.model.layer1(x) 268 | x = self.model.layer2(x) 269 | x = self.model.layer3(x) 270 | x = self.model.layer4(x) 271 | # print(x.shape) 272 | # print(x.shape) 273 | 274 | if self.pool == 'avg+max': 275 | x1 = self.get_part_pool(x, pool='avg') 276 | x2 = self.get_part_pool(x, pool='max') 277 | x = torch.cat((x1, x2), dim=1) 278 | x = x.view(x.size(0), x.size(1), -1) 279 | elif self.pool == 'avg': 280 | x = self.get_part_pool(x) 281 | x = x.view(x.size(0), x.size(1), -1) 282 | elif self.pool == 'max': 283 | x = self.get_part_pool(x, pool='max') 284 | x = x.view(x.size(0), x.size(1), -1) 285 | 286 | return x 287 | 288 | def get_part_pool(self, x, pool='avg', no_overlap=True): 289 | result = [] 290 | if pool == 'avg': 291 | pooling = torch.nn.AdaptiveAvgPool2d((1, 1)) 292 | elif pool == 'max': 293 | pooling = torch.nn.AdaptiveMaxPool2d((1, 1)) 294 | H, W = x.size(2), x.size(3) 295 | c_h, c_w = int(H/2), int(W/2) 296 | per_h, per_w = H/(2*self.block), W/(2*self.block) 297 | if per_h < 1 and per_w < 1: 298 | new_H, new_W = H+(self.block-c_h)*2, W+(self.block-c_w)*2 299 | x = nn.functional.interpolate(x, size=[new_H, new_W], mode='bilinear', align_corners=True) 300 | H, W = x.size(2), x.size(3) 301 | c_h, c_w = 
int(H/2), int(W/2) 302 | per_h, per_w = H/(2*self.block), W/(2*self.block) 303 | per_h, per_w = math.floor(per_h), math.floor(per_w) # 向下取整 304 | for i in range(self.block): 305 | i = i + 1 306 | if i < self.block: 307 | # print("x", x.shape) 308 | x_curr = x[:,:,(c_h-i*per_h):(c_h+i*per_h),(c_w-i*per_w):(c_w+i*per_w)] 309 | # print("x_curr", x_curr.shape) 310 | if no_overlap and i > 1: 311 | x_pre = x[:,:,(c_h-(i-1)*per_h):(c_h+(i-1)*per_h),(c_w-(i-1)*per_w):(c_w+(i-1)*per_w)] 312 | x_pad = functional.pad(x_pre,(per_h,per_h,per_w,per_w),"constant",0) 313 | x_curr = x_curr - x_pad 314 | # print("x_curr", x_curr.shape) 315 | avgpool = pooling(x_curr) 316 | # print("pool", avgpool.shape) 317 | result.append(avgpool) 318 | else: 319 | if no_overlap and i > 1: 320 | x_pre = x[:,:,(c_h-(i-1)*per_h):(c_h+(i-1)*per_h),(c_w-(i-1)*per_w):(c_w+(i-1)*per_w)] 321 | pad_h = c_h-(i-1)*per_h 322 | pad_w = c_w-(i-1)*per_w 323 | # x_pad = F.pad(x_pre,(pad_h,pad_h,pad_w,pad_w),"constant",0) 324 | if x_pre.size(2)+2*pad_h == H: 325 | x_pad = functional.pad(x_pre,(pad_h,pad_h,pad_w,pad_w),"constant",0) 326 | else: 327 | ep = H - (x_pre.size(2)+2*pad_h) 328 | x_pad = functional.pad(x_pre,(pad_h+ep,pad_h,pad_w+ep,pad_w),"constant",0) 329 | x = x - x_pad 330 | avgpool = pooling(x) 331 | result.append(avgpool) 332 | return torch.cat(result, dim=2) 333 | 334 | 335 | class two_view_net(nn.Module): 336 | def __init__(self, class_num, droprate, pool='avg', share_weight=False, VGG16=False, LPN=False, block=4): 337 | super(two_view_net, self).__init__() 338 | self.LPN = LPN 339 | self.block = block 340 | self.model_1 = ft_net_LPN(pool=pool, block=block) 341 | # self.model_2 = ft_net_LPN(class_num, stride=stride, pool=pool, block=block) 342 | 343 | if share_weight: 344 | self.model_2 = self.model_1 345 | else: 346 | self.model_2 = ft_net_LPN(pool=pool, block=block) 347 | 348 | if pool == 'avg+max': 349 | for i in range(self.block): 350 | name = 'classifier'+str(i) 351 | setattr(self, name, ClassBlock(4096, class_num, droprate)) 352 | else: 353 | for i in range(self.block): 354 | name = 'classifier'+str(i) 355 | setattr(self, name, ClassBlock(2048, class_num, droprate)) 356 | 357 | def forward(self, x1, x2): # x4 is extra data 358 | 359 | if x1 is None: 360 | y1 = None 361 | else: 362 | x1 = self.model_1(x1) 363 | y1 = self.part_classifier(x1) 364 | 365 | if x2 is None: 366 | y2 = None 367 | else: 368 | x2 = self.model_2(x2) 369 | y2 = self.part_classifier(x2) 370 | 371 | return y1, y2 372 | 373 | def part_classifier(self, x): 374 | part = {} 375 | predict = {} 376 | for i in range(self.block): 377 | part[i] = x[:, :, i].view(x.size(0), -1) 378 | # part[i] = torch.squeeze(x[:,:,i]) 379 | name = 'classifier'+str(i) 380 | c = getattr(self, name) 381 | # print(c) 382 | predict[i] = c(part[i]) 383 | # print(predict[i].shape) 384 | # print(predict) 385 | y = [] 386 | for i in range(self.block): 387 | y.append(predict[i]) 388 | if not self.training: 389 | return torch.stack(y, dim=2) 390 | return y 391 | 392 | 393 | def weights_init_kaiming(m): 394 | classname = m.__class__.__name__ 395 | # print(classname) 396 | if classname.find('Conv') != -1: 397 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal. 
398 | elif classname.find('Linear') != -1: 399 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 400 | init.constant_(m.bias.data, 0.0) 401 | elif classname.find('BatchNorm1d') != -1: 402 | init.normal_(m.weight.data, 1.0, 0.02) 403 | init.constant_(m.bias.data, 0.0) 404 | 405 | 406 | def weights_init_classifier(m): 407 | classname = m.__class__.__name__ 408 | if classname.find('Linear') != -1: 409 | init.normal_(m.weight.data, std=0.001) 410 | init.constant_(m.bias.data, 0.0) 411 | 412 | 413 | if __name__ == '__main__': 414 | # import ssl 415 | 416 | # ssl._create_default_https_context = ssl._create_unverified_context 417 | # model = ViT_two_view_LPN(100, 0.1).cuda() 418 | # model = Hybird_ViT(100, 0.1).cuda() 419 | # model = ViT_two_view_LPN(100, 0.1).cuda() 420 | model = Hybird_ViT(100, 0.1, True).cuda() 421 | # print(model) 422 | # model = EfficientNet_b() 423 | # print(model.device) 424 | # print(model.extract_features) 425 | # Here I left a simple forward function. 426 | # Test the model, before you train it. 427 | input = torch.randn(1, 3, 384, 384).cuda() 428 | output1, output2 = model(input, input) 429 | print(output1.size()) 430 | # print(output) 431 | 432 | model_dict = { 433 | "LPN": two_view_net, 434 | "resnet": ResNet, 435 | "seresnet": SEResNet_50, 436 | "dense": DenseNet, 437 | "vit": ViT, 438 | "swin": Swin, 439 | "hybrid": Hybird_ViT 440 | } 441 | -------------------------------------------------------------------------------- /MBF-SUES/settings.yaml: -------------------------------------------------------------------------------- 1 | 2 | # dateset path 3 | dataset_path: /home/LVM_date/zhurz/dataset/SUES-200-512x512 4 | weight_save_path: /home/LVM_date/zhurz/dataset/save_model_weight 5 | 6 | # apply LPN and set block number 7 | LPN : 1 8 | block : 2 9 | 10 | # super parameters 11 | batch_size : 8 12 | num_epochs : 40 13 | drop_rate : 0.35 14 | weight_decay : 0.0001 15 | lr : 0.01 16 | 17 | #intial parameters 18 | height : 150 19 | query : drone 20 | image_size: 384 21 | fp16 : 0 22 | classes : 120 23 | 24 | model : MBF 25 | name: MBF 26 | -------------------------------------------------------------------------------- /MBF-SUES/test_and_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import glob 3 | import os 4 | import time 5 | import timm 6 | import torch 7 | import shutil 8 | import argparse 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from torch import nn 13 | 14 | from utils import fliplr, load_network, which_view, get_id, get_yaml_value 15 | 16 | from Preprocessing import Create_Testing_Datasets 17 | 18 | if torch.cuda.is_available(): 19 | device = torch.device("cuda:0") 20 | 21 | def evaluate(qf, ql, gf, gl): 22 | 23 | query = qf.view(-1, 1) 24 | 25 | score = torch.mm(gf, query) 26 | # score.shape = (51355,1) 27 | score = score.squeeze(1).cpu() 28 | # score.shape = (51355,) 29 | score = score.numpy() 30 | # print(score) 31 | # print(score.shape) 32 | 33 | # predict index 34 | index = np.argsort(score) # from small to large 35 | # 从小到大的索引排列 36 | # print("index before", index) 37 | index = index[::-1] 38 | # print("index after", index) 39 | # 从大到小的索引排列 40 | 41 | # index = index[0:2000] 42 | # good index 43 | query_index = np.argwhere(gl == ql) 44 | # print(query_index.shape) (54, 1) 45 | # gl = ql 返回标签值相同的索引矩阵 46 | # 得到 ql:卫星图标签,gl:无人机图标签 47 | # 即 卫星图标签在 gl中的索引位置 组成的矩阵 48 | good_index = query_index 49 | 50 | # print(good_index) 51 | # print(index[0:10]) 52 | junk_index 
= np.argwhere(gl == -1) 53 | # print(junk_index) = [] 54 | 55 | CMC_tmp = compute_mAP(index, good_index, junk_index) 56 | return CMC_tmp 57 | 58 | 59 | def compute_mAP(index, good_index, junk_index): 60 | # CMC就是recall的,只要前K里面有一个正确答案就算recall成功是1否则是0 61 | # mAP是传统retrieval的指标,算的是 recall和precision曲线,这个曲线和x轴的面积。 62 | # 你可以自己搜索一下mAP 63 | 64 | ap = 0 65 | cmc = torch.IntTensor(len(index)).zero_() 66 | # print(cmc.shape) torch.Size([51355]) 67 | if good_index.size == 0: # if empty 68 | cmc[0] = -1 69 | return ap, cmc 70 | 71 | # remove junk_index 72 | mask = np.in1d(index, junk_index, invert=True) 73 | index = index[mask] 74 | # print(index.shape) (51355,) 75 | # if junk_index == [] 76 | # return index fully 77 | 78 | # find good_index index 79 | ngood = len(good_index) 80 | # print("good_index", good_index) (54, 1) 81 | # print(index) 82 | # print(good_index) 83 | mask = np.in1d(index, good_index) 84 | # print(mask) 85 | # print(mask.shape) (51355,) 86 | # 51355 中 54 个对应元素变为了True 87 | 88 | rows_good = np.argwhere(mask == True) 89 | # print(rows_good.shape) (54, 1) 90 | # rows_good 得到这 54 个为 True 元素的索引位置 91 | 92 | rows_good = rows_good.flatten() 93 | # print(rows_good.shape) (54,) 94 | # print(rows_good[0]) 95 | 96 | cmc[rows_good[0]:] = 1 97 | # print(cmc) 98 | # print(cmc.shape) torch.Size([51355]) 99 | 100 | # print(cmc) 101 | for i in range(ngood): 102 | d_recall = 1.0 / ngood 103 | # d_racall = 1/54 104 | precision = (i + 1) * 1.0 / (rows_good[i] + 1) 105 | # n/sum 106 | # print("row_good[]", i, rows_good[i]) 107 | # print(precision) 108 | if rows_good[i] != 0: 109 | old_precision = i * 1.0 / rows_good[i] 110 | else: 111 | old_precision = 1.0 112 | ap = ap + d_recall * (old_precision + precision) / 2 113 | 114 | return ap, cmc 115 | 116 | 117 | def extract_feature(model, dataloaders, block, LPN, view_index=1): 118 | features = torch.FloatTensor() 119 | count = 0 120 | for data in dataloaders: 121 | img, text, label = data 122 | n, c, h, w = img.size() 123 | text = text.to(device) 124 | count += n 125 | 126 | if LPN: 127 | ff = torch.FloatTensor(n, 512, block+2).zero_().cuda() 128 | else: 129 | ff = torch.FloatTensor(n, 512).zero_().cuda() 130 | 131 | # why for in range(2): 132 | # 1. for flip img 133 | # 2. 
for normal img 134 | 135 | for i in range(1): 136 | if i == 1: 137 | img = fliplr(img) 138 | 139 | input_img = img.to(device) 140 | outputs = None 141 | since = time.time() 142 | 143 | if view_index == 1: 144 | outputs, _ = model(input_img, None, text, None) 145 | elif view_index == 2: 146 | _, outputs = model(None, input_img, None, text) 147 | # print(outputs.shape) 148 | ff += outputs 149 | time_elapsed = time.time() - since 150 | # print(time_elapsed) 151 | # ff.shape = [16, 512, 4] 152 | 153 | if LPN: 154 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(block) 155 | # print("fnorm", fnorm.shape) 156 | ff = ff.div(fnorm.expand_as(ff)) 157 | # print("ff", ff.shape) 158 | ff = ff.view(ff.size(0), -1) 159 | # print("ff", ff.shape) 160 | else: 161 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 162 | # print("fnorm", fnorm.shape) 163 | ff = ff.div(fnorm.expand_as(ff)) 164 | # print("ff", ff.shape) 165 | 166 | features = torch.cat((features, ff.data.cpu()), 0) # 在维度0上拼接 167 | return features 168 | 169 | 170 | ############################### main function ####################################### 171 | def eval_and_test(cfg_path, name, seqs): 172 | param_dict = get_yaml_value(cfg_path) 173 | if name == "": 174 | name = param_dict["name"] 175 | block = param_dict["block"] 176 | LPN = param_dict["LPN"] 177 | data_path = param_dict["dataset_path"] 178 | batch_size = param_dict["batch_size"] 179 | image_size = param_dict["image_size"] 180 | height = param_dict["height"] 181 | 182 | all_block = block 183 | 184 | dataloaders, image_datasets = Create_Testing_Datasets(test_data_path=data_path + "/Testing/{}".format(height), 185 | batch_size=batch_size, 186 | image_size=image_size) 187 | 188 | # print("Testing Start >>>>>>>>") 189 | table_path = os.path.join(param_dict["weight_save_path"], 190 | name + ".csv") 191 | save_model_list = glob.glob(os.path.join(param_dict["weight_save_path"], 192 | name, "*.pth")) 193 | # print(param_dict("name")) 194 | if os.path.exists(os.path.join(param_dict["weight_save_path"], 195 | name)) and len(save_model_list) >= 1: 196 | if not os.path.exists(table_path): 197 | evaluate_csv = pd.DataFrame(index=["recall@1", "recall@5", "recall@10", "recall@1p", "AP", "time"]) 198 | else: 199 | evaluate_csv = pd.read_csv(table_path) 200 | evaluate_csv.index = evaluate_csv["index"] 201 | for query in ['satellite', 'drone']: 202 | for seq in range(-seqs, 0): 203 | # net_name = "mae_pretrained" 204 | # model, net_name = load_network(seq=seq) 205 | # model = model_.two_view_net(701, 0.3) 206 | # model.load_state_dict(torch.load("/home/sues/Reza/SS-Study/LPN/save_model_weight/three_view_lr_0.01_dr_0.5_2/net_119.pth")) 207 | # print(model) 208 | # LPN 209 | model, net_name = load_network(seq) 210 | # model = Hybird_ViT(120, 0.1) 211 | # model.load_state_dict(torch.load(net_path)) 212 | if LPN: 213 | for i in range(all_block): 214 | cls_name = 'classifier' + str(i) 215 | c = getattr(model, cls_name) 216 | c.classifier = nn.Sequential() 217 | else: 218 | model.classifier.classifier = nn.Sequential() 219 | # print(net_name) 220 | 221 | model = model.eval() 222 | model = model.cuda() 223 | # print(model) 224 | query_name = "" 225 | gallery_name = "" 226 | 227 | if query == "satellite": 228 | query_name = 'query_satellite' 229 | gallery_name = 'gallery_drone' 230 | elif query == "drone": 231 | query_name = 'query_drone' 232 | gallery_name = 'gallery_satellite' 233 | 234 | which_query = which_view(query_name) 235 | which_gallery = which_view(gallery_name) 236 | 237 | print('%s -> 
%s:' % (query_name, gallery_name)) 238 | 239 | # image_datasets, data_loader = Create_Testing_Datasets(test_data_path=data_path) 240 | 241 | gallery_path = image_datasets[gallery_name].imgs 242 | query_path = image_datasets[query_name].imgs 243 | 244 | gallery_label, gallery_path = get_id(gallery_path) 245 | query_label, query_path = get_id(query_path) 246 | 247 | with torch.no_grad(): 248 | since = time.time() 249 | query_feature = extract_feature(model, dataloaders[query_name], all_block, LPN, which_query) 250 | gallery_feature = extract_feature(model, dataloaders[gallery_name], all_block, LPN, which_gallery) 251 | 252 | time_elapsed = time.time() - since 253 | print('Testing complete in {:.0f}m {:.0f}s'.format( 254 | time_elapsed // 60, time_elapsed % 60)) 255 | 256 | # result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 257 | # 'gallery_path': gallery_path, 258 | # 'query_f': query_feature.numpy(), 'query_label': query_label, 'query_path': query_path} 259 | # 260 | # scipy.io.savemat('U1652_pytorch_result.mat', result) 261 | # 262 | # print(">>>>>>>> Testing END") 263 | # 264 | # print("Evaluating Start >>>>>>>>") 265 | # 266 | # result = scipy.io.loadmat("U1652_pytorch_result.mat") 267 | # 268 | # # initialize query feature data 269 | # query_feature = torch.FloatTensor(result['query_f']) 270 | # query_label = result['query_label'][0] 271 | # 272 | # # initialize all(gallery) feature data 273 | # gallery_feature = torch.FloatTensor(result['gallery_f']) 274 | # gallery_label = result['gallery_label'][0] 275 | query_feature = query_feature.cuda() 276 | gallery_feature = gallery_feature.cuda() 277 | query_label = np.array(query_label) 278 | gallery_label = np.array(gallery_label) 279 | 280 | # fed tensor to GPU 281 | query_feature = query_feature.cuda() 282 | gallery_feature = gallery_feature.cuda() 283 | 284 | # CMC = recall 285 | CMC = torch.IntTensor(len(gallery_label)).zero_() 286 | # ap = average precision 287 | ap = 0.0 288 | 289 | for i in range(len(query_label)): 290 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 291 | if CMC_tmp[0] == -1: 292 | continue 293 | CMC += CMC_tmp 294 | ap += ap_tmp 295 | 296 | # average CMC 297 | 298 | CMC = CMC.float() 299 | CMC = CMC / len(query_label) 300 | # print(len(query_label)) 301 | recall_1 = CMC[0] * 100 302 | recall_5 = CMC[4] * 100 303 | recall_10 = CMC[9] * 100 304 | recall_1p = CMC[round(len(gallery_label) * 0.01)] * 100 305 | AP = ap / len(query_label) * 100 306 | 307 | evaluate_csv[query_name+"_"+net_name] = [float(recall_1), float(recall_5), 308 | float(recall_10), float(recall_1p), 309 | float(AP), 310 | float(time_elapsed) 311 | ] 312 | evaluate_result = 'Recall@1:%.2f Recall@5:%.2f Recall@10:%.2f Recall@top1:%.2f AP:%.2f Time::%.2f' % ( 313 | recall_1, recall_5, recall_10, recall_1p, AP, time_elapsed 314 | ) 315 | 316 | # show result and save 317 | save_path = os.path.join(param_dict["weight_save_path"], name) 318 | save_txt_path = os.path.join(save_path, '%s_to_%s_%s_%.2f_%.2f.txt' % (query_name[6:], gallery_name[8:], net_name[:7], recall_1, AP)) 319 | # print(save_txt_path) 320 | 321 | with open(save_txt_path, 'w') as f: 322 | f.write(evaluate_result) 323 | f.close() 324 | 325 | shutil.copy('settings.yaml', os.path.join(save_path, "settings_saved.yaml")) 326 | shutil.copy('train.py', os.path.join(save_path, "train.py")) 327 | shutil.copy('Multi.py', os.path.join(save_path, "model.py")) 328 | 329 | # print(round(len(gallery_label)*0.01)) 330 | 
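The loop above ranks the gallery for each query with a single `torch.mm` and then derives CMC and AP in `compute_mAP`. Below is a self-contained sketch of the same metrics on synthetic L2-normalised features; it uses the standard non-interpolated AP (slightly simpler than the trapezoidal form in `compute_mAP`), and every tensor name here is illustrative rather than repository code.

```python
import torch

def recall_at_k_and_ap(query_f, query_l, gallery_f, gallery_l, ks=(1, 5, 10)):
    # features are L2-normalised, so the matrix product is cosine similarity
    sim = query_f @ gallery_f.t()                      # [num_query, num_gallery]
    order = sim.argsort(dim=1, descending=True)        # ranked gallery indices per query
    match = gallery_l[order] == query_l.unsqueeze(1)   # True where the ranked label matches

    recalls = {k: match[:, :k].any(dim=1).float().mean().item() for k in ks}

    aps = []
    for row in match:                                  # non-interpolated AP per query
        hits = torch.nonzero(row).flatten()
        if hits.numel() == 0:                          # no relevant gallery image: skip, as above
            continue
        precision = torch.arange(1, hits.numel() + 1, dtype=torch.float) / (hits + 1).float()
        aps.append(precision.mean().item())
    return recalls, sum(aps) / max(len(aps), 1)

# toy data: 4 queries, 20 gallery images, 4 classes
torch.manual_seed(0)
gallery_f = torch.nn.functional.normalize(torch.randn(20, 512), dim=1)
query_f = torch.nn.functional.normalize(torch.randn(4, 512), dim=1)
gallery_l, query_l = torch.randint(0, 4, (20,)), torch.arange(4)
print(recall_at_k_and_ap(query_f, query_l, gallery_f, gallery_l))
```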
print(evaluate_result) 331 | # evaluate_csv["max"] = 332 | drone_max = [] 333 | satellite_max = [] 334 | 335 | for index in evaluate_csv.index: 336 | drone_max.append(evaluate_csv.loc[index].iloc[:5].max()) 337 | satellite_max.append(evaluate_csv.loc[index].iloc[5:].max()) 338 | 339 | evaluate_csv['drone_max'] = drone_max 340 | evaluate_csv['satellite_max'] = satellite_max 341 | evaluate_csv.columns.name = "net" 342 | evaluate_csv.index.name = "index" 343 | evaluate_csv.to_csv(table_path) 344 | else: 345 | print("Don't have enough weights to evaluate!") 346 | 347 | 348 | def parse_opt(known=False): 349 | parser = argparse.ArgumentParser() 350 | parser.add_argument('--cfg', type=str, default='settings.yaml', help='config file XXX.yaml path') 351 | parser.add_argument('--name', type=str, default='', help='evaluate which weight,dir name') 352 | parser.add_argument('--seq', type=int, default=1, help='evaluate how many weights from loss value(small -> big)') 353 | 354 | opt = parser.parse_known_args()[0] if known else parser.parse_args() 355 | 356 | return opt 357 | 358 | 359 | if __name__ == '__main__': 360 | opt = parse_opt(True) 361 | 362 | eval_and_test(opt.cfg, opt.name, opt.seq) -------------------------------------------------------------------------------- /MBF-SUES/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import time 4 | import torch 5 | import argparse 6 | 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim import lr_scheduler 11 | import torch.backends.cudnn as cudnn 12 | from pytorch_metric_learning import losses, miners 13 | 14 | 15 | from Multi_HBP import Hybird_ViT 16 | from utils import get_yaml_value, parameter, create_dir, save_feature_network, setup_seed 17 | from Preprocessing import Create_Training_Datasets 18 | import random 19 | import os 20 | 21 | if torch.cuda.is_available(): 22 | device = torch.device("cuda:1") 23 | cudnn.benchmark = True 24 | 25 | 26 | # torch.cuda.manual_seed(random.randint(1, 100)) 27 | # setup_seed() 28 | 29 | def one_LPN_output(outputs, labels, criterion, block): 30 | # part = {} 31 | # print(len(outputs)) 32 | sm = nn.Softmax(dim=1) 33 | num_part = block 34 | score = 0 35 | loss = 0 36 | # print(len(outputs)) 37 | for i in range(num_part): 38 | part = outputs[i] 39 | score += sm(part) 40 | loss += criterion(part, labels) 41 | _, preds = torch.max(score.data, 1) 42 | 43 | return preds, loss 44 | 45 | 46 | def train(config_path): 47 | param_dict = get_yaml_value(config_path) 48 | print(param_dict) 49 | classes = param_dict["classes"] 50 | num_epochs = param_dict["num_epochs"] 51 | drop_rate = param_dict["drop_rate"] 52 | lr = param_dict["lr"] 53 | weight_decay = param_dict["weight_decay"] 54 | model_name = param_dict["model"] 55 | fp16 = param_dict["fp16"] 56 | weight_save_path = param_dict["weight_save_path"] 57 | LPN = param_dict["LPN"] 58 | batchsize = param_dict["batch_size"] 59 | height = param_dict["height"] 60 | data_path = param_dict["dataset_path"] 61 | block = param_dict["block"] 62 | image_size = param_dict["image_size"] 63 | 64 | 65 | all_block = block 66 | train_data_path = data_path + "/Training/{}".format(height) 67 | 68 | dataloaders, image_datasets = Create_Training_Datasets(train_data_path=train_data_path, batch_size=batchsize, image_size=image_size) 69 | dataset_sizes = {x: len(image_datasets[x]) for x in ['satellite', 'drone']} 70 | 71 | model = Hybird_ViT(classes, 
drop_rate, all_block).to(device) 72 | 73 | if LPN: 74 | ignored_params = list() 75 | for i in range(all_block): 76 | cls_name = 'classifier' + str(i) 77 | c = getattr(model, cls_name) 78 | ignored_params += list(map(id, c.parameters())) 79 | 80 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) 81 | 82 | optim_params = [{'params': base_params, 'lr': 0.1 * lr}] 83 | for i in range(all_block): 84 | cls_name = 'classifier' + str(i) 85 | c = getattr(model, cls_name) 86 | optim_params.append({'params': c.parameters(), 'lr': lr}) 87 | optimizer = optim.SGD(optim_params, weight_decay=weight_decay, momentum=0.9, nesterov=True) 88 | # opt = torchcontrib.optim.SWA(optimizer) 89 | else: 90 | ignored_params = list(map(id, model.classifier.parameters())) 91 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) 92 | 93 | optimizer = optim.SGD([ 94 | {'params': base_params, 'lr': 0.1 * lr}, 95 | {'params': model.classifier.parameters(), 'lr': lr} 96 | ], weight_decay=weight_decay, momentum=0.9, nesterov=True) 97 | 98 | if fp16: 99 | # from apex.fp16_utils import * 100 | from apex import amp, optimizers 101 | model, optimizer_ft = amp.initialize(model, optimizer, opt_level="O2") 102 | 103 | criterion = nn.CrossEntropyLoss() 104 | # criterion1 = nn.KLDivLoss() 105 | # circle = circle_loss.CircleLoss(m=0.4, gamma=80) 106 | criterion_func = losses.TripletMarginLoss(margin=0.3) 107 | miner = miners.MultiSimilarityMiner() 108 | 109 | scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5) 110 | # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) 111 | 112 | print("Dataloader Preprocessing Finished...") 113 | MAX_LOSS = 10 114 | print("Training Start >>>>>>>>") 115 | weight_save_name = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 116 | dir_model_name = model_name + "_" + str(height) + "_" + weight_save_name 117 | save_path = os.path.join(weight_save_path, dir_model_name) 118 | create_dir(save_path) 119 | print(save_path) 120 | parameter("name", dir_model_name) 121 | 122 | warm_epoch = 5 123 | warm_up = 0.1 # We start from the 0.1*lrRate 124 | warm_iteration = round(dataset_sizes['satellite'] / batchsize) * warm_epoch # first 5 epoch 125 | 126 | for epoch in range(num_epochs): 127 | since = time.time() 128 | 129 | running_loss = 0.0 130 | running_corrects1 = 0.0 131 | running_corrects2 = 0.0 132 | total1 = 0.0 133 | total2 = 0.0 134 | model.train(True) 135 | for data1, data2 in zip(dataloaders["satellite"], dataloaders["drone"]): 136 | 137 | input1, text1, label1 = data1 138 | input2, text2, label2 = data2 139 | 140 | input1, input2 = input1.to(device), input2.to(device) 141 | text1, text2 = text1.to(device), text2.to(device) 142 | label1, label2 = label1.to(device), label2.to(device) 143 | 144 | total1 += label1.size(0) 145 | total2 += label2.size(0) 146 | 147 | optimizer.zero_grad() 148 | 149 | output1, output2, feature1, feature2 = model(input1, input2, text1, text2) 150 | 151 | fnorm = torch.norm(feature1, p=2, dim=1, keepdim=True) * np.sqrt(all_block + 2) 152 | fnorm2 = torch.norm(feature2, p=2, dim=1, keepdim=True) * np.sqrt(all_block + 2) 153 | # fnorm3 = torch.norm(feature3, p=2, dim=1, keepdim=True) * np.sqrt(all_block) 154 | # fnorm4 = torch.norm(feature4, p=2, dim=1, keepdim=True) * np.sqrt(all_block) 155 | 156 | feature1 = feature1.div(fnorm.expand_as(feature1)) 157 | feature2 = feature2.div(fnorm2.expand_as(feature2)) 158 | loss1 = loss2 = loss3 = loss4 = loss6 = loss5 = loss7 = loss8 = 0 159 | 160 | if LPN: 
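The training setup above gives the pretrained backbone a tenth of the classifier learning rate and scales the loss by a linearly growing warm-up factor during the first five epochs. A minimal sketch of that scheme on a stand-in network (`ToyNet`, `iters_per_epoch`, and the random data are illustrative assumptions, not project code):

```python
import torch
import torch.nn as nn
import torch.optim as optim

class ToyNet(nn.Module):
    def __init__(self, classes=10):
        super().__init__()
        self.backbone = nn.Linear(128, 64)        # stands in for the pretrained part
        self.classifier = nn.Linear(64, classes)  # stands in for the new heads
    def forward(self, x):
        return self.classifier(self.backbone(x))

model, lr = ToyNet(), 0.01
optimizer = optim.SGD(
    [{'params': model.backbone.parameters(), 'lr': 0.1 * lr},   # backbone: slower
     {'params': model.classifier.parameters(), 'lr': lr}],      # heads: full rate
    weight_decay=0.0001, momentum=0.9, nesterov=True)

# linear warm-up: the loss multiplier grows from 0.1 to 1.0 over the first warm_epoch epochs
warm_epoch, iters_per_epoch = 5, 100
warm_up, warm_iteration = 0.1, warm_epoch * iters_per_epoch
criterion = nn.CrossEntropyLoss()

for step in range(3):                              # a few illustrative steps
    x, y = torch.randn(8, 128), torch.randint(0, 10, (8,))
    loss = criterion(model(x), y)
    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
    (loss * warm_up).backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step}: warm_up={warm_up:.4f} loss={loss.item():.4f}")
```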
161 | # print(len(output1)) 162 | preds1, loss1 = one_LPN_output(output1[2:], label1, criterion, all_block) 163 | preds2, loss2 = one_LPN_output(output2[2:], label2, criterion, all_block) 164 | 165 | loss3 = criterion(output1[1], label1) 166 | loss4 = criterion(output2[1], label2) 167 | 168 | loss7 = criterion(output1[0], label1) 169 | loss8 = criterion(output2[0], label2) 170 | # _, preds1 = torch.max(output1[1].data, 1) 171 | # _, preds2 = torch.max(output2[1].data, 1) 172 | # print(loss) 173 | else: 174 | loss1 = criterion(output1[0], label1) 175 | loss2 = criterion(output2[1], label2) 176 | loss3 = criterion(output1[0], label1) 177 | loss4 = criterion(output2[1], label2) 178 | 179 | _, preds1 = torch.max(output1[0].data, 1) 180 | _, preds2 = torch.max(output2[1].data, 1) 181 | _, preds3 = torch.max(output1[0].data, 1) 182 | _, preds4 = torch.max(output2[1].data, 1) 183 | 184 | # Identity loss 185 | loss = loss1 + loss2 + loss3 + loss4 + loss7 + loss8 186 | 187 | # Triplet loss 188 | hard_pairs = miner(feature1, label1) 189 | hard_pairs2 = miner(feature2, label2) 190 | loss += criterion_func(feature1, label1, hard_pairs) + \ 191 | criterion_func(feature2, label2, hard_pairs2) 192 | 193 | 194 | if epoch < warm_epoch: 195 | warm_up = min(1.0, warm_up + 0.9 / warm_iteration) 196 | loss *= warm_up 197 | if fp16: # we use optimizer to backward loss 198 | with amp.scale_loss(loss, optimizer) as scaled_loss: 199 | scaled_loss.backward() 200 | # pass 201 | else: 202 | loss.backward() 203 | optimizer.step() 204 | 205 | running_loss += loss.item() 206 | running_corrects1 += preds1.eq(label1.data).sum() 207 | running_corrects2 += preds2.eq(label2.data).sum() 208 | # print(loss.item(), preds1.eq(label1.data).sum(), preds2.eq(label2.data).sum()) 209 | 210 | scheduler.step() 211 | epoch_loss = running_loss / classes 212 | satellite_acc = running_corrects1 / total1 213 | drone_acc = running_corrects2 / total2 214 | time_elapsed = time.time() - since 215 | 216 | print('[Epoch {}/{}] {} | Loss: {:.4f} | Drone_Acc: {:.2f}% | Satellite_Acc: {:.2f}% | Time: {:.2f}s' \ 217 | .format(epoch + 1, num_epochs, "Train", epoch_loss, drone_acc * 100, satellite_acc * 100, time_elapsed)) 218 | 219 | if drone_acc > 0.95 and satellite_acc > 0.95: 220 | if epoch_loss < MAX_LOSS and epoch > (num_epochs - 15): 221 | MAX_LOSS = epoch_loss 222 | save_feature_network(model, dir_model_name, epoch + 1) 223 | print(model_name + " Epoch: " + str(epoch + 1) + " has saved with loss: " + str(epoch_loss)) 224 | 225 | def parse_opt(known=False): 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument('--cfg', type=str, default='settings.yaml', help='config file XXX.yaml path') 228 | opt = parser.parse_known_args()[0] if known else parser.parse_args() 229 | 230 | return opt 231 | 232 | 233 | if __name__ == '__main__': 234 | opt = parse_opt(True) 235 | print(opt.cfg) 236 | train(opt.cfg) 237 | -------------------------------------------------------------------------------- /MBF-SUES/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import sys 5 | import glob 6 | import yaml 7 | import math 8 | import torch 9 | from Multi_HBP import Hybird_ViT 10 | import pandas as pd 11 | from shutil import copyfile, copy 12 | # from evaluation_methods import select_best_weight 13 | import torch.distributed as dist 14 | 15 | # from new_model import two_view_net 16 | 17 | def get_params_value(key_name, file_name="settings.yaml"): 18 | f = 
open(file_name, 'r', encoding="utf-8") 19 | t_value = yaml.load(f, Loader=yaml.FullLoader) 20 | f.close() 21 | params = t_value[key_name] 22 | return params 23 | 24 | 25 | def get_yaml_value(config_path="settings.yaml"): 26 | f = open(config_path, 'r', encoding="utf-8") 27 | t_value = yaml.load(f, Loader=yaml.FullLoader) 28 | f.close() 29 | # params = t_value[key_name] 30 | return t_value 31 | 32 | 33 | def save_network(network, dir_model_name, epoch_label, loss): 34 | save_path = get_params_value('weight_save_path') 35 | # with open("settings.yaml", "r", encoding="utf-8") as f: 36 | # dict = yaml.load(f, Loader=yaml.FullLoader) 37 | # dict['name'] = dir_model_name 38 | # with open("settings.yaml", "w", encoding="utf-8") as f: 39 | # yaml.dump(dict, f) 40 | 41 | # if not os.path.isdir(os.path.join(save_path, dir_model_name)): 42 | # os.mkdir(os.path.join(save_path, dir_model_name)) 43 | 44 | if isinstance(epoch_label, int): 45 | save_filename = 'net_%03d_loss_%f.pth' % (epoch_label, loss) 46 | else: 47 | save_filename = 'net_%s_loss_%f.pth' % (epoch_label, loss) 48 | save_path1 = os.path.join(save_path, dir_model_name, "visualized_" + save_filename) 49 | torch.save(network.module.state_dict(), save_path1) 50 | 51 | save_path2 = os.path.join(save_path, dir_model_name, "pretrained_" + save_filename) 52 | torch.save(network.state_dict(), save_path2) 53 | 54 | 55 | def save_feature_network(network, dir_model_name, epoch_label): 56 | save_path = get_params_value('weight_save_path') 57 | # with open("settings.yaml", "r", encoding="utf-8") as f: 58 | # dict = yaml.load(f, Loader=yaml.FullLoader) 59 | # dict['name'] = dir_model_name 60 | # with open("settings.yaml", "w", encoding="utf-8") as f: 61 | # yaml.dump(dict, f) 62 | 63 | # if not os.path.isdir(os.path.join(save_path, dir_model_name)): 64 | # os.mkdir(os.path.join(save_path, dir_model_name)) 65 | 66 | if isinstance(epoch_label, int): 67 | save_filename = 'net_%03d.pth' % (epoch_label) 68 | else: 69 | save_filename = 'net_%s.pth' % (epoch_label) 70 | save_path = os.path.join(save_path, dir_model_name, save_filename) 71 | torch.save(network.state_dict(), save_path) 72 | 73 | 74 | def fliplr(img): 75 | '''flip horizontal''' 76 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 77 | img_flip = img.index_select(3, inv_idx) 78 | return img_flip 79 | 80 | 81 | def which_view(name): 82 | if 'satellite' in name: 83 | return 1 84 | elif 'drone' in name: 85 | return 2 86 | else: 87 | print('unknown view') 88 | return -1 89 | 90 | 91 | def get_model_list(dirname, key, seq): 92 | if os.path.exists(dirname) is False: 93 | print('no dir: %s' % dirname) 94 | return None 95 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 96 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pth" in f] 97 | if gen_models is None: 98 | return None 99 | gen_models.sort() 100 | last_model_name = gen_models[seq] 101 | return last_model_name 102 | 103 | 104 | def load_network(seq): 105 | model_name = get_params_value("model") 106 | print(model_name) 107 | name = get_params_value("name") 108 | weight_save_path = get_params_value("weight_save_path") 109 | 110 | dirname = os.path.join(weight_save_path, name) 111 | last_model_name = os.path.basename(get_model_list(dirname, 'net', seq)) 112 | print(get_model_list(dirname, 'net', seq) + " " + "seq: " + str(seq)) 113 | # print(os.path.join(dirname,last_model_name)) 114 | classes = get_params_value("classes") 115 | drop_rate = get_params_value("drop_rate") 116 | block = 
get_params_value("block") 117 | 118 | model = Hybird_ViT(classes, drop_rate, block) 119 | # model = model_.ResNet(classes, drop_rate) 120 | model.load_state_dict(torch.load(os.path.join(dirname, last_model_name))) 121 | return model, last_model_name 122 | 123 | 124 | def get_id(img_path): 125 | camera_id = [] 126 | labels = [] 127 | paths = [] 128 | for path, v in img_path: 129 | folder_name = os.path.basename(os.path.dirname(path)) 130 | labels.append(int(folder_name)) 131 | paths.append(path) 132 | return labels, paths 133 | 134 | 135 | def create_dir(path): 136 | if not os.path.exists(path): 137 | os.mkdir(path) 138 | 139 | 140 | # def get_best_weight(query_name, model_name, height, csv_path): 141 | # drone_best_list, satellite_best_list = select_best_weight(model_name, csv_path) 142 | # net_path = None 143 | # if "drone" in query_name: 144 | # for weight in drone_best_list: 145 | # if str(height) in weight: 146 | # drone_best_weight = weight.split(".")[0] 147 | # table = pd.read_csv(weight, index_col=0) 148 | # query_number = len(list(filter(lambda x: "drone" in x, table.columns))) - 1 149 | # 150 | # values = list(table.loc["recall@1", :])[:query_number] 151 | # indexes = list(table.loc["recall@1", :].index)[:query_number] 152 | # net_name = indexes[values.index(max(values))] 153 | # net = net_name.split("_")[2] + "_" + net_name.split("_")[3] 154 | # net_path = os.path.join(drone_best_weight, net) 155 | # # print(values, indexes) 156 | # if "satellite" in query_name: 157 | # for weight in satellite_best_list: 158 | # if str(height) in weight: 159 | # satellite_best_weight = weight.split(".")[0] 160 | # table = pd.read_csv(weight, index_col=0) 161 | # query_number = len(list(filter(lambda x: "drone" in x, table.columns))) - 1 162 | # 163 | # values = list(table.loc["recall@1", :])[query_number:query_number*2] 164 | # indexes = list(table.loc["recall@1", :].index)[query_number:query_number*2] 165 | # net_name = indexes[values.index(max(values))] 166 | # net = net_name.split("_")[2] + "_" + net_name.split("_")[3] 167 | # net_path = os.path.join(satellite_best_weight, net) 168 | # return net_path 169 | 170 | def parameter(index_name, index_number): 171 | with open("settings.yaml", "r", encoding="utf-8") as f: 172 | setting_dict = yaml.load(f, Loader=yaml.FullLoader) 173 | setting_dict[index_name] = index_number 174 | print(setting_dict) 175 | f.close() 176 | with open("settings.yaml", "w", encoding="utf-8") as f: 177 | yaml.dump(setting_dict, f) 178 | f.close() 179 | 180 | 181 | def summary_csv_extract_pic(csv_path): 182 | csv_table = pd.read_csv(csv_path, index_col=0) 183 | csv_path = os.path.join("result", csv_path.split("_")[-3]) 184 | create_dir(csv_path) 185 | query_pic = list(csv_table.columns) 186 | for pic in query_pic: 187 | dir_path = os.path.join(csv_path, pic.split("/")[-4] + "_" + pic.split("/")[-3]) 188 | create_dir(dir_path) 189 | dir_path = os.path.join(dir_path, pic.split("/")[-2]) 190 | create_dir(dir_path) 191 | copy(pic, dir_path) 192 | gallery_list = list(csv_table[pic]) 193 | print(gallery_list) 194 | count = 0 195 | for gl_path in gallery_list: 196 | print(gl_path) 197 | copy(gl_path, dir_path) 198 | src_name = os.path.join(dir_path, gl_path.split("/")[-1]) 199 | dest_name = os.path.dirname(src_name) + os.sep + str(count) + "_" + gl_path.split("/")[-2] + "." 
+ gl_path.split(".")[-1] 200 | print(src_name) 201 | print(dest_name) 202 | os.rename(src_name, dest_name) 203 | count = count + 1 204 | 205 | if __name__ == '__main__': 206 | csv_list = glob.glob(os.path.join("result", "*matching.csv")) 207 | print(len(csv_list)) 208 | for csv in csv_list: 209 | summary_csv_extract_pic(csv) 210 | # break 211 | 212 | def is_dist_avail_and_initialized(): 213 | if not dist.is_available(): 214 | return False 215 | if not dist.is_initialized(): 216 | return False 217 | return True 218 | 219 | 220 | def get_world_size(): 221 | if not is_dist_avail_and_initialized(): 222 | return 1 223 | return dist.get_world_size() 224 | 225 | 226 | def all_reduce_mean(x): 227 | world_size = get_world_size() 228 | if world_size > 1: 229 | x_reduce = torch.tensor(x).cuda() 230 | dist.all_reduce(x_reduce) 231 | x_reduce /= world_size 232 | return x_reduce.item() 233 | else: 234 | return x 235 | 236 | def adjust_learning_rate(optimizer, epochs, epoch, lr, min_lr): 237 | """Decay the learning rate with half-cycle cosine after warmup""" 238 | # if epoch < args.warmup_epochs: 239 | # lr = args.lr * epoch / args.warmup_epochs 240 | # else: 241 | warmup_epochs = 40 242 | lr = min_lr + (lr - min_lr) * 0.5 * \ 243 | (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) 244 | # for param_group in optimizer.param_groups: 245 | # if "lr_scale" in param_group: 246 | # param_group["lr"] = lr * param_group["lr_scale"] 247 | # else: 248 | # param_group["lr"] = lr 249 | return lr 250 | 251 | def setup_seed(seed=3407): 252 | os.environ['PYTHONHASHSEED'] = str(seed) 253 | 254 | torch.manual_seed(seed) 255 | torch.cuda.manual_seed(seed) 256 | torch.cuda.manual_seed_all(seed) 257 | 258 | np.random.seed(seed) 259 | random.seed(seed) 260 | 261 | torch.backends.cudnn.deterministic = True 262 | torch.backends.cudnn.benchmark = False 263 | # torch.backends.cudnn.enabled = False -------------------------------------------------------------------------------- /MBF-SUES/vision_transformer_hybrid.py: -------------------------------------------------------------------------------- 1 | """ Hybrid Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of the Hybrid Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | NOTE These hybrid model definitions depend on code in vision_transformer.py. 12 | They were moved here to keep file sizes sane. 
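The `adjust_learning_rate` helper in `utils.py` above applies a half-cycle cosine decay with a hard-coded 40-epoch warm-up, so it is only well defined for `epochs > 40`. A quick numerical check of that formula, assuming `epochs=80`, `lr=0.01`, `min_lr=1e-5` (illustrative values; only the formula comes from the code above):

```python
import math

def cosine_lr(epoch, epochs=80, lr=0.01, min_lr=1e-5, warmup_epochs=40):
    # half-cycle cosine decay, as in utils.adjust_learning_rate
    return min_lr + (lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))

for epoch in (40, 50, 60, 70, 80):
    print(epoch, round(cosine_lr(epoch), 6))
# 40 -> 0.01 (full rate), 60 -> ~0.005 (halfway point), 80 -> 1e-05 (min_lr)
```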
13 | 14 | Hacked together by / Copyright 2020, Ross Wightman 15 | """ 16 | from copy import deepcopy 17 | from functools import partial 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 23 | from timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple 24 | from timm.models.resnet import resnet26d, resnet50d 25 | from timm.models.resnetv2 import ResNetV2, create_resnetv2_stem 26 | from timm.models.registry import register_model 27 | # from timm.models.vision_transformer import _create_vision_transformer 28 | from vision_transformer import _create_vision_transformer 29 | 30 | def _cfg(url='', **kwargs): 31 | return { 32 | 'url': url, 33 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 34 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 35 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 36 | 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', 37 | **kwargs 38 | } 39 | 40 | 41 | default_cfgs = { 42 | # hybrid in-1k models (weights from official JAX impl where they exist) 43 | 'vit_tiny_r_s16_p8_224': _cfg( 44 | url='https://storage.googleapis.com/vit_models/augreg/' 45 | 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 46 | first_conv='patch_embed.backbone.conv'), 47 | 'vit_tiny_r_s16_p8_384': _cfg( 48 | url='https://storage.googleapis.com/vit_models/augreg/' 49 | 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 50 | first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), 51 | 'vit_small_r26_s32_224': _cfg( 52 | url='https://storage.googleapis.com/vit_models/augreg/' 53 | 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', 54 | ), 55 | 'vit_small_r26_s32_384': _cfg( 56 | url='https://storage.googleapis.com/vit_models/augreg/' 57 | 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 58 | input_size=(3, 384, 384), crop_pct=1.0), 59 | 'vit_base_r26_s32_224': _cfg(), 60 | 'vit_base_r50_s16_224': _cfg(), 61 | 'vit_base_r50_s16_384': _cfg( 62 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', 63 | input_size=(3, 384, 384), crop_pct=1.0), 64 | 'vit_large_r50_s32_224': _cfg( 65 | url='https://storage.googleapis.com/vit_models/augreg/' 66 | 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' 67 | ), 68 | 'vit_large_r50_s32_384': _cfg( 69 | url='https://storage.googleapis.com/vit_models/augreg/' 70 | 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 71 | input_size=(3, 384, 384), crop_pct=1.0 72 | ), 73 | 74 | # hybrid in-21k models (weights from official Google JAX impl where they exist) 75 | 'vit_tiny_r_s16_p8_224_in21k': _cfg( 76 | url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 77 | num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), 78 | 'vit_small_r26_s32_224_in21k': _cfg( 79 | url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', 80 | num_classes=21843, crop_pct=0.9), 81 | 'vit_base_r50_s16_224_in21k': _cfg( 82 | 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', 83 | num_classes=21843, crop_pct=0.9), 84 | 'vit_large_r50_s32_224_in21k': _cfg( 85 | url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', 86 | num_classes=21843, crop_pct=0.9), 87 | 88 | # hybrid models (using timm resnet backbones) 89 | 'vit_small_resnet26d_224': _cfg( 90 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 91 | 'vit_small_resnet50d_s16_224': _cfg( 92 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 93 | 'vit_base_resnet26d_224': _cfg( 94 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 95 | 'vit_base_resnet50d_224': _cfg( 96 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 97 | } 98 | 99 | 100 | class HybridEmbed(nn.Module): 101 | """ CNN Feature Map Embedding 102 | Extract feature map from CNN, flatten, project to embedding dim. 103 | """ 104 | def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): 105 | super().__init__() 106 | assert isinstance(backbone, nn.Module) 107 | img_size = to_2tuple(img_size) 108 | patch_size = to_2tuple(patch_size) 109 | self.img_size = img_size 110 | self.patch_size = patch_size 111 | self.backbone = backbone 112 | if feature_size is None: 113 | with torch.no_grad(): 114 | # NOTE Most reliable way of determining output dims is to run forward pass 115 | training = backbone.training 116 | if training: 117 | backbone.eval() 118 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 119 | if isinstance(o, (list, tuple)): 120 | o = o[-1] # last feature if backbone outputs list/tuple of features 121 | feature_size = o.shape[-2:] 122 | feature_dim = o.shape[1] 123 | backbone.train(training) 124 | else: 125 | feature_size = to_2tuple(feature_size) 126 | if hasattr(self.backbone, 'feature_info'): 127 | feature_dim = self.backbone.feature_info.channels()[-1] 128 | else: 129 | feature_dim = self.backbone.num_features 130 | assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 131 | self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) 132 | self.num_patches = self.grid_size[0] * self.grid_size[1] 133 | self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) 134 | 135 | def forward(self, x): 136 | x = self.backbone(x) 137 | feature = x 138 | if isinstance(x, (list, tuple)): 139 | x = x[-1] # last feature if backbone outputs list/tuple of features 140 | x = self.proj(x).flatten(2).transpose(1, 2) 141 | return x, feature 142 | 143 | 144 | def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): 145 | embed_layer = partial(HybridEmbed, backbone=backbone) 146 | 147 | kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set 148 | # print(kwargs) 149 | return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) 150 | 151 | 152 | def _resnetv2(layers=(3, 4, 9), **kwargs): 153 | """ ResNet-V2 backbone helper""" 154 | padding_same = kwargs.get('padding_same', True) 155 | stem_type = 'same' if padding_same else '' 156 | conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8) 
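`HybridEmbed` above runs the CNN backbone, projects its last feature map with a `patch_size`-strided convolution, and flattens the result into a token sequence. The shape bookkeeping, sketched with random stand-in tensors (the 1024-channel 24×24 map corresponds to a 384×384 input through an R50-style stem, as used elsewhere in this project; the tensors themselves are illustrative):

```python
import torch
import torch.nn as nn

feature_map = torch.randn(2, 1024, 24, 24)             # [B, C, H, W] from the CNN backbone
proj = nn.Conv2d(1024, 768, kernel_size=1, stride=1)   # patch_size=1 projection to embed_dim

tokens = proj(feature_map).flatten(2).transpose(1, 2)  # [B, H*W, embed_dim]
print(tokens.shape)                                    # torch.Size([2, 576, 768])

# the inverse reshape (used later in Multi_HBP.restore_vit_feature) recovers the spatial grid
grid = tokens.transpose(1, 2).reshape(2, 768, 24, 24)
print(grid.shape)                                      # torch.Size([2, 768, 24, 24])
```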
157 | if len(layers): 158 | backbone = ResNetV2( 159 | layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), 160 | preact=False, stem_type=stem_type, conv_layer=conv_layer) 161 | else: 162 | backbone = create_resnetv2_stem( 163 | kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) 164 | return backbone 165 | 166 | 167 | @register_model 168 | def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): 169 | """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. 170 | """ 171 | backbone = _resnetv2(layers=(), **kwargs) 172 | model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) 173 | model = _create_vision_transformer_hybrid( 174 | 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 175 | return model 176 | 177 | 178 | @register_model 179 | def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): 180 | """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. 181 | """ 182 | backbone = _resnetv2(layers=(), **kwargs) 183 | model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) 184 | model = _create_vision_transformer_hybrid( 185 | 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 186 | return model 187 | 188 | 189 | @register_model 190 | def vit_small_r26_s32_224(pretrained=False, **kwargs): 191 | """ R26+ViT-S/S32 hybrid. 192 | """ 193 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 194 | model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) 195 | model = _create_vision_transformer_hybrid( 196 | 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 197 | return model 198 | 199 | 200 | @register_model 201 | def vit_small_r26_s32_384(pretrained=False, **kwargs): 202 | """ R26+ViT-S/S32 hybrid. 203 | """ 204 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 205 | model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) 206 | model = _create_vision_transformer_hybrid( 207 | 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 208 | return model 209 | 210 | 211 | @register_model 212 | def vit_base_r26_s32_224(pretrained=False, **kwargs): 213 | """ R26+ViT-B/S32 hybrid. 214 | """ 215 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 216 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 217 | model = _create_vision_transformer_hybrid( 218 | 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 219 | return model 220 | 221 | 222 | @register_model 223 | def vit_base_r50_s16_224(pretrained=False, **kwargs): 224 | """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 225 | """ 226 | backbone = _resnetv2((3, 4, 9), **kwargs) 227 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 228 | model = _create_vision_transformer_hybrid( 229 | 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 230 | return model 231 | 232 | 233 | @register_model 234 | def vit_base_r50_s16_384(pretrained=False, **kwargs): 235 | """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 236 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
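`vit_base_r50_s16_384` is the hybrid variant this project builds on. With stock timm 0.6.7 (i.e. without the modified local `vision_transformer.py`), an equivalent backbone can be created and shape-checked as follows; this is a sketch of the upstream model, not the repository's multimodal version:

```python
import torch
import timm

# stock R50+ViT-B/16 hybrid at 384x384; num_classes=0 strips the classification head
model = timm.create_model('vit_base_r50_s16_384', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    feat = model(torch.randn(1, 3, 384, 384))
print(feat.shape)   # torch.Size([1, 768]) -- pooled token feature
```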
237 | """ 238 | backbone = _resnetv2((3, 4, 9), **kwargs) 239 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 240 | model = _create_vision_transformer_hybrid( 241 | 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 242 | return model 243 | 244 | 245 | @register_model 246 | def vit_base_resnet50_384(pretrained=False, **kwargs): 247 | # DEPRECATED this is forwarding to model def above for backwards compatibility 248 | return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) 249 | 250 | 251 | @register_model 252 | def vit_large_r50_s32_224(pretrained=False, **kwargs): 253 | """ R50+ViT-L/S32 hybrid. 254 | """ 255 | backbone = _resnetv2((3, 4, 6, 3), **kwargs) 256 | model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) 257 | model = _create_vision_transformer_hybrid( 258 | 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 259 | return model 260 | 261 | 262 | @register_model 263 | def vit_large_r50_s32_384(pretrained=False, **kwargs): 264 | """ R50+ViT-L/S32 hybrid. 265 | """ 266 | backbone = _resnetv2((3, 4, 6, 3), **kwargs) 267 | model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) 268 | model = _create_vision_transformer_hybrid( 269 | 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 270 | return model 271 | 272 | 273 | @register_model 274 | def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): 275 | """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. 276 | """ 277 | backbone = _resnetv2(layers=(), **kwargs) 278 | model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) 279 | model = _create_vision_transformer_hybrid( 280 | 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 281 | return model 282 | 283 | 284 | @register_model 285 | def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): 286 | """ R26+ViT-S/S32 hybrid. ImageNet-21k. 287 | """ 288 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 289 | model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) 290 | model = _create_vision_transformer_hybrid( 291 | 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 292 | return model 293 | 294 | 295 | @register_model 296 | def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): 297 | """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). 298 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 299 | """ 300 | backbone = _resnetv2(layers=(3, 4, 9), **kwargs) 301 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 302 | model = _create_vision_transformer_hybrid( 303 | 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 304 | return model 305 | 306 | 307 | @register_model 308 | def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): 309 | # DEPRECATED this is forwarding to model def above for backwards compatibility 310 | return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) 311 | 312 | 313 | @register_model 314 | def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): 315 | """ R50+ViT-L/S32 hybrid. ImageNet-21k. 
316 | """ 317 | backbone = _resnetv2((3, 4, 6, 3), **kwargs) 318 | model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) 319 | model = _create_vision_transformer_hybrid( 320 | 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 321 | return model 322 | 323 | 324 | @register_model 325 | def vit_small_resnet26d_224(pretrained=False, **kwargs): 326 | """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. 327 | """ 328 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 329 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) 330 | model = _create_vision_transformer_hybrid( 331 | 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 332 | return model 333 | 334 | 335 | @register_model 336 | def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): 337 | """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. 338 | """ 339 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) 340 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) 341 | model = _create_vision_transformer_hybrid( 342 | 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 343 | return model 344 | 345 | 346 | @register_model 347 | def vit_base_resnet26d_224(pretrained=False, **kwargs): 348 | """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. 349 | """ 350 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 351 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 352 | model = _create_vision_transformer_hybrid( 353 | 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 354 | return model 355 | 356 | 357 | @register_model 358 | def vit_base_resnet50d_224(pretrained=False, **kwargs): 359 | """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. 
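The custom hybrids in this group, including the definition that continues below, obtain their CNN part through timm's `features_only` interface, which returns a list of stage feature maps. A minimal sketch of what such a backbone produces, assuming timm 0.6.7 and a 224×224 input:

```python
import torch
import timm

# stride-32 ResNet50-D backbone exposing only its last stage, as in vit_base_resnet50d_224
backbone = timm.create_model('resnet50d', pretrained=False,
                             features_only=True, out_indices=(4,))

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])   # [torch.Size([1, 2048, 7, 7])]
```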
360 | """ 361 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 362 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 363 | model = _create_vision_transformer_hybrid( 364 | 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 365 | return model 366 | -------------------------------------------------------------------------------- /Multi_HBP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from vision_transformer_hybrid import _create_vision_transformer_hybrid 4 | from torch import nn 5 | from torch.nn import init, functional 6 | # from utils import get_yaml_value 7 | from resnetv2 import ResNetV2 8 | from timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple 9 | from vision_transformer import VisionTransformer, checkpoint_filter_fn, _create_vision_transformer, Block 10 | # from timm.models.vision_transformer_hybrid import _create_vision_transformer_hybrid 11 | from functools import partial 12 | from einops import rearrange 13 | from activation import GeM 14 | 15 | 16 | 17 | def weights_init_kaiming(m): 18 | classname = m.__class__.__name__ 19 | # print(classname) 20 | if classname.find('Conv') != -1: 21 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal. 22 | elif classname.find('Linear') != -1: 23 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 24 | init.constant_(m.bias.data, 0.0) 25 | elif classname.find('BatchNorm1d') != -1: 26 | init.normal_(m.weight.data, 1.0, 0.02) 27 | init.constant_(m.bias.data, 0.0) 28 | 29 | 30 | def weights_init_classifier(m): 31 | classname = m.__class__.__name__ 32 | if classname.find('Linear') != -1: 33 | init.normal_(m.weight.data, std=0.001) 34 | init.constant_(m.bias.data, 0.0) 35 | 36 | 37 | class ClassBlock(nn.Module): 38 | 39 | def __init__(self, input_dim, class_num, drop_rate, num_bottleneck=512): 40 | super(ClassBlock, self).__init__() 41 | add_block = [] 42 | add_block += [ 43 | nn.Linear(input_dim, num_bottleneck), 44 | nn.GELU(), 45 | nn.BatchNorm1d(num_bottleneck), 46 | nn.Dropout(p=drop_rate) 47 | ] 48 | 49 | add_block = nn.Sequential(*add_block) 50 | add_block.apply(weights_init_kaiming) 51 | 52 | classifier = [] 53 | classifier += [nn.Linear(num_bottleneck, class_num)] 54 | classifier = nn.Sequential(*classifier) 55 | classifier.apply(weights_init_classifier) 56 | 57 | self.add_block = add_block 58 | self.classifier = classifier 59 | 60 | def forward(self, x): 61 | x = self.add_block(x) 62 | feature = x 63 | x = self.classifier(x) 64 | return x, feature 65 | 66 | 67 | class Hybird_ViT(nn.Module): 68 | def __init__(self, classes, drop_rate, block, share_weight=True): 69 | super(Hybird_ViT, self).__init__() 70 | self.block = block 71 | conv_layer = partial(StdConv2dSame, eps=1e-8) 72 | backbone = ResNetV2( 73 | layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=3, 74 | preact=False, stem_type="same", conv_layer=conv_layer, act_layer=nn.ReLU) 75 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, num_classes=0) 76 | model = _create_vision_transformer_hybrid( 77 | 'vit_base_r50_s16_384', backbone=backbone, pretrained=True, **model_kwargs) 78 | self.model_1 = model 79 | if share_weight: 80 | self.model_2 = self.model_1 81 | # else: 82 | # self.model_2 = hybrid_model(layers=(3, 4, 9), img_size=24, patch_size=1, num_classes=1000, depth=12) 83 | self.classifier_hbp = 
ClassBlock(2048*3, classes, drop_rate) 84 | self.classifier_multi = ClassBlock(768*2, classes, drop_rate) 85 | self.classifier = ClassBlock(768, classes, drop_rate) 86 | 87 | self.proj = nn.Conv2d(768, 1024, kernel_size=1, stride=1) 88 | self.bilinear_proj = torch.nn.Sequential(torch.nn.Conv2d(1024, 2048, kernel_size=1, bias=False), 89 | torch.nn.BatchNorm2d(2048), 90 | torch.nn.ReLU()) 91 | 92 | self.bilinear_proj_lpn = torch.nn.Sequential(torch.nn.Conv2d(1024, 2048, kernel_size=1, bias=False), 93 | torch.nn.BatchNorm2d(2048), 94 | torch.nn.ReLU()) 95 | self.Vit_block = Block(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True, init_values=None, 96 | drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-6), 97 | act_layer=nn.GELU) 98 | self.gem = GeM(1024) 99 | for m in self.bilinear_proj.modules(): 100 | if isinstance(m, torch.nn.Conv2d): 101 | torch.nn.init.xavier_normal_(m.weight) 102 | if m.bias is not None: 103 | torch.nn.init.constant_(m.bias, 0) 104 | elif isinstance(m, torch.nn.BatchNorm2d): 105 | torch.nn.init.constant_(m.weight, 1) 106 | torch.nn.init.constant_(m.bias, 0) 107 | elif isinstance(m, torch.nn.Linear): 108 | torch.nn.init.xavier_normal_(m.weight) 109 | torch.nn.init.constant_(m.bias, 0) 110 | 111 | 112 | LPN = 1 113 | if LPN: 114 | for i in range(self.block): 115 | # before lpn 116 | # name = 'classifier' + str(i + 1) 117 | # after lpn 118 | name = 'classifier' + str(i) 119 | setattr(self, name, ClassBlock(1024, classes, drop_rate)) 120 | # print(name) 121 | 122 | def hbp(self, conv1, conv2): 123 | N = conv1.size()[0] 124 | proj_1 = self.bilinear_proj(conv1) 125 | proj_2 = self.bilinear_proj(conv2) 126 | 127 | X = proj_1 * proj_2 128 | # print(X.shape) 129 | X = torch.sum(X.view(X.size()[0], X.size()[1], -1), dim=2) 130 | # print(X.shape) 131 | X = X.view(N, 2048) 132 | X = torch.sqrt(X + 1e-5) 133 | X = torch.nn.functional.normalize(X) 134 | return X 135 | 136 | def restore_vit_feature(self, x): 137 | x = x[:, 1:, :] 138 | x = rearrange(x, "b (h w) y -> b y h w", h=24, w=24) 139 | x = self.proj(x) 140 | return x 141 | 142 | def fusion_features(self, x, t, model): 143 | 144 | y = [] 145 | # with torch.no_grad(): 146 | x, p_f, v_f, l_f = model(x, t) 147 | 148 | l_f = self.Vit_block(l_f) 149 | 150 | # direct softmax 151 | y0, f = self.classifier(x) 152 | 153 | # multi modal softmax 154 | v_f = self.restore_vit_feature(v_f) 155 | l_f = self.restore_vit_feature(l_f) 156 | 157 | # HBP softmax 3 layer feature X multiply 158 | x1 = self.hbp(p_f, v_f) 159 | x2 = self.hbp(p_f, l_f) 160 | x3 = self.hbp(v_f, l_f) 161 | x = torch.concat([x1, x2, x3], dim=1) 162 | 163 | y2, hbp_f = self.classifier_hbp(x) 164 | 165 | result = self.get_part_pool(v_f) 166 | 167 | if self.training: 168 | y3, lpn_f = self.part_classifier(result) 169 | else: 170 | lpn_f = self.part_classifier(result) 171 | y3 = [None, None] 172 | 173 | y.append(y0) 174 | # y.append(y1) 175 | y.append(y2) 176 | y.append(y3[0]) 177 | y.append(y3[1]) 178 | # y.append(y4) 179 | if self.training: 180 | f_all = torch.concat([f, hbp_f, lpn_f], dim=1) 181 | else: 182 | f = f.view(f.size()[0], f.size()[1], 1) 183 | hbp_f = hbp_f.view(hbp_f.size()[0], hbp_f.size()[1], 1) 184 | f_all = torch.concat([f, hbp_f, lpn_f], dim=2) 185 | return y, f_all 186 | 187 | def forward(self, x1, x2, t1, t2): 188 | 189 | if x1 is None: 190 | y1 = None 191 | f1 = None 192 | t1 = None 193 | output1 = None 194 | else: 195 | y1, f1 = self.fusion_features(x1, t1, self.model_1) 196 | 197 | if x2 is None: 198 | y2 = None 199 | 
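`hbp` above fuses two feature maps by pushing both through a shared 1×1 projection to 2048 channels, multiplying them element-wise, sum-pooling over the spatial grid, then applying a square root and L2 normalisation. The same computation as a standalone sketch with random stand-in features:

```python
import torch
import torch.nn as nn

proj = nn.Sequential(nn.Conv2d(1024, 2048, kernel_size=1, bias=False),
                     nn.BatchNorm2d(2048), nn.ReLU())   # shared projection, as in hbp()

def hbp_fuse(conv1, conv2):
    x = proj(conv1) * proj(conv2)        # element-wise interaction, [N, 2048, H, W]
    x = x.flatten(2).sum(dim=2)          # sum-pool over the spatial grid -> [N, 2048]
    x = torch.sqrt(x + 1e-5)             # square root (activations are >= 0 after ReLU)
    return nn.functional.normalize(x)    # L2 normalisation

a, b = torch.randn(4, 1024, 24, 24), torch.randn(4, 1024, 24, 24)
print(hbp_fuse(a, b).shape)              # torch.Size([4, 2048])
```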
f2 = None 200 | t2 = None 201 | output2 = None 202 | else: 203 | y2, f2 = self.fusion_features(x2, t2, self.model_2) 204 | 205 | if self.training: 206 | return y1, y2, f1, f2 207 | # output1, output2 208 | else: 209 | # print("ff12", f2.shape) 210 | return f1, f2 211 | 212 | def get_part_pool(self, x, pool='max', no_overlap=True): 213 | result = [] 214 | if pool == 'avg': 215 | pooling = torch.nn.AdaptiveAvgPool2d((1, 1)) 216 | elif pool == 'max': 217 | pooling = torch.nn.AdaptiveMaxPool2d((1, 1)) 218 | H, W = x.size(2), x.size(3) 219 | c_h, c_w = int(H / 2), int(W / 2) 220 | per_h, per_w = H / (2 * self.block), W / (2 * self.block) 221 | if per_h < 1 and per_w < 1: 222 | new_H, new_W = H + (self.block - c_h) * 2, W + (self.block - c_w) * 2 223 | x = nn.functional.interpolate(x, size=[new_H, new_W], mode='bilinear', align_corners=True) 224 | H, W = x.size(2), x.size(3) 225 | c_h, c_w = int(H / 2), int(W / 2) 226 | per_h, per_w = H / (2 * self.block), W / (2 * self.block) 227 | per_h, per_w = math.floor(per_h), math.floor(per_w) # 向下取整 228 | for i in range(self.block): 229 | i = i + 1 230 | if i < self.block: 231 | # print("x", x.shape) 232 | x_curr = x[:, :, (c_h - i * per_h):(c_h + i * per_h), (c_w - i * per_w):(c_w + i * per_w)] 233 | # print("x_curr", x_curr.shape) 234 | if no_overlap and i > 1: 235 | x_pre = x[:, :, (c_h - (i - 1) * per_h):(c_h + (i - 1) * per_h), 236 | (c_w - (i - 1) * per_w):(c_w + (i - 1) * per_w)] 237 | x_pad = functional.pad(x_pre, (per_h, per_h, per_w, per_w), "constant", 0) 238 | x_curr = x_curr - x_pad 239 | # print("x_curr", x_curr.shape) 240 | avgpool = pooling(x_curr) 241 | # print("pool", avgpool.shape) 242 | result.append(avgpool) 243 | # print(x_curr.shape) 244 | else: 245 | if no_overlap and i > 1: 246 | x_pre = x[:, :, (c_h - (i - 1) * per_h):(c_h + (i - 1) * per_h), 247 | (c_w - (i - 1) * per_w):(c_w + (i - 1) * per_w)] 248 | pad_h = c_h - (i - 1) * per_h 249 | pad_w = c_w - (i - 1) * per_w 250 | # x_pad = F.pad(x_pre,(pad_h,pad_h,pad_w,pad_w),"constant",0) 251 | if x_pre.size(2) + 2 * pad_h == H: 252 | x_pad = functional.pad(x_pre, (pad_h, pad_h, pad_w, pad_w), "constant", 0) 253 | else: 254 | ep = H - (x_pre.size(2) + 2 * pad_h) 255 | x_pad = functional.pad(x_pre, (pad_h + ep, pad_h, pad_w + ep, pad_w), "constant", 0) 256 | x = x - x_pad 257 | avgpool = pooling(x) 258 | result.append(avgpool) 259 | # print(x.shape) 260 | return torch.concat(result, dim=2) 261 | 262 | def part_classifier(self, x): 263 | part = {} 264 | predict = {} 265 | features = [] 266 | for i in range(self.block): 267 | part[i] = x[:, :, i].view(x.size(0), -1) 268 | 269 | name = 'classifier' + str(i) 270 | c = getattr(self, name) 271 | # print(c) 272 | predict[i], feature = c(part[i]) 273 | features.append(feature) 274 | 275 | # print(predict[i][0].shape) 276 | # print(predict) 277 | y = [] 278 | for i in range(self.block): 279 | y.append(predict[i]) 280 | if not self.training: 281 | return torch.stack(y, dim=2) 282 | return y, torch.concat(features, dim=1) 283 | 284 | 285 | 286 | 287 | if __name__ == '__main__': 288 | # create_model() 289 | model = Hybird_ViT(classes=701, drop_rate=0.3).cuda() 290 | 291 | feature = torch.randn(8, 3, 384, 384).cuda() 292 | text = torch.rand(8, 1, 768).cuda() 293 | output = model(feature, feature, text, text) 294 | print(output) 295 | -------------------------------------------------------------------------------- /Preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 
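`get_part_pool` above divides the ViT feature map into `block` concentric square rings around the centre and pools each ring separately (the LPN-style partition). A simplified, mask-based sketch of that partition for `block=2`; the original instead uses pad-and-subtract with adaptive pooling and handles interpolation edge cases omitted here:

```python
import torch

def square_ring_pool(x, block=2):
    """Max-pool each of `block` concentric square rings of x ([N, C, H, W]) -> [N, C, block]."""
    N, C, H, W = x.shape
    c_h, c_w = H // 2, W // 2
    per_h, per_w = H // (2 * block), W // (2 * block)
    ys = torch.arange(H).view(-1, 1).expand(H, W)
    xs = torch.arange(W).view(1, -1).expand(H, W)
    pooled, inner = [], torch.zeros(H, W, dtype=torch.bool)
    for i in range(1, block + 1):
        if i < block:
            square = (ys >= c_h - i * per_h) & (ys < c_h + i * per_h) & \
                     (xs >= c_w - i * per_w) & (xs < c_w + i * per_w)
        else:
            square = torch.ones(H, W, dtype=torch.bool)   # outermost ring: the rest of the map
        ring = square & ~inner                            # keep only this ring's band
        vals = x[:, :, ring]                              # [N, C, n_positions_in_ring]
        pooled.append(vals.max(dim=2).values)
        inner = square
    return torch.stack(pooled, dim=2)

feat = torch.randn(4, 1024, 24, 24)
print(square_ring_pool(feat, block=2).shape)   # torch.Size([4, 1024, 2])
```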
from utils import get_yaml_value, parameter, create_dir, save_feature_network 4 | from torchvision import datasets, transforms 5 | from Create_MultiModal_Dataset import Multimodel_Dateset 6 | 7 | 8 | def create_U1652_dataloader(data_dir, batch_size, image_size): 9 | transform_train_list = [ 10 | # transforms.RandomResizedCrop(size=(opt.h, opt.w), scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC) 11 | transforms.Resize((image_size, image_size), interpolation=3), 12 | transforms.Pad(10, padding_mode='edge'), 13 | transforms.RandomCrop((image_size, image_size)), 14 | transforms.RandomPerspective(), 15 | # transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 18 | ] 19 | 20 | transform_satellite_list = [ 21 | transforms.Resize((image_size, image_size), interpolation=3), 22 | transforms.Pad(10, padding_mode='edge'), 23 | transforms.RandomAffine(90), 24 | transforms.RandomCrop((image_size, image_size)), 25 | # transforms.RandomPerspective(), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | ] 30 | 31 | data_transforms = { 32 | 'train': transforms.Compose(transform_train_list), 33 | 'satellite': transforms.Compose(transform_satellite_list)} 34 | 35 | image_datasets = {} 36 | image_datasets['satellite'] = Multimodel_Dateset(os.path.join(data_dir, 'train', 'satellite'), 37 | data_transforms['satellite']) 38 | image_datasets['drone'] = Multimodel_Dateset(os.path.join(data_dir, 'train', 'drone'), 39 | data_transforms['train']) 40 | dataloaders = {} 41 | dataloaders['satellite'] = torch.utils.data.DataLoader(image_datasets['satellite'], batch_size=batch_size, 42 | shuffle=True) 43 | 44 | dataloaders['drone'] = torch.utils.data.DataLoader(image_datasets['drone'], batch_size=batch_size, 45 | shuffle=True) 46 | return dataloaders, image_datasets 47 | 48 | 49 | if __name__ == "__main__": 50 | # Cross_Dataset("../Datasets/SUES-200/Training/150", 224) 51 | dataloaders, image_datasets = create_U1652_dataloader() 52 | print(image_datasets['drone'].classes) 53 | for img, text, label in dataloaders['drone']: 54 | print(text) 55 | print(img, label) 56 | break 57 | # U1652_path = "/media/data1/University-Release/University-Release/train" 58 | # Cross_Dataset_1652(U1652_path, 224) 59 | 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UAV’s Status Is Worth Considering: A Fusion Representations Matching Method for Geo-Localization 2 | Paper Link : https://doi.org/10.3390/s23020720 (Open Access) 3 | 4 | ## Experiment Result 5 | 6 | | Method | Image_Size | **Drone** → Satellite | **Drone** → Satellite | Satellite → Drone | Satellite → Drone | 7 | | ------------- | ---------- | --------------------- | --------------------- | ----------------- | ----------------- | 8 | | | | Recall@1 | AP | Recall@1 | AP | 9 | | Baseline | 384*384 | 62.99 | 67.69 | 75.75 | 62.09 | 10 | | LCM | 384*384 | 66.65 | 70.82 | 79.89 | 65.38 | 11 | | LPN | 384*384 | 78.02 | 80.99 | 86.16 | 76.56 | 12 | | LDRVSD | 384*384 | 81.02 | 83.51 | 89.87 | 79.80 | 13 | | SGM | 256*256 | 82.14 | 84.72 | 88.16 | 81.81 | 14 | | PCL | 512*512 | 83.27 | 87.32 | 91.78 | 82.18 | 15 | | FSRA | 384*384 | 85.50 | 87.53 | 89.73 | 84.94 | 16 | | MSBA | 384*384 | 86.61 | 88.55 | 92.15 | 84.54 | 17 | | **MBF(ours)** | 384*384 | 89.05 | 
90.61 | 93.15 | 88.17 | 18 | 19 | ## Quick Start 20 | ### Installation 21 | Install Pytorch and Torchvision https://pytorch.org/get-started/locally/ 22 | 23 | install other libs (timm should be 0.6.7, not latest) 24 | ```shell 25 | pip install timm==0.6.7 pyyaml pytorch-metric-learning scipy pandas grad-cam pillow pytorch_pretrained_bert 26 | ``` 27 | 28 | ### Generate word embeddings for University-1652 29 | University-1652 Dataset Link https://github.com/layumi/University1652-Baseline 30 | 31 | 32 | set correct dataset path in settings.yaml, then run 33 | ```shell 34 | python U1652_bert.py 35 | ``` 36 | 37 | ### Generate word embeddings for SUES-200 38 | SUES-200 Dataset Link https://github.com/Reza-Zhu/SUES-200-Benchmark 39 | Download SUES-200 Dataset and split dataset, set correct dataset path in settings.yaml, then run 40 | ```shell 41 | python SUES_bert.py 42 | ``` 43 | 44 | ### Dataset files form 45 | University-1652 dir tree: 46 | ```text 47 | 48 | ├── University-1652/ 49 | │ ├── readme.txt 50 | │ ├── train/ 51 | │ ├── drone/ /* drone-view training images 52 | │ ├── 0001 53 | | ├── 0002 54 | | ... 55 | │ ├── street/ /* street-view training images 56 | │ ├── satellite/ /* satellite-view training images 57 | │ ├── google/ /* noisy street-view training images (collected from Google Image) 58 | │ ├── text_drone/ /* word embeddings 59 | | ├── image-01.pth 60 | | ├── image-02.pth 61 | | ... 62 | │ ├── text_satellite/ 63 | | ├── satellite.pth 64 | │ ├── test/ 65 | │ ├── query_drone/ 66 | │ ├── gallery_drone/ 67 | │ ├── query_street/ 68 | │ ├── gallery_street/ 69 | │ ├── query_satellite/ 70 | │ ├── gallery_satellite/ 71 | │ ├── 4K_drone/ 72 | │ ├── text_drone/ /* word embeddings 73 | | ├── image-01.pth 74 | | ├── image-02.pth 75 | | ... 76 | │ ├── text_satellite/ 77 | | ├── satellite.pth 78 | ``` 79 | SUES-200 dir tree: 80 | ```text 81 | ├── SUES-200/ 82 | │ ├── Training/ 83 | │ ├── 150 84 | │ ├── drone/ /* drone-view training images 85 | │ ├── 0001 /* drone-view image of the first site: 50 images 86 | │ ├── 0.jpg 87 | │ ├── 1.jpg 88 | │ ... 89 | │ ├── 49.jpg 90 | │ ├── 0002 /* drone-view image of the second site: 50 images 91 | │ ... 92 | │ ├── satellite/ /* satellite-view training images 93 | │ ├── 0001 /* satellite-view image of the first site: 1 image 94 | │ ├── 0.png 95 | │ ├── 0002 /* satellite-view image of the second site: 1 image 96 | │ ... 97 | │ ├── text_drone 98 | │ ├── drone.pth /* word embeddings 99 | │ ├── text_satellite 100 | │ ├── satellite.pth /* word embeddings 101 | │ ├── 200 102 | │ ├── 250 103 | │ ├── 300 104 | │ ├── Testing/ 105 | │ ├── 150 106 | │ ├── query_drone/ /* drone-view query images 107 | │ ├── 0008 108 | │ ... 109 | │ ├── gallery_drone/ /* drone-view gallery images 110 | │ ├── 0001 111 | │ ... 
112 | │ ├── 0200 113 | │ ├── query_satellite/ /* satellite-view query images 114 | │ ├── gallery_satellite/ /* satellite-view gallery images 115 | │ ├── text_drone 116 | │ ├── drone.pth 117 | │ ├── text_satellite 118 | │ ├── satellite.pth 119 | │ ├── 200 120 | │ ├── 250 121 | │ ├── 300 122 | 123 | ``` 124 | 125 | ### Train for University-1652 126 | ```shell 127 | python train.py --cfg "settings.yaml" 128 | ``` 129 | Config file (settings.yaml) sets parameter and path 130 | ```yaml 131 | # dateset path 132 | dataset_path: /home/sues/media/disk1/University-Release-MultiModel/University-Release 133 | weight_save_path: /home/sues/save_model_weight 134 | 135 | # apply LPN and set block number 136 | LPN : 1 137 | block : 2 138 | 139 | # super parameters 140 | batch_size : 16 141 | num_epochs : 80 142 | drop_rate : 0.35 143 | weight_decay : 0.0001 144 | lr : 0.01 145 | 146 | #intial parameters 147 | image_size: 384 148 | fp16 : 1 149 | classes : 701 150 | 151 | model : MBF 152 | name: MBF_1652_2022-11-15-18:56:39 153 | ``` 154 | ### Train for SUES-200 155 | ```shell 156 | python train.py --cfg "settings.yaml" 157 | ``` 158 | Config file (settings.yaml) sets parameter and path 159 | ```yaml 160 | 161 | # dateset path 162 | dataset_path: /home/LVM_date/zhurz/dataset/SUES-200-512x512 163 | weight_save_path: /home/LVM_date/zhurz/dataset/save_model_weight 164 | 165 | # apply LPN and set block number 166 | LPN : 1 167 | block : 2 168 | 169 | # super parameters 170 | batch_size : 8 171 | num_epochs : 40 172 | drop_rate : 0.35 173 | weight_decay : 0.0001 174 | lr : 0.01 175 | 176 | #intial parameters 177 | height : 150 178 | query : drone 179 | image_size: 384 180 | fp16 : 0 181 | classes : 120 182 | 183 | model : MBF 184 | name: MBF 185 | ``` 186 | 187 | 188 | ### Test and evaluate (University-1652 Dataset) 189 | ```shell 190 | python U1652_test_and_evaluate.py --cfg "settings.yaml" --name "your_weight_dirname_1652_2022-11-16-15:14:14" --seq 1 191 | ``` 192 | 193 | ### Test and evaluate (SUES-200 Dataset) 194 | ```shell 195 | python test_and_evaluate.py --cfg "settings.yaml" --name "your_weight_dirname_1652_2022-11-16-15:14:14" --seq 1 196 | ``` 197 | 198 | 199 | ### Multiply Queries (University-1652 Dataset) 200 | ```shell 201 | python multi_test_and_evaluate.py --cfg "settings.yaml" --multi 1 --weight "your_weight_path.pth" --csv_save_path "./result" 202 | 203 | ``` 204 | 205 | ### Shifted Query (University-1652 Dataset) 206 | ```shell 207 | python Shifted_test_and_evaluate.py --cfg "settings.yaml" --query "drone" --weight "your_weight_path.pth" --csv_save_path "./result" --gap 10 208 | ``` 209 | 210 | ### Best Weights 211 | Please check the Release page 212 | Best weights for University-1652 Dataset have been uploaded 213 | 214 | Any questions or suggestions feel free to contact me 215 | email : rzzhu24@m.fudan.edu.cn 216 | 217 | ## Relevant research 218 | 219 | SUES-200 https://github.com/Reza-Zhu/SUES-200-Benchmark 220 | 221 | University-1652 https://github.com/layumi/University1652-Baseline 222 | 223 | LPN https://github.com/wtyhub/LPN 224 | 225 | FRSA https://github.com/dmmm1997/fsra 226 | 227 | ## Citation 228 | 229 | ```text 230 | @Article{uav2023zhu, 231 | AUTHOR = {Zhu, Runzhe and Yang, Mingze and Yin, Ling and Wu, Fei and Yang, Yuncheng}, 232 | TITLE = {UAV’s Status Is Worth Considering: A Fusion Representations Matching Method for Geo-Localization}, 233 | JOURNAL = {Sensors}, 234 | VOLUME = {23}, 235 | YEAR = {2023}, 236 | NUMBER = {2}, 237 | ARTICLE-NUMBER = {720}, 238 | URL = 
{https://www.mdpi.com/1424-8220/23/2/720}, 239 | PubMedID = {36679517}, 240 | ISSN = {1424-8220}, 241 | DOI = {10.3390/s23020720} 242 | ``` 243 | } 244 | -------------------------------------------------------------------------------- /Shifited_test_and_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import glob 3 | import os 4 | import time 5 | import model_ 6 | import torch 7 | import scipy.io 8 | import shutil 9 | import argparse 10 | import numpy as np 11 | import pandas as pd 12 | from torch import nn 13 | from utils import fliplr, load_network, which_view, get_id, get_yaml_value 14 | from Create_MultiModal_Dataset import Multimodel_Dateset_flip 15 | from U1652_test_and_evaluate import evaluate 16 | from torchvision import datasets, models, transforms 17 | 18 | from Multi_HBP import Hybird_ViT 19 | if torch.cuda.is_available(): 20 | device = torch.device("cuda:0") 21 | 22 | def extract_feature(model, dataloaders, view_index=1): 23 | features = torch.FloatTensor() 24 | count = 0 25 | for data in dataloaders: 26 | img, text, label = data 27 | n, c, h, w = img.size() 28 | count += n 29 | text = text.to(device) 30 | ff = torch.FloatTensor(n, 512, 4).zero_().cuda() 31 | 32 | # why for in range(2): 33 | # 1. for flip img 34 | # 2. for normal img 35 | 36 | for i in range(2): 37 | if i == 1: 38 | img = fliplr(img) 39 | 40 | input_img = img.to(device) 41 | outputs = None 42 | if view_index == 1: 43 | outputs, _ = model(input_img, None, text, None) 44 | elif view_index == 2: 45 | _, outputs = model(None, input_img, None, text) 46 | # print(outputs.shape) 47 | # print(ff.shape) 48 | ff += outputs 49 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(4) 50 | # print("fnorm", fnorm.shape) 51 | ff = ff.div(fnorm.expand_as(ff)) 52 | # print("ff", ff.shape) 53 | ff = ff.view(ff.size(0), -1) 54 | 55 | features = torch.cat((features, ff.data.cpu()), 0) # 在维度0上拼接 56 | return features 57 | 58 | ############################### main function ####################################### 59 | def eval_and_test(query_name, config_file, net_path, save_path, gap): 60 | 61 | param = get_yaml_value(config_file) 62 | data_path = param["dataset_path"] 63 | data_transforms = transforms.Compose([ 64 | transforms.Resize((384, 384), interpolation=3), 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 67 | ]) 68 | 69 | table_path = os.path.join(save_path, param["model"] + "_" + str(1652) + "_" + "shifted_query_" + 70 | ".csv") 71 | 72 | evaluate_csv = pd.DataFrame(index=["recall@1", "recall@5", "recall@10", "recall@1p", "AP", "time"]) 73 | 74 | image_datasets = {x: Multimodel_Dateset_flip(os.path.join(data_path, 'test', x), data_transforms, gap) for x in 75 | ['gallery_satellite', 'gallery_drone', 'query_satellite', 'query_drone']} 76 | data_loader = {x: torch.utils.data.DataLoader(image_datasets[x], 77 | batch_size=param["batch_size"], 78 | # batch_size=16, 79 | shuffle=False) for x in 80 | ['gallery_satellite', 'gallery_drone', 'query_satellite', 'query_drone']} 81 | 82 | model = Hybird_ViT(701, 0.1) 83 | model.load_state_dict(torch.load(net_path)) 84 | for i in range(2): 85 | cls_name = 'classifier' + str(i) 86 | c = getattr(model, cls_name) 87 | c.classifier = nn.Sequential() 88 | 89 | model = model.eval() 90 | model = model.cuda() 91 | 92 | if "drone" in query_name: 93 | gallery_name = "gallery_satellite" 94 | query_name = "query_drone" 95 | else: 96 | gallery_name = "gallery_drone" 97 
| query_name = "query_satellite" 98 | 99 | which_query = which_view(query_name) 100 | which_gallery = which_view(gallery_name) 101 | 102 | gallery_path = image_datasets[gallery_name].imgs 103 | query_path = image_datasets[query_name].imgs 104 | 105 | gallery_label, gallery_path = get_id(gallery_path) 106 | query_label, query_path = get_id(query_path) 107 | 108 | with torch.no_grad(): 109 | 110 | query_feature = extract_feature(model, data_loader[query_name], which_query) 111 | gallery_feature = extract_feature(model, data_loader[gallery_name], which_gallery) 112 | 113 | # fed tensor to GPU 114 | query_feature = query_feature.cuda() 115 | gallery_feature = gallery_feature.cuda() 116 | 117 | # CMC = recall 118 | CMC = torch.IntTensor(len(gallery_label)).zero_() 119 | 120 | # ap = average precision 121 | ap = 0.0 122 | 123 | for i in range(len(query_label)): 124 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 125 | if CMC_tmp[0] == -1: 126 | continue 127 | CMC += CMC_tmp 128 | ap += ap_tmp 129 | 130 | CMC = CMC.float() 131 | CMC = CMC / len(query_label) 132 | # print(len(query_label)) 133 | recall_1 = CMC[0] * 100 134 | recall_5 = CMC[4] * 100 135 | recall_10 = CMC[9] * 100 136 | recall_1p = CMC[round(len(gallery_label) * 0.01)] * 100 137 | AP = ap / len(query_label) * 100 138 | 139 | evaluate_result = 'Recall@1:%.4f Recall@5:%.4f Recall@10:%.4f Recall@top1:%.4f AP:%.4f' % ( 140 | recall_1, recall_5, recall_10, recall_1p, AP) 141 | 142 | evaluate_csv["shifted_query" + "_" + str(gap) + 143 | "_" + str(1652)] = \ 144 | [float(recall_1), float(recall_5), 145 | float(recall_10), float(recall_1p), 146 | float(AP), float(0)] 147 | 148 | print(evaluate_csv) 149 | 150 | evaluate_csv.columns.name = "" 151 | evaluate_csv.index.name = "index" 152 | evaluate_csv = evaluate_csv.T 153 | evaluate_csv.to_csv(table_path) 154 | print(evaluate_result) 155 | 156 | 157 | if __name__ == '__main__': 158 | parser = argparse.ArgumentParser() 159 | 160 | parser.add_argument('--query', type=str, default="drone", help='query set: drone or satellite') 161 | parser.add_argument('--cfg', type=str, default='settings.yaml', help='config file XXX.yaml path') 162 | 163 | parser.add_argument('--weight', type=str, default=None, help='evaluate which weight, path') 164 | parser.add_argument('--csv_save_path', type=str, default="./result", help="evaluation result table store path") 165 | parser.add_argument('--gap', type=int, default=10, help='shifted gap') 166 | opt = parser.parse_known_args()[0] 167 | 168 | eval_and_test(opt.query, opt.cfg, opt.weight, opt.csv_save_path, opt.gap) 169 | -------------------------------------------------------------------------------- /U1652_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from utils import create_dir, get_yaml_value 4 | from pytorch_pretrained_bert import BertTokenizer, BertModel 5 | 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | class Word_Embeding: 10 | def __init__(self): 11 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 12 | self.model = BertModel.from_pretrained('bert-base-uncased') 13 | self.model.eval() 14 | 15 | def word_embedding(self, text): 16 | # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 17 | # text = "After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank." 
18 | marked_text = text 19 | tokenized_text = self.tokenizer.tokenize(marked_text).to(device) 20 | 21 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text) 22 | tokens_tensor = torch.tensor([indexed_tokens]).to(device) 23 | segments_ids = [1] * len(tokenized_text) 24 | 25 | segments_tensors = torch.tensor([segments_ids]).to(device) 26 | 27 | with torch.no_grad(): 28 | encoded_layers, _ = self.model(tokens_tensor, segments_tensors) 29 | 30 | sentence_embedding = torch.mean(encoded_layers[11], 1) 31 | 32 | return sentence_embedding 33 | 34 | 35 | param = get_yaml_value("settings.yaml") 36 | train_path = os.path.join(param["dataset_path"], "train") 37 | test_path = os.path.join(param["dataset_path"], "test") 38 | 39 | wd = Word_Embeding() 40 | 41 | # calculate image height from 256m - 121.5m 42 | coff = (256 - 121.5)/53 43 | heights = [256 - coff*i for i in range(1, 54)] 44 | heights.insert(0, 256) 45 | print("image-%02d" % 1) 46 | 47 | 48 | create_dir(os.path.join(train_path, "text_drone")) 49 | create_dir(os.path.join(test_path, "text_drone")) 50 | create_dir(os.path.join(train_path, "text_satellite")) 51 | create_dir(os.path.join(test_path, "text_satellite")) 52 | 53 | 54 | # drone 55 | for i in range(54): 56 | drone = "The altitude of the drone is %d meters" % heights[i] 57 | drone_tensor = wd.word_embedding(drone) 58 | torch.save(drone_tensor, os.path.join(train_path, "text_drone", "image-%02d.pth" % (i + 1))) 59 | torch.save(drone_tensor, os.path.join(test_path, "text_drone", "image-%02d.pth" % (i + 1))) 60 | print(os.path.join(train_path, "text_drone", "image-%02d.pth" % (i + 1))) 61 | 62 | 63 | # satellite 64 | satellite = "The altitude of the satellite is 1000 kilometers" 65 | satellite_tensor = wd.word_embedding(satellite) 66 | torch.save(satellite_tensor, os.path.join(train_path, "text_satellite", "satellite.pth")) 67 | torch.save(satellite_tensor, os.path.join(test_path, "text_satellite", "satellite.pth")) 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /U1652_test_and_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import glob 3 | import os 4 | import time 5 | import timm 6 | import torch 7 | import shutil 8 | import argparse 9 | import numpy as np 10 | import pandas as pd 11 | from torch import nn 12 | import scipy 13 | from utils import fliplr, load_network, which_view, get_id, get_yaml_value 14 | from torchvision import datasets, models, transforms 15 | from Create_MultiModal_Dataset import Multimodel_Dateset 16 | 17 | if torch.cuda.is_available(): 18 | device = torch.device("cuda:0") 19 | 20 | 21 | def evaluate(qf, ql, gf, gl): 22 | 23 | query = qf.view(-1, 1) 24 | score = torch.mm(gf, query) 25 | score = score.squeeze(1).cpu() 26 | score = score.numpy() 27 | 28 | # predict index 29 | index = np.argsort(score) # from small to large 30 | index = index[::-1] 31 | 32 | # good index 33 | query_index = np.argwhere(gl == ql) 34 | good_index = query_index 35 | junk_index = np.argwhere(gl == -1) 36 | 37 | CMC_tmp = compute_mAP(index, good_index, junk_index) 38 | return CMC_tmp 39 | 40 | 41 | def compute_mAP(index, good_index, junk_index): 42 | 43 | ap = 0 44 | cmc = torch.IntTensor(len(index)).zero_() 45 | # print(cmc.shape) torch.Size([51355]) 46 | if good_index.size == 0: # if empty 47 | cmc[0] = -1 48 | return ap, cmc 49 | 50 | # remove junk_index 51 | mask = np.in1d(index, junk_index, invert=True) 52 | index = index[mask] 53 | 
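# NOTE: after masking out junk entries, `index` holds the ranked gallery list for this
# query. CMC is a step function that turns to 1 at the rank of the first correct match,
# and AP is accumulated below by averaging the precision just before and at each of the
# ngood correct matches (a trapezoidal rule). Worked example with hypothetical numbers:
# if ngood = 2 and the correct matches sit at ranks 0 and 2 (rows_good = [0, 2]), then
# CMC[0:] = 1 and ap = 0.5*(1.0 + 1/1)/2 + 0.5*(1/2 + 2/3)/2 ≈ 0.79.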
54 | 55 | # find good_index index 56 | ngood = len(good_index) 57 | 58 | mask = np.in1d(index, good_index) 59 | 60 | 61 | rows_good = np.argwhere(mask == True) 62 | 63 | rows_good = rows_good.flatten() 64 | 65 | cmc[rows_good[0]:] = 1 66 | 67 | for i in range(ngood): 68 | d_recall = 1.0 / ngood 69 | # d_racall = 1/54 70 | precision = (i + 1) * 1.0 / (rows_good[i] + 1) 71 | # n/sum 72 | # print("row_good[]", i, rows_good[i]) 73 | # print(precision) 74 | if rows_good[i] != 0: 75 | old_precision = i * 1.0 / rows_good[i] 76 | else: 77 | old_precision = 1.0 78 | ap = ap + d_recall * (old_precision + precision) / 2 79 | 80 | return ap, cmc 81 | 82 | 83 | def extract_feature(model, dataloaders, block, LPN, view_index=1): 84 | features = torch.FloatTensor() 85 | count = 0 86 | for data in dataloaders: 87 | img, text, label = data 88 | n, c, h, w = img.size() 89 | count += n 90 | text = text.to(device) 91 | 92 | if LPN: 93 | ff = torch.FloatTensor(n, 512, block+2).zero_().cuda() 94 | else: 95 | ff = torch.FloatTensor(n, 512).zero_().cuda() 96 | 97 | # why for in range(2): 98 | # 1. for flip img 99 | # 2. for normal img 100 | 101 | for i in range(2): 102 | if i == 1: 103 | img = fliplr(img) 104 | 105 | input_img = img.to(device) 106 | outputs = None 107 | since = time.time() 108 | 109 | if view_index == 1: 110 | outputs, _ = model(input_img, None, text, None) 111 | elif view_index == 2: 112 | _, outputs = model(None, input_img, None, text) 113 | 114 | ff += outputs 115 | 116 | 117 | if LPN: 118 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(block) 119 | # print("fnorm", fnorm.shape) 120 | ff = ff.div(fnorm.expand_as(ff)) 121 | # print("ff", ff.shape) 122 | ff = ff.view(ff.size(0), -1) 123 | # print("ff", ff.shape) 124 | else: 125 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 126 | # print("fnorm", fnorm.shape) 127 | ff = ff.div(fnorm.expand_as(ff)) 128 | # print("ff", ff.shape) 129 | 130 | features = torch.cat((features, ff.data.cpu()), 0) # 在维度0上拼接 131 | return features 132 | 133 | 134 | 135 | ############################### main function ####################################### 136 | def eval_and_test(cfg_path, name, seqs): 137 | param_dict = get_yaml_value(cfg_path) 138 | image_size = param_dict['image_size'] 139 | if name == "": 140 | name = param_dict["name"] 141 | data_transforms = transforms.Compose([ 142 | transforms.Resize((image_size, image_size), interpolation=3), 143 | transforms.ToTensor(), 144 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 145 | ]) 146 | block = param_dict["block"] 147 | LPN = param_dict["LPN"] 148 | data_dir = param_dict["dataset_path"] 149 | all_block = block 150 | image_datasets = {x: Multimodel_Dateset(os.path.join(data_dir, 'test', x), data_transforms) for x in 151 | ['gallery_satellite', 'gallery_drone', 'query_satellite', 'query_drone']} 152 | print(len(image_datasets["query_drone"])) 153 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], 154 | # batch_size=get_yaml_value("batch_size"), 155 | batch_size=16, 156 | shuffle=False) for x in 157 | ['gallery_satellite', 'gallery_drone', 'query_satellite', 'query_drone']} 158 | # print("Testing Start >>>>>>>>") 159 | table_path = os.path.join(param_dict["weight_save_path"], 160 | name + ".csv") 161 | save_model_list = glob.glob(os.path.join(param_dict["weight_save_path"], 162 | name, "*.pth")) 163 | # print(get_yaml_value("name")) 164 | if os.path.exists(os.path.join(param_dict["weight_save_path"], 165 | name)) and len(save_model_list) >= 1: 166 | if not 
os.path.exists(table_path): 167 | evaluate_csv = pd.DataFrame(index=["recall@1", "recall@5", "recall@10", "recall@1p", "AP", "time"]) 168 | else: 169 | evaluate_csv = pd.read_csv(table_path) 170 | evaluate_csv.index = evaluate_csv["index"] 171 | for query in ['drone', 'satellite']: 172 | for seq in range(-seqs, 0): 173 | # net_name = "mae_pretrained" 174 | model, net_name = load_network(seq=seq) 175 | 176 | if LPN: 177 | for i in range(all_block): 178 | cls_name = 'classifier' + str(i) 179 | c = getattr(model, cls_name) 180 | c.classifier = nn.Sequential() 181 | else: 182 | model.classifier.classifier = nn.Sequential() 183 | # print(net_name) 184 | 185 | model = model.eval() 186 | model = model.cuda() 187 | # print(model) 188 | query_name = "" 189 | gallery_name = "" 190 | 191 | if query == "satellite": 192 | query_name = 'query_satellite' 193 | gallery_name = 'gallery_drone' 194 | elif query == "drone": 195 | query_name = 'query_drone' 196 | gallery_name = 'gallery_satellite' 197 | 198 | which_query = which_view(query_name) 199 | which_gallery = which_view(gallery_name) 200 | 201 | print('%s -> %s:' % (query_name, gallery_name)) 202 | 203 | # image_datasets, data_loader = Create_Testing_Datasets(test_data_path=data_path) 204 | 205 | gallery_path = image_datasets[gallery_name].imgs 206 | query_path = image_datasets[query_name].imgs 207 | 208 | gallery_label, gallery_path = get_id(gallery_path) 209 | query_label, query_path = get_id(query_path) 210 | 211 | with torch.no_grad(): 212 | since = time.time() 213 | query_feature = extract_feature(model, dataloaders[query_name], all_block, LPN, which_query) 214 | gallery_feature = extract_feature(model, dataloaders[gallery_name], all_block, LPN, which_gallery) 215 | print(query_feature.shape) 216 | print(gallery_feature.shape) 217 | 218 | time_elapsed = time.time() - since 219 | print('Testing complete in {:.0f}m {:.0f}s'.format( 220 | time_elapsed // 60, time_elapsed % 60)) 221 | 222 | result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 223 | 'gallery_path': gallery_path, 224 | 'query_f': query_feature.numpy(), 'query_label': query_label, 'query_path': query_path} 225 | 226 | scipy.io.savemat('U1652_pytorch_result.mat', result) 227 | 228 | print(">>>>>>>> Testing END") 229 | 230 | print("Evaluating Start >>>>>>>>") 231 | # 232 | result = scipy.io.loadmat("U1652_pytorch_result.mat") 233 | 234 | # initialize query feature data 235 | query_feature = torch.FloatTensor(result['query_f']) 236 | query_label = result['query_label'][0] 237 | 238 | # initialize all(gallery) feature data 239 | gallery_feature = torch.FloatTensor(result['gallery_f']) 240 | gallery_label = result['gallery_label'][0] 241 | query_feature = query_feature.cuda() 242 | gallery_feature = gallery_feature.cuda() 243 | query_label = np.array(query_label) 244 | gallery_label = np.array(gallery_label) 245 | 246 | # fed tensor to GPU 247 | query_feature = query_feature.cuda() 248 | gallery_feature = gallery_feature.cuda() 249 | 250 | # CMC = recall 251 | CMC = torch.IntTensor(len(gallery_label)).zero_() 252 | # ap = average precision 253 | ap = 0.0 254 | 255 | for i in range(len(query_label)): 256 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 257 | if CMC_tmp[0] == -1: 258 | continue 259 | CMC += CMC_tmp 260 | ap += ap_tmp 261 | 262 | # average CMC 263 | 264 | CMC = CMC.float() 265 | CMC = CMC / len(query_label) 266 | # print(len(query_label)) 267 | recall_1 = CMC[0] * 100 268 | recall_5 = CMC[4] * 100 269 | 
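# NOTE: CMC is cumulative, so Recall@K is CMC[K-1] scaled to percent; Recall@top1 below
# reads CMC at an index equal to 1% of the gallery size, and AP is the average precision
# averaged over all queries, also scaled to percent.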
recall_10 = CMC[9] * 100 270 | recall_1p = CMC[round(len(gallery_label) * 0.01)] * 100 271 | AP = ap / len(query_label) * 100 272 | 273 | evaluate_csv[query_name+"_"+net_name] = [float(recall_1), float(recall_5), 274 | float(recall_10), float(recall_1p), 275 | float(AP), 276 | float(time_elapsed) 277 | ] 278 | evaluate_result = 'Recall@1:%.2f Recall@5:%.2f Recall@10:%.2f Recall@top1:%.2f AP:%.2f Time::%.2f' % ( 279 | recall_1, recall_5, recall_10, recall_1p, AP, time_elapsed 280 | ) 281 | 282 | # show result and save 283 | save_path = os.path.join(param_dict["weight_save_path"], name) 284 | save_txt_path = os.path.join(save_path, 285 | '%s_to_%s_%s_%.2f_%.2f.txt' % (query_name[6:], gallery_name[8:], net_name[:7], 286 | recall_1, AP)) 287 | # print(save_txt_path) 288 | 289 | with open(save_txt_path, 'w') as f: 290 | f.write(evaluate_result) 291 | f.close() 292 | 293 | shutil.copy('settings.yaml', os.path.join(save_path, "settings_saved.yaml")) 294 | shutil.copy('train.py', os.path.join(save_path, "train.py")) 295 | shutil.copy('Multi_HBP.py', os.path.join(save_path, "model.py")) 296 | 297 | # print(round(len(gallery_label)*0.01)) 298 | print(evaluate_result) 299 | # evaluate_csv["max"] = 300 | drone_max = [] 301 | satellite_max = [] 302 | 303 | for index in evaluate_csv.index: 304 | drone_max.append(evaluate_csv.loc[index].iloc[:5].max()) 305 | satellite_max.append(evaluate_csv.loc[index].iloc[5:].max()) 306 | 307 | evaluate_csv['drone_max'] = drone_max 308 | evaluate_csv['satellite_max'] = satellite_max 309 | evaluate_csv.columns.name = "net" 310 | evaluate_csv.index.name = "index" 311 | evaluate_csv.to_csv(table_path) 312 | else: 313 | print("Don't have enough weights to evaluate!") 314 | 315 | def parse_opt(known=False): 316 | parser = argparse.ArgumentParser() 317 | parser.add_argument('--cfg', type=str, default='settings.yaml', help='config file XXX.yaml path') 318 | parser.add_argument('--name', type=str, default='', help='evaluate which weight,dir name') 319 | parser.add_argument('--seq', type=int, default=1, help='evaluate how many weights from loss value(small -> big)') 320 | 321 | opt = parser.parse_known_args()[0] if known else parser.parse_args() 322 | 323 | return opt 324 | 325 | 326 | if __name__ == '__main__': 327 | opt = parse_opt(True) 328 | 329 | eval_and_test(opt.cfg, opt.name, opt.seq) -------------------------------------------------------------------------------- /Visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import torch 5 | import model_ 6 | from shutil import copyfile, copy 7 | import random 8 | import scipy 9 | import pandas as pd 10 | import numpy as np 11 | from torch import nn 12 | # from evaluation_methods import select_best_weight 13 | from utils import get_yaml_value, which_view, create_dir 14 | from U1652_test_and_evaluate import extract_feature 15 | # from Preprocessing import 16 | from Create_MultiModal_Dataset import Multimodel_Dateset 17 | from Multi_HBP import Hybird_ViT 18 | from torchvision import datasets, models, transforms 19 | 20 | 21 | def get_rank(query_name, gallery_name): 22 | data_transforms = transforms.Compose([ 23 | transforms.Resize((384, 384), interpolation=3), 24 | transforms.ToTensor(), 25 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 26 | ]) 27 | data_path = get_yaml_value("dataset_path") 28 | 29 | 30 | gallery_drone_path = os.path.join(data_path, "test", "gallery_drone") 31 | gallery_satellite_path = 
os.path.join(data_path, "test", "gallery_satellite") 32 | gallery_drone_list = glob.glob(os.path.join(gallery_drone_path, "*")) 33 | gallery_drone_list = sorted(gallery_drone_list, key=lambda x: int(re.findall("[0-9]+", x[-4:])[0])) 34 | 35 | 36 | gallery_satellite_list = glob.glob(os.path.join(gallery_satellite_path, "*")) 37 | gallery_satellite_list = sorted(gallery_satellite_list, key=lambda x: int(re.findall("[0-9]+", x[-4:])[0])) 38 | drone_list = [] 39 | satellite_list = [] 40 | 41 | if "drone" in gallery_name: 42 | for drone_img in gallery_drone_list: 43 | img_list = glob.glob(os.path.join(drone_img, "*")) 44 | img_list = sorted(img_list, key=lambda x: int(re.findall("[0-9]+", x.split('/')[-1])[0])) 45 | for img in img_list: 46 | drone_list.append(img) 47 | elif "satellite" in gallery_name: 48 | for satellite_img in gallery_satellite_list: 49 | img_list = glob.glob(os.path.join(satellite_img, "*")) 50 | img_list = sorted(img_list, key=lambda x: int(re.findall("[0-9]+", x.split('/')[-1])[0])) 51 | for img in img_list: 52 | satellite_list.append(img) 53 | 54 | image_datasets = {x: Multimodel_Dateset(os.path.join(data_path, 'test', x), data_transforms) for x in 55 | ['gallery_satellite', 'gallery_drone', 'query_satellite', 'query_drone']} 56 | data_loader = {x: torch.utils.data.DataLoader(image_datasets[x], 57 | # batch_size=get_yaml_value("batch_size"), 58 | batch_size=4, 59 | shuffle=False) for x in 60 | ['gallery_satellite', 'gallery_drone', 'query_satellite', 'query_drone']} 61 | net_path = "/home/sues/save_model_weight/Release_final_weight/net_077.pth" 62 | 63 | which_query = which_view(query_name) 64 | which_gallery = which_view(gallery_name) 65 | print(net_path) 66 | model = Hybird_ViT(701, 0.1).cuda() 67 | model.load_state_dict(torch.load(net_path)) 68 | for i in range(2): 69 | cls_name = 'classifier' + str(i) 70 | c = getattr(model, cls_name) 71 | c.classifier = nn.Sequential() 72 | model = model.eval() 73 | 74 | if not os.path.exists("visual_pytorch_result.mat"): 75 | query_feature = extract_feature(model, data_loader[query_name], 2, 1, which_query) 76 | gallery_feature = extract_feature(model, data_loader[gallery_name], 2, 1, which_gallery) 77 | # result = scipy.io.loadmat("U1652_pytorch_result.mat") 78 | result = {'gallery_f': gallery_feature.numpy(), 'query_f': query_feature.numpy()} 79 | scipy.io.savemat('visual_pytorch_result.mat', result) 80 | # else: 81 | # result = scipy.io.loadmat("visual_pytorch_result.mat") 82 | 83 | result = scipy.io.loadmat("visual_pytorch_result.mat") 84 | # initialize query feature data 85 | query_feature = torch.FloatTensor(result['query_f']) 86 | 87 | 88 | gallery_feature = torch.FloatTensor(result['gallery_f']) 89 | 90 | # gallery_feature = np.array(gallery_feature) 91 | # gallery_features = np.array_split(gallery_feature, len(gallery_label)) 92 | # gallery_feature = torch.FloatTensor() 93 | # label_index = np.argsort(gallery_label) 94 | # label_index = label_index[::-1] 95 | # for i in label_index: 96 | # gallery_feature = torch.cat([gallery_feature, torch.from_numpy(gallery_features[i])]) 97 | # gallery_features = sorted(gallery_features, key=label_index) 98 | # gallery_feature = np.stack(gallery_features) 99 | query_img_list = image_datasets[query_name].imgs 100 | gallery_img_list = image_datasets[gallery_name].imgs 101 | matching_table = {} 102 | random_sample_list = random.sample(range(0, len(query_img_list)), 10) 103 | print(random_sample_list) 104 | for i in random_sample_list: 105 | query = query_feature[i].view(-1, 1) 106 | score = 
torch.mm(gallery_feature, query) 107 | score = score.squeeze(1).cpu() 108 | index = np.argsort(score.numpy()) 109 | index = index[::-1].tolist() 110 | max_score_list = index[0:10] 111 | query_img = query_img_list[i][0] 112 | most_correlative_img = [] 113 | for index in max_score_list: 114 | if "satellite" in query_name: 115 | most_correlative_img.append(gallery_img_list[index][0]) 116 | elif "drone" in query_name: 117 | most_correlative_img.append(gallery_img_list[index][0]) 118 | matching_table[query_img] = most_correlative_img 119 | matching_table = pd.DataFrame(matching_table) 120 | print(matching_table) 121 | save_path = query_name.split("_")[-1] + "_" + str(1652) + "_matching.csv" 122 | matching_table.to_csv(save_path) 123 | return save_path 124 | 125 | def summary_csv_extract_pic(csv_path): 126 | csv_table = pd.read_csv(csv_path, index_col=0) 127 | create_dir("result") 128 | 129 | csv_path = os.path.join("result", csv_path.split("_")[-3]) 130 | create_dir(csv_path) 131 | query_pic = list(csv_table.columns) 132 | for pic in query_pic: 133 | dir_path = os.path.join(csv_path, pic.split("/")[-4] + "_" + pic.split("/")[-3]) 134 | create_dir(dir_path) 135 | dir_path = os.path.join(dir_path, pic.split("/")[-2]) 136 | create_dir(dir_path) 137 | copy(pic, dir_path) 138 | gallery_list = list(csv_table[pic]) 139 | print(gallery_list) 140 | count = 0 141 | for gl_path in gallery_list: 142 | print(gl_path) 143 | copy(gl_path, dir_path) 144 | src_name = os.path.join(dir_path, gl_path.split("/")[-1]) 145 | dest_name = os.path.dirname(src_name) + os.sep + str(count) + "_" + gl_path.split("/")[-2] + "." + gl_path.split(".")[-1] 146 | print(src_name) 147 | print(dest_name) 148 | os.rename(src_name, dest_name) 149 | count = count + 1 150 | 151 | 152 | if __name__ == '__main__': 153 | 154 | # query_name = 'query_satellite' 155 | # gallery_name = 'gallery_drone' 156 | # 157 | # path = get_rank(query_name, gallery_name) 158 | # summary_csv_extract_pic(path) 159 | 160 | query_name = 'query_drone' 161 | gallery_name = 'gallery_satellite' 162 | 163 | path = get_rank(query_name, gallery_name) 164 | summary_csv_extract_pic(path) -------------------------------------------------------------------------------- /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class SiLU(nn.Module): # export-friendly version of nn.SiLU() 8 | @staticmethod 9 | def forward(x): 10 | return x * torch.sigmoid(x) 11 | 12 | 13 | class Hardswish(nn.Module): # export-friendly version of nn.Hardswish() 14 | @staticmethod 15 | def forward(x): 16 | # return x * F.hardsigmoid(x) # for torchscript and CoreML 17 | return x * F.hardtanh(x + 3, 0., 6.) / 6. 
# for torchscript, CoreML and ONNX 18 | 19 | class Gelu(nn.Module): 20 | @staticmethod 21 | def forward(self, x): 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | class GeM(nn.Module): 25 | # GeM zhedong zheng 26 | def __init__(self, dim=2048, p=3, eps=1e-6): 27 | super(GeM, self).__init__() 28 | self.p = nn.Parameter(torch.ones(dim)*p, requires_grad=True).cuda() # initial p 29 | self.eps = eps 30 | self.dim = dim 31 | def forward(self, x): 32 | return self.gem(x, p=self.p, eps=self.eps) 33 | 34 | def gem(self, x, p=3, eps=1e-6): 35 | x = x.cuda() 36 | x = torch.transpose(x, 1, -1) 37 | x = x.clamp(min=eps).pow(p) 38 | x = torch.transpose(x, 1, -1) 39 | x = F.avg_pool2d(x, (x.size(-2), x.size(-1))) 40 | x = x.view(x.size(0), x.size(1)) 41 | x = x.pow(1./p) 42 | return x 43 | 44 | def __repr__(self): 45 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ',' + 'dim='+str(self.dim)+')' 46 | 47 | class Conv(nn.Module): 48 | # Standard convolution 49 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups 50 | super(Conv, self).__init__() 51 | self.conv = nn.Conv2d(c1, c2, k, s, groups=g, bias=False) 52 | self.bn = nn.BatchNorm2d(c2) 53 | self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 54 | 55 | def forward(self, x): 56 | return self.act(self.bn(self.conv(x))) 57 | 58 | def fuseforward(self, x): 59 | return self.act(self.conv(x)) 60 | 61 | def autopad(k, p=None): # kernel, padding 62 | # Pad to 'same' 63 | if p is None: 64 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad 65 | return p 66 | 67 | 68 | if __name__ == "__main__": 69 | pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 70 | c = Conv(64, 64, k=65, s=1) 71 | tensor = torch.randn(8, 64, 128, 128) 72 | p_1 = pool(tensor) 73 | p_2 = c(tensor) 74 | print(p_1.shape, p_2.shape) 75 | t = p_1+p_2 76 | print(t.shape) 77 | # concat = torch.concat((p_1, p_2), dim=1) 78 | # print(concat.shape) 79 | -------------------------------------------------------------------------------- /draw_cam_ViT.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import glob 6 | import re 7 | import os 8 | import Multi_HBP 9 | from einops import rearrange 10 | import matplotlib.pyplot as plt 11 | 12 | from pytorch_grad_cam import GradCAM, \ 13 | ScoreCAM, \ 14 | GradCAMPlusPlus, \ 15 | AblationCAM, \ 16 | XGradCAM, \ 17 | EigenCAM, \ 18 | EigenGradCAM, \ 19 | LayerCAM, \ 20 | FullGrad 21 | 22 | from pytorch_grad_cam import GuidedBackpropReLUModel 23 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 24 | preprocess_image 25 | from pytorch_grad_cam.ablation_layer import AblationLayerVit 26 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 27 | 28 | 29 | def draw_heat_map(weights, img_path): 30 | 31 | count = 0 32 | for weight in weights: 33 | print(weight) 34 | model = Multi_HBP.Hybird_ViT(701, 0) 35 | 36 | model.load_state_dict(torch.load(weight)) 37 | model = model.model_1 38 | 39 | model.eval() 40 | 41 | if args.use_cuda: 42 | model = model.cuda() 43 | # model.model_1. 
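# NOTE: the CAM target is the norm1 LayerNorm of the last transformer block of the
# model_1 backbone. Because ViT activations are token sequences rather than 2-D feature
# maps, reshape_transform (defined below) is passed to the CAM constructor: for the
# 384x384 inputs used here it drops the class token and reshapes the remaining
# 24 x 24 = 576 patch tokens into a channel-first [B, C, 24, 24] map.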
44 | target_layers = [model.last_block[-1].norm1] 45 | 46 | if args.method not in methods: 47 | raise Exception(f"Method {args.method} not implemented") 48 | 49 | if args.method == "ablationcam": 50 | cam = methods[args.method](model=model, 51 | target_layers=target_layers, 52 | use_cuda=args.use_cuda, 53 | reshape_transform=reshape_transform, 54 | ablation_layer=AblationLayerVit()) 55 | else: 56 | cam = methods[args.method](model=model, 57 | target_layers=target_layers, 58 | use_cuda=args.use_cuda, 59 | reshape_transform=reshape_transform) 60 | 61 | rgb_img = cv2.imread(img_path, 1)[:, :, ::-1] 62 | rgb_img = cv2.resize(rgb_img, (384, 384)) 63 | rgb_img = np.float32(rgb_img) / 255 64 | input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406], 65 | std=[0.229, 0.224, 0.225]) 66 | print(input_tensor.shape) 67 | 68 | # If None, returns the map for the highest scoring category. 69 | # Otherwise, targets the requested category. 70 | targets = [ClassifierOutputTarget(-1)] 71 | 72 | # AblationCAM and ScoreCAM have batched implementations. 73 | # You can override the internal batch size for faster computation. 74 | cam.batch_size = 32 75 | 76 | grayscale_cam = cam(input_tensor=input_tensor, 77 | targets=targets, 78 | eigen_smooth=args.eigen_smooth, 79 | aug_smooth=args.aug_smooth) 80 | 81 | # Here grayscale_cam has only one image in the batch 82 | grayscale_cam = grayscale_cam[0, :] 83 | cam_image = show_cam_on_image(rgb_img, grayscale_cam) 84 | plt.figure("black") 85 | plt.imshow(cam_image) 86 | plt.show() 87 | 88 | cv2.imwrite(os.path.join("./draw_imgs", f'{args.method}_cam_%d_vit.jpg' % count), cam_image) 89 | count += 5 90 | 91 | 92 | def get_args(): 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--use-cuda', action='store_true', default=True, 95 | help='Use NVIDIA GPU acceleration') 96 | 97 | parser.add_argument('--aug_smooth', action='store_true', 98 | help='Apply test time augmentation to smooth the CAM') 99 | parser.add_argument( 100 | '--eigen_smooth', 101 | action='store_true', 102 | help='Reduce noise by taking the first principle componenet' 103 | 'of cam_weights*activations') 104 | 105 | parser.add_argument( 106 | '--method', 107 | type=str, 108 | default='eigengradcam', 109 | help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam') 110 | 111 | args = parser.parse_args() 112 | args.use_cuda = args.use_cuda and torch.cuda.is_available() 113 | if args.use_cuda: 114 | print('Using GPU for acceleration') 115 | else: 116 | print('Using CPU for computation') 117 | 118 | return args 119 | 120 | 121 | def reshape_transform(tensor, height=24, width=24): 122 | # print(tensor.shape) 123 | result = tensor[:, 1:, :].reshape(tensor.size(0), 124 | height, width, tensor.size(2)) 125 | # print(result.shape) 126 | # result = rearrange(result, "b (h w) y -> b y h w", h=24, w=24) 127 | # Bring the channels to the first dimension, 128 | # like in CNNs. 129 | result = result.transpose(2, 3).transpose(1, 2) 130 | # print(result.shape) 131 | 132 | return result 133 | 134 | 135 | if __name__ == '__main__': 136 | """ python vit_gradcam.py --image-path 137 | Example usage of using cam-methods on a VIT network. 
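    e.g. python draw_cam_ViT.py --method eigengradcam --eigen_smooth
    (the weight and image paths are hard-coded in __main__ below)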
138 | 139 | """ 140 | 141 | args = get_args() 142 | methods = \ 143 | {"gradcam": GradCAM, 144 | "scorecam": ScoreCAM, 145 | "gradcam++": GradCAMPlusPlus, 146 | "ablationcam": AblationCAM, 147 | "xgradcam": XGradCAM, 148 | "eigencam": EigenCAM, 149 | "eigengradcam": EigenGradCAM, 150 | "layercam": LayerCAM, 151 | "fullgrad": FullGrad} 152 | 153 | if args.method not in list(methods.keys()): 154 | raise Exception(f"method should be one of {list(methods.keys())}") 155 | 156 | paths = ["/home/sues/save_model_weight/Release_final_weight/net_077.pth"] 157 | 158 | img_path = "/home/sues/media/disk2/University-Release/University-Release/test/gallery_satellite/0001/0001.jpg" 159 | # model.load_state_dict(torch.load(weights[0])) 160 | draw_heat_map(paths, img_path) 161 | -------------------------------------------------------------------------------- /multi_test_and_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import glob 3 | import os 4 | import time 5 | import model_ 6 | import torch 7 | import scipy.io 8 | import argparse 9 | import shutil 10 | import numpy as np 11 | import pandas as pd 12 | from torch import nn 13 | from utils import fliplr, load_network, which_view, get_id, get_yaml_value 14 | from Create_MultiModal_Dataset import Multimodel_Dateset 15 | from U1652_test_and_evaluate import evaluate 16 | 17 | from torchvision import datasets, models, transforms 18 | 19 | from Multi_HBP import Hybird_ViT 20 | if torch.cuda.is_available(): 21 | device = torch.device("cuda:0") 22 | 23 | 24 | def extract_feature(model, dataloaders, view_index=1): 25 | features = torch.FloatTensor() 26 | count = 0 27 | for data in dataloaders: 28 | img, text, label = data 29 | n, c, h, w = img.size() 30 | count += n 31 | text = text.to(device) 32 | ff = torch.FloatTensor(n, 512, 4).zero_().cuda() 33 | 34 | # why for in range(2): 35 | # 1. for flip img 36 | # 2. 
for normal img 37 | 38 | for i in range(2): 39 | if i == 1: 40 | img = fliplr(img) 41 | 42 | input_img = img.to(device) 43 | outputs = None 44 | if view_index == 1: 45 | outputs, _ = model(input_img, None, text, None) 46 | elif view_index == 2: 47 | _, outputs = model(None, input_img, None, text) 48 | # print(outputs.shape) 49 | # print(ff.shape) 50 | ff += outputs 51 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(4) 52 | # print("fnorm", fnorm.shape) 53 | ff = ff.div(fnorm.expand_as(ff)) 54 | # print("ff", ff.shape) 55 | ff = ff.view(ff.size(0), -1) 56 | 57 | features = torch.cat((features, ff.data.cpu()), 0) # 在维度0上拼接 58 | return features 59 | 60 | 61 | ############################### main function ####################################### 62 | def eval_and_test(multi_coff, config_file, weight_path, save_path): 63 | param = get_yaml_value(config_file) 64 | data_transforms = transforms.Compose([ 65 | transforms.Resize((384, 384), interpolation=3), 66 | transforms.ToTensor(), 67 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 68 | ]) 69 | 70 | table_path = os.path.join(save_path, param["model"] + "_" + str(1652) + "_" + "multi_query_" + 71 | ".csv") 72 | evaluate_csv = pd.DataFrame(index=["recall@1", "recall@5", "recall@10", "recall@1p", "AP", "time"]) 73 | 74 | image_datasets = {x: Multimodel_Dateset(os.path.join(param["dataset_path"], 'test', x), data_transforms) for x in 75 | ['gallery_satellite', 'query_drone']} 76 | data_loader = {x: torch.utils.data.DataLoader(image_datasets[x], 77 | batch_size=16, 78 | shuffle=False) for x in 79 | ['gallery_satellite', 'query_drone']} 80 | 81 | query_name = "query_drone" 82 | gallery_name = "gallery_satellite" 83 | 84 | model = Hybird_ViT(701, 0.1) 85 | model.load_state_dict(torch.load(weight_path)) 86 | 87 | for i in range(2): 88 | cls_name = 'classifier' + str(i) 89 | c = getattr(model, cls_name) 90 | c.classifier = nn.Sequential() 91 | 92 | model = model.eval() 93 | model = model.cuda() 94 | which_query = which_view(query_name) 95 | which_gallery = which_view(gallery_name) 96 | 97 | gallery_path = image_datasets[gallery_name].imgs 98 | query_path = image_datasets[query_name].imgs 99 | 100 | gallery_label, gallery_path = get_id(gallery_path) 101 | query_label, query_path = get_id(query_path) 102 | 103 | with torch.no_grad(): 104 | query_feature = extract_feature(model, data_loader[query_name], which_query) 105 | gallery_feature = extract_feature(model, data_loader[gallery_name], which_gallery) 106 | 107 | # fed tensor to GPU 108 | query_feature = query_feature.cuda() 109 | new_query_feature = torch.FloatTensor().cuda() 110 | gallery_feature = gallery_feature.cuda() 111 | multi = True 112 | 113 | # coffs = [1, 2, 6, 18, 54] 114 | # University-1652 115 | image_per_class = 54 // multi_coff 116 | # coff = 54 // image_per_class 117 | 118 | print(image_per_class) 119 | query_length = len(query_label) + image_per_class 120 | 121 | feature_list = list(range(0, query_length, image_per_class)) 122 | query_concat = np.ones(((len(feature_list)-1)//multi_coff, multi_coff)) 123 | 124 | if multi: 125 | index = list(query_label).index 126 | query_label = sorted(list(set(list(query_label))), key=index) 127 | 128 | for i in range(len(query_label)): 129 | query_concat[i] = query_label[i] * query_concat[i] 130 | 131 | query_label = query_concat.reshape(-1,) 132 | # print(query_feature.shape) 133 | for i in range(len(feature_list)): 134 | if feature_list[i] == (query_length - image_per_class): 135 | continue 136 | 137 | multi_feature = 
torch.mean(query_feature[feature_list[i]:feature_list[i+1], :], 0) 138 | # print(multi_feature.shape) 139 | multi_feature = multi_feature.view(1, 2048) 140 | new_query_feature = torch.cat((new_query_feature, multi_feature), 0) 141 | 142 | query_feature = new_query_feature 143 | 144 | # CMC = recall 145 | CMC = torch.IntTensor(len(gallery_label)).zero_() 146 | 147 | # ap = average precision 148 | ap = 0.0 149 | 150 | for i in range(len(query_label)): 151 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 152 | if CMC_tmp[0] == -1: 153 | continue 154 | CMC += CMC_tmp 155 | ap += ap_tmp 156 | 157 | CMC = CMC.float() 158 | CMC = CMC / len(query_label) 159 | # print(len(query_label)) 160 | recall_1 = CMC[0] * 100 161 | recall_5 = CMC[4] * 100 162 | recall_10 = CMC[9] * 100 163 | recall_1p = CMC[round(len(gallery_label) * 0.01)] * 100 164 | AP = ap / len(query_label) * 100 165 | 166 | evaluate_csv["multi_query" + "_" + str(image_per_class) + 167 | "_" + str(1652)] = \ 168 | [float(recall_1), float(recall_5), 169 | float(recall_10), float(recall_1p), 170 | float(AP), float(0)] 171 | 172 | print(evaluate_csv) 173 | 174 | evaluate_csv.columns.name = "" 175 | evaluate_csv.index.name = "index" 176 | evaluate_csv = evaluate_csv.T 177 | evaluate_csv.to_csv(table_path) 178 | 179 | 180 | if __name__ == '__main__': 181 | parser = argparse.ArgumentParser() 182 | 183 | parser.add_argument('--cfg', type=str, default='settings.yaml', help='config file XXX.yaml path') 184 | parser.add_argument('--multi', type=int, default=1, help='multi number for example: if multi == 1 fusion image ' 185 | 'number = 50/1 = 50') 186 | parser.add_argument('--weight', type=str, default=None, help='evaluate which weight, path') 187 | parser.add_argument('--csv_save_path', type=str, default="./result", help="evaluation result table store path") 188 | opt = parser.parse_known_args()[0] 189 | 190 | eval_and_test(opt.multi, opt.cfg, opt.weight, opt.csv_save_path) 191 | -------------------------------------------------------------------------------- /settings.yaml: -------------------------------------------------------------------------------- 1 | # dateset path 2 | dataset_path: /home/sues/media/disk1/University-Release-MultiModel/University-Release 3 | weight_save_path: /home/sues/save_model_weight 4 | 5 | # apply LPN and set block number 6 | LPN : 1 7 | block : 2 8 | 9 | # super parameters 10 | batch_size : 16 11 | num_epochs : 80 12 | drop_rate : 0.35 13 | weight_decay : 0.0001 14 | lr : 0.01 15 | 16 | #intial parameters 17 | image_size: 384 18 | fp16 : 1 19 | classes : 701 20 | 21 | model : MBF 22 | name: MBF 23 | 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import time 4 | import torch 5 | import argparse 6 | 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim import lr_scheduler 11 | import torch.backends.cudnn as cudnn 12 | from pytorch_metric_learning import losses, miners 13 | 14 | from Multi_HBP import Hybird_ViT 15 | from utils import get_yaml_value, parameter, create_dir, save_feature_network, setup_seed 16 | from Preprocessing import create_U1652_dataloader 17 | import random 18 | import os 19 | 20 | if torch.cuda.is_available(): 21 | device = torch.device("cuda:0") 22 | # torch.cuda.manual_seed(random.randint(1, 100)) 23 | setup_seed() 
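# NOTE: setup_seed() (imported from utils) presumably fixes the RNG seeds; the
# cudnn.benchmark flag below lets cuDNN pick the fastest convolution algorithms for the
# fixed 384x384 input size, which is faster but not bit-exact reproducible across runs.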
24 | cudnn.benchmark = True 25 | 26 | def one_LPN_output(outputs, labels, criterion, block): 27 | # part = {} 28 | 29 | sm = nn.Softmax(dim=1) 30 | num_part = block 31 | score = 0 32 | loss = 0 33 | # print(len(outputs)) 34 | for i in range(num_part): 35 | part = outputs[i] 36 | score += sm(part) 37 | loss += criterion(part, labels) 38 | _, preds = torch.max(score.data, 1) 39 | 40 | return preds, loss 41 | 42 | 43 | def train(config_path): 44 | param_dict = get_yaml_value(config_path) 45 | print(param_dict) 46 | classes = param_dict["classes"] 47 | num_epochs = param_dict["num_epochs"] 48 | drop_rate = param_dict["drop_rate"] 49 | lr = param_dict["lr"] 50 | weight_decay = param_dict["weight_decay"] 51 | model_name = param_dict["model"] 52 | fp16 = param_dict["fp16"] 53 | weight_save_path = param_dict["weight_save_path"] 54 | LPN = param_dict["LPN"] 55 | batchsize = param_dict["batch_size"] 56 | all_block = param_dict["block"] 57 | data_dir = param_dict["dataset_path"] 58 | image_size = param_dict["image_size"] 59 | 60 | dataloaders, image_datasets = create_U1652_dataloader(data_dir, batchsize, image_size) 61 | dataset_sizes = {x: len(image_datasets[x]) for x in ['satellite', 'drone']} 62 | 63 | print(dataset_sizes) 64 | class_names = image_datasets['satellite'].classes 65 | print(len(class_names)) 66 | 67 | model = Hybird_ViT(classes, drop_rate, all_block).cuda() 68 | 69 | # apply LPN strategy 70 | if LPN: 71 | ignored_params = list() 72 | for i in range(all_block): 73 | cls_name = 'classifier' + str(i) 74 | c = getattr(model, cls_name) 75 | ignored_params += list(map(id, c.parameters())) 76 | 77 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) 78 | 79 | optim_params = [{'params': base_params, 'lr': 0.1 * lr}] 80 | for i in range(all_block): 81 | cls_name = 'classifier' + str(i) 82 | c = getattr(model, cls_name) 83 | optim_params.append({'params': c.parameters(), 'lr': lr}) 84 | 85 | optimizer = optim.SGD(optim_params, weight_decay=weight_decay, momentum=0.9, nesterov=True) 86 | # opt = torchcontrib.optim.SWA(optimizer) 87 | else: 88 | ignored_params = list(map(id, model.classifier.parameters())) 89 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) 90 | 91 | optimizer = optim.SGD([ 92 | {'params': base_params, 'lr': 0.1 * lr}, 93 | {'params': model.classifier.parameters(), 'lr': lr} 94 | ], weight_decay=weight_decay, momentum=0.9, nesterov=True) 95 | 96 | if fp16: 97 | # from apex.fp16_utils import * 98 | from apex import amp, optimizers 99 | model, optimizer_ft = amp.initialize(model, optimizer, opt_level="O2") 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | criterion_func = losses.TripletMarginLoss(margin=0.3) 103 | 104 | miner = miners.MultiSimilarityMiner() 105 | 106 | scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5) 107 | 108 | print("Dataloader Preprocessing Finished...") 109 | MAX_LOSS = 10 110 | print("Training Start >>>>>>>>") 111 | weight_save_name = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 112 | dir_model_name = model_name + "_" + str(1652) + "_" + weight_save_name 113 | save_path = os.path.join(weight_save_path, dir_model_name) 114 | create_dir(save_path) 115 | print(save_path) 116 | parameter("name", dir_model_name) 117 | 118 | warm_epoch = 5 119 | warm_up = 0.1 # We start from the 0.1*lrRate 120 | warm_iteration = round(dataset_sizes['satellite'] / batchsize) * warm_epoch # first 5 epoch 121 | 122 | for epoch in range(num_epochs): 123 | since = time.time() 124 | 125 | 
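# NOTE: the per-epoch accumulators below are reset at the start of every epoch; the
# training loop zips the satellite and drone dataloaders, so an epoch runs for as many
# batches as the shorter of the two loaders provides.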
running_loss = 0.0 126 | running_corrects1 = 0.0 127 | running_corrects2 = 0.0 128 | total1 = 0.0 129 | total2 = 0.0 130 | model.train(True) 131 | for data1, data2 in zip(dataloaders["satellite"], dataloaders["drone"]): 132 | 133 | input1, text1, label1 = data1 134 | input2, text2, label2 = data2 135 | 136 | input1, input2 = input1.to(device), input2.to(device) 137 | text1, text2 = text1.to(device), text2.to(device) 138 | label1, label2 = label1.to(device), label2.to(device) 139 | 140 | total1 += label1.size(0) 141 | total2 += label2.size(0) 142 | 143 | optimizer.zero_grad() 144 | # output1, output2, feature1, feature2, lpn_1, lpn_2 = model(input1, input2) 145 | output1, output2, feature1, feature2, = model(input1, input2, text1, text2) 146 | 147 | fnorm = torch.norm(feature1, p=2, dim=1, keepdim=True) * np.sqrt(all_block+2) 148 | fnorm2 = torch.norm(feature2, p=2, dim=1, keepdim=True) * np.sqrt(all_block+2) 149 | # fnorm3 = torch.norm(feature3, p=2, dim=1, keepdim=True) * np.sqrt(all_block) 150 | # fnorm4 = torch.norm(feature4, p=2, dim=1, keepdim=True) * np.sqrt(all_block) 151 | 152 | feature1 = feature1.div(fnorm.expand_as(feature1)) 153 | feature2 = feature2.div(fnorm2.expand_as(feature2)) 154 | # feature3 = feature3.div(fnorm3.expand_as(feature3)) 155 | # feature4 = feature4.div(fnorm4.expand_as(feature4)) 156 | 157 | loss1 = loss2 = loss3 = loss4 = loss6 = loss5 = loss7 = loss8 = 0 158 | 159 | 160 | if LPN: 161 | # print(len(output1)) 162 | preds1, loss1 = one_LPN_output(output1[2:], label1, criterion, all_block) 163 | preds2, loss2 = one_LPN_output(output2[2:], label2, criterion, all_block) 164 | 165 | loss3 = criterion(output1[1], label1) 166 | loss4 = criterion(output2[1], label2) 167 | 168 | # loss5 = criterion(output1[1], label1) 169 | # loss6 = criterion(output2[1], label2) 170 | 171 | loss7 = criterion(output1[0], label1) 172 | loss8 = criterion(output2[0], label2) 173 | # _, preds1 = torch.max(output1[1].data, 1) 174 | # _, preds2 = torch.max(output2[1].data, 1) 175 | # print(loss) 176 | else: 177 | loss1 = criterion(output1[0], label1) 178 | loss2 = criterion(output2[1], label2) 179 | loss3 = criterion(output1[0], label1) 180 | loss4 = criterion(output2[1], label2) 181 | 182 | _, preds1 = torch.max(output1[0].data, 1) 183 | _, preds2 = torch.max(output2[1].data, 1) 184 | _, preds3 = torch.max(output1[0].data, 1) 185 | _, preds4 = torch.max(output2[1].data, 1) 186 | 187 | # Identity loss 188 | loss = loss1 + loss2 + loss3 + loss4 + loss7 + loss8 189 | 190 | # Triplet loss 191 | hard_pairs = miner(feature1, label1) 192 | hard_pairs2 = miner(feature2, label2) 193 | 194 | loss += criterion_func(feature1, label1, hard_pairs) + \ 195 | criterion_func(feature2, label2, hard_pairs2) 196 | 197 | if epoch < warm_epoch: 198 | warm_up = min(1.0, warm_up + 0.9 / warm_iteration) 199 | loss *= warm_up 200 | if fp16: # we use optimizer to backward loss 201 | with amp.scale_loss(loss, optimizer) as scaled_loss: 202 | scaled_loss.backward() 203 | # pass 204 | else: 205 | loss.backward() 206 | optimizer.step() 207 | 208 | running_loss += loss.item() 209 | running_corrects1 += preds1.eq(label1.data).sum() 210 | running_corrects2 += preds2.eq(label2.data).sum() 211 | # print(loss.item(), preds1.eq(label1.data).sum(), preds2.eq(label2.data).sum()) 212 | 213 | scheduler.step() 214 | epoch_loss = running_loss / len(class_names) 215 | satellite_acc = running_corrects1 / total1 216 | drone_acc = running_corrects2 / total2 217 | time_elapsed = time.time() - since 218 | 219 | print('[Epoch {}/{}] {} | 
Loss: {:.4f} | Drone_Acc: {:.2f}% | Satellite_Acc: {:.2f}% | Time: {:.2f}s' \ 220 | .format(epoch + 1, num_epochs, "Train", epoch_loss, drone_acc * 100, satellite_acc * 100, time_elapsed)) 221 | 222 | if drone_acc > 0.95 and satellite_acc > 0.95: 223 | if epoch_loss < MAX_LOSS and epoch > (num_epochs - 50): 224 | MAX_LOSS = epoch_loss 225 | save_feature_network(model, dir_model_name, epoch + 1) 226 | print(model_name + " Epoch: " + str(epoch + 1) + " has saved with loss: " + str(epoch_loss)) 227 | 228 | 229 | def parse_opt(known=False): 230 | parser = argparse.ArgumentParser() 231 | parser.add_argument('--cfg', type=str, default='settings.yaml', help='config file XXX.yaml path') 232 | opt = parser.parse_known_args()[0] if known else parser.parse_args() 233 | 234 | return opt 235 | 236 | 237 | if __name__ == '__main__': 238 | opt = parse_opt(True) 239 | print(opt.cfg) 240 | train(opt.cfg) 241 | 242 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import sys 5 | import glob 6 | import yaml 7 | import math 8 | import torch 9 | from Multi_HBP import Hybird_ViT 10 | import pandas as pd 11 | from shutil import copyfile, copy 12 | # from evaluation_methods import select_best_weight 13 | import torch.distributed as dist 14 | 15 | # from new_model import two_view_net 16 | def get_params_value(key_name, file_name="settings.yaml"): 17 | f = open(file_name, 'r', encoding="utf-8") 18 | t_value = yaml.load(f, Loader=yaml.FullLoader) 19 | f.close() 20 | params = t_value[key_name] 21 | return params 22 | 23 | def get_yaml_value(config_path): 24 | f = open(config_path, 'r', encoding="utf-8") 25 | t_value = yaml.load(f, Loader=yaml.FullLoader) 26 | f.close() 27 | # params = t_value[key_name] 28 | return t_value 29 | 30 | 31 | def save_network(network, dir_model_name, epoch_label, loss): 32 | save_path = get_params_value('weight_save_path') 33 | # with open("settings.yaml", "r", encoding="utf-8") as f: 34 | # dict = yaml.load(f, Loader=yaml.FullLoader) 35 | # dict['name'] = dir_model_name 36 | # with open("settings.yaml", "w", encoding="utf-8") as f: 37 | # yaml.dump(dict, f) 38 | 39 | # if not os.path.isdir(os.path.join(save_path, dir_model_name)): 40 | # os.mkdir(os.path.join(save_path, dir_model_name)) 41 | 42 | if isinstance(epoch_label, int): 43 | save_filename = 'net_%03d_loss_%f.pth' % (epoch_label, loss) 44 | else: 45 | save_filename = 'net_%s_loss_%f.pth' % (epoch_label, loss) 46 | save_path1 = os.path.join(save_path, dir_model_name, "visualized_" + save_filename) 47 | torch.save(network.module.state_dict(), save_path1) 48 | 49 | save_path2 = os.path.join(save_path, dir_model_name, "pretrained_" + save_filename) 50 | torch.save(network.state_dict(), save_path2) 51 | 52 | 53 | def save_feature_network(network, dir_model_name, epoch_label): 54 | save_path = get_params_value('weight_save_path') 55 | # with open("settings.yaml", "r", encoding="utf-8") as f: 56 | # dict = yaml.load(f, Loader=yaml.FullLoader) 57 | # dict['name'] = dir_model_name 58 | # with open("settings.yaml", "w", encoding="utf-8") as f: 59 | # yaml.dump(dict, f) 60 | 61 | # if not os.path.isdir(os.path.join(save_path, dir_model_name)): 62 | # os.mkdir(os.path.join(save_path, dir_model_name)) 63 | 64 | if isinstance(epoch_label, int): 65 | save_filename = 'net_%03d.pth' % (epoch_label) 66 | else: 67 | save_filename = 'net_%s.pth' % (epoch_label) 68 | 
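# NOTE: checkpoints are written as net_XXX.pth inside the run directory; load_network()
# further down relies on this naming -- get_model_list() globs *.pth, sorts by name, and
# a negative seq index then selects checkpoints counted back from the latest epoch.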
save_path = os.path.join(save_path, dir_model_name, save_filename) 69 | torch.save(network.state_dict(), save_path) 70 | 71 | 72 | def fliplr(img): 73 | '''flip horizontal''' 74 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 75 | img_flip = img.index_select(3, inv_idx) 76 | return img_flip 77 | 78 | 79 | def which_view(name): 80 | if 'satellite' in name: 81 | return 1 82 | elif 'drone' in name: 83 | return 2 84 | else: 85 | print('unknown view') 86 | return -1 87 | 88 | 89 | def get_model_list(dirname, key, seq): 90 | if os.path.exists(dirname) is False: 91 | print('no dir: %s' % dirname) 92 | return None 93 | # gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 94 | # os.path.isfile(os.path.join(dirname, f)) and key in f and ".pth" in f] 95 | # print(dirname, key) 96 | gen_models = glob.glob(os.path.join(dirname, "*.pth")) 97 | # print(gen_models) 98 | if gen_models is None: 99 | return None 100 | gen_models.sort() 101 | last_model_name = gen_models[seq] 102 | return last_model_name 103 | 104 | 105 | def load_network(seq): 106 | model_name = get_params_value("model") 107 | print(model_name) 108 | name = get_params_value("name") 109 | weight_save_path = get_params_value("weight_save_path") 110 | 111 | dirname = os.path.join(weight_save_path, name) 112 | # print(get_model_list(dirname, 'net', seq)) 113 | last_model_name = os.path.basename(get_model_list(dirname, 'net', seq)) 114 | print(get_model_list(dirname, 'net', seq) + " " + "seq: " + str(seq)) 115 | # print(os.path.join(dirname,last_model_name)) 116 | classes = get_params_value("classes") 117 | drop_rate = get_params_value("drop_rate") 118 | 119 | model = Hybird_ViT(classes, drop_rate) 120 | # model = model_.ResNet(classes, drop_rate) 121 | model.load_state_dict(torch.load(os.path.join(dirname, last_model_name))) 122 | return model, last_model_name 123 | 124 | 125 | def get_id(img_path): 126 | camera_id = [] 127 | labels = [] 128 | paths = [] 129 | for path, v in img_path: 130 | folder_name = os.path.basename(os.path.dirname(path)) 131 | labels.append(int(folder_name)) 132 | paths.append(path) 133 | return labels, paths 134 | 135 | 136 | def create_dir(path): 137 | if not os.path.exists(path): 138 | os.mkdir(path) 139 | 140 | 141 | # def get_best_weight(query_name, model_name, height, csv_path): 142 | # drone_best_list, satellite_best_list = select_best_weight(model_name, csv_path) 143 | # net_path = None 144 | # if "drone" in query_name: 145 | # for weight in drone_best_list: 146 | # if str(height) in weight: 147 | # drone_best_weight = weight.split(".")[0] 148 | # table = pd.read_csv(weight, index_col=0) 149 | # query_number = len(list(filter(lambda x: "drone" in x, table.columns))) - 1 150 | # 151 | # values = list(table.loc["recall@1", :])[:query_number] 152 | # indexes = list(table.loc["recall@1", :].index)[:query_number] 153 | # net_name = indexes[values.index(max(values))] 154 | # net = net_name.split("_")[2] + "_" + net_name.split("_")[3] 155 | # net_path = os.path.join(drone_best_weight, net) 156 | # # print(values, indexes) 157 | # if "satellite" in query_name: 158 | # for weight in satellite_best_list: 159 | # if str(height) in weight: 160 | # satellite_best_weight = weight.split(".")[0] 161 | # table = pd.read_csv(weight, index_col=0) 162 | # query_number = len(list(filter(lambda x: "drone" in x, table.columns))) - 1 163 | # 164 | # values = list(table.loc["recall@1", :])[query_number:query_number*2] 165 | # indexes = list(table.loc["recall@1", 
:].index)[query_number:query_number*2] 166 | # net_name = indexes[values.index(max(values))] 167 | # net = net_name.split("_")[2] + "_" + net_name.split("_")[3] 168 | # net_path = os.path.join(satellite_best_weight, net) 169 | # return net_path 170 | 171 | def parameter(index_name, index_number): 172 | with open("settings.yaml", "r", encoding="utf-8") as f: 173 | setting_dict = yaml.load(f, Loader=yaml.FullLoader) 174 | setting_dict[index_name] = index_number 175 | print(setting_dict) 176 | f.close() 177 | with open("settings.yaml", "w", encoding="utf-8") as f: 178 | yaml.dump(setting_dict, f) 179 | f.close() 180 | 181 | 182 | def summary_csv_extract_pic(csv_path): 183 | csv_table = pd.read_csv(csv_path, index_col=0) 184 | csv_path = os.path.join("result", csv_path.split("_")[-3]) 185 | create_dir(csv_path) 186 | query_pic = list(csv_table.columns) 187 | for pic in query_pic: 188 | dir_path = os.path.join(csv_path, pic.split("/")[-4] + "_" + pic.split("/")[-3]) 189 | create_dir(dir_path) 190 | dir_path = os.path.join(dir_path, pic.split("/")[-2]) 191 | create_dir(dir_path) 192 | copy(pic, dir_path) 193 | gallery_list = list(csv_table[pic]) 194 | print(gallery_list) 195 | count = 0 196 | for gl_path in gallery_list: 197 | print(gl_path) 198 | copy(gl_path, dir_path) 199 | src_name = os.path.join(dir_path, gl_path.split("/")[-1]) 200 | dest_name = os.path.dirname(src_name) + os.sep + str(count) + "_" + gl_path.split("/")[-2] + "." + gl_path.split(".")[-1] 201 | print(src_name) 202 | print(dest_name) 203 | os.rename(src_name, dest_name) 204 | count = count + 1 205 | 206 | if __name__ == '__main__': 207 | csv_list = glob.glob(os.path.join("result", "*matching.csv")) 208 | print(len(csv_list)) 209 | for csv in csv_list: 210 | summary_csv_extract_pic(csv) 211 | # break 212 | 213 | def is_dist_avail_and_initialized(): 214 | if not dist.is_available(): 215 | return False 216 | if not dist.is_initialized(): 217 | return False 218 | return True 219 | 220 | 221 | def get_world_size(): 222 | if not is_dist_avail_and_initialized(): 223 | return 1 224 | return dist.get_world_size() 225 | 226 | 227 | def all_reduce_mean(x): 228 | world_size = get_world_size() 229 | if world_size > 1: 230 | x_reduce = torch.tensor(x).cuda() 231 | dist.all_reduce(x_reduce) 232 | x_reduce /= world_size 233 | return x_reduce.item() 234 | else: 235 | return x 236 | 237 | def adjust_learning_rate(optimizer, epochs, epoch, lr, min_lr): 238 | """Decay the learning rate with half-cycle cosine after warmup""" 239 | # if epoch < args.warmup_epochs: 240 | # lr = args.lr * epoch / args.warmup_epochs 241 | # else: 242 | warmup_epochs = 40 243 | lr = min_lr + (lr - min_lr) * 0.5 * \ 244 | (1. 
+ math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) 245 | # for param_group in optimizer.param_groups: 246 | # if "lr_scale" in param_group: 247 | # param_group["lr"] = lr * param_group["lr_scale"] 248 | # else: 249 | # param_group["lr"] = lr 250 | return lr 251 | 252 | def setup_seed(seed=3407): 253 | os.environ['PYTHONHASHSEED'] = str(seed) 254 | 255 | torch.manual_seed(seed) 256 | torch.cuda.manual_seed(seed) 257 | torch.cuda.manual_seed_all(seed) 258 | 259 | np.random.seed(seed) 260 | random.seed(seed) 261 | 262 | torch.backends.cudnn.deterministic = True 263 | torch.backends.cudnn.benchmark = False 264 | # torch.backends.cudnn.enabled = False -------------------------------------------------------------------------------- /vision_transformer_hybrid.py: -------------------------------------------------------------------------------- 1 | """ Hybrid Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of the Hybrid Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | NOTE These hybrid model definitions depend on code in vision_transformer.py. 12 | They were moved here to keep file sizes sane. 13 | 14 | Hacked together by / Copyright 2020, Ross Wightman 15 | """ 16 | from copy import deepcopy 17 | from functools import partial 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 23 | from timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple 24 | from timm.models.resnet import resnet26d, resnet50d 25 | from timm.models.resnetv2 import ResNetV2, create_resnetv2_stem 26 | from timm.models.registry import register_model 27 | # from timm.models.vision_transformer import _create_vision_transformer 28 | from vision_transformer import _create_vision_transformer 29 | 30 | def _cfg(url='', **kwargs): 31 | return { 32 | 'url': url, 33 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 34 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 35 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 36 | 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', 37 | **kwargs 38 | } 39 | 40 | 41 | default_cfgs = { 42 | # hybrid in-1k models (weights from official JAX impl where they exist) 43 | 'vit_tiny_r_s16_p8_224': _cfg( 44 | url='https://storage.googleapis.com/vit_models/augreg/' 45 | 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 46 | first_conv='patch_embed.backbone.conv'), 47 | 'vit_tiny_r_s16_p8_384': _cfg( 48 | url='https://storage.googleapis.com/vit_models/augreg/' 49 | 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 50 | first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), 51 | 'vit_small_r26_s32_224': _cfg( 52 | url='https://storage.googleapis.com/vit_models/augreg/' 53 | 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', 54 | ), 55 | 'vit_small_r26_s32_384': _cfg( 56 | url='https://storage.googleapis.com/vit_models/augreg/' 57 | 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 58 | input_size=(3, 384, 384), crop_pct=1.0), 
59 | 'vit_base_r26_s32_224': _cfg(), 60 | 'vit_base_r50_s16_224': _cfg(), 61 | 'vit_base_r50_s16_384': _cfg( 62 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', 63 | input_size=(3, 384, 384), crop_pct=1.0), 64 | 'vit_large_r50_s32_224': _cfg( 65 | url='https://storage.googleapis.com/vit_models/augreg/' 66 | 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' 67 | ), 68 | 'vit_large_r50_s32_384': _cfg( 69 | url='https://storage.googleapis.com/vit_models/augreg/' 70 | 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 71 | input_size=(3, 384, 384), crop_pct=1.0 72 | ), 73 | 74 | # hybrid in-21k models (weights from official Google JAX impl where they exist) 75 | 'vit_tiny_r_s16_p8_224_in21k': _cfg( 76 | url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 77 | num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), 78 | 'vit_small_r26_s32_224_in21k': _cfg( 79 | url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', 80 | num_classes=21843, crop_pct=0.9), 81 | 'vit_base_r50_s16_224_in21k': _cfg( 82 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', 83 | num_classes=21843, crop_pct=0.9), 84 | 'vit_large_r50_s32_224_in21k': _cfg( 85 | url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', 86 | num_classes=21843, crop_pct=0.9), 87 | 88 | # hybrid models (using timm resnet backbones) 89 | 'vit_small_resnet26d_224': _cfg( 90 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 91 | 'vit_small_resnet50d_s16_224': _cfg( 92 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 93 | 'vit_base_resnet26d_224': _cfg( 94 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 95 | 'vit_base_resnet50d_224': _cfg( 96 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 97 | } 98 | 99 | 100 | class HybridEmbed(nn.Module): 101 | """ CNN Feature Map Embedding 102 | Extract feature map from CNN, flatten, project to embedding dim. 
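    NOTE: in this repo forward() returns the projected patch tokens together with the raw backbone output (`return x, feature`); the stock timm HybridEmbed returns only the tokens.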
103 | """ 104 | def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): 105 | super().__init__() 106 | assert isinstance(backbone, nn.Module) 107 | img_size = to_2tuple(img_size) 108 | patch_size = to_2tuple(patch_size) 109 | self.img_size = img_size 110 | self.patch_size = patch_size 111 | self.backbone = backbone 112 | if feature_size is None: 113 | with torch.no_grad(): 114 | # NOTE Most reliable way of determining output dims is to run forward pass 115 | training = backbone.training 116 | if training: 117 | backbone.eval() 118 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 119 | if isinstance(o, (list, tuple)): 120 | o = o[-1] # last feature if backbone outputs list/tuple of features 121 | feature_size = o.shape[-2:] 122 | feature_dim = o.shape[1] 123 | backbone.train(training) 124 | else: 125 | feature_size = to_2tuple(feature_size) 126 | if hasattr(self.backbone, 'feature_info'): 127 | feature_dim = self.backbone.feature_info.channels()[-1] 128 | else: 129 | feature_dim = self.backbone.num_features 130 | assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 131 | self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) 132 | self.num_patches = self.grid_size[0] * self.grid_size[1] 133 | self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) 134 | 135 | def forward(self, x): 136 | x = self.backbone(x) 137 | feature = x 138 | if isinstance(x, (list, tuple)): 139 | x = x[-1] # last feature if backbone outputs list/tuple of features 140 | x = self.proj(x).flatten(2).transpose(1, 2) 141 | return x, feature 142 | 143 | 144 | def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): 145 | embed_layer = partial(HybridEmbed, backbone=backbone) 146 | 147 | kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set 148 | # print(kwargs) 149 | return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) 150 | 151 | 152 | def _resnetv2(layers=(3, 4, 9), **kwargs): 153 | """ ResNet-V2 backbone helper""" 154 | padding_same = kwargs.get('padding_same', True) 155 | stem_type = 'same' if padding_same else '' 156 | conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8) 157 | if len(layers): 158 | backbone = ResNetV2( 159 | layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), 160 | preact=False, stem_type=stem_type, conv_layer=conv_layer) 161 | else: 162 | backbone = create_resnetv2_stem( 163 | kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) 164 | return backbone 165 | 166 | 167 | @register_model 168 | def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): 169 | """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. 170 | """ 171 | backbone = _resnetv2(layers=(), **kwargs) 172 | model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) 173 | model = _create_vision_transformer_hybrid( 174 | 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 175 | return model 176 | 177 | 178 | @register_model 179 | def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): 180 | """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. 
181 | """ 182 | backbone = _resnetv2(layers=(), **kwargs) 183 | model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) 184 | model = _create_vision_transformer_hybrid( 185 | 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 186 | return model 187 | 188 | 189 | @register_model 190 | def vit_small_r26_s32_224(pretrained=False, **kwargs): 191 | """ R26+ViT-S/S32 hybrid. 192 | """ 193 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 194 | model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) 195 | model = _create_vision_transformer_hybrid( 196 | 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 197 | return model 198 | 199 | 200 | @register_model 201 | def vit_small_r26_s32_384(pretrained=False, **kwargs): 202 | """ R26+ViT-S/S32 hybrid. 203 | """ 204 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 205 | model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) 206 | model = _create_vision_transformer_hybrid( 207 | 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 208 | return model 209 | 210 | 211 | @register_model 212 | def vit_base_r26_s32_224(pretrained=False, **kwargs): 213 | """ R26+ViT-B/S32 hybrid. 214 | """ 215 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 216 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 217 | model = _create_vision_transformer_hybrid( 218 | 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 219 | return model 220 | 221 | 222 | @register_model 223 | def vit_base_r50_s16_224(pretrained=False, **kwargs): 224 | """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 225 | """ 226 | backbone = _resnetv2((3, 4, 9), **kwargs) 227 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 228 | model = _create_vision_transformer_hybrid( 229 | 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 230 | return model 231 | 232 | 233 | @register_model 234 | def vit_base_r50_s16_384(pretrained=False, **kwargs): 235 | """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 236 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 237 | """ 238 | backbone = _resnetv2((3, 4, 9), **kwargs) 239 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 240 | model = _create_vision_transformer_hybrid( 241 | 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 242 | return model 243 | 244 | 245 | @register_model 246 | def vit_base_resnet50_384(pretrained=False, **kwargs): 247 | # DEPRECATED this is forwarding to model def above for backwards compatibility 248 | return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) 249 | 250 | 251 | @register_model 252 | def vit_large_r50_s32_224(pretrained=False, **kwargs): 253 | """ R50+ViT-L/S32 hybrid. 254 | """ 255 | backbone = _resnetv2((3, 4, 6, 3), **kwargs) 256 | model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) 257 | model = _create_vision_transformer_hybrid( 258 | 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 259 | return model 260 | 261 | 262 | @register_model 263 | def vit_large_r50_s32_384(pretrained=False, **kwargs): 264 | """ R50+ViT-L/S32 hybrid. 
265 | """ 266 | backbone = _resnetv2((3, 4, 6, 3), **kwargs) 267 | model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) 268 | model = _create_vision_transformer_hybrid( 269 | 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 270 | return model 271 | 272 | 273 | @register_model 274 | def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): 275 | """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. 276 | """ 277 | backbone = _resnetv2(layers=(), **kwargs) 278 | model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) 279 | model = _create_vision_transformer_hybrid( 280 | 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 281 | return model 282 | 283 | 284 | @register_model 285 | def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): 286 | """ R26+ViT-S/S32 hybrid. ImageNet-21k. 287 | """ 288 | backbone = _resnetv2((2, 2, 2, 2), **kwargs) 289 | model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) 290 | model = _create_vision_transformer_hybrid( 291 | 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 292 | return model 293 | 294 | 295 | @register_model 296 | def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): 297 | """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). 298 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 299 | """ 300 | backbone = _resnetv2(layers=(3, 4, 9), **kwargs) 301 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 302 | model = _create_vision_transformer_hybrid( 303 | 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 304 | return model 305 | 306 | 307 | @register_model 308 | def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): 309 | # DEPRECATED this is forwarding to model def above for backwards compatibility 310 | return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) 311 | 312 | 313 | @register_model 314 | def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): 315 | """ R50+ViT-L/S32 hybrid. ImageNet-21k. 316 | """ 317 | backbone = _resnetv2((3, 4, 6, 3), **kwargs) 318 | model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) 319 | model = _create_vision_transformer_hybrid( 320 | 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 321 | return model 322 | 323 | 324 | @register_model 325 | def vit_small_resnet26d_224(pretrained=False, **kwargs): 326 | """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. 327 | """ 328 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 329 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) 330 | model = _create_vision_transformer_hybrid( 331 | 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 332 | return model 333 | 334 | 335 | @register_model 336 | def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): 337 | """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. 
338 | """ 339 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) 340 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) 341 | model = _create_vision_transformer_hybrid( 342 | 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 343 | return model 344 | 345 | 346 | @register_model 347 | def vit_base_resnet26d_224(pretrained=False, **kwargs): 348 | """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. 349 | """ 350 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 351 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 352 | model = _create_vision_transformer_hybrid( 353 | 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 354 | return model 355 | 356 | 357 | @register_model 358 | def vit_base_resnet50d_224(pretrained=False, **kwargs): 359 | """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. 360 | """ 361 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 362 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) 363 | model = _create_vision_transformer_hybrid( 364 | 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 365 | return model 366 | --------------------------------------------------------------------------------