├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── README.md ├── batch_inference_time.sh ├── datasets ├── Dataloader_University.py ├── autoaugment.py ├── make_dataloader.py └── queryDataset.py ├── docs ├── Get_started.md ├── Request.md ├── images │ ├── data.jpg │ ├── framework.jpg │ └── model.png └── training_parameters.md ├── evaluateDistance.py ├── evaluateDistance_DifHeight.py ├── evaluateMA.py ├── evaluateMA_dense.py ├── evaluate_RDS.py ├── evaluate_gpu.py ├── heatmap.py ├── losses ├── ArcfaceLoss.py ├── FocalLoss.py ├── TripletLoss.py ├── __Init__.py ├── cal_loss.py └── loss.py ├── models ├── Backbone │ ├── RKNet.py │ ├── __init__.py │ ├── backbone.py │ └── cvt.py ├── Head │ ├── FSRA.py │ ├── GeM.py │ ├── LPN.py │ ├── NeXtVLAD.py │ ├── NetVLAD.py │ ├── SingleBranch.py │ ├── __init__.py │ ├── head.py │ └── utils.py ├── __init__.py ├── model.py └── taskflow.py ├── optimizers └── make_optimizer.py ├── requirments.txt ├── test.py ├── test_hard.py ├── tool ├── SDM@K_analyze.py ├── SDM@K_compare.py ├── applications │ ├── forwardAllSatelliteHub.py │ ├── inference_global.py │ └── inference_neibor.py ├── dataset_preprocess │ ├── 1-generateSatelliteByUav.py │ ├── 2-generate_new_croped_resized_images_difheights.py │ ├── 3-generate_format_testset.py │ ├── TEST-1-downloadALlAndReCut.py │ ├── TEST-2-generateSatelliteHub.py │ ├── TEST-3-preprocess_difHeightTest.py │ ├── TEST-Regenerate_dense_testset.py │ ├── get_property.py │ ├── google_interface.py │ ├── google_interface_2.py │ ├── split_dataset_long_middle_short.py │ ├── utils.py │ └── validation_testset.py ├── get_inference_time.py ├── get_model_flops_params.py ├── get_property.py ├── mount_dist.sh ├── transforms │ ├── rotatetranformtest.py │ └── transform_visual.py ├── utils.py ├── visual │ ├── Times New Roman.ttf │ ├── demo.py │ ├── demo_custom.py │ ├── demo_custom_visualization.py │ ├── draw_SDMcurve.py │ ├── grad_cam.py │ └── heatmap.py └── visual_demo.py ├── train.py └── train_test_local.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed depepwdndencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *.pth 132 | maps/ 133 | checkpoints/* 134 | visualization/ 135 | .history/ 136 | .vscode/ 137 | experiment_*.sh 138 | experiments/ -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: 当前文件", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "cwd": "${workspaceFolder}/checkpoints/Head_Experiment-FSRA2B" 15 | } 16 | ] 17 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Vision-Based UAV Self-Positioning in Low-Altitude Urban Environments

2 | 3 | This repository contains code and dataset for the paper titled [Vision-Based UAV Self-Positioning in Low-Altitude Urban Environments](https://arxiv.org/abs/2201.09201). In this paper, we propose a method for accurately self-positioning unmanned aerial vehicles (UAVs) in challenging low-altitude urban environments using vision-based techniques. We provide the DenseUAV dataset and a Baseline model implementation to facilitate research in this task. Thank you for your kind attention. 4 | 5 | ![](https://github.com/Dmmm1997/DenseUAV/blob/main/docs/images/data.jpg) 6 | 7 | ![](https://github.com/Dmmm1997/DenseUAV/blob/main/docs/images/framework.jpg) 8 | 9 | ![](https://github.com/Dmmm1997/DenseUAV/blob/main/docs/images/model.png) 10 | 11 | ## News 12 | 13 | - **`2023/12/18`**: Our paper is accepted by IEEE Trans on Image Process. 14 | - **`2023/8/14`**: Our dataset and code are released. 15 | 16 | ## Table of contents 17 | 18 | - [News](#news) 19 | - [Table of contents](#table-of-contents) 20 | - [About Dataset](#about-dataset) 21 | - [Prerequisites](#prerequisites) 22 | - [Installation](#installation) 23 | - [Dataset \& Preparation](#dataset--preparation) 24 | - [Train \& Evaluation](#train--evaluation) 25 | - [Training and Testing](#training-and-testing) 26 | - [Evaluation](#evaluation) 27 | - [Supported Methods](#supported-methods) 28 | - [License](#license) 29 | - [Citation](#citation) 30 | - [Related Work](#related-work) 31 | 32 | ## About Dataset 33 | 34 | The dataset split is as follows: 35 | | Subset | UAV-view | Satellite-view | Classes | universities | 36 | | -------- | ----- | ---- | ---- | ---- | 37 | | Training | 6,768 | 13,536 | 2,256 | 10 | 38 | | Query | 2,331 | 4,662 | 777 | 4 | 39 | | Gallery | 9099 | 18198 | 3033 | 14 | 40 | 41 | More detailed file structure: 42 | 43 | ``` 44 | ├── DenseUAV/ 45 | │ ├── Dense_GPS_ALL.txt /* format as: path latitude longitude height 46 | │ ├── Dense_GPS_test.txt 47 | │ ├── Dense_GPS_train.txt 48 | │ ├── train/ 49 | │ ├── drone/ /* drone-view training images 50 | │ ├── 000001 51 | │ ├── H100.JPG 52 | │ ├── H90.JPG 53 | │ ├── H80.JPG 54 | | ... 55 | │ ├── satellite/ /* satellite-view training images 56 | │ ├── 000001 57 | │ ├── H100_old.tif 58 | │ ├── H90_old.tif 59 | │ ├── H80_old.tif 60 | │ ├── H100.tif 61 | │ ├── H90.tif 62 | │ ├── H80.tif 63 | | ... 64 | │ ├── test/ 65 | │ ├── query_drone/ /* UAV-view testing images 66 | │ ├── query_satellite/ /* satellite-view testing images 67 | ``` 68 | 69 | ## Prerequisites 70 | 71 | - Python 3.7+ 72 | - GPU Memory >= 8G 73 | - Numpy 1.21.2 74 | - Pytorch 1.10.0+cu113 75 | - Torchvision 0.11.1+cu113 76 | 77 | ## Installation 78 | 79 | It is best to use cuda version 11.3 and pytorch version 1.10.0. You can download the corresponding version from this [website](https://download.pytorch.org/whl/torch_stable.html) and install it through `pip install`. Then you can execute the following command to install all dependencies. 80 | 81 | ``` 82 | pip install -r requirments.txt 83 | ``` 84 | 85 | Create the directory for saving the training log and ckpts. 86 | 87 | ``` 88 | mkdir checkpoints 89 | ``` 90 | 91 | ## Dataset & Preparation 92 | 93 | Download DenseUAV upon request. You may use the request [Template](https://github.com/Dmmm1997/DenseUAV//blob/main/docs/Request.md). 94 | 95 | ## Train & Evaluation 96 | 97 | ### Training and Testing 98 | 99 | You could execute the following command to implement the entire process of training and testing. 
100 | 101 | ``` 102 | bash train_test_local.sh 103 | ``` 104 | 105 | For the parameter settings in **train_test_local.sh**, please refer to [Get Started](https://github.com/Dmmm1997/DenseUAV/blob/main/docs/training_parameters.md). 106 | 107 | ### Evaluation 108 | 109 | The following commands are required to evaluate Recall and SDM separately. 110 | 111 | ``` 112 | cd checkpoints/ 113 | python test.py --name <name> --test_dir <dataset_root>/test --gpu_ids 0 --num_worker 4 114 | ``` 115 | 116 | Here `<name>` is the experiment name used in your training setting; you can find it under `checkpoints/`. 117 | 118 | **For Recall** 119 | 120 | ``` 121 | python evaluate_gpu.py 122 | ``` 123 | 124 | **For SDM** 125 | 126 | ``` 127 | python evaluateDistance.py --root_dir <dataset_root> 128 | ``` 129 | 130 | We also provide the baseline checkpoints: [quark](https://pan.quark.cn/s/3ced42633793) / [one-drive](https://seunic-my.sharepoint.cn/:u:/g/personal/230238525_seu_edu_cn/EUFoYjIdK_JNuxmvpb5QjLcB1hUHyedGwOnT3wTeN7Zqdg?e=LZuUxz). 131 | 132 | ``` 133 | unzip <baseline_checkpoint.zip> -d checkpoints 134 | cd checkpoints/baseline 135 | python test.py --test_dir <dataset_root>/test 136 | python evaluate_gpu.py 137 | python evaluateDistance.py --root_dir <dataset_root> 138 | ``` 139 | 140 | ## Supported Methods 141 | 142 | | Augment | Backbone | Head | Loss | 143 | | ----------------- | --------------- | ----------- | -------------------------- | 144 | | Random Rotate | ResNet | MaxPool | CrossEntropy Loss | 145 | | Random Affine | EfficientNet | AvgPool | Focal Loss | 146 | | Random Brightness | ConvNext | MaxAvgPool | Triplet Loss | 147 | | Random Erasing | DeiT | GlobalPool | Hard-Mining Triplet Loss | 148 | | | PvT | GemPool | Same-Domain Triplet Loss | 149 | | | SwinTransformer | LPN | Soft-Weighted Triplet Loss | 150 | | | ViT | FSRA | KL Loss | 151 | 152 | ## License 153 | 154 | This project is licensed under the [Apache 2.0 license](https://github.com/Dmmm1997/DenseUAV//blob/main/LICENSE). 155 | 156 | ## Citation 157 | 158 | The following paper presents the dataset and reports the results of the baseline model. You may cite it in your paper.
159 | 160 | ```bibtex 161 | @ARTICLE{DenseUAV, 162 | author={Dai, Ming and Zheng, Enhui and Feng, Zhenhua and Qi, Lei and Zhuang, Jiedong and Yang, Wankou}, 163 | journal={IEEE Transactions on Image Processing}, 164 | title={Vision-Based UAV Self-Positioning in Low-Altitude Urban Environments}, 165 | year={2024}, 166 | volume={33}, 167 | number={}, 168 | pages={493-508}, 169 | doi={10.1109/TIP.2023.3346279}} 170 | ``` 171 | 172 | ## Related Work 173 | 174 | - University-1652 [https://github.com/layumi/University1652-Baseline](https://github.com/layumi/University1652-Baseline) 175 | - FSRA [https://github.com/Dmmm1997/FSRA](https://github.com/Dmmm1997/FSRA) 176 | -------------------------------------------------------------------------------- /batch_inference_time.sh: -------------------------------------------------------------------------------- 1 | data_dir="/home/dmmm/Dataset/DenseUAV/data_2022/train" #"/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/train" 2 | # data_dir="/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/train" 3 | test_dir="/home/dmmm/Dataset/DenseUAV/data_2022/test" #"/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/test" 4 | # test_dir="/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/test" 5 | num_worker=8 6 | gpu_ids=0 7 | 8 | name="checkpoints/Backbone_Experiment_SENet" 9 | cd $name 10 | cd tool 11 | python get_inference_time.py --name $name 12 | cd ../../../ 13 | 14 | # name="checkpoints/Backbone_Experiment_ConvnextT" 15 | # cd $name 16 | # cd tool 17 | # python get_inference_time.py --name $name 18 | # cd ../../../ 19 | 20 | # name="checkpoints/Backbone_Experiment_DeitS" 21 | # cd $name 22 | # cd tool 23 | # python get_inference_time.py --name $name 24 | # cd ../../../ 25 | 26 | # name="checkpoints/Backbone_Experiment_EfficientNet-B2" 27 | # cd $name 28 | # cd tool 29 | # python get_inference_time.py --name $name 30 | # cd ../../../ 31 | 32 | # name="checkpoints/Backbone_Experiment_EfficientNet-B3" 33 | # cd $name 34 | # cd tool 35 | # python get_inference_time.py --name $name 36 | # cd ../../../ 37 | 38 | # name="checkpoints/Backbone_Experiment_PvTv2b2" 39 | # cd $name 40 | # cd tool 41 | # python get_inference_time.py --name $name 42 | # cd ../../../ 43 | 44 | # name="checkpoints/Backbone_Experiment_resnet50" 45 | # cd $name 46 | # cd tool 47 | # python get_inference_time.py --name $name 48 | # cd ../../../ 49 | 50 | # name="checkpoints/Backbone_Experiment_Swinv2T-256" 51 | # cd $name 52 | # cd tool 53 | # python get_inference_time.py --name $name --test_h 256 --test_w 256 54 | # cd ../../../ 55 | 56 | # name="checkpoints/Backbone_Experiment_VGG16" 57 | # cd $name 58 | # cd tool 59 | # python get_inference_time.py --name $name 60 | # cd ../../../ 61 | 62 | # name="checkpoints/Backbone_Experiment_ViTB" 63 | # cd $name 64 | # cd tool 65 | # python get_inference_time.py --name $name 66 | # cd ../../../ 67 | 68 | # name="checkpoints/Head_Experiment-FSRA2B" 69 | # cd $name 70 | # cd tool 71 | # python get_inference_time.py --name $name 72 | # cd ../../../ 73 | 74 | # name="checkpoints/Head_Experiment-FSRA3B" 75 | # cd $name 76 | # cd tool 77 | # python get_inference_time.py --name $name 78 | # cd ../../../ 79 | 80 | # name="checkpoints/Head_Experiment-GeM" 81 | # cd $name 82 | # cd tool 83 | # python get_inference_time.py --name $name 84 | # cd ../../../ 85 | 86 | # name="checkpoints/Head_Experiment-LPN2B" 87 | # cd $name 88 | # cd tool 89 | # python get_inference_time.py --name $name 90 | # cd ../../../ 91 | 92 | # 
name="checkpoints/Head_Experiment-LPN3B" 93 | # cd $name 94 | # cd tool 95 | # python get_inference_time.py --name $name 96 | # cd ../../../ 97 | -------------------------------------------------------------------------------- /datasets/Dataloader_University.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision import datasets, transforms 4 | import os 5 | import numpy as np 6 | from PIL import Image 7 | import glob 8 | 9 | 10 | class Dataloader_University(Dataset): 11 | def __init__(self, root, transforms, names=['satellite', 'drone']): 12 | super(Dataloader_University).__init__() 13 | self.transforms_drone_street = transforms['train'] 14 | self.transforms_satellite = transforms['satellite'] 15 | self.root = root 16 | self.names = names 17 | # 获取所有图片的相对路径分别放到对应的类别中 18 | # {satelite:{0839:[0839.jpg],0840:[0840.jpg]}} 19 | dict_path = {} 20 | for name in names: 21 | dict_ = {} 22 | for cls_name in os.listdir(os.path.join(root, name)): 23 | img_list = os.listdir(os.path.join(root, name, cls_name)) 24 | img_path_list = [os.path.join( 25 | root, name, cls_name, img) for img in img_list] 26 | dict_[cls_name] = img_path_list 27 | dict_path[name] = dict_ 28 | # dict_path[name+"/"+cls_name] = img_path_list 29 | 30 | # 获取设置名字与索引之间的镜像 31 | cls_names = os.listdir(os.path.join(root, names[0])) 32 | cls_names.sort() 33 | map_dict = {i: cls_names[i] for i in range(len(cls_names))} 34 | 35 | self.cls_names = cls_names 36 | self.map_dict = map_dict 37 | self.dict_path = dict_path 38 | self.index_cls_nums = 2 39 | 40 | # 从对应的类别中抽一张出来 41 | def sample_from_cls(self, name, cls_num): 42 | img_path = self.dict_path[name][cls_num] 43 | img_path = np.random.choice(img_path, 1)[0] 44 | img = Image.open(img_path).convert("RGB") 45 | return img 46 | 47 | def __getitem__(self, index): 48 | cls_nums = self.map_dict[index] 49 | img = self.sample_from_cls("satellite", cls_nums) 50 | img_s = self.transforms_satellite(img) 51 | 52 | # img = self.sample_from_cls("street",cls_nums) 53 | # img_st = self.transforms_drone_street(img) 54 | 55 | img = self.sample_from_cls("drone", cls_nums) 56 | img_d = self.transforms_drone_street(img) 57 | return img_s, img_d, index 58 | 59 | def __len__(self): 60 | return len(self.cls_names) 61 | 62 | 63 | class DataLoader_Inference(Dataset): 64 | def __init__(self, root, transforms): 65 | super(DataLoader_Inference, self).__init__() 66 | self.root = root 67 | self.imgs = glob.glob(root+"/*.tif") 68 | self.tranforms = transforms 69 | sorted(self.imgs) 70 | self.labels = [os.path.basename(img).split(".tif")[ 71 | 0] for img in self.imgs] 72 | 73 | def __getitem__(self, index): 74 | img = Image.open(self.imgs[index]) 75 | return self.tranforms(img), self.labels[index] 76 | 77 | def __len__(self): 78 | return len(self.imgs) 79 | 80 | 81 | class Sampler_University(object): 82 | r"""Base class for all Samplers. 83 | Every Sampler subclass has to provide an :meth:`__iter__` method, providing a 84 | way to iterate over indices of dataset elements, and a :meth:`__len__` method 85 | that returns the length of the returned iterators. 86 | .. note:: The :meth:`__len__` method isn't strictly required by 87 | :class:`~torch.utils.data.DataLoader`, but is expected in any 88 | calculation involving the length of a :class:`~torch.utils.data.DataLoader`. 
89 | """ 90 | 91 | def __init__(self, data_source, batchsize=8, sample_num=4): 92 | self.data_len = len(data_source) 93 | self.batchsize = batchsize 94 | self.sample_num = sample_num 95 | 96 | def __iter__(self): 97 | list = np.arange(0, self.data_len) 98 | np.random.shuffle(list) 99 | nums = np.repeat(list, self.sample_num, axis=0) 100 | return iter(nums) 101 | 102 | def __len__(self): 103 | return len(self.data_source) 104 | 105 | 106 | def train_collate_fn(batch): 107 | """ 108 | # collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果 109 | """ 110 | img_s, img_d, ids = zip(*batch) 111 | ids = torch.tensor(ids, dtype=torch.int64) 112 | return [torch.stack(img_s, dim=0), ids], [torch.stack(img_d, dim=0), ids] 113 | 114 | 115 | if __name__ == '__main__': 116 | transform_train_list = [ 117 | # transforms.RandomResizedCrop(size=(opt.h, opt.w), scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC) 118 | transforms.Resize((256, 256), interpolation=3), 119 | transforms.Pad(10, padding_mode='edge'), 120 | transforms.RandomCrop((256, 256)), 121 | transforms.RandomHorizontalFlip(), 122 | transforms.ToTensor(), 123 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 124 | ] 125 | 126 | transform_train_list = {"satellite": transforms.Compose(transform_train_list), 127 | "train": transforms.Compose(transform_train_list)} 128 | datasets = Dataloader_University(root="/home/dmmm/University-Release/train", 129 | transforms=transform_train_list, names=['satellite', 'drone']) 130 | samper = Sampler_University(datasets, 8) 131 | dataloader = DataLoader(datasets, batch_size=8, num_workers=0, 132 | sampler=samper, collate_fn=train_collate_fn) 133 | for data_s, data_d in dataloader: 134 | print() 135 | -------------------------------------------------------------------------------- /datasets/make_dataloader.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from .Dataloader_University import Sampler_University, Dataloader_University, train_collate_fn 3 | from .autoaugment import ImageNetPolicy 4 | import torch 5 | from .queryDataset import RotateAndCrop, RandomCrop, RandomErasing 6 | 7 | 8 | def make_dataset(opt): 9 | transform_train_list = [] 10 | transform_satellite_list = [] 11 | if "uav" in opt.rr: 12 | transform_train_list.append(RotateAndCrop(0.5)) 13 | if "satellite" in opt.rr: 14 | transform_satellite_list.append(RotateAndCrop(0.5)) 15 | transform_train_list += [ 16 | transforms.Resize((opt.h, opt.w), interpolation=3), 17 | transforms.Pad(opt.pad, padding_mode='edge'), 18 | transforms.RandomHorizontalFlip(), 19 | ] 20 | 21 | transform_satellite_list += [ 22 | transforms.Resize((opt.h, opt.w), interpolation=3), 23 | transforms.Pad(opt.pad, padding_mode='edge'), 24 | transforms.RandomHorizontalFlip(), 25 | ] 26 | 27 | transform_val_list = [ 28 | transforms.Resize(size=(opt.h, opt.w), 29 | interpolation=3), # Image.BICUBIC 30 | ] 31 | 32 | if "uav" in opt.ra: 33 | transform_train_list = transform_train_list + \ 34 | [transforms.RandomAffine(180)] 35 | if "satellite" in opt.ra: 36 | transform_satellite_list = transform_satellite_list + \ 37 | [transforms.RandomAffine(180)] 38 | 39 | if "uav" in opt.re: 40 | transform_train_list = transform_train_list + \ 41 | [RandomErasing(probability=opt.erasing_p)] 42 | if "satellite" in opt.re: 43 | transform_satellite_list = transform_satellite_list + \ 44 | [RandomErasing(probability=opt.erasing_p)] 45 | 46 | if "uav" in opt.cj: 47 | 
transform_train_list = transform_train_list + \ 48 | [transforms.ColorJitter(brightness=0.5, contrast=0.1, saturation=0.1, 49 | hue=0)] 50 | if "satellite" in opt.cj: 51 | transform_satellite_list = transform_satellite_list + \ 52 | [transforms.ColorJitter(brightness=0.5, contrast=0.1, saturation=0.1, 53 | hue=0)] 54 | 55 | if opt.DA: 56 | transform_train_list = [ImageNetPolicy()] + transform_train_list 57 | 58 | last_aug = [ 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 61 | ] 62 | 63 | transform_train_list += last_aug 64 | transform_satellite_list += last_aug 65 | transform_val_list += last_aug 66 | 67 | print(transform_train_list) 68 | print(transform_satellite_list) 69 | 70 | data_transforms = { 71 | 'train': transforms.Compose(transform_train_list), 72 | 'val': transforms.Compose(transform_val_list), 73 | 'satellite': transforms.Compose(transform_satellite_list)} 74 | 75 | # custom Dataset 76 | image_datasets = Dataloader_University( 77 | opt.data_dir, transforms=data_transforms) 78 | samper = Sampler_University( 79 | image_datasets, batchsize=opt.batchsize, sample_num=opt.sample_num) 80 | dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=opt.batchsize, 81 | sampler=samper, num_workers=opt.num_worker, pin_memory=True, collate_fn=train_collate_fn) 82 | dataset_sizes = {x: len(image_datasets) * 83 | opt.sample_num for x in ['satellite', 'drone']} 84 | class_names = image_datasets.cls_names 85 | return dataloaders, class_names, dataset_sizes 86 | -------------------------------------------------------------------------------- /docs/Get_started.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/Get_started.md -------------------------------------------------------------------------------- /docs/Request.md: -------------------------------------------------------------------------------- 1 | [Title] Request of DenseUAV Dataset 2 | 3 | This database will only be used for research purposes. I will not make any part of this database available to a third party. 4 | I'll not sell any part of this database or make any profit from its use. 5 | 6 | Thank you! 7 | 8 | 9 | *** Please send it to 869906992@qq.com via your academic email. I will reply the dataset address to you. 
*** -------------------------------------------------------------------------------- /docs/images/data.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/images/data.jpg -------------------------------------------------------------------------------- /docs/images/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/images/framework.jpg -------------------------------------------------------------------------------- /docs/images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/images/model.png -------------------------------------------------------------------------------- /docs/training_parameters.md: -------------------------------------------------------------------------------- 1 | # Parameter Introduction 2 | 3 | ### Data-Related Parameters 4 | - `--name`: Experiment name, used for saving models and parameters under `checkpoints/`, facilitating management and tracking of different experimental results. 5 | - `--data_dir`: Directory path for training data. 6 | - `--num_worker`: Number of worker threads used for data loading, affecting the parallelism and efficiency of data preprocessing. 7 | - `--pad`: Amount of padding for input data. Please distinguish this from the `--pad` in Position Shifting. 8 | - `--h, --w`: Height and width of the input images. 9 | - `--rr`: Random rotation applied to one or more views to enhance data diversity. 10 | - `--ra`: Random affine transformation applied to one or more views to enhance data diversity. 11 | - `--re`: Random occlusion applied to one or more views to enhance data diversity. 12 | - `--cj`: Color jitter applied to one or more views to enhance data diversity. 13 | - `--erasing_p`: Probability of random occlusion, controlling the proportion of randomly occluded areas in the images. 14 | 15 | ### Training-Related Parameters 16 | - `--warm_epoch`: Warm-up phase, setting the learning rate to gradually increase over the first `K` epochs. 17 | - `--lr`: Learning rate. 18 | - `--DA`: Whether to use color data augmentation. 19 | - `--droprate`: Dropout rate. 20 | - `--autocast`: Whether to use mixed precision training. 21 | - `--load_from`: Path to the pre-loaded checkpoint for restoring the model from a previous training state. 22 | - `--gpu_ids`: Specification of the GPU devices used, supporting multi-GPU configurations for flexible training environments. 23 | - `--batchsize`: Number of samples per training step. 24 | 25 | ### Model-Related Parameters 26 | - `--block`: Number of ClassBlocks in the model. 27 | - `--cls_loss`: Type of loss function for Representation Learning. Various preset or custom losses can be used, with `CELoss` as the default. 28 | - `--feature_loss`: Type of loss function for Metric Learning. Various preset or custom losses can be used, with no loss applied by default. 29 | - `--kl_loss`: Type of loss function for Mutual Learning. Various preset or custom losses can be used, with no loss applied by default. 30 | - `--num_bottleneck`: Dimensionality of feature embeddings. 31 | - `--backbone`: Backbone architecture used. Various preset or custom backbones can be selected, with `cvt13` as the default. 
32 | - `--head`: Head architecture used. Various preset or custom heads can be selected, with `FSRA_CNN` as the default. 33 | - `--head_pool`: Type of pooling used in the head, with various preset or custom pooling methods available, defaulting to `max pooling`. 34 | -------------------------------------------------------------------------------- /evaluateDistance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy.io 3 | import torch 4 | import numpy as np 5 | import os 6 | from torchvision import datasets 7 | import matplotlib 8 | # matplotlib.use('agg') 9 | import matplotlib.pyplot as plt 10 | import json 11 | from tqdm import tqdm 12 | import math 13 | 14 | 15 | ##################################################################### 16 | # Show result 17 | def imshow(path, title=None): 18 | """Imshow for Tensor.""" 19 | im = plt.imread(path) 20 | plt.imshow(im) 21 | if title is not None: 22 | plt.title(title) 23 | plt.pause(0.1) # pause a bit so that plots are updated 24 | 25 | 26 | ###################################################################### 27 | 28 | 29 | def getLatitudeAndLongitude(imgPath): 30 | if isinstance(imgPath, list): 31 | posInfo = [configDict[p.split("/")[-2]] for p in imgPath] 32 | else: 33 | posInfo = configDict[imgPath.split("/")[-2]] 34 | return posInfo 35 | 36 | 37 | def euclideanDistance(query, gallery): 38 | query = np.array(query, dtype=np.float32) 39 | gallery = np.array(gallery, dtype=np.float32) 40 | A = gallery - query 41 | A_T = A.transpose() 42 | distance = np.matmul(A, A_T) 43 | mask = np.eye(distance.shape[0], dtype=np.bool8) 44 | distance = distance[mask] 45 | distance = np.sqrt(distance.reshape(-1)) 46 | return distance 47 | 48 | 49 | def evaluateSingle(distance, K): 50 | # maxDistance = max(distance) + 1e-14 51 | # weight = np.ones(K) - np.log(range(1, K + 1, 1)) / np.log(opts.M * K) 52 | weight = np.ones(K) - np.array(range(0, K, 1))/K 53 | # m1 = distance / maxDistance 54 | m2 = 1 / np.exp(distance*5e3) 55 | m3 = m2 * weight 56 | result = np.sum(m3) / np.sum(weight) 57 | return result 58 | 59 | 60 | def latlog2meter(lata, loga, latb, logb): 61 | # log 纬度 lat 经度 62 | # EARTH_RADIUS = 6371.0 63 | EARTH_RADIUS =6378.137 64 | PI = math.pi 65 | # // 转弧度 66 | lat_a = lata * PI / 180 67 | lat_b = latb * PI / 180 68 | a = lat_a - lat_b 69 | b = loga * PI / 180 - logb * PI / 180 70 | dis = 2 * math.asin( 71 | math.sqrt(math.pow(math.sin(a / 2), 2) + math.cos(lat_a) * math.cos(lat_b) * math.pow(math.sin(b / 2), 2))) 72 | 73 | distance = EARTH_RADIUS * dis * 1000 74 | return distance 75 | 76 | 77 | def evaluate_SDM(indexOfTopK, queryIndex, K): 78 | query_path, _ = image_datasets[query_name].imgs[queryIndex] 79 | galleryTopKPath = [image_datasets[gallery_name].imgs[i][0] 80 | for i in indexOfTopK[:K]] 81 | # get position information including latitude and longitude 82 | queryPosInfo = getLatitudeAndLongitude(query_path) 83 | galleryTopKPosInfo = getLatitudeAndLongitude(galleryTopKPath) 84 | # compute Euclidean distance of query and gallery 85 | distance = euclideanDistance(queryPosInfo, galleryTopKPosInfo) 86 | # compute single query evaluate result 87 | P = evaluateSingle(distance, K) 88 | return P 89 | 90 | 91 | def evaluate_MA(indexOfTop1, queryIndex): 92 | query_path, _ = image_datasets[query_name].imgs[queryIndex] 93 | galleryTopKPath = image_datasets[gallery_name].imgs[indexOfTop1][0] 94 | # get position information including latitude and longitude 95 | queryPosInfo = 
getLatitudeAndLongitude(query_path) 96 | galleryTopKPosInfo = getLatitudeAndLongitude(galleryTopKPath) 97 | # get real distance 98 | distance_meter = latlog2meter(queryPosInfo[1],queryPosInfo[0],galleryTopKPosInfo[1],galleryTopKPosInfo[0]) 99 | return distance_meter 100 | 101 | 102 | if '__main__' == __name__: 103 | ####################################################################### 104 | # Evaluate 105 | parser = argparse.ArgumentParser(description='Demo') 106 | # parser.add_argument('--query_index', default=10, type=int, help='test_image_index') 107 | parser.add_argument( 108 | '--root_dir', default='/home/dmmm/Dataset/DenseUAV/data_2022/', type=str, help='./test_data') 109 | parser.add_argument('--K', default=[1, 3, 5, 10], type=str, help='./test_data') 110 | parser.add_argument('--M', default=5e3, type=str, help='./test_data') 111 | parser.add_argument('--mode', default="1", type=str, 112 | help='1:drone->satellite 2:satellite->drone') 113 | opts = parser.parse_args() 114 | 115 | opts.config = os.path.join(opts.root_dir, "Dense_GPS_ALL.txt") 116 | opts.test_dir = os.path.join(opts.root_dir, "test") 117 | configDict = {} 118 | with open(opts.config, "r") as F: 119 | context = F.readlines() 120 | for line in context: 121 | splitLineList = line.split(" ") 122 | configDict[splitLineList[0].split("/")[-2]] = [float(splitLineList[1].split("E")[-1]), 123 | float(splitLineList[2].split("N")[-1])] 124 | 125 | if opts.mode == "1": 126 | gallery_name = 'gallery_satellite' 127 | query_name = 'query_drone' 128 | else: 129 | gallery_name = 'gallery_drone' 130 | query_name = 'query_satellite' 131 | 132 | data_dir = opts.test_dir 133 | image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x)) for x in [ 134 | gallery_name, query_name]} 135 | 136 | 137 | 138 | if opts.mode == "1": 139 | result = scipy.io.loadmat('pytorch_result_1.mat') 140 | else: 141 | result = scipy.io.loadmat('pytorch_result_2.mat') 142 | query_feature = torch.FloatTensor(result['query_f']) 143 | query_label = result['query_label'][0] 144 | gallery_feature = torch.FloatTensor(result['gallery_f']) 145 | gallery_label = result['gallery_label'][0] 146 | 147 | multi = os.path.isfile('multi_query.mat') 148 | 149 | if multi: 150 | m_result = scipy.io.loadmat('multi_query.mat') 151 | mquery_feature = torch.FloatTensor(m_result['mquery_f']) 152 | mquery_cam = m_result['mquery_cam'][0] 153 | mquery_label = m_result['mquery_label'][0] 154 | mquery_feature = mquery_feature.cuda() 155 | 156 | query_feature = query_feature.cuda() 157 | gallery_feature = gallery_feature.cuda() 158 | 159 | 160 | ####################################################################### 161 | # sort the images and return topK index 162 | def sort_img(qf, ql, gf, gl, K): 163 | query = qf.view(-1, 1) 164 | # print(query.shape) 165 | score = torch.mm(gf, query) 166 | score = score.squeeze(1).cpu() 167 | score = score.numpy() 168 | # predict index 169 | index = np.argsort(score) # from small to large 170 | index = index[::-1] 171 | # index = index[0:2000] 172 | # good index 173 | query_index = np.argwhere(gl == ql) 174 | 175 | # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 176 | junk_index = np.argwhere(gl == -1) 177 | 178 | mask = np.in1d(index, junk_index, invert=True) 179 | index = index[mask] 180 | return index[:K] 181 | 182 | 183 | indexOfTopK_list = [] 184 | for i in range(len(query_label)): 185 | indexOfTopK = sort_img( 186 | query_feature[i], query_label[i], gallery_feature, gallery_label, 100) 187 | 
indexOfTopK_list.append(indexOfTopK) 188 | 189 | SDM_dict = {} 190 | for K in tqdm(range(1, 101, 1)): 191 | metric = 0 192 | for i in range(len(query_label)): 193 | P_ = evaluate_SDM(indexOfTopK_list[i], i, K) 194 | metric += P_ 195 | metric = metric / len(query_label) 196 | if K in opts.K: 197 | print("metric{} = {:.2f}%".format(K, metric * 100)) 198 | SDM_dict[K] = metric 199 | 200 | MA_dict = {} 201 | for meter in tqdm(range(1,101,1)): 202 | MA_K = 0 203 | for i in range(len(query_label)): 204 | MA_meter = evaluate_MA(indexOfTopK_list[i][0],i) 205 | if MA_meter None: 11 | super(Loss,self).__init__() 12 | self.opt = opt 13 | # classification loss 14 | if opt.cls_loss == "CELoss": 15 | self.cls_loss = nn.CrossEntropyLoss() 16 | elif opt.cls_loss == "FocalLoss": 17 | self.cls_loss = FocalLoss(alpha=0.25, gamma=2, num_classes = opt.nclasses) 18 | else: 19 | self.cls_loss = None 20 | 21 | # contrastive (metric) loss 22 | if opt.feature_loss == "TripletLoss": 23 | self.feature_loss = TripletLoss(margin=0.3, normalize_feature=True) 24 | elif opt.feature_loss == "HardMiningTripletLoss": 25 | self.feature_loss = HardMiningTripletLoss(margin=0.3, normalize_feature=True) 26 | elif opt.feature_loss == "SameDomainTripletLoss": 27 | self.feature_loss = SameDomainTripletLoss(margin=0.3) 28 | elif opt.feature_loss == "WeightedSoftTripletLoss": 29 | self.feature_loss = WeightedSoftTripletLoss() 30 | elif opt.feature_loss == "ContrastiveLoss": 31 | self.feature_loss = losses.ContrastiveLoss(pos_margin=0, neg_margin=1) 32 | else: 33 | self.feature_loss = None 34 | 35 | # KL loss 36 | if opt.kl_loss == "KLLoss": 37 | self.kl_loss = nn.KLDivLoss(reduction='batchmean') 38 | else: 39 | self.kl_loss = None 40 | 41 | 42 | def forward(self, outputs, outputs2, labels, labels2): 43 | cls1,feature1 = outputs 44 | cls2,feature2 = outputs2 45 | loss = 0 46 | 47 | # classification loss 48 | res_cls_loss = torch.tensor((0)) 49 | if self.cls_loss is not None: 50 | res_cls_loss = self.calc_cls_loss(cls1, labels, self.cls_loss) + \ 51 | self.calc_cls_loss(cls2, labels2, self.cls_loss) 52 | loss += res_cls_loss 53 | 54 | # feature contrastive loss 55 | res_triplet_loss = torch.tensor((0)) 56 | if self.feature_loss is not None: 57 | split_num = self.opt.batchsize//self.opt.sample_num 58 | res_triplet_loss = self.calc_triplet_loss( 59 | feature1, feature2, labels, self.feature_loss, split_num) 60 | loss += res_triplet_loss 61 | 62 | # mutual learning 63 | res_kl_loss = torch.tensor((0)) 64 | if self.kl_loss is not None: 65 | res_kl_loss = self.calc_kl_loss(cls1, cls2, self.kl_loss) 66 | loss += res_kl_loss 67 | 68 | # if self.opt.epoch < self.opt.warm_epoch: 69 | # warm_up = 0.1 # We start from the 0.1*lrRate 70 | # warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch # first 5 epoch 71 | # warm_up = min(1.0, warm_up + 0.9 / warm_iteration) 72 | # loss *= warm_up 73 | 74 | return loss, res_cls_loss, res_triplet_loss, res_kl_loss 75 | 76 | 77 | def calc_cls_loss(self, outputs, labels, loss_func): 78 | loss = 0 79 | if isinstance(outputs, list): 80 | for i in outputs: 81 | loss += loss_func(i, labels) 82 | loss = loss/len(outputs) 83 | else: 84 | loss = loss_func(outputs, labels) 85 | return loss 86 | 87 | 88 | def calc_kl_loss(self, outputs, outputs2, loss_func): 89 | loss = 0 90 | if isinstance(outputs, list): 91 | for i in range(len(outputs)): 92 | loss += loss_func(F.log_softmax(outputs[i], dim=1), 93 | F.softmax(Variable(outputs2[i]), dim=1)) 94 | loss = loss/len(outputs) 95 | else: 96 | loss = loss_func(F.log_softmax(outputs, dim=1), 97 | F.softmax(Variable(outputs2),
dim=1)) 98 | return loss 99 | 100 | 101 | def calc_triplet_loss(self, outputs, outputs2, labels, loss_func, split_num=8): 102 | if isinstance(outputs, list): 103 | loss = 0 104 | for i in range(len(outputs)): 105 | out_concat = torch.cat((outputs[i], outputs2[i]), dim=0) 106 | labels_concat = torch.cat((labels, labels), dim=0) 107 | loss += loss_func(out_concat, labels_concat) 108 | loss = loss/len(outputs) 109 | else: 110 | out_concat = torch.cat((outputs, outputs2), dim=0) 111 | labels_concat = torch.cat((labels, labels), dim=0) 112 | loss = loss_func(out_concat, labels_concat) 113 | return loss 114 | 115 | 116 | -------------------------------------------------------------------------------- /models/Backbone/RKNet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | class USAM(nn.Module): 7 | def __init__(self, kernel_size=3, padding=1, polish=False): 8 | super(USAM, self).__init__() 9 | 10 | kernel = torch.ones((kernel_size, kernel_size)) 11 | kernel = kernel.unsqueeze(0).unsqueeze(0) 12 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 13 | 14 | 15 | kernel2 = torch.ones((1, 1)) * (kernel_size * kernel_size) 16 | kernel2 = kernel2.unsqueeze(0).unsqueeze(0) 17 | self.weight2 = nn.Parameter(data=kernel2, requires_grad=False) 18 | 19 | self.polish = polish 20 | self.pad = padding 21 | self.relu = nn.ReLU() 22 | self.bn = nn.BatchNorm2d(1) 23 | 24 | def __call__(self, x): 25 | fmap = x.sum(1, keepdim=True) 26 | x1 = F.conv2d(fmap, self.weight, padding=self.pad) 27 | x2 = F.conv2d(fmap, self.weight2, padding=0) 28 | 29 | att = x2 - x1 30 | att = self.bn(att) 31 | att = self.relu(att) 32 | 33 | if self.polish: 34 | att[:, :, :, 0] = 0 35 | att[:, :, :, -1] = 0 36 | att[:, :, 0, :] = 0 37 | att[:, :, -1, :] = 0 38 | 39 | output = x + att * x 40 | 41 | return output 42 | 43 | 44 | 45 | class RKNet(nn.Module): 46 | def __init__(self, stride=2, init_model=None, pool='avg'): 47 | super(RKNet, self).__init__() 48 | model_ft = models.resnet50(pretrained=True) 49 | # avg pooling to global pooling 50 | if stride == 1: 51 | model_ft.layer4[0].downsample[0].stride = (1,1) 52 | model_ft.layer4[0].conv2.stride = (1,1) 53 | 54 | self.pool = pool 55 | if pool =='avg+max': 56 | model_ft.avgpool2 = nn.AdaptiveAvgPool2d((1,1)) 57 | model_ft.maxpool2 = nn.AdaptiveMaxPool2d((1,1)) 58 | #self.classifier = ClassBlock(4096, class_num, droprate) 59 | elif pool=='avg': 60 | model_ft.avgpool2 = nn.AdaptiveAvgPool2d((1,1)) 61 | #self.classifier = ClassBlock(2048, class_num, droprate) 62 | elif pool=='max': 63 | model_ft.maxpool2 = nn.AdaptiveMaxPool2d((1,1)) 64 | elif pool=='gem': 65 | model_ft.gem2 = GeM(dim=2048) 66 | 67 | self.model = model_ft 68 | 69 | if init_model!=None: 70 | self.model = init_model.model 71 | self.pool = init_model.pool 72 | #self.classifier.add_block = init_model.classifier.add_block 73 | 74 | self.usam_1 = USAM() 75 | self.usam_2 = USAM() 76 | 77 | def forward_features(self, x): 78 | x = self.model.conv1(x) 79 | x = self.model.bn1(x) 80 | x = self.model.relu(x) 81 | x = self.usam_1(x) 82 | x = self.model.maxpool(x) 83 | x = self.model.layer1(x) 84 | x = self.usam_2(x) 85 | x = self.model.layer2(x) 86 | x = self.model.layer3(x) 87 | x = self.model.layer4(x) 88 | 89 | return x -------------------------------------------------------------------------------- /models/Backbone/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/models/Backbone/__init__.py -------------------------------------------------------------------------------- /models/Backbone/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import timm 3 | from .RKNet import RKNet 4 | from .cvt import get_cvt_models 5 | import torch 6 | 7 | def make_backbone(opt): 8 | backbone_model = Backbone(opt) 9 | return backbone_model 10 | 11 | 12 | class Backbone(nn.Module): 13 | def __init__(self, opt): 14 | super().__init__() 15 | self.opt = opt 16 | self.img_size = (opt.h,opt.w) 17 | self.backbone,self.output_channel = self.init_backbone(opt.backbone) 18 | 19 | 20 | def init_backbone(self, backbone): 21 | if backbone=="resnet50": 22 | backbone_model = timm.create_model('resnet50', pretrained=True) 23 | output_channel = 2048 24 | elif backbone=="RKNet": 25 | backbone_model = RKNet() 26 | output_channel = 2048 27 | elif backbone=="senet": 28 | backbone_model = timm.create_model('legacy_seresnet50', pretrained=True) 29 | output_channel = 2048 30 | elif backbone=="ViTS-224": 31 | backbone_model = timm.create_model("vit_small_patch16_224", pretrained=True, img_size=self.img_size) 32 | output_channel = 384 33 | elif backbone=="ViTS-384": 34 | backbone_model = timm.create_model("vit_small_patch16_384", pretrained=True) 35 | output_channel = 384 36 | elif backbone=="DeitS-224": 37 | backbone_model = timm.create_model("deit_small_distilled_patch16_224", pretrained=True) 38 | output_channel = 384 39 | elif backbone=="DeitB-224": 40 | backbone_model = timm.create_model("deit_base_distilled_patch16_224", pretrained=True) 41 | output_channel = 384 42 | elif backbone=="Pvtv2b2": 43 | backbone_model = timm.create_model("pvt_v2_b2", pretrained=True) 44 | output_channel = 512 45 | elif backbone=="ViTB-224": 46 | backbone_model = timm.create_model("vit_base_patch16_224", pretrained=True) 47 | output_channel = 768 48 | elif backbone=="SwinB-224": 49 | backbone_model = timm.create_model("swin_base_patch4_window7_224", pretrained=True) 50 | output_channel = 768 51 | elif backbone=="Swinv2S-256": 52 | backbone_model = timm.create_model("swinv2_small_window8_256", pretrained=True) 53 | output_channel = 768 54 | elif backbone=="Swinv2T-256": 55 | backbone_model = timm.create_model("swinv2_tiny_window16_256", pretrained=True) 56 | output_channel = 768 57 | elif backbone=="Convnext-T": 58 | backbone_model = timm.create_model("convnext_tiny", pretrained=True) 59 | output_channel = 768 60 | elif backbone=="EfficientNet-B2": 61 | backbone_model = timm.create_model("efficientnet_b2", pretrained=True) 62 | output_channel = 1408 63 | elif backbone=="EfficientNet-B3": 64 | backbone_model = timm.create_model("efficientnet_b3", pretrained=True) 65 | output_channel = 1536 66 | elif backbone=="EfficientNet-B5": 67 | backbone_model = timm.create_model("tf_efficientnet_b5", pretrained=True) 68 | output_channel = 2048 69 | elif backbone=="EfficientNet-B6": 70 | backbone_model = timm.create_model("tf_efficientnet_b6", pretrained=True) 71 | output_channel = 2304 72 | elif backbone=="vgg16": 73 | backbone_model = timm.create_model("vgg16", pretrained=True) 74 | output_channel = 512 75 | elif backbone=="cvt13": 76 | backbone_model, channels = get_cvt_models(model_size="cvt13") 77 | output_channel = channels[-1] 78 | checkpoint_weight = 
"/home/dmmm/VscodeProject/FPI/pretrain_model/CvT-13-384x384-IN-22k.pth" 79 | backbone_model = self.load_checkpoints(checkpoint_weight, backbone_model) 80 | else: 81 | raise NameError("{} not in the backbone list!!!".format(backbone)) 82 | return backbone_model,output_channel 83 | 84 | def load_checkpoints(self, checkpoint_path, model): 85 | ckpt = torch.load(checkpoint_path, map_location='cpu') 86 | filter_ckpt = {k: v for k, v in ckpt.items() if "pos_embed" not in k} 87 | missing_keys, unexpected_keys = model.load_state_dict(filter_ckpt, strict=False) 88 | print("Load pretrained backbone checkpoint from:", checkpoint_path) 89 | print("missing keys:", missing_keys) 90 | print("unexpected keys:", unexpected_keys) 91 | return model 92 | 93 | def forward(self, image): 94 | features = self.backbone.forward_features(image) 95 | return features -------------------------------------------------------------------------------- /models/Head/FSRA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import ClassBlock 4 | 5 | 6 | class FSRA(nn.Module): 7 | def __init__(self, opt) -> None: 8 | super().__init__() 9 | 10 | self.opt = opt 11 | num_classes = opt.nclasses 12 | droprate = opt.droprate 13 | in_planes = opt.in_planes 14 | self.class_name = "classifier_heat" 15 | self.block = opt.block 16 | # global classifier 17 | self.classifier1 = ClassBlock(in_planes, num_classes, droprate) 18 | # local classifier 19 | for i in range(self.block): 20 | name = self.class_name + str(i+1) 21 | setattr(self, name, ClassBlock(in_planes, num_classes, droprate)) 22 | 23 | def forward(self, features): 24 | global_cls, global_feature = self.classifier1(features[:, 0]) 25 | # tranformer_feature = torch.mean(features,dim=1) 26 | # tranformer_feature = self.classifier1(tranformer_feature) 27 | if self.block == 1: 28 | return global_cls, global_feature 29 | 30 | part_features = features[:, 1:] 31 | 32 | heat_result = self.get_heartmap_pool(part_features) 33 | cls_list, features_list = self.part_classifier( 34 | self.block, heat_result, cls_name=self.class_name) 35 | 36 | total_cls = [global_cls] + cls_list 37 | total_features = [global_feature] + features_list 38 | if not self.training: 39 | total_features = torch.stack(total_features,dim=-1) 40 | return [total_cls, total_features] 41 | 42 | def get_heartmap_pool(self, part_features, add_global=False, otherbranch=False): 43 | heatmap = torch.mean(part_features, dim=-1) 44 | size = part_features.size(1) 45 | arg = torch.argsort(heatmap, dim=1, descending=True) 46 | x_sort = [part_features[i, arg[i], :] 47 | for i in range(part_features.size(0))] 48 | x_sort = torch.stack(x_sort, dim=0) 49 | 50 | split_each = size / self.block 51 | split_list = [int(split_each) for i in range(self.block - 1)] 52 | split_list.append(size - sum(split_list)) 53 | split_x = x_sort.split(split_list, dim=1) 54 | 55 | split_list = [torch.mean(split, dim=1) for split in split_x] 56 | part_featuers_ = torch.stack(split_list, dim=2) 57 | if add_global: 58 | global_feat = torch.mean(part_features, dim=1).view( 59 | part_features.size(0), -1, 1).expand(-1, -1, self.block) 60 | part_featuers_ = part_featuers_ + global_feat 61 | if otherbranch: 62 | otherbranch_ = torch.mean( 63 | torch.stack(split_list[1:], dim=2), dim=-1) 64 | return part_featuers_, otherbranch_ 65 | return part_featuers_ 66 | 67 | def part_classifier(self, block, x, cls_name='classifier_lpn'): 68 | part = {} 69 | cls_list, features_list = [], [] 
70 | for i in range(block): 71 | part[i] = x[:, :, i].view(x.size(0), -1) 72 | # part[i] = torch.squeeze(x[:,:,i]) 73 | name = cls_name + str(i+1) 74 | c = getattr(self, name) 75 | res = c(part[i]) 76 | cls_list.append(res[0]) 77 | features_list.append(res[1]) 78 | return cls_list, features_list 79 | 80 | 81 | class FSRA_CNN(nn.Module): 82 | def __init__(self, opt) -> None: 83 | super().__init__() 84 | 85 | self.opt = opt 86 | num_classes = opt.nclasses 87 | droprate = opt.droprate 88 | in_planes = opt.in_planes 89 | self.class_name = "classifier_heat" 90 | self.block = opt.block 91 | # global classifier 92 | self.classifier1 = ClassBlock(in_planes, num_classes, droprate) 93 | # local classifier 94 | for i in range(self.block): 95 | name = self.class_name + str(i+1) 96 | setattr(self, name, ClassBlock(in_planes, num_classes, droprate)) 97 | 98 | def forward(self, features): 99 | # global_cls, global_feature = self.classifier1(features[:, 0]) 100 | features = features.reshape(features.shape[0], features.shape[1], -1).transpose(1,2) 101 | global_feature = torch.mean(features,dim=1) 102 | global_cls, global_feature = self.classifier1(global_feature) 103 | if self.block == 1: 104 | return global_cls, global_feature 105 | 106 | part_features = features 107 | # print(part_features.shape) 108 | 109 | 110 | heat_result = self.get_heartmap_pool(part_features) 111 | cls_list, features_list = self.part_classifier( 112 | self.block, heat_result, cls_name=self.class_name) 113 | 114 | total_cls = [global_cls] + cls_list 115 | total_features = [global_feature] + features_list 116 | if not self.training: 117 | total_features = torch.stack(total_features,dim=-1) 118 | return [total_cls, total_features] 119 | 120 | def get_heartmap_pool(self, part_features, add_global=False, otherbranch=False): 121 | heatmap = torch.mean(part_features, dim=-1) 122 | size = part_features.size(1) 123 | arg = torch.argsort(heatmap, dim=1, descending=True) 124 | x_sort = [part_features[i, arg[i], :] 125 | for i in range(part_features.size(0))] 126 | x_sort = torch.stack(x_sort, dim=0) 127 | 128 | split_each = size / self.block 129 | split_list = [int(split_each) for i in range(self.block - 1)] 130 | split_list.append(size - sum(split_list)) 131 | split_x = x_sort.split(split_list, dim=1) 132 | 133 | split_list = [torch.mean(split, dim=1) for split in split_x] 134 | part_featuers_ = torch.stack(split_list, dim=2) 135 | if add_global: 136 | global_feat = torch.mean(part_features, dim=1).view( 137 | part_features.size(0), -1, 1).expand(-1, -1, self.block) 138 | part_featuers_ = part_featuers_ + global_feat 139 | if otherbranch: 140 | otherbranch_ = torch.mean( 141 | torch.stack(split_list[1:], dim=2), dim=-1) 142 | return part_featuers_, otherbranch_ 143 | return part_featuers_ 144 | 145 | def part_classifier(self, block, x, cls_name='classifier_lpn'): 146 | part = {} 147 | cls_list, features_list = [], [] 148 | for i in range(block): 149 | part[i] = x[:, :, i].view(x.size(0), -1) 150 | # part[i] = torch.squeeze(x[:,:,i]) 151 | name = cls_name + str(i+1) 152 | c = getattr(self, name) 153 | res = c(part[i]) 154 | cls_list.append(res[0]) 155 | features_list.append(res[1]) 156 | return cls_list, features_list -------------------------------------------------------------------------------- /models/Head/GeM.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import ClassBlock, Pooling, vector2image 3 | 4 | 5 | 6 | class GeM(nn.Module): 7 | def __init__(self, 
opt) -> None: 8 | super().__init__() 9 | self.opt = opt 10 | self.classifier = ClassBlock( 11 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck) 12 | self.pool = Pooling(opt.h//16*opt.w//16, "gem") 13 | 14 | def forward(self, features):# (N,(H*W+1),C) 15 | local_feature = features[:, 1:] 16 | local_feature = local_feature.transpose(1,2).contiguous() 17 | # local_feature = vector2image(local_feature,dim = 2) 18 | global_feature = self.pool(local_feature) 19 | cls, feature = self.classifier(global_feature) 20 | return [cls, feature] -------------------------------------------------------------------------------- /models/Head/NeXtVLAD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import ClassBlock, Pooling, vector2image 5 | 6 | 7 | class NeXtVLAD(nn.Module): 8 | def __init__(self, opt) -> None: 9 | super(NeXtVLAD, self).__init__() 10 | self.opt = opt 11 | self.classifier = ClassBlock( 12 | int(opt.in_planes*opt.block), opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck) 13 | self.netvlad = NeXtVLAD_block( 14 | num_clusters=opt.block, dim=opt.in_planes) 15 | 16 | def forward(self, features): 17 | local_feature = features[:, 1:] 18 | local_feature = local_feature.transpose(1, 2) 19 | 20 | local_feature = vector2image(local_feature, dim=2) 21 | local_features = self.netvlad(local_feature) 22 | 23 | cls, feature = self.classifier(local_features) 24 | return [cls, feature] 25 | 26 | 27 | 28 | class NeXtVLAD_block(nn.Module): 29 | """NeXtVLAD layer implementation""" 30 | 31 | def __init__(self, num_clusters=64, dim=1024, lamb=2, groups=8, max_frames=300): 32 | super(NeXtVLAD_block, self).__init__() 33 | self.num_clusters = num_clusters 34 | self.dim = dim 35 | self.alpha = 0 36 | self.K = num_clusters 37 | self.G = groups 38 | self.group_size = int((lamb * dim) // self.G) 39 | # expansion FC 40 | self.fc0 = nn.Linear(dim, lamb * dim) 41 | # soft assignment FC (the cluster weights) 42 | self.fc_gk = nn.Linear(lamb * dim, self.G * self.K) 43 | # attention over groups FC 44 | self.fc_g = nn.Linear(lamb * dim, self.G) 45 | self.cluster_weights2 = nn.Parameter(torch.rand(1, self.group_size, self.K)) 46 | 47 | self.bn0 = nn.BatchNorm1d(max_frames) 48 | self.bn1 = nn.BatchNorm1d(1) 49 | 50 | def forward(self, x, mask=None): 51 | # print(f"x: {x.shape}") 52 | 53 | _, M, N = x.shape 54 | # expansion FC: B x M x N -> B x M x λN 55 | x_dot = self.fc0(x) 56 | 57 | # reshape into groups: B x M x λN -> B x M x G x (λN/G) 58 | x_tilde = x_dot.reshape(-1, M, self.G, self.group_size) 59 | 60 | # residuals across groups and clusters: B x M x λN -> B x M x (G*K) 61 | WgkX = self.fc_gk(x_dot) 62 | WgkX = self.bn0(WgkX) 63 | 64 | # residuals reshape across clusters: B x M x (G*K) -> B x (M*G) x K 65 | WgkX = WgkX.reshape(-1, M * self.G, self.K) 66 | 67 | # softmax over assignment: B x (M*G) x K -> B x (M*G) x K 68 | alpha_gk = F.softmax(WgkX, dim=-1) 69 | 70 | # attention across groups: B x M x λN -> B x M x G 71 | alpha_g = torch.sigmoid(self.fc_g(x_dot)) 72 | if mask is not None: 73 | alpha_g = torch.mul(alpha_g, mask.unsqueeze(2)) 74 | 75 | # reshape across time: B x M x G -> B x (M*G) x 1 76 | alpha_g = alpha_g.reshape(-1, M * self.G, 1) 77 | 78 | # apply attention: B x (M*G) x K (X) B x (M*G) x 1 -> B x (M*G) x K 79 | activation = torch.mul(alpha_gk, alpha_g) 80 | 81 | # sum over time and group: B x (M*G) x K -> B x 1 x K 82 | a_sum = 
torch.sum(activation, -2, keepdim=True) 83 | 84 | # calculate group centers: B x 1 x K (X) 1 x (λN/G) x K -> B x (λN/G) x K 85 | a = torch.mul(a_sum, self.cluster_weights2) 86 | 87 | # permute: B x (M*G) x K -> B x K x (M*G) 88 | activation = activation.permute(0, 2, 1) 89 | 90 | # reshape: B x M x G x (λN/G) -> B x (M*G) x (λN/G) 91 | reshaped_x_tilde = x_tilde.reshape(-1, M * self.G, self.group_size) 92 | 93 | # cluster activation: B x K x (M*G) (X) B x (M*G) x (λN/G) -> B x K x (λN/G) 94 | vlad = torch.matmul(activation, reshaped_x_tilde) 95 | # print(f"vlad: {vlad.shape}") 96 | 97 | # permute: B x K x (λN/G) (X) B x (λN/G) x K 98 | vlad = vlad.permute(0, 2, 1) 99 | # distance to centers: B x (λN/G) x K (-) B x (λN/G) x K 100 | vlad = torch.sub(vlad, a) 101 | # normalize: B x (λN/G) x K 102 | vlad = F.normalize(vlad, 1) 103 | # reshape: B x (λN/G) x K -> B x 1 x (K * (λN/G)) 104 | vlad = vlad.reshape(-1, 1, self.K * self.group_size) 105 | vlad = self.bn1(vlad) 106 | # reshape: B x 1 x (K * (λN/G)) -> B x (K * (λN/G)) 107 | vlad = vlad.reshape(-1, self.K * self.group_size) 108 | 109 | return vlad -------------------------------------------------------------------------------- /models/Head/NetVLAD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import ClassBlock, Pooling, vector2image 5 | 6 | 7 | class NetVLAD(nn.Module): 8 | def __init__(self, opt) -> None: 9 | super(NetVLAD, self).__init__() 10 | self.opt = opt 11 | self.classifier = ClassBlock( 12 | int(opt.in_planes*opt.block), opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck) 13 | self.netvlad = NetVLAD_block( 14 | num_clusters=opt.block, dim=opt.in_planes, alpha=100.0, normalize_input=True) 15 | 16 | def forward(self, features): 17 | local_feature = features[:, 1:] 18 | local_feature = local_feature.transpose(1, 2) 19 | 20 | local_feature = vector2image(local_feature, dim=2) 21 | local_features = self.netvlad(local_feature) 22 | 23 | cls, feature = self.classifier(local_features) 24 | return [cls, feature] 25 | 26 | 27 | class NetVLAD_block(nn.Module): 28 | """NetVLAD layer implementation""" 29 | 30 | def __init__(self, num_clusters=64, dim=128, alpha=100.0, 31 | normalize_input=True): 32 | """ 33 | Args: 34 | num_clusters : int 35 | The number of clusters 36 | dim : int 37 | Dimension of descriptors 38 | alpha : float 39 | Parameter of initialization. Larger value is harder assignment. 40 | normalize_input : bool 41 | If true, descriptor-wise L2 normalization is applied to input. 
42 | """ 43 | super(NetVLAD_block, self).__init__() 44 | self.num_clusters = num_clusters 45 | self.dim = dim 46 | self.alpha = alpha 47 | self.normalize_input = normalize_input 48 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True) 49 | self.centroids = nn.Parameter( 50 | torch.rand(num_clusters, dim)) # 聚类中心,参见注释1 51 | self._init_params() 52 | 53 | def _init_params(self): 54 | self.conv.weight = nn.Parameter( 55 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) 56 | ) 57 | self.conv.bias = nn.Parameter( 58 | - self.alpha * self.centroids.norm(dim=1) 59 | ) 60 | 61 | def forward(self, x): # x: (N, C, H, W), H * W对应论文中的N表示局部特征的数目,C对应论文中的D表示特征维度 62 | N, C = x.shape[:2] 63 | 64 | if self.normalize_input: 65 | x = F.normalize(x, p=2, dim=1) # across descriptor dim,使用L2归一化特征维度 66 | 67 | # soft-assignment 68 | # (N, C, H, W)->(N, num_clusters, H, W)->(N, num_clusters, H * W) 69 | soft_assign = self.conv(x).view(N, self.num_clusters, -1) 70 | # (N, num_clusters, H * W) # 参见注释3 71 | soft_assign = F.softmax(soft_assign, dim=1) 72 | 73 | x_flatten = x.view(N, C, -1) # (N, C, H, W) -> (N, C, H * W) 74 | 75 | # calculate residuals to each clusters 76 | # 减号前面前记为a,后面记为b, residual = a - b 77 | # a: (N, C, H * W) -> (num_clusters, N, C, H * W) -> (N, num_clusters, C, H * W) 78 | # b: (num_clusters, C) -> (H * W, num_clusters, C) -> (num_clusters, C, H * W) 79 | # residual: (N, num_clusters, C, H * W) 参见注释2 80 | residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ 81 | self.centroids.expand(x_flatten.size(-1), - 82 | 1, -1).permute(1, 2, 0).unsqueeze(0) 83 | 84 | # soft_assign: (N, num_clusters, H * W) -> (N, num_clusters, 1, H * W) 85 | # (N, num_clusters, C, H * W) * (N, num_clusters, 1, H * W) 86 | residual *= soft_assign.unsqueeze(2) 87 | # (N, num_clusters, C, H * W) -> (N, num_clusters, C) 88 | vlad = residual.sum(dim=-1) 89 | 90 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 91 | # flatten;vald: (N, num_clusters, C) -> (N, num_clusters * C) 92 | vlad = vlad.view(x.size(0), -1) 93 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 94 | 95 | return vlad 96 | -------------------------------------------------------------------------------- /models/Head/SingleBranch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import ClassBlock, Pooling 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | class SingleBranch(nn.Module): 7 | def __init__(self, opt) -> None: 8 | super().__init__() 9 | self.opt = opt 10 | self.head_pool = opt.head_pool 11 | self.classifier = ClassBlock( 12 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck) 13 | 14 | def forward(self, features): 15 | global_feature = features[:, 0] 16 | local_feature = features[:, 1:] 17 | if self.head_pool == "global": 18 | feature = global_feature 19 | elif self.head_pool == "avg": 20 | local_feature = local_feature.transpose(1, 2) 21 | feature = torch.mean(local_feature, 2).squeeze() 22 | elif self.head_pool == "max": 23 | local_feature = local_feature.transpose(1, 2) 24 | feature = torch.max(local_feature, 2)[0].squeeze() 25 | elif self.head_pool == "avg+max": 26 | local_feature = local_feature.transpose(1, 2) 27 | avg_feature = torch.mean(local_feature, 2).squeeze() 28 | max_feature = torch.max(local_feature, 2)[0].squeeze() 29 | feature = avg_feature+max_feature 30 | else: 31 | raise TypeError("head_pool 不在支持的列表中!!!") 32 | 33 | cls, feature = 
self.classifier(feature) 34 | return [cls, feature] 35 | 36 | 37 | class SingleBranchCNN(nn.Module): 38 | def __init__(self, opt) -> None: 39 | super().__init__() 40 | self.opt = opt 41 | self.pool = nn.AdaptiveAvgPool2d(1) 42 | self.classifier = ClassBlock( 43 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck) 44 | 45 | def forward(self, features): 46 | global_feature = self.pool(features).reshape(features.shape[0], -1) 47 | cls, feature = self.classifier(global_feature) 48 | return [cls, feature] 49 | 50 | 51 | class SingleBranchSwin(nn.Module): 52 | def __init__(self, opt) -> None: 53 | super().__init__() 54 | self.opt = opt 55 | self.pool = nn.AdaptiveAvgPool1d(1) 56 | self.classifier = ClassBlock( 57 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck) 58 | 59 | def forward(self, features): 60 | global_feature = self.pool(features.transpose( 61 | 2, 1)).reshape(features.shape[0], -1) 62 | cls, feature = self.classifier(global_feature) 63 | return [cls, feature] 64 | -------------------------------------------------------------------------------- /models/Head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/models/Head/__init__.py -------------------------------------------------------------------------------- /models/Head/head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .SingleBranch import SingleBranch, SingleBranchCNN, SingleBranchSwin 3 | from .FSRA import FSRA, FSRA_CNN 4 | from .LPN import LPN, LPN_CNN 5 | from .GeM import GeM 6 | from .NetVLAD import NetVLAD 7 | 8 | def make_head(opt): 9 | return Head(opt) 10 | 11 | 12 | class Head(nn.Module): 13 | def __init__(self, opt) -> None: 14 | super().__init__() 15 | self.head = self.init_head(opt) 16 | self.opt = opt 17 | 18 | def init_head(self, opt): 19 | head = opt.head 20 | if head == "SingleBranch": 21 | head_model = SingleBranch(opt) 22 | elif head == "SingleBranchCNN": 23 | head_model = SingleBranchCNN(opt) 24 | elif head == "SingleBranchSwin": 25 | head_model = SingleBranchSwin(opt) 26 | elif head == "NetVLAD": 27 | head_model = NetVLAD(opt) 28 | elif head == "FSRA": 29 | head_model = FSRA(opt) 30 | elif head == "FSRA_CNN": 31 | head_model = FSRA_CNN(opt) 32 | elif head == "LPN": 33 | head_model = LPN(opt) 34 | elif head == "LPN_CNN": 35 | head_model = LPN_CNN(opt) 36 | elif head == "GeM": 37 | head_model = GeM(opt) 38 | else: 39 | raise NameError("{} not in the head list!!!".format(head)) 40 | return head_model 41 | 42 | def forward(self, features): 43 | features = self.head(features) 44 | return features 45 | -------------------------------------------------------------------------------- /models/Head/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | 7 | def weights_init_kaiming(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Linear') != -1: 10 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 11 | nn.init.constant_(m.bias, 0.0) 12 | 13 | elif classname.find('Conv') != -1: 14 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 15 | if m.bias is not None: 16 | nn.init.constant_(m.bias, 0.0) 17 | elif classname.find('BatchNorm') != -1: 18 | if m.affine: 19 | 
nn.init.constant_(m.weight, 1.0) 20 | nn.init.constant_(m.bias, 0.0) 21 | 22 | 23 | def weights_init_classifier(m): 24 | # classname = m.__class__.__name__ 25 | # if classname.find('Linear') != -1: 26 | # nn.init.normal_(m.weight, std=0.001) 27 | # if m.bias: 28 | # nn.init.constant_(m.bias, 0.0) 29 | classname = m.__class__.__name__ 30 | if classname.find('Linear') != -1: 31 | nn.init.normal_(m.weight.data, std=0.001) 32 | nn.init.constant_(m.bias.data, 0.0) 33 | 34 | 35 | class ClassBlock(nn.Module): 36 | def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f=False): 37 | super(ClassBlock, self).__init__() 38 | self.return_f = return_f 39 | add_block = [] 40 | if linear: 41 | add_block += [nn.Linear(input_dim, num_bottleneck)] 42 | else: 43 | num_bottleneck = input_dim 44 | if bnorm: 45 | add_block += [nn.BatchNorm1d(num_bottleneck)] 46 | if relu: 47 | add_block += [nn.LeakyReLU(0.1)] 48 | if droprate > 0: 49 | add_block += [nn.Dropout(p=droprate)] 50 | add_block = nn.Sequential(*add_block) 51 | add_block.apply(weights_init_kaiming) 52 | 53 | classifier = [] 54 | classifier += [nn.Linear(num_bottleneck, class_num)] 55 | classifier = nn.Sequential(*classifier) 56 | classifier.apply(weights_init_classifier) 57 | 58 | self.add_block = add_block 59 | self.classifier = classifier 60 | 61 | def forward(self, x): 62 | feature_ = self.add_block(x) 63 | cls_ = self.classifier(feature_) 64 | return cls_, feature_ 65 | 66 | 67 | class Gem_heat(nn.Module): 68 | def __init__(self, dim=768, p=3, eps=1e-6): 69 | super(Gem_heat, self).__init__() 70 | self.p = nn.Parameter(torch.ones(dim) * p) # initial p 71 | self.eps = eps 72 | 73 | def forward(self, x): 74 | return self.gem(x, p=self.p, eps=self.eps) 75 | 76 | def gem(self, x, p=3, eps=1e-6): 77 | # x = torch.transpose(x, 1, -1) 78 | p = F.softmax(p).unsqueeze(-1) 79 | x = torch.matmul(x, p) 80 | # x = torch.transpose(x, 1, -1) 81 | # x = F.avg_pool1d(x, x.size(-1)) 82 | x = x.view(x.size(0), x.size(1)) 83 | # x = x.pow(1. 
/ p) 84 | return x 85 | 86 | 87 | class GeM(nn.Module): 88 | # channel-wise GeM zhedong zheng 89 | def __init__(self, dim=2048, p=1, eps=1e-6): 90 | super(GeM, self).__init__() 91 | self.p = nn.Parameter(torch.ones(dim)*p) # initial p 92 | self.eps = eps 93 | 94 | def forward(self, x): 95 | x = torch.transpose(x, 1, -1) 96 | x = (x+self.eps).pow(self.p) 97 | x = torch.transpose(x, 1, -1) 98 | x = F.avg_pool2d(x, (x.size(-2), x.size(-1))) 99 | x = x.view(x.size(0), x.size(1)).contiguous() 100 | x = x.pow(1./self.p) 101 | return x 102 | 103 | 104 | 105 | class Pooling(nn.Module): 106 | def __init__(self, dim, pool="avg"): 107 | super(Pooling, self).__init__() 108 | self.pool = pool 109 | if pool == 'avg+max': 110 | self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) 111 | self.maxpool2 = nn.AdaptiveMaxPool2d((1, 1)) 112 | elif pool == 'avg': 113 | self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) 114 | elif pool == 'max': 115 | self.maxpool2 = nn.AdaptiveMaxPool2d((1, 1)) 116 | elif pool == 'gem': 117 | self.gem2 = Gem_heat(dim=dim) 118 | 119 | def forward(self, x): 120 | if self.pool == 'avg+max': 121 | x1 = self.avgpool2(x) 122 | x2 = self.maxpool2(x) 123 | x = torch.cat((x1, x2), dim=1) 124 | elif self.pool == 'avg': 125 | x = self.avgpool2(x) 126 | elif self.pool == 'max': 127 | x = self.maxpool2(x) 128 | elif self.pool == 'gem': 129 | x = self.gem2(x) 130 | return x 131 | 132 | 133 | def vector2image(x, dim=1): # (B,N,C) 134 | B, N, C = x.shape 135 | if dim == 1: 136 | return x.reshape(B, int(np.sqrt(N)), int(np.sqrt(N)), C) 137 | if dim == 2: 138 | return x.reshape(B, N, int(np.sqrt(C)), int(np.sqrt(C))) 139 | else: 140 | raise TypeError("dim is not correct!!") 141 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/models/__init__.py -------------------------------------------------------------------------------- /models/taskflow.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .Backbone.backbone import make_backbone 3 | from .Head.head import make_head 4 | import os 5 | import torch 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, opt): 10 | super().__init__() 11 | self.backbone = make_backbone(opt) 12 | opt.in_planes = self.backbone.output_channel 13 | self.head = make_head(opt) 14 | self.opt = opt 15 | 16 | def forward(self, drone_image, satellite_image): 17 | if drone_image is None: 18 | drone_res = None 19 | else: 20 | drone_features = self.backbone(drone_image) 21 | drone_res = self.head(drone_features) 22 | if satellite_image is None: 23 | satellite_res = None 24 | else: 25 | satellite_features = self.backbone(satellite_image) 26 | satellite_res = self.head(satellite_features) 27 | 28 | return drone_res,satellite_res 29 | 30 | def load_params(self, load_from): 31 | pretran_model = torch.load(load_from) 32 | model2_dict = self.state_dict() 33 | state_dict = {k: v for k, v in pretran_model.items() if k in model2_dict.keys() and v.size() == model2_dict[k].size()} 34 | model2_dict.update(state_dict) 35 | self.load_state_dict(model2_dict) 36 | 37 | 38 | def make_model(opt): 39 | model = Model(opt) 40 | if os.path.exists(opt.load_from): 41 | model.load_params(opt.load_from) 42 | return model 43 | -------------------------------------------------------------------------------- 
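For orientation, here is a minimal usage sketch of the two-branch Model defined in models/taskflow.py above: make_model(opt) builds a shared backbone plus head, and forward() runs whichever of the drone/satellite inputs is not None. The option values below (the backbone and head identifiers, class count, image size) are illustrative assumptions rather than the repository's defaults, and make_backbone() may require additional fields defined in files not reproduced here.

import argparse
import torch
from models.taskflow import make_model

# Hypothetical option namespace: only fields referenced by the code above are filled in,
# and the concrete values are assumptions for illustration.
opt = argparse.Namespace(
    backbone="ViT-S",        # assumed name; must be one handled by make_backbone() in models/Backbone/backbone.py
    head="SingleBranch",     # one of the heads dispatched in models/Head/head.py
    head_pool="avg",         # pooling mode used by the SingleBranch head
    nclasses=2256, droprate=0.5, num_bottleneck=512,
    h=224, w=224,
    load_from="",            # no existing checkpoint, so make_model() skips load_params()
)

model = make_model(opt).eval()
drone = torch.randn(2, 3, opt.h, opt.w)        # drone-view batch
satellite = torch.randn(2, 3, opt.h, opt.w)    # satellite-view batch
with torch.no_grad():
    drone_res, satellite_res = model(drone, satellite)  # each result is [cls, feature] from the head
    drone_only, _ = model(drone, None)                   # passing None skips a branch, as test.py does

--------------------------------------------------------------------------------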
/optimizers/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.optim import lr_scheduler 3 | 4 | 5 | def make_optimizer(model,opt): 6 | ignored_params = [] 7 | ignored_params += list(map(id, model.backbone.parameters())) 8 | extra_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) 9 | base_params = filter(lambda p: id(p) in ignored_params, model.parameters()) 10 | optimizer_ft = optim.SGD([ 11 | {'params': base_params, 'lr': 0.3 * opt.lr}, 12 | {'params': extra_params, 'lr': opt.lr} 13 | ], weight_decay=5e-4, momentum=0.9, nesterov=True) 14 | 15 | 16 | exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[70,110], gamma=0.1) 17 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=80, gamma=0.1) 18 | # exp_lr_scheduler = lr_scheduler.ExponentialLR(optimizer_ft, gamma=0.95) 19 | # exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, mode='min', factor=0.5, patience=4, verbose=True,threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-5, eps=1e-08) 20 | 21 | return optimizer_ft,exp_lr_scheduler -------------------------------------------------------------------------------- /requirments.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | tqdm 3 | pyyaml 4 | matplotlib 5 | opencv-python 6 | timm 7 | pillow 8 | einops 9 | thop 10 | pytorch_metric_learning -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | from datasets.queryDataset import Dataset_query, Query_transforms 5 | 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | import torch.backends.cudnn as cudnn 11 | import numpy as np 12 | from torchvision import datasets, models, transforms 13 | import time 14 | import os 15 | import scipy.io 16 | import yaml 17 | import math 18 | from tool.utils import load_network 19 | from tqdm import tqdm 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | parser = argparse.ArgumentParser(description='Training') 24 | parser.add_argument('--gpu_ids', default='0', type=str, 25 | help='gpu_ids: e.g. 0 0,1,2 0,2') 26 | parser.add_argument( 27 | '--test_dir', default='', type=str, help='./test_data') 28 | parser.add_argument('--name', default='', 29 | type=str, help='save model path') 30 | parser.add_argument('--checkpoint', default='net_119.pth', 31 | type=str, help='save model path') 32 | parser.add_argument('--batchsize', default=128, type=int, help='batchsize') 33 | parser.add_argument('--h', default=256, type=int, help='height') 34 | parser.add_argument('--w', default=256, type=int, help='width') 35 | parser.add_argument('--ms', default='1', type=str, 36 | help='multiple_scale: e.g. 
1 1,1.1 1,1.1,1.2') 37 | parser.add_argument('--num_worker', default=4, type=int,help='') 38 | parser.add_argument('--mode',default='1', type=int,help='1:drone->satellite 2:satellite->drone') 39 | opt = parser.parse_args() 40 | 41 | print(opt.name) 42 | 43 | ###load config### 44 | # load the training config 45 | config_path = 'opts.yaml' 46 | with open(config_path, 'r') as stream: 47 | config = yaml.load(stream, Loader=yaml.FullLoader) 48 | for cfg, value in config.items(): 49 | setattr(opt, cfg, value) 50 | 51 | str_ids = opt.gpu_ids.split(',') 52 | test_dir = opt.test_dir 53 | 54 | gpu_ids = [] 55 | for str_id in str_ids: 56 | id = int(str_id) 57 | if id >= 0: 58 | gpu_ids.append(id) 59 | 60 | print('We use the scale: %s' % opt.ms) 61 | str_ms = opt.ms.split(',') 62 | ms = [] 63 | 64 | for s in str_ms: 65 | s_f = float(s) 66 | ms.append(math.sqrt(s_f)) 67 | 68 | if len(gpu_ids) > 0: 69 | torch.cuda.set_device(gpu_ids[0]) 70 | cudnn.benchmark = True 71 | 72 | 73 | data_transforms = transforms.Compose([ 74 | transforms.Resize((opt.h, opt.w), interpolation=3), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 77 | ]) 78 | 79 | data_query_transforms = transforms.Compose([ 80 | transforms.Resize((opt.h, opt.w), interpolation=3), 81 | # Query_transforms(pad=10,size=opt.w), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 84 | ]) 85 | 86 | 87 | data_dir = test_dir 88 | 89 | image_datasets_query = {x: datasets.ImageFolder(os.path.join( 90 | data_dir, x), data_query_transforms) for x in ['query_drone']} 91 | 92 | image_datasets_gallery = {x: datasets.ImageFolder(os.path.join( 93 | data_dir, x), data_transforms) for x in ['gallery_satellite']} 94 | 95 | image_datasets = {**image_datasets_query, **image_datasets_gallery} 96 | 97 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, 98 | shuffle=False, num_workers=opt.num_worker) for x in ['gallery_satellite', 'query_drone']} 99 | use_gpu = torch.cuda.is_available() 100 | 101 | 102 | def fliplr(img): 103 | '''flip horizontal''' 104 | inv_idx = torch.arange(img.size(3)-1, -1, -1).long() # N x C x H x W 105 | img_flip = img.index_select(3, inv_idx) 106 | return img_flip 107 | 108 | 109 | def which_view(name): 110 | if 'satellite' in name: 111 | return 1 112 | elif 'street' in name: 113 | return 2 114 | elif 'drone' in name: 115 | return 3 116 | else: 117 | print('unknown view') 118 | return -1 119 | 120 | 121 | def extract_feature(model, dataloaders, view_index=1): 122 | features = torch.FloatTensor() 123 | count = 0 124 | for data in tqdm(dataloaders): 125 | img, _ = data 126 | batchsize = img.size()[0] 127 | count += batchsize 128 | # if opt.LPN: 129 | # # ff = torch.FloatTensor(n,2048,6).zero_().cuda() 130 | # ff = torch.FloatTensor(n,512,opt.block).zero_().cuda() 131 | # else: 132 | # ff = torch.FloatTensor(n, 2048).zero_().cuda() 133 | for i in range(2): 134 | if(i == 1): 135 | img = fliplr(img) 136 | input_img = Variable(img.cuda()) 137 | if view_index == 1: 138 | outputs, _ = model(input_img, None) 139 | elif view_index == 3: 140 | _, outputs = model(None, input_img) 141 | outputs = outputs[1] 142 | if i == 0: 143 | ff = outputs 144 | else: 145 | ff += outputs 146 | # norm feature 147 | if len(ff.shape) == 3: 148 | # feature size (n,2048,6) 149 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. 150 | # 2. 
To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 151 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * \ 152 | np.sqrt(opt.block) 153 | ff = ff.div(fnorm.expand_as(ff)) 154 | ff = ff.view(ff.size(0), -1) 155 | else: 156 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 157 | ff = ff.div(fnorm.expand_as(ff)) 158 | 159 | features = torch.cat((features, ff.data.cpu()), 0) 160 | return features 161 | 162 | 163 | def get_id(img_path): 164 | camera_id = [] 165 | labels = [] 166 | paths = [] 167 | for path, v in img_path: 168 | folder_name = os.path.basename(os.path.dirname(path)) 169 | labels.append(int(folder_name)) 170 | paths.append(path) 171 | return labels, paths 172 | 173 | 174 | ###################################################################### 175 | # Load Collected data Trained model 176 | print('-------test-----------') 177 | 178 | model = load_network(opt) 179 | print("这是%s的结果" % opt.checkpoint) 180 | # model.classifier.classifier = nn.Sequential() 181 | model = model.eval() 182 | if use_gpu: 183 | model = model.cuda() 184 | 185 | # Extract feature 186 | since = time.time() 187 | 188 | if opt.mode == 1: 189 | query_name = 'query_drone' 190 | gallery_name = 'gallery_satellite' 191 | elif opt.mode == 2: 192 | query_name = 'query_satellite' 193 | gallery_name = 'gallery_drone' 194 | else: 195 | raise Exception("opt.mode is not required") 196 | 197 | 198 | which_gallery = which_view(gallery_name) 199 | which_query = which_view(query_name) 200 | print('%d -> %d:' % (which_query, which_gallery)) 201 | print(query_name.split("_")[-1], "->", gallery_name.split("_")[-1]) 202 | 203 | gallery_path = image_datasets[gallery_name].imgs 204 | 205 | query_path = image_datasets[query_name].imgs 206 | 207 | gallery_label, gallery_path = get_id(gallery_path) 208 | query_label, query_path = get_id(query_path) 209 | 210 | if __name__ == "__main__": 211 | with torch.no_grad(): 212 | query_feature = extract_feature( 213 | model, dataloaders[query_name], which_query) 214 | gallery_feature = extract_feature( 215 | model, dataloaders[gallery_name], which_gallery) 216 | 217 | # For street-view image, we use the avg feature as the final feature. 
218 | 219 | time_elapsed = time.time() - since 220 | print('Test complete in {:.0f}m {:.0f}s'.format( 221 | time_elapsed // 60, time_elapsed % 60)) 222 | 223 | with open('inference_time.txt', 'w') as F: 224 | F.write('Test complete in {:.0f}m {:.0f}s'.format( 225 | time_elapsed // 60, time_elapsed % 60)) 226 | 227 | # Save to Matlab for check 228 | result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_path': gallery_path, 229 | 'query_f': query_feature.numpy(), 'query_label': query_label, 'query_path': query_path} 230 | scipy.io.savemat('pytorch_result_{}.mat'.format(opt.mode), result) 231 | 232 | # print(opt.name) 233 | # result = 'result.txt' 234 | # os.system('python evaluate_gpu.py | tee -a %s'%result) 235 | -------------------------------------------------------------------------------- /test_hard.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | from datasets.queryDataset import Dataset_query, Query_transforms, Dataset_gallery, test_collate_fn 5 | 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | import torch.backends.cudnn as cudnn 11 | import numpy as np 12 | from torchvision import datasets, models, transforms 13 | import time 14 | import os 15 | import scipy.io 16 | import yaml 17 | import math 18 | from tool.utils import load_network 19 | from tqdm import tqdm 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | parser = argparse.ArgumentParser(description='Training') 24 | parser.add_argument('--gpu_ids', default='0', type=str, 25 | help='gpu_ids: e.g. 0 0,1,2 0,2') 26 | parser.add_argument( 27 | '--test_dir', default='/home/dmmm/Dataset/DenseUAV/data_2022/test', type=str, help='./test_data') 28 | parser.add_argument('--name', default='resnet', 29 | type=str, help='save model path') 30 | parser.add_argument('--checkpoint', default='net_119.pth', 31 | type=str, help='save model path') 32 | parser.add_argument('--batchsize', default=128, type=int, help='batchsize') 33 | parser.add_argument('--h', default=224, type=int, help='height') 34 | parser.add_argument('--w', default=224, type=int, help='width') 35 | parser.add_argument('--ms', default='1', type=str, 36 | help='multiple_scale: e.g. 
1 1,1.1 1,1.1,1.2') 37 | parser.add_argument('--mode', default='hard', type=str, 38 | help='1:drone->satellite 2:satellite->drone') 39 | parser.add_argument('--num_worker', default=8, type=int, 40 | help='1:drone->satellite 2:satellite->drone') 41 | 42 | parser.add_argument('--split_feature', default=1, type=int, help='') 43 | 44 | opt = parser.parse_args() 45 | print(opt.name) 46 | ###load config### 47 | # load the training config 48 | config_path = 'opts.yaml' 49 | with open(config_path, 'r') as stream: 50 | config = yaml.load(stream) 51 | for cfg, value in config.items(): 52 | if cfg not in opt: 53 | setattr(opt, cfg, value) 54 | 55 | str_ids = opt.gpu_ids.split(',') 56 | test_dir = opt.test_dir 57 | 58 | gpu_ids = [] 59 | for str_id in str_ids: 60 | id = int(str_id) 61 | if id >= 0: 62 | gpu_ids.append(id) 63 | 64 | print('We use the scale: %s' % opt.ms) 65 | str_ms = opt.ms.split(',') 66 | ms = [] 67 | for s in str_ms: 68 | s_f = float(s) 69 | ms.append(math.sqrt(s_f)) 70 | 71 | if len(gpu_ids) > 0: 72 | torch.cuda.set_device(gpu_ids[0]) 73 | cudnn.benchmark = True 74 | 75 | 76 | data_transforms = transforms.Compose([ 77 | transforms.Resize((opt.h, opt.w), interpolation=3), 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 80 | ]) 81 | 82 | data_query_transforms = transforms.Compose([ 83 | transforms.Resize((opt.h, opt.w), interpolation=3), 84 | # Query_transforms(pad=10,size=opt.w), 85 | transforms.ToTensor(), 86 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 87 | ]) 88 | 89 | 90 | data_dir = test_dir 91 | 92 | image_datasets_query = Dataset_query(os.path.join(data_dir, "query_drone"), data_query_transforms) 93 | 94 | # image_datasets_gallery = Dataset_gallery(os.path.join(opt.test_dir, "total_info_ms_i10m.txt"), data_transforms) 95 | 96 | image_datasets_gallery = Dataset_gallery(os.path.join(opt.test_dir, "total_info_ss_i10m.txt"), data_transforms) 97 | 98 | 99 | dataloaders_query = torch.utils.data.DataLoader(image_datasets_query, batch_size=opt.batchsize, shuffle=False, num_workers=opt.num_worker, collate_fn=test_collate_fn) 100 | 101 | split_nums = len(image_datasets_gallery)//opt.split_feature 102 | 103 | list_split = [split_nums]*opt.split_feature 104 | 105 | list_split[-1] = len(image_datasets_gallery)-(opt.split_feature-1)*split_nums 106 | 107 | gallery_datasets_list = torch.utils.data.random_split(image_datasets_gallery, list_split) 108 | 109 | dataloaders_gallery = {ind: torch.utils.data.DataLoader(gallery_datasets_list[ind], batch_size=opt.batchsize, shuffle=False, num_workers=opt.num_worker, collate_fn=test_collate_fn) for ind in range(opt.split_feature)} 110 | 111 | use_gpu = torch.cuda.is_available() 112 | 113 | def extract_feature(model, dataloaders, view_index=1): 114 | features = torch.FloatTensor() 115 | infos_list = np.zeros((0,2),dtype=np.float32) 116 | path_list = [] 117 | for data in tqdm(dataloaders): 118 | img, infos, path = data 119 | path_list.extend(path) 120 | # infos_list.extend(infos) 121 | infos_list = np.concatenate((infos_list,infos),0) 122 | # if opt.LPN: 123 | # # ff = torch.FloatTensor(n,2048,6).zero_().cuda() 124 | # ff = torch.FloatTensor(n,512,opt.block).zero_().cuda() 125 | # else: 126 | # ff = torch.FloatTensor(n, 2048).zero_().cuda() 127 | 128 | input_img = Variable(img.cuda()) 129 | if view_index == 1: 130 | outputs, _ = model(input_img, None) 131 | elif view_index == 3: 132 | _, outputs = model(None, input_img) 133 | outputs = outputs[1] 134 | ff = outputs 135 | # # 
norm feature 136 | if len(ff.shape) == 3: 137 | # feature size (n,2048,6) 138 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. 139 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 140 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * \ 141 | np.sqrt(opt.block) 142 | ff = ff.div(fnorm.expand_as(ff)) 143 | ff = ff.view(ff.size(0), -1) 144 | else: 145 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 146 | ff = ff.div(fnorm.expand_as(ff)) 147 | 148 | features = torch.cat((features, ff.data.cpu()), 0) 149 | return features,infos_list,path_list 150 | 151 | 152 | model = load_network(opt) 153 | print("这是%s的结果" % opt.checkpoint) 154 | # model.classifier.classifier = nn.Sequential() 155 | model = model.eval() 156 | if use_gpu: 157 | model = model.cuda() 158 | 159 | # Extract feature 160 | since = time.time() 161 | 162 | if __name__ == "__main__": 163 | with torch.no_grad(): 164 | query_feature, query_infos, query_path = extract_feature( 165 | model, dataloaders_query, 1) 166 | gallery_features = torch.FloatTensor() 167 | gallery_infos = np.zeros((0,2),dtype=np.float32) 168 | gallery_paths = [] 169 | for i in range(opt.split_feature): 170 | gallery_feature,gallery_info,gallery_path = extract_feature( 171 | model, dataloaders_gallery[i], 3) 172 | gallery_infos = np.concatenate((gallery_infos,gallery_info),0) 173 | gallery_features = torch.cat((gallery_features,gallery_feature),0) 174 | gallery_paths.extend(gallery_path) 175 | 176 | 177 | # For street-view image, we use the avg feature as the final feature. 178 | 179 | time_elapsed = time.time() - since 180 | print('Test complete in {:.0f}m {:.0f}s'.format( 181 | time_elapsed // 60, time_elapsed % 60)) 182 | 183 | with open('inference_time.txt', 'w') as F: 184 | F.write('Test complete in {:.0f}m {:.0f}s'.format( 185 | time_elapsed // 60, time_elapsed % 60)) 186 | 187 | # Save to Matlab for check 188 | 189 | result = {'gallery_f': gallery_features.numpy(), 'gallery_infos': gallery_infos.astype(np.float32), 'query_path':query_path, 190 | 'query_f': query_feature.numpy(), 'query_infos': query_infos.astype(np.float32), 'gallery_path':gallery_paths} 191 | scipy.io.savemat('pytorch_result_{}_ss.mat'.format(opt.mode), result) 192 | 193 | # print(opt.name) 194 | # result = 'result.txt' 195 | # os.system('python evaluate_gpu.py | tee -a %s'%result) 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /tool/SDM@K_analyze.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | ResNet50 = [0.165, 0.343, 0.473, 0.535, 1.0] 5 | ConvNextT = [0.6023, 0.7404, 0.7975, 0.828, 1.0] 6 | ViTS = [0.8018, 0.888, 0.9197, 0.9335, 1.0] 7 | 8 | y1 = ConvNextT 9 | for i in range(len(y1)-1,0,-1): 10 | y1[i] = y1[i]-y1[i-1] 11 | 12 | y2 = ViTS 13 | for i in range(len(y2)-1,0,-1): 14 | y2[i] = y2[i]-y2[i-1] 15 | 16 | y0 = ResNet50 17 | for i in range(len(y0)-1,0,-1): 18 | y0[i] = y0[i]-y0[i-1] 19 | 20 | totol_images = 2331 21 | 22 | # 创建数据 23 | categories = ["0","1","2","3", "other"] 24 | 25 | bar_width = 0.3 26 | 27 | r0 = np.arange(len(categories)) 28 | r1 = [x + bar_width for x in r0] 29 | r2 = [x + bar_width for x in r1] 30 | 31 | # 创建左右两个子图 32 | plt.subplots(figsize=(7, 5)) 33 | 34 | plt.xlabel("Error on Sampling Interval",fontdict={'family' : 'Times New Roman', 'size': 17}) 35 | plt.xticks(fontproperties='Times New Roman',fontsize=15) 36 | 
37 | plt.ylabel('Proportion(%)', fontdict={'family': 'Times New Roman', 'size': 17})  # y-axis label
38 | plt.yticks(fontproperties='Times New Roman',fontsize=15)
39 | plt.ylim(0,0.85)
40 | 
41 | # draw the grouped bars
42 | plt.bar(r0, y0, width=bar_width, color="#eed777", edgecolor="k",linewidth=2, label='ResNet-50')
43 | plt.bar(r1, y1, width=bar_width, color="#45a776", edgecolor="k",linewidth=2, label='ConvNext-T')
44 | plt.bar(r2, y2, width=bar_width, color="#b3974e", edgecolor="k",linewidth=2, label="ViT-S")
45 | for i, value in zip(r0,y0):
46 |     plt.text(i, value, "{:.1f}".format(value*100), ha='center', va='bottom', color="black", fontsize=13, fontproperties='Times New Roman')
47 | for i, value in zip(r1,y1):
48 |     plt.text(i, value, "{:.1f}".format(value*100), ha='center', va='bottom', color="black", fontsize=13, fontproperties='Times New Roman')
49 | for i, value in zip(r2,y2):
50 |     plt.text(i, value, "{:.1f}".format(value*100), ha='center', va='bottom', color="black", fontsize=13, fontproperties='Times New Roman')
51 | 
52 | # tick labels
53 | plt.xticks([r + bar_width for r in range(len(categories))], categories)
54 | 
55 | plt.legend(prop={'family': 'Times New Roman', 'size': 15})
56 | 
57 | # adjust the layout
58 | plt.tight_layout()
59 | 
60 | # save the figure
61 | plt.savefig("SDM@K_analyze.eps", dpi = 600)
62 | 
63 | 
64 | 
--------------------------------------------------------------------------------
/tool/SDM@K_compare.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | import numpy as np
 3 | import sys
 4 | import matplotlib as mpl
 5 | mpl.rcParams['font.family'] = 'Times New Roman'
 6 | 
 7 | 
 8 | def evaluateSingle(distance, K):
 9 |     # maxDistance = max(distance) + 1e-14
10 |     # weight = np.ones(K) - np.log(range(1, K + 1, 1)) / np.log(opts.M * K)
11 |     weight = np.ones(K) - np.array(range(0, K, 1))/K
12 |     # m1 = distance / maxDistance
13 |     m2 = 1 / np.exp(distance*5e3)
14 |     m3 = m2 * weight
15 |     result = np.sum(m3) / np.sum(weight)
16 |     return result
17 | 
18 | 
19 | def evaluate_enclidean(distance, K):
20 |     # maxDistance = max(distance) + 1e-14
21 |     # weight = np.ones(K) - np.log(range(1, K + 1, 1)) / np.log(opts.M * K)
22 |     weight = np.ones(K) - np.array(range(0, K, 1))/K
23 |     # m1 = distance / maxDistance
24 |     m2 = 1-distance*1e3
25 |     m3 = m2 * weight
26 |     result = np.sum(m3) / np.sum(weight)
27 |     return result
28 | 
29 | 
30 | def Recall_Data(data):
31 |     x_len, y_len = data.shape
32 |     data = np.zeros_like(data)
33 |     data[x_len//2, y_len//2] = 1
34 |     return data
35 | 
36 | def euclideanDistance(query, gallery):
37 |     query = np.array(query, dtype=np.float32)
38 |     gallery = np.array(gallery, dtype=np.float32)
39 |     A = gallery - query
40 |     A_T = A.transpose()
41 |     distance = np.matmul(A, A_T)
42 |     mask = np.eye(distance.shape[0], dtype=np.bool8)
43 |     distance = distance[mask]
44 |     distance = np.sqrt(distance.reshape(-1))
45 |     return distance
46 | 
47 | 
48 | def SDM_Data(data):
49 |     x_len, y_len = data.shape
50 |     x = np.linspace(120.358111-0.0003, 120.358111+0.0003, x_len)
51 |     y = np.linspace(30.317842-0.0003, 30.317842+0.0003, y_len)
52 |     x_,y_ = np.meshgrid(x,y)
53 |     x_ = x_.reshape(-1,1)
54 |     y_ = y_.reshape(-1,1)
55 |     input = np.concatenate((x_,y_),axis=-1)
56 | 
57 |     target = np.array((120.358111,30.317842)).reshape(-1,2)
58 |     distance = euclideanDistance(input, target)
59 |     # compute single query evaluate result
60 |     P_list = np.array([evaluateSingle(dist_single, 1) for dist_single in distance])
61 |     return P_list.reshape(x_len,y_len)
62 | 
63 | 
64 | def 
Enclidean_Data(data): 65 | x_len, y_len = data.shape 66 | x = np.linspace(120.358111-0.0003, 120.358111+0.0003, x_len) 67 | y = np.linspace(30.317842-0.0003, 30.317842+0.0003, y_len) 68 | x_,y_ = np.meshgrid(x,y) 69 | x_ = x_.reshape(-1,1) 70 | y_ = y_.reshape(-1,1) 71 | input = np.concatenate((x_,y_),axis=-1) 72 | 73 | target = np.array((120.358111,30.317842)).reshape(-1,2) 74 | distance = euclideanDistance(input, target) 75 | # compute single query evaluate result 76 | P_list = np.array([evaluate_enclidean(dist_single, 1) for dist_single in distance]) 77 | return P_list.reshape(x_len,y_len) 78 | 79 | 80 | 81 | # 创建一个随机的2D数组作为网格数据 82 | data = np.random.rand(7, 7) 83 | 84 | Recall_data = Recall_Data(data) 85 | SDM_data = SDM_Data(data) 86 | Enclidean_data = Enclidean_Data(data) 87 | 88 | 89 | # 创建一个figure和axes对象 90 | fig, ax = plt.subplots(1,3,figsize=(14,4)) 91 | 92 | # 使用imshow函数显示网格数据 93 | img1 = ax[0].imshow(Recall_data, cmap='coolwarm', interpolation='nearest') 94 | img2 = ax[2].imshow(SDM_data, cmap='coolwarm', interpolation='nearest') 95 | img3 = ax[1].imshow(Enclidean_data, cmap='coolwarm', interpolation='nearest') 96 | 97 | 98 | # 在每个格子中间显示数值 99 | for i in range(data.shape[0]): 100 | for j in range(data.shape[1]): 101 | ax[0].text(j, i, f'{Recall_data[i, j]:.1f}', ha='center', va='center', color='white') 102 | ax[2].text(j, i, f'{SDM_data[i, j]:.1f}', ha='center', va='center', color='white') 103 | ax[1].text(j, i, f'{Enclidean_data[i, j]:.1f}', ha='center', va='center', color='white') 104 | 105 | for a in ax: 106 | a.set_xticks([]) 107 | a.set_yticks([]) 108 | a.set_xticklabels([]) 109 | a.set_yticklabels([]) 110 | 111 | # 添加颜色条 112 | cbar = plt.colorbar(img1) 113 | 114 | 115 | # 设置子图标题 116 | ax[0].set_title('(a) Recall', fontsize=16, pad=10) 117 | ax[2].set_title('(b) SDM', fontsize=16, pad=10) 118 | ax[1].set_title('(c) Euclidean Distance', fontsize=16, pad=10) 119 | 120 | plt.subplots_adjust(wspace=0.0, hspace=0.000) 121 | 122 | # 调整布局,确保子图之间的间距合适 123 | plt.tight_layout() 124 | 125 | 126 | 127 | # ax[0].grid(True, which='both', linestyle='-', linewidth=2, color='black') 128 | # ax[1].grid(True, which='both', linestyle='-', linewidth=2, color='black') 129 | # ax[2].grid(True, which='both', linestyle='-', linewidth=2, color='black') 130 | 131 | plt.savefig("SDM_Recall_Enclidean_compare.eps", dpi=300) 132 | -------------------------------------------------------------------------------- /tool/applications/forwardAllSatelliteHub.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import torch.backends.cudnn as cudnn 10 | import numpy as np 11 | from torchvision import datasets, models, transforms 12 | import time 13 | import os 14 | import scipy.io 15 | import yaml 16 | import math 17 | from tool.utils import load_network 18 | from tqdm import tqdm 19 | import warnings 20 | from datasets.Dataloader_University import DataLoader_Inference 21 | warnings.filterwarnings("ignore") 22 | from datasets.queryDataset import Dataset_query,Query_transforms 23 | # Options 24 | # -------- 25 | University="计量" 26 | parser = argparse.ArgumentParser(description='Training') 27 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 
0 0,1,2 0,2') 28 | parser.add_argument('--root',default='/media/dmmm/4T-3/DataSets/DenseCV_Data/inference_data/satelliteHub({})'.format(University),type=str, help='./test_data') 29 | parser.add_argument('--savename', default='features{}.mat'.format(University), type=str, help='save model path') 30 | parser.add_argument('--checkpoint', default='net_119.pth', type=str, help='save model path') 31 | parser.add_argument('--batchsize', default=128, type=int, help='batchsize') 32 | parser.add_argument('--h', default=256, type=int, help='height') 33 | parser.add_argument('--w', default=256, type=int, help='width') 34 | parser.add_argument('--ms',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') 35 | parser.add_argument('--num_worker',default=4, type=int,help='1:drone->satellite 2:satellite->drone') 36 | 37 | opt = parser.parse_args() 38 | ###load config### 39 | # load the training config 40 | config_path = 'opts.yaml' 41 | with open(config_path, 'r') as stream: 42 | config = yaml.load(stream) 43 | for cfg,value in config.items(): 44 | setattr(opt,cfg,value) 45 | 46 | if 'h' in config: 47 | opt.h = config['h'] 48 | opt.w = config['w'] 49 | 50 | if 'nclasses' in config: # tp compatible with old config files 51 | opt.nclasses = config['nclasses'] 52 | else: 53 | opt.nclasses = 729 54 | 55 | str_ids = opt.gpu_ids.split(',') 56 | 57 | gpu_ids = [] 58 | for str_id in str_ids: 59 | id = int(str_id) 60 | if id >=0: 61 | gpu_ids.append(id) 62 | 63 | print('We use the scale: %s'%opt.ms) 64 | str_ms = opt.ms.split(',') 65 | ms = [] 66 | for s in str_ms: 67 | s_f = float(s) 68 | ms.append(math.sqrt(s_f)) 69 | 70 | # set gpu ids 71 | if len(gpu_ids)>0: 72 | torch.cuda.set_device(gpu_ids[0]) 73 | cudnn.benchmark = True 74 | 75 | ###################################################################### 76 | # Load Data 77 | # --------- 78 | # 79 | # We will use torchvision and torch.utils.data packages for loading the 80 | # data. 81 | # 82 | def fliplr(img): 83 | '''flip horizontal''' 84 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 85 | img_flip = img.index_select(3,inv_idx) 86 | return img_flip 87 | 88 | data_transforms = transforms.Compose([ 89 | transforms.Resize((opt.h, opt.w), interpolation=3), 90 | transforms.ToTensor(), 91 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 92 | ]) 93 | 94 | 95 | image_datasets = DataLoader_Inference(root=opt.root,transforms=data_transforms) 96 | 97 | dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=opt.batchsize, 98 | shuffle=False, num_workers=opt.num_worker) 99 | use_gpu = torch.cuda.is_available() 100 | 101 | 102 | 103 | def extract_feature(model,dataloaders, view_index = 1): 104 | features = torch.FloatTensor() 105 | count = 0 106 | for data in tqdm(dataloaders): 107 | img, label = data 108 | n, c, h, w = img.size() 109 | count += n 110 | # if opt.LPN: 111 | # # ff = torch.FloatTensor(n,2048,6).zero_().cuda() 112 | # ff = torch.FloatTensor(n,512,opt.block).zero_().cuda() 113 | # else: 114 | # ff = torch.FloatTensor(n, 2048).zero_().cuda() 115 | input_img = Variable(img.cuda()) 116 | outputs, _ = model(input_img, None) 117 | outputs = outputs[1] 118 | ff=outputs 119 | # norm feature 120 | if len(ff.shape)==3: 121 | # feature size (n,2048,6) 122 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. 123 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 
124 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(opt.block) 125 | ff = ff.div(fnorm.expand_as(ff)) 126 | ff = ff.view(ff.size(0), -1) 127 | else: 128 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 129 | ff = ff.div(fnorm.expand_as(ff)) 130 | 131 | features = torch.cat((features,ff.data.cpu()), 0) 132 | return features 133 | 134 | 135 | 136 | # Load Collected data Trained model 137 | model = load_network(opt) 138 | model = model.eval() 139 | if use_gpu: 140 | model = model.cuda() 141 | # Extract feature 142 | since = time.time() 143 | 144 | 145 | images = image_datasets.imgs 146 | 147 | labels = image_datasets.labels 148 | 149 | if __name__ == "__main__": 150 | with torch.no_grad(): 151 | features = extract_feature(model,dataloaders) 152 | 153 | # Save to Matlab for check 154 | result = {'features':features.numpy(),'labels':labels} 155 | scipy.io.savemat(opt.savename,result) 156 | -------------------------------------------------------------------------------- /tool/applications/inference_global.py: -------------------------------------------------------------------------------- 1 | from __future__ import with_statement 2 | import argparse 3 | import scipy.io 4 | import torch 5 | import numpy as np 6 | from datasets.queryDataset import CenterCrop 7 | import glob 8 | from torchvision import transforms 9 | from PIL import Image 10 | import yaml 11 | from tool.utils import load_network 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | import os 15 | from tool.get_property import find_GPS_image 16 | import cv2 17 | 18 | ####################################################################### 19 | # Evaluate 20 | University="计量" 21 | parser = argparse.ArgumentParser(description='Demo') 22 | parser.add_argument('--imageDir', default="/media/dmmm/CE31-3598/DataSets/DenseCV_Data/实际测试图像({})/test02".format(University), type=str, 23 | help='test_image_index') 24 | parser.add_argument('--satelliteMat', default="features{}.mat".format(University), type=str, help='./test_data') 25 | parser.add_argument('--MapDir', default="../../maps/{}.tif".format(University), type=str, help='./test_data') 26 | parser.add_argument('--galleryPath', default="/media/dmmm/4T-3/DataSets/DenseCV_Data/inference_data/satelliteHub({})".format(University), 27 | type=str, help='./test_data') 28 | opts = parser.parse_args() 29 | # TN30.325763471673625 TE120.37341802729506 BN30.320529681696023 BE120.38174250997761 jinrong 30 | 31 | mapPosInfodir = "/home/dmmm/PycharmProjects/DenseCV/demo/maps/pos.txt" 32 | with open(mapPosInfodir,"r") as F: 33 | listLine = F.readlines() 34 | for line in listLine: 35 | name,TN,TE,BN,BE = line.split(" ") 36 | if name==University: 37 | startE = eval(TE.split("TE")[-1]) 38 | startN = eval(TN.split("TN")[-1]) 39 | endE = eval(BE.split("BE")[-1]) 40 | endN = eval(BN.split("BN")[-1]) 41 | 42 | AllImage = cv2.imread(opts.MapDir) 43 | h, w, c = AllImage.shape 44 | 45 | 46 | def generateDictOfGalleryPosInfo(): 47 | satellite_configDict = {} 48 | with open(os.path.join(opts.galleryPath, "PosInfo.txt"), "r") as F: 49 | context = F.readlines() 50 | for line in context: 51 | splitLineList = line.split(" ") 52 | satellite_configDict[splitLineList[0]] = [float(splitLineList[1].split("N")[-1]), 53 | float(splitLineList[2].split("E")[-1])] 54 | return satellite_configDict 55 | 56 | 57 | def sort_(list): 58 | list_ = [int(i.split(".JPG")[-2].split("DJI_")[-1]) for i in list] 59 | arg = np.argsort(list_) 60 | newlist = np.array(list)[arg] 61 | return newlist 62 | 
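# --- Illustrative sketch (editor's addition, not part of the original script) ---
# getBestImage() below ranks all satellite gallery features by their inner product with
# the single query feature. Because extract_feature() L2-normalizes every feature, the
# inner product equals the cosine similarity, so sorting the scores in descending order
# gives a nearest-neighbour ranking. A minimal standalone version of that ranking,
# assuming qf has shape (d,) and gf has shape (n, d), both already L2-normalized:
def _cosine_rank_sketch(qf, gf):
    score = torch.mm(gf, qf.view(-1, 1)).squeeze(1)  # (n,) cosine similarities
    return torch.argsort(score, descending=True)     # gallery indices, best match first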
63 | 64 | ####################################################################### 65 | # sort the images and return topK index 66 | def getBestImage(qf, gf, gl): 67 | query = qf.view(-1, 1) 68 | # print(query.shape) 69 | score = torch.mm(gf, query) 70 | score = score.squeeze().cpu() 71 | score = score.numpy() 72 | # predict index 73 | index = np.argsort(score) # from small to large 74 | index = index[::-1] 75 | # index = index[0:2000] 76 | # good index 77 | # query_index = np.argwhere(gl == ql) 78 | 79 | # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 80 | # junk_index = np.argwhere(gl == -1) 81 | 82 | # mask = np.in1d(index, junk_index, invert=True) 83 | # index = index[mask] 84 | return gl[index[0]] 85 | 86 | 87 | data_transforms = transforms.Compose([ 88 | CenterCrop(), 89 | transforms.Resize((256, 256), interpolation=3), 90 | transforms.ToTensor(), 91 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 92 | ]) 93 | 94 | query_paths = glob.glob(opts.imageDir + "/*.JPG") 95 | # sorted(query_paths,key=lambda x : int(x.split(".JPG")[-2].split("DJI_")[-1])) 96 | query_paths = sort_(query_paths) 97 | 98 | 99 | ##################################################################### 100 | def extract_feature(img, model): 101 | count = 0 102 | n, c, h, w = img.size() 103 | count += n 104 | input_img = Variable(img.cuda()) 105 | outputs, _ = model(input_img, None) 106 | ff = outputs 107 | # norm feature 108 | if len(ff.shape) == 3: 109 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. 110 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 111 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(opts.block) 112 | ff = ff.div(fnorm.expand_as(ff)) 113 | ff = ff.view(ff.size(0), -1) 114 | else: 115 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 116 | ff = ff.div(fnorm.expand_as(ff)) 117 | 118 | features = ff.data 119 | return features 120 | 121 | 122 | def getPosInfo(imgPath): 123 | GPS_info = find_GPS_image(imgPath) 124 | x = list(GPS_info.values()) 125 | gps_dict_formate = x[0] 126 | y = list(gps_dict_formate.values()) 127 | height = eval(y[5]) 128 | E = y[3] 129 | N = y[1] 130 | return [N, E] 131 | 132 | 133 | def imshowByIndex(index): 134 | galleryPath = os.path.join(opts.galleryPath, index + ".tif") 135 | image = cv2.imread(galleryPath) 136 | cv2.imshow("gallery", image) 137 | 138 | 139 | ###################################################################### 140 | # load network 141 | config_path = 'opts.yaml' 142 | with open(config_path, 'r') as stream: 143 | config = yaml.load(stream) 144 | opts.stride = config['stride'] 145 | opts.views = config['views'] 146 | opts.transformer = config['transformer'] 147 | opts.pool = config['pool'] 148 | opts.views = config['views'] 149 | opts.LPN = config['LPN'] 150 | opts.block = config['block'] 151 | opts.nclasses = config['nclasses'] 152 | opts.droprate = config['droprate'] 153 | opts.share = config['share'] 154 | opts.checkpoint = "net_119.pth" 155 | torch.cuda.set_device("cuda:0") 156 | cudnn.benchmark = True 157 | 158 | model = load_network(opts) 159 | model = model.eval() 160 | model = model.cuda() 161 | 162 | ###################################################################### 163 | result = scipy.io.loadmat(opts.satelliteMat) 164 | gallery_feature = torch.FloatTensor(result['features']) 165 | gallery_label = result['labels'] 166 | gallery_feature = gallery_feature.cuda() 167 | 168 | satellitePosInfoDict = 
generateDictOfGalleryPosInfo() # 字典中的数组第一位为N 第二位为E 169 | firstPos = getPosInfo(query_paths[0]) 170 | # gmap = gmplot.GoogleMapPlotter(firstPos[0], firstPos[1], 19) 171 | 172 | queryPosDict = {"N": [], "E": []} 173 | galleryPosDict = {"N": [], "E": []} 174 | for query in query_paths: 175 | queryPosInfo = getPosInfo(query) 176 | queryPosDict["N"].append(float(queryPosInfo[0])) 177 | queryPosDict["E"].append(float(queryPosInfo[1])) 178 | img = Image.open(query) 179 | input = data_transforms(img) 180 | input = torch.unsqueeze(input, 0) 181 | with torch.no_grad(): 182 | feature = extract_feature(input, model) 183 | bestIndex = getBestImage(feature, gallery_feature, gallery_label) 184 | imshowByIndex(bestIndex) 185 | bestMatchedPosInfo = satellitePosInfoDict[bestIndex] 186 | galleryPosDict["N"].append(float(bestMatchedPosInfo[0])) 187 | galleryPosDict["E"].append(float(bestMatchedPosInfo[1])) 188 | print("query--N:{} E:{} gallery--N:{} E:{}".format(queryPosInfo[0], queryPosInfo[1], bestMatchedPosInfo[0], 189 | bestMatchedPosInfo[1])) 190 | # cv2.waitKey(0) 191 | 192 | result = {"query": queryPosDict, "gallery": galleryPosDict} 193 | scipy.io.savemat('global_{}.mat'.format(University), result) 194 | 195 | index = 1 196 | for N, E in zip(queryPosDict["N"], queryPosDict["E"]): 197 | X = int((E - startE) / (endE - startE) * w) 198 | Y = int((N - startN) / (endN - startN) * h) 199 | if index>=10: 200 | cv2.circle(AllImage, (X, Y), 50, color=(255, 0, 0), thickness=8) 201 | cv2.putText(AllImage, str(index), (X - 40, Y + 25), cv2.FONT_HERSHEY_COMPLEX, 2.2, color=(255, 0, 0), 202 | thickness=3) 203 | else: 204 | cv2.circle(AllImage, (X, Y), 50, color=(255, 0, 0), thickness=8) 205 | cv2.putText(AllImage, str(index), (X - 30, Y + 30), cv2.FONT_HERSHEY_COMPLEX, 3, color=(255, 0, 0), 206 | thickness=3) 207 | index += 1 208 | 209 | index = 1 210 | for N, E in zip(galleryPosDict["N"], galleryPosDict["E"]): 211 | X = int((E - startE) / (endE - startE) * w) 212 | Y = int((N - startN) / (endN - startN) * h) 213 | if index>=10: 214 | cv2.circle(AllImage, (X, Y), 50, color=(0, 0, 255), thickness=8) 215 | cv2.putText(AllImage, str(index), (X - 40, Y + 25), cv2.FONT_HERSHEY_COMPLEX, 2.2, color=(0, 0, 255), 216 | thickness=3) 217 | else: 218 | cv2.circle(AllImage, (X, Y), 50, color=(0, 0, 255), thickness=8) 219 | cv2.putText(AllImage, str(index), (X - 30, Y + 30), cv2.FONT_HERSHEY_COMPLEX, 3, color=(0, 0, 255), 220 | thickness=3) 221 | index += 1 222 | 223 | # AllImage = cv2.resize(AllImage,(0,0),fx=0.25,fy=0.25) 224 | cv2.imwrite("global_{}.tif".format(University), AllImage) 225 | # gmap.plot(queryPosDict["N"], queryPosDict["E"], color="red") 226 | # gmap.plot(galleryPosDict["N"], galleryPosDict["E"], color="blue") 227 | # 228 | # gmap.draw("user001_map.html") 229 | -------------------------------------------------------------------------------- /tool/dataset_preprocess/1-generateSatelliteByUav.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import glob 4 | from collections import defaultdict 5 | from get_property import find_GPS_image 6 | from utils import get_fileNames 7 | import sys 8 | from tqdm import tqdm 9 | from multiprocessing import Pool 10 | ''' 11 | 通过Uav图像找出satellite图像,用于训练部分 12 | ''' 13 | root_dir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/oridata/train/University_UAV_Images" 14 | tif_dir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/oridata/train/old_tif" 15 | PlaceNameList = [] 16 | 17 | for root, PlaceName, files in 
os.walk(root_dir):
18 |     PlaceNameList = PlaceName
19 |     break
20 | 
21 | # PlaceNameList = os.listdir(root_dir)
22 | # y = 10x-500
23 | correspond_size = {'80':640,'90':768,'100':896}
24 | 
25 | place_info_dict = defaultdict(list)
26 | with open(os.path.join(tif_dir,"PosInfo.txt"),"r") as F:
27 |     context = F.readlines()
28 |     for line in context:
29 |         name = line.split(" ")[0]
30 |         TN = float(line.split((" "))[1].split("TN")[-1])
31 |         TE = float(line.split((" "))[2].split("TE")[-1])
32 |         BN = float(line.split((" "))[3].split("BN")[-1])
33 |         BE = float(line.split((" "))[4].split("BE")[-1])
34 |         place_info_dict[name] = [TN,TE,BN,BE]
35 | 
36 | 
37 | def process(place):
38 |     place_root = os.path.join(root_dir,place)
39 |     cur_TN,cur_TE,cur_BN,cur_BE = place_info_dict[place]
40 |     satellite_tif = os.path.join(tif_dir,place + ".tif")
41 | 
42 |     BigSatellite = cv2.imread(satellite_tif)
43 |     h,w,c = BigSatellite.shape
44 |     JPG_List = get_fileNames(place_root,endwith=".JPG")
45 |     for JPG in tqdm(JPG_List):
46 |         satelliteTif = JPG.replace(".JPG","_satellite_old.tif")
47 |         GPS_info = find_GPS_image(JPG)
48 |         gps_dict_formate = list(GPS_info.values())[0]
49 |         y = list(gps_dict_formate.values())
50 |         E,N = y[3],y[1]
51 |         satellite_size = correspond_size[JPG.split("/")[-2]]
52 |         centerX = (E-cur_TE)/(cur_BE-cur_TE)*w  # map the current UAV position into pixel coordinates of the big satellite map
53 |         centerY = (N-cur_TN)/(cur_BN-cur_TN)*h
54 | 
55 |         if centerX<satellite_size/2 or centerY<satellite_size/2 or centerX>w-satellite_size/2 or centerY>h-satellite_size/2:
56 |             raise ValueError("The crop region falls outside the satellite map")
57 | 
58 |         TLX = int(centerX-satellite_size/2)
59 |         TLY = int(centerY-satellite_size/2)
60 |         BRX = int(centerX+satellite_size/2)
61 |         BRY = int(centerY+satellite_size/2)
62 | 
63 |         cropImage = BigSatellite[TLY:BRY,TLX:BRX,:]  # crop the patch at the computed position
64 | 
65 |         cv2.imwrite(satelliteTif,cropImage)
66 | 
67 | p = Pool(10)
68 | # iterate over every place folder: load the big satellite map first, then crop the patches
69 | for res in p.imap(process, PlaceNameList):
70 |     pass
71 | 
72 | p.close()
73 | 
--------------------------------------------------------------------------------
/tool/dataset_preprocess/2-generate_new_croped_resized_images_difheights.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | from utils import get_fileNames
 3 | import os
 4 | from get_property import find_GPS_image
 5 | from tqdm import tqdm
 6 | import numpy as np
 7 | import functools
 8 | import sys
 9 | from multiprocessing import Pool
10 | 
11 | 
12 | UAV_target_size = (1440,1080)
13 | Satellite_target_size = (512,512)
14 | 
15 | def sixNumber(str_number):
16 |     str_number=str(str_number)
17 |     while(len(str_number)<6):
18 |         str_number='0'+str_number
19 |     return str_number
20 | 
21 | def compare_personal(x,y):
22 |     x1 = int(os.path.split(x)[1].split(".JPG")[0])
23 |     y1 = int(os.path.split(y)[1].split(".JPG")[0])
24 |     return x1-y1
25 | 
26 | # from center crop image
27 | def center_crop_and_resize(img,target_size=None):
28 |     h,w,c = img.shape
29 |     min_edge = min((h,w))
30 |     if min_edge==h:
31 |         edge_lenth = int((w-min_edge)/2)
32 |         new_image = img[:,edge_lenth:w-edge_lenth,:]
33 |     else:
34 |         edge_lenth = int((h - min_edge) / 2)
35 |         new_image = img[edge_lenth:h-edge_lenth, :, :]
36 |     assert new_image.shape[0]==new_image.shape[1],"the shape is not correct"
37 |     # LINEAR Interpolation
38 |     if target_size:
39 |         new_image = cv2.resize(new_image,target_size)
40 | 
41 |     return new_image
42 | 
43 | def resize(img,target_size=None):
44 |     # LINEAR Interpolation
45 |     return cv2.resize(img,target_size)
46 | 
47 | 
48 | def getFileNameList(fullnamelist):
49 |     list_return = []
50 |     for i in fullnamelist:
51 | 
_,filename = os.path.split(i) 52 | list_return.append(filename) 53 | return list_return 54 | 55 | def process(info): 56 | index, [drone_80,drone_90,drone_100] = info 57 | satellite_80 = drone_80.replace(".JPG","_satellite.tif") 58 | satellite_90 = drone_90.replace(".JPG","_satellite.tif") 59 | satellite_100 = drone_100.replace(".JPG","_satellite.tif") 60 | if not (os.path.exists(satellite_80) and os.path.exists(satellite_90) and os.path.exists(satellite_100)): 61 | print("没有对应的satellite图像存在,请查看{}".format(satellite_80+" "+satellite_90+" "+satellite_100)) 62 | sys.exit(0) 63 | # ----new added---- 64 | satellite_80_old = drone_80.replace(".JPG","_satellite_old.tif") 65 | satellite_90_old = drone_90.replace(".JPG","_satellite_old.tif") 66 | satellite_100_old = drone_100.replace(".JPG","_satellite_old.tif") 67 | if not (os.path.exists(satellite_80_old) and os.path.exists(satellite_90_old) and os.path.exists(satellite_100_old)): 68 | print("没有对应的satellite图像存在,请查看{}".format(satellite_80_old+" "+satellite_90_old+" "+satellite_100_old)) 69 | sys.exit(0) 70 | 71 | name = sixNumber(index) 72 | droneCurDir = os.path.join(dronePath, name) 73 | SatelliteCurDir = os.path.join(satellitePath, name) 74 | os.makedirs(droneCurDir,exist_ok=True) 75 | os.makedirs(SatelliteCurDir,exist_ok=True) 76 | 77 | # # load drone and satellite image 78 | # drone80_img = cv2.imread(drone_80) 79 | # drone90_img = cv2.imread(drone_90) 80 | # drone100_img = cv2.imread(drone_100) 81 | # satellite80_img = cv2.imread(satellite_80) 82 | # satellite90_img = cv2.imread(satellite_90) 83 | # satellite100_img = cv2.imread(satellite_100) 84 | # # ---new added--- 85 | # satellite80_img_old = cv2.imread(satellite_80_old) 86 | # satellite90_img_old = cv2.imread(satellite_90_old) 87 | # satellite100_img_old = cv2.imread(satellite_100_old) 88 | 89 | # # process image including crop and resize 90 | # processed_drone80_img = resize(drone80_img,target_size=UAV_target_size) 91 | # processed_drone90_img = resize(drone90_img,target_size=UAV_target_size) 92 | # processed_drone100_img = resize(drone100_img,target_size=UAV_target_size) 93 | # processed_satellite80_img = resize(satellite80_img,target_size=Satellite_target_size) 94 | # processed_satellite90_img = resize(satellite90_img,target_size=Satellite_target_size) 95 | # processed_satellite100_img = resize(satellite100_img,target_size=Satellite_target_size) 96 | # # ---new added--- 97 | # processed_satellite80_img_old = resize(satellite80_img_old,target_size=Satellite_target_size) 98 | # processed_satellite90_img_old = resize(satellite90_img_old,target_size=Satellite_target_size) 99 | # processed_satellite100_img_old = resize(satellite100_img_old,target_size=Satellite_target_size) 100 | 101 | # cv2.imwrite(os.path.join(droneCurDir, "H80.JPG"), processed_drone80_img) 102 | # cv2.imwrite(os.path.join(droneCurDir, "H90.JPG"), processed_drone90_img) 103 | # cv2.imwrite(os.path.join(droneCurDir, "H100.JPG"), processed_drone100_img) 104 | satelliteImgPath80 = os.path.join(SatelliteCurDir,"H80.tif") 105 | satelliteImgPath90 = os.path.join(SatelliteCurDir,"H90.tif") 106 | satelliteImgPath100 = os.path.join(SatelliteCurDir,"H100.tif") 107 | # cv2.imwrite(satelliteImgPath80, processed_satellite80_img) 108 | # cv2.imwrite(satelliteImgPath90, processed_satellite90_img) 109 | # cv2.imwrite(satelliteImgPath100, processed_satellite100_img) 110 | # # # ---new added--- 111 | # satelliteImgPath80_old = os.path.join(SatelliteCurDir,"H80_old.tif") 112 | # satelliteImgPath90_old = 
os.path.join(SatelliteCurDir,"H90_old.tif") 113 | # satelliteImgPath100_old = os.path.join(SatelliteCurDir,"H100_old.tif") 114 | # cv2.imwrite(satelliteImgPath80_old, processed_satellite80_img_old) 115 | # cv2.imwrite(satelliteImgPath90_old, processed_satellite90_img_old) 116 | # cv2.imwrite(satelliteImgPath100_old, processed_satellite100_img_old) 117 | 118 | # write the GPS information 119 | GPS_info = find_GPS_image(drone_80) 120 | x = list(GPS_info.values()) 121 | gps_dict_formate = x[0] 122 | y = list(gps_dict_formate.values()) 123 | height = eval(y[5]) 124 | information = "{} {}{} {}{} {}\n".format(satelliteImgPath80,y[2],y[3],y[0],y[1],height) 125 | # ---new added--- 126 | # information_old = "{} {}{} {}{} {}\n".format(satelliteImgPath80_old,y[2],y[3],y[0],y[1],height) 127 | # return information,information_old 128 | return [information] 129 | 130 | heightList = ["80","90","100"] 131 | index = 2256 # 测试集需要根据训练集的总长设置index 132 | dir_path = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/" 133 | mode = "test" 134 | oriPath = os.path.join(dir_path,"oridata", mode, "University_UAV_Images") 135 | dirList = os.listdir(oriPath) 136 | root_dir = os.path.join(dir_path, "data_2021") 137 | mode_dir_path = os.path.join(root_dir, mode) 138 | os.makedirs(mode_dir_path,exist_ok=True) 139 | 140 | # GPS txt file 141 | txt_path = os.path.join(root_dir, "Dense_GPS_{}.txt".format(mode)) 142 | file = open(txt_path, 'w') 143 | dronePath = os.path.join(mode_dir_path, "drone") 144 | os.makedirs(dronePath,exist_ok=True) 145 | satellitePath = os.path.join(mode_dir_path, "satellite") 146 | os.makedirs(satellitePath,exist_ok=True) 147 | 148 | p = Pool(8) 149 | for p_idx, place in enumerate(dirList): 150 | if not os.path.isdir(os.path.join(oriPath,place)): 151 | continue 152 | 153 | Drone_JPG_paths_80 = get_fileNames(os.path.join(oriPath, place, "80"),endwith=".JPG") 154 | Drone_JPG_paths_90 = get_fileNames(os.path.join(oriPath, place, "90"), endwith=".JPG") 155 | Drone_JPG_paths_100 = get_fileNames(os.path.join(oriPath, place, "100"), endwith=".JPG") 156 | Drone_JPG_paths_80 = sorted(Drone_JPG_paths_80,key=functools.cmp_to_key(compare_personal)) 157 | Drone_JPG_paths_90 = sorted(Drone_JPG_paths_90,key=functools.cmp_to_key(compare_personal)) 158 | Drone_JPG_paths_100 = sorted(Drone_JPG_paths_100,key=functools.cmp_to_key(compare_personal)) 159 | f_80 = getFileNameList(Drone_JPG_paths_80) 160 | f_90 = getFileNameList(Drone_JPG_paths_90) 161 | f_100 = getFileNameList(Drone_JPG_paths_100) 162 | assert f_80==f_90 and f_90==f_100, "数据存在不对应请检查{}".format(place) 163 | print("Set index for the every iter") 164 | indexed_iters = [] 165 | for ind, (drone_80,drone_90,drone_100) in tqdm(enumerate(zip(Drone_JPG_paths_80,Drone_JPG_paths_90,Drone_JPG_paths_100))): 166 | indexed_iters.append([index+ind, [drone_80,drone_90,drone_100]]) 167 | index+=len(indexed_iters) 168 | for idx, res in enumerate(p.imap(process,indexed_iters)): 169 | if idx%50==0: 170 | print("-{}- {}/{} {}".format(p_idx, idx, len(indexed_iters), place)) 171 | for info in res: 172 | file.write(info) 173 | 174 | file.close() 175 | 176 | p.close() -------------------------------------------------------------------------------- /tool/dataset_preprocess/3-generate_format_testset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from tqdm import tqdm 4 | import glob 5 | from multiprocessing import Pool 6 | 7 | 8 | def processData(datalist,targetDir): 9 | os.makedirs(targetDir,exist_ok=True) 10 | for 
dir in tqdm(datalist): 11 | name = os.path.basename(dir) 12 | target_dir = os.path.join(targetDir,name) 13 | shutil.copytree(dir,target_dir) 14 | 15 | 16 | def main(): 17 | rootDir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/" 18 | # # 测试集 19 | # testDir = os.path.join(rootDir,"test") 20 | # ClassForTestDrone = glob.glob(os.path.join(testDir, "drone/*")) 21 | # ClassForTestSatellite = glob.glob(os.path.join(testDir, "satellite/*")) 22 | 23 | # query_drone = os.path.join(testDir,"query_drone") 24 | # gallery_drone = os.path.join(testDir,"gallery_drone") 25 | # query_satellite = os.path.join(testDir,"query_satellite") 26 | # gallery_satellite = os.path.join(testDir,"gallery_satellite") 27 | 28 | # # 训练集 29 | # trainDir = os.path.join(rootDir, "train") 30 | # ClassForTrainDrone = glob.glob(os.path.join(trainDir, "drone/*")) 31 | # ClassForTrainSatellite = glob.glob(os.path.join(trainDir, "satellite/*")) 32 | 33 | # # process test data 34 | # processData(ClassForTestDrone,query_drone) 35 | # processData(ClassForTestSatellite,query_satellite) 36 | # # gallery需要把测试集和训练集合在一起 37 | # processData(ClassForTrainDrone+ClassForTestDrone,gallery_drone) 38 | # processData(ClassForTrainSatellite+ClassForTestSatellite,gallery_satellite) 39 | 40 | # 生成最终的Dense_GPS_ALL.txt 41 | mode_list = ["train","test"] 42 | total_lines = [] 43 | for mode in mode_list: 44 | mode_txt = os.path.join(rootDir,"Dense_GPS_{}.txt".format(mode)) 45 | with open(mode_txt,"r") as F: 46 | lines = F.readlines() 47 | total_lines.extend(lines) 48 | ALL_filename = os.path.join(rootDir,"Dense_GPS_ALL.txt") 49 | with open(ALL_filename,"w") as F: 50 | for info in total_lines: 51 | F.write(info) 52 | print("success write {}".format(ALL_filename)) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() -------------------------------------------------------------------------------- /tool/dataset_preprocess/TEST-1-downloadALlAndReCut.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from tqdm import tqdm 4 | import os 5 | import shutil 6 | 7 | # TN30.325763481658377 TE120.3524959718277 BN30.318953289529816 BE120.36312306473656 jiliang 8 | # TN30.32631193830858 TE120.36678791233928 BN30.317870065161575 BE120.37745052386553 xianke 9 | # TN30.325763471673625 TE120.37341802729506 BN30.320441588110295 BE120.38193743012363 10 | mapPosInfodir = "/home/dmmm/PycharmProjects/DenseCV/demo/maps/pos.txt" 11 | # get_picture(startE, startN, endE, endN, 20, "./计量.tif", server="Google") 12 | 13 | University = "Jiliang" 14 | 15 | with open(mapPosInfodir,"r") as F: 16 | listLine = F.readlines() 17 | for line in listLine: 18 | name,TN,TE,BN,BE = line.split(" ") 19 | if name==University: 20 | startE = eval(TE.split("TE")[-1]) 21 | startN = eval(TN.split("TN")[-1]) 22 | endE = eval(BE.split("BE")[-1]) 23 | endN = eval(BN.split("BN")[-1]) 24 | 25 | # 图像的根文件夹 26 | dir_path = r"/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub({})".format(University) 27 | if os.path.exists(dir_path): 28 | shutil.rmtree(dir_path) 29 | os.mkdir(dir_path) 30 | # 经纬度信息存放文件夹 31 | infoFile = "/media/dmmm/4T-31/DataSets/DenseCV_Data/高度数据集/oridata/train/old_tif/PosInfo.txt".format(University) 32 | file = open(infoFile, "w") 33 | 34 | def sixNumber(str_number): 35 | str_number=str(str_number) 36 | while(len(str_number)<6): 37 | str_number='0'+str_number 38 | return str_number 39 | 40 | oriImage = cv2.imread("/home/dmmm/PycharmProjects/DenseCV/demo/maps/{}.tif".format(University)) 41 | h,w,c = 
oriImage.shape 42 | 43 | cropedSizeList = [640,768,896] 44 | marginrate = 4 45 | index = 0 46 | for cropedSize in cropedSizeList: 47 | margin = cropedSize//marginrate #间距为图像尺寸的1/marginrate 48 | TopLeftEast = startE + (endE - startE)*cropedSize/2/w 49 | TopLeftNorth = startN + (endN - startN)*cropedSize/2/h 50 | BottomRightEast = endE - (endE - startE)*cropedSize/2/w 51 | BottomRightNorth = endN - (endN - startN)*cropedSize/2/h 52 | 53 | # YY = list(range(cropedSize//2,h-cropedSize//2+margin,margin)) 54 | # XX = list(range(cropedSize//2,w-cropedSize//2+margin,margin)) 55 | # Y_Size = len(YY)#y轴上总共记录图像数 56 | # X_Size = len(XX)#x轴上总共记录图像数 57 | X_Size = (w-cropedSize)//margin 58 | Y_Size = (h-cropedSize)//margin 59 | YY = np.linspace(cropedSize//2,h-cropedSize//2,Y_Size,dtype=np.int16) 60 | XX = np.linspace(cropedSize//2,w-cropedSize//2,X_Size,dtype=np.int16) 61 | 62 | pbar = tqdm(total=Y_Size*X_Size) 63 | 64 | PosInfoN = np.linspace(TopLeftNorth,BottomRightNorth,Y_Size) 65 | PosInfoE = np.linspace(TopLeftEast,BottomRightEast,X_Size) 66 | for n,y in zip(PosInfoN,YY): 67 | for e,x in zip(PosInfoE,XX): 68 | topLX = x - cropedSize//2 69 | topLY = y - cropedSize//2 70 | BottomRX = x + cropedSize//2 71 | BottomRY = y + cropedSize//2 72 | cropImage = oriImage[topLY:BottomRY,topLX:BottomRX,:] 73 | 74 | index_ = sixNumber(index) 75 | 76 | cv2.imwrite(os.path.join(dir_path,index_+".tif"),cropImage) 77 | 78 | file.write("{} N{:.10f} E{:.10f}\n".format(index_,n,e)) 79 | 80 | index+=1 81 | pbar.update() -------------------------------------------------------------------------------- /tool/dataset_preprocess/TEST-2-generateSatelliteHub.py: -------------------------------------------------------------------------------- 1 | # from google_interface import get_picture 2 | import time 3 | from google_interface_2 import get_picture 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | ''' 9 | 测试阶段用于产生卫星图像库 10 | ''' 11 | 12 | def sixNumber(str_number): 13 | str_number=str(str_number) 14 | while(len(str_number)<6): 15 | str_number='0'+str_number 16 | return str_number 17 | 18 | 19 | if __name__ == '__main__':# 30.32331706,120.37025917 20 | startE = 120.3524959718277 #start左上角 21 | startN = 30.325763481658377 22 | endE = 120.36312306473656 #end右下角30.32128231,120.37425118, 23 | endN = 30.318953289529816 24 | start_time = time.time() 25 | margin = 0.0001 26 | #图像的根文件夹 27 | dir_path = r"/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub(现科dense)" 28 | os.mkdir(dir_path) 29 | #经纬度信息存放文件夹 30 | infoFile = "/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub(现科dense)/PosInfo.txt" 31 | file = open(infoFile,"w") 32 | #获取文件夹中所有的jpg图像的路径 33 | # jpg_paths = get_fileNames(dir_path,endwith=".JPG") 34 | index = 0 35 | NorthList = np.linspace(startN, endN, int((startN - endN) / (margin))) 36 | EastList = np.linspace(startE,endE,int((endE-startE)/(margin))) 37 | pbar = tqdm(total=len(NorthList)*len(EastList)) 38 | for North in NorthList: 39 | for East in EastList: 40 | east_left = East - margin 41 | east_right = East + margin 42 | north_top = North + margin 43 | north_bottom = North - margin 44 | 45 | # 设置新的存放的路径 46 | index_ = sixNumber(index) 47 | satellite_tif_path = os.path.join(dir_path,"{}.tif".format(index_)) 48 | get_picture(east_left, north_top, east_right, north_bottom, 19, satellite_tif_path, server="Google") 49 | file.write("{} N{:.10f} E{:.10f}\n".format(index_,North,East)) 50 | index+=1 51 | pbar.update() 52 | 53 | file.close() 54 | pbar.close() 
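# The gallery in TEST-2-generateSatelliteHub.py above is built by stepping the tile centre in
# fixed increments of margin = 0.0001 degrees. A minimal sketch (not part of the repository
# scripts) of what one step means in metres, reusing the haversine-style Distance helper from
# tool/dataset_preprocess/utils.py (imported the same way TEST-Regenerate_dense_testset.py does).
# At the latitudes used here (~30.32 N) one step is roughly 11 m north-south and roughly 9.6 m
# east-west, since the east-west span shrinks with cos(latitude):
from utils import Distance

lat, lon, margin = 30.3257, 120.3525, 0.0001
step_ns = Distance(lat, lon, lat + margin, lon)  # ~11.1 m per 0.0001 deg of latitude
step_ew = Distance(lat, lon, lat, lon + margin)  # ~9.6 m at this latitude
print("gallery grid step: {:.1f} m (N-S) x {:.1f} m (E-W)".format(step_ns, step_ew))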
-------------------------------------------------------------------------------- /tool/dataset_preprocess/TEST-3-preprocess_difHeightTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import glob 4 | import tqdm 5 | 6 | 7 | def mkdirAndCopyfile(ori_path,to_path): 8 | if not os.path.exists(to_path): 9 | os.mkdir(to_path) 10 | basename,name = ori_path.split("/")[-2:] 11 | new_basename = os.path.join(to_path,basename) 12 | if os.path.exists(new_basename): 13 | shutil.rmtree(new_basename) 14 | os.mkdir(new_basename) 15 | shutil.copyfile(ori_path,os.path.join(new_basename,name)) 16 | 17 | 18 | 19 | if __name__ == '__main__': 20 | root = "/home/dmmm/Dataset/DenseUAV/data_2022/test" 21 | root_80 = os.path.join(root,"queryDrone80") 22 | root_90 = os.path.join(root,"queryDrone90") 23 | root_100 = os.path.join(root,"queryDrone100") 24 | root_Drone_all = os.path.join(root,"query_drone") 25 | Drone_80_List = glob.glob(os.path.join(root_Drone_all,"*","H80.JPG")) 26 | Drone_90_List = glob.glob(os.path.join(root_Drone_all,"*","H90.JPG")) 27 | Drone_100_List = glob.glob(os.path.join(root_Drone_all,"*","H100.JPG")) 28 | tq = tqdm.tqdm(len(Drone_80_List)) 29 | for H80,H90,H100 in zip(Drone_80_List,Drone_90_List,Drone_100_List): 30 | mkdirAndCopyfile(H80,root_80) 31 | mkdirAndCopyfile(H90,root_90) 32 | mkdirAndCopyfile(H100,root_100) 33 | tq.update() 34 | 35 | 36 | -------------------------------------------------------------------------------- /tool/dataset_preprocess/TEST-Regenerate_dense_testset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import os 4 | import shutil 5 | import glob 6 | import cv2 7 | from multiprocessing import Pool,Manager 8 | import copy 9 | from utils import Distance 10 | 11 | # 对测试卫星图像进行密集切分,old版本,生成一个GPS_ALL.txt文件 12 | 13 | def sixNumber(str_number): 14 | str_number=str(str_number) 15 | while(len(str_number)<6): 16 | str_number='0'+str_number 17 | return str_number 18 | 19 | 20 | type = "new" 21 | 22 | root_dir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/oridata/test/{}_tif".format(type) 23 | 24 | source_loc_info = os.path.join(root_dir, "PosInfo.txt") 25 | 26 | output_dir = "/home/dmmm/Dataset/DenseUAV/data_2022/test/hard_gallery_satellite_ss_interval10m" 27 | os.makedirs(output_dir,exist_ok=True) 28 | 29 | with open(source_loc_info, "r") as F: 30 | context = F.readlines() 31 | 32 | correspond_size = [640,768,896] 33 | correspond_size = correspond_size[1:2] 34 | gap=77 35 | output_size = [256,256] 36 | total_info = [] 37 | 38 | 39 | 40 | # p = Pool(10) 41 | 42 | for line in context: 43 | info = line.strip().split(" ") 44 | university_name, TN, TE, BN, BE = info 45 | TE = eval(TE.split("TE")[-1]) 46 | TN = eval(TN.split("TN")[-1]) 47 | BE = eval(BE.split("BE")[-1]) 48 | BN = eval(BN.split("BN")[-1]) 49 | print("acutal_height:{}--width:{}".format(Distance(TN,TE,BN,TE),Distance(TN,TE,TN,BE))) 50 | 51 | source_map_tif_path = os.path.join(root_dir,"{}.tif".format(university_name)) 52 | source_map_image = cv2.imread(source_map_tif_path) 53 | 54 | manager = Manager() 55 | shared_dict = manager.dict() 56 | shared_dict['total_info'] = [] 57 | 58 | h,w,c = source_map_image.shape 59 | print("image_height:{}--width:{}".format(h,w)) 60 | x,y = np.meshgrid(list(range(0,w-896,gap)),list(range(0,h-896,gap))) 61 | x,y = x.reshape(-1),y.reshape(-1) 62 | inds = np.array(list(range(0,len(x),1))) 63 | 64 | def process(infos): 65 | 
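# process() crops one size x size tile at pixel offset (x, y) from the stitched campus map and
# converts the tile centre back to GPS by linear interpolation between the corner coordinates
# parsed from PosInfo.txt: East = TE + (x + size/2)/w * (BE - TE) and
# North = TN - (y + size/2)/h * (TN - BN). For example, a 768-pixel tile at x = y = 0 has its
# centre at East = TE + (384/w) * (BE - TE). Each tile is resized to output_size and recorded as
# [filepath, East, North, size]; with gap = 77 pixels the tiles form the dense gallery written to
# hard_gallery_satellite_ss_interval10m.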
ind = infos[0] 66 | position = infos[1:] 67 | 68 | info_list = [] 69 | for size in correspond_size: 70 | East = TE + (position[0]+size/2)/w*(BE-TE) 71 | North = TN - (position[1]+size/2)/h*(TN-BN) 72 | image = source_map_image[position[1]:position[1]+size, position[0]:position[0]+size, :] 73 | image = cv2.resize(image, output_size) 74 | filepath = os.path.join(output_dir, "{}_{}_{}_{}.jpg".format(university_name,sixNumber(ind),size,type)) 75 | cv2.imwrite(filepath, image) 76 | pos_info = [filepath, East, North, size] 77 | info_list.append(pos_info) 78 | 79 | return info_list 80 | 81 | p = Pool(20) 82 | for ind,res in enumerate(p.imap(process,zip(inds,x,y))): 83 | if ind % 100 == 0: 84 | print("{}/{} process the image!!".format(ind,len(inds))) 85 | total_info.extend(res) 86 | p.close() 87 | 88 | info_path = "/home/dmmm/Dataset/DenseUAV/data_2022/test/{}_info.txt".format(type) 89 | 90 | F = open(info_path, "w") 91 | for info in total_info: 92 | F.write("{} {} {} {}\n".format(*info)) 93 | 94 | F.close() 95 | 96 | # cat file1.txt file2.txt > merged.txt 合并文件 -------------------------------------------------------------------------------- /tool/dataset_preprocess/get_property.py: -------------------------------------------------------------------------------- 1 | import exifread 2 | import re 3 | import json 4 | import requests 5 | 6 | def latitude_and_longitude_convert_to_decimal_system(*arg): 7 | """ 8 | 经纬度转为小数, param arg: 9 | :return: 十进制小数 10 | """ 11 | return float(arg[0]) + ((float(arg[1]) + (float(arg[2].split('/')[0]) / float(arg[2].split('/')[-1]) / 60)) / 60) 12 | 13 | def find_GPS_image(pic_path): 14 | GPS = {} 15 | date = '' 16 | with open(pic_path, 'rb') as f: 17 | tags = exifread.process_file(f) 18 | for tag, value in tags.items(): 19 | if re.match('GPS GPSLatitudeRef', tag): 20 | GPS['GPSLatitudeRef'] = str(value) 21 | elif re.match('GPS GPSLongitudeRef', tag): 22 | GPS['GPSLongitudeRef'] = str(value) 23 | elif re.match('GPS GPSAltitudeRef', tag): 24 | GPS['GPSAltitudeRef'] = str(value) 25 | elif re.match('GPS GPSLatitude', tag): 26 | try: 27 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups() 28 | GPS['GPSLatitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2]) 29 | except: 30 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')] 31 | GPS['GPSLatitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec) 32 | elif re.match('GPS GPSLongitude', tag): 33 | try: 34 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups() 35 | GPS['GPSLongitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2]) 36 | except: 37 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')] 38 | GPS['GPSLongitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec) 39 | elif re.match('GPS GPSAltitude', tag): 40 | GPS['GPSAltitude'] = str(value) 41 | elif re.match('.*Date.*', tag): 42 | date = str(value) 43 | return {'GPS_information': GPS, 'date_information': date} 44 | 45 | def find_address_from_GPS(GPS): 46 | """ 47 | 使用Geocoding API把经纬度坐标转换为结构化地址。 48 | :param GPS: 49 | :return: 50 | """ 51 | secret_key = 'zbLsuDDL4CS2U0M4KezOZZbGUY9iWtVf' 52 | if not GPS['GPS_information']: 53 | return '该照片无GPS信息' 54 | lat, lng = GPS['GPS_information']['GPSLatitude'], GPS['GPS_information']['GPSLongitude'] 55 | baidu_map_api = "http://api.map.baidu.com/geocoder/v2/?ak={0}&callback=renderReverse&location={1},{2}s&output=json&pois=0".format( 56 | secret_key, lat, 
lng) 57 | response = requests.get(baidu_map_api) 58 | content = response.text.replace("renderReverse&&renderReverse(", "")[:-1] 59 | baidu_map_address = json.loads(content) 60 | formatted_address = baidu_map_address["result"]["formatted_address"] 61 | province = baidu_map_address["result"]["addressComponent"]["province"] 62 | city = baidu_map_address["result"]["addressComponent"]["city"] 63 | district = baidu_map_address["result"]["addressComponent"]["district"] 64 | return formatted_address,province,city,district 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /tool/dataset_preprocess/split_dataset_long_middle_short.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import glob 4 | base_dir = "/home/dmmm/University-Release" 5 | 6 | class MakeDataset(): 7 | def __init__(self,name): 8 | self.name = name 9 | self.dir = os.path.join(base_dir, name) 10 | self.ori_dir = os.path.join(base_dir,"test") 11 | self.target_path = os.path.join(self.dir,"query_drone") 12 | # 移动query_drone的图片 13 | self.copy_pictures() 14 | # 移动其他需要的文件 15 | self.copy_other() 16 | 17 | 18 | def copy_pictures(self): 19 | class_list = os.listdir(os.path.join(self.ori_dir,"query_drone")) 20 | for i in class_list: 21 | #新建分类对应的文件夹 22 | self.mkdir(os.path.join(self.target_path,i)) 23 | 24 | #取出每个Long Middle Short对应的图片 25 | path = os.path.join(self.ori_dir,"query_drone",i) 26 | tar_path = os.path.join(self.target_path,i) 27 | img_list = os.listdir(path) 28 | img_list.sort() 29 | long_num = len(img_list)//3 30 | middle_num = len(img_list)//3*2 31 | short_num = len(img_list) 32 | if self.name =="Long": 33 | list = img_list[:long_num] 34 | elif self.name == "Middle": 35 | list = img_list[long_num:middle_num] 36 | elif self.name == "Short": 37 | list = img_list[middle_num:short_num] 38 | else: 39 | raise ValueError("输入的name参数有误,必须为Long、Middle或Short") 40 | 41 | #复制图片到指定路径 42 | for j in list: 43 | path_j = os.path.join(path,j) 44 | path_t = os.path.join(tar_path,j) 45 | shutil.copyfile(path_j,path_t) 46 | 47 | 48 | 49 | 50 | def mkdir(self,path): 51 | if not os.path.exists(path): 52 | os.makedirs(path) 53 | 54 | def copy_other(self): 55 | filename_list = ["gallery_drone","gallery_satellite","query_satellite"] 56 | for i in filename_list: 57 | source_path = os.path.join(self.ori_dir,i) 58 | target_path = os.path.join(self.dir,i) 59 | shutil.copytree(source_path, target_path) 60 | 61 | 62 | 63 | 64 | if __name__ == '__main__': 65 | MakeDataset("Long") 66 | MakeDataset("Middle") 67 | MakeDataset("Short") -------------------------------------------------------------------------------- /tool/dataset_preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import math 4 | 5 | def get_fileNames(rootdir,endwith=".JPG"): 6 | fs = [] 7 | for root, dirs, files in os.walk(rootdir,topdown = True): 8 | for name in files: 9 | _, ending = os.path.splitext(name) 10 | if ending == endwith: 11 | fs.append(os.path.join(root,name)) 12 | return fs 13 | 14 | 15 | def Distance(lata, loga, latb, logb): 16 | # EARTH_RADIUS = 6371.0 17 | EARTH_RADIUS = 6378.137 18 | PI = math.pi 19 | # // 转弧度 20 | lat_a = lata * PI / 180 21 | lat_b = latb * PI / 180 22 | a = lat_a - lat_b 23 | b = loga * PI / 180 - logb * PI / 180 24 | dis = 2 * math.asin(math.sqrt(math.pow(math.sin(a / 2), 2) + math.cos(lat_a) 25 | * math.cos(lat_b) * math.pow(math.sin(b / 2), 2))) 26 | 27 
| distance = EARTH_RADIUS * dis * 1000 28 | return distance -------------------------------------------------------------------------------- /tool/dataset_preprocess/validation_testset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import os 4 | import shutil 5 | import glob 6 | import cv2 7 | from multiprocessing import Pool 8 | import copy 9 | 10 | type = "new" 11 | 12 | info_file = "/home/dmmm/Dataset/DenseUAV/data_2022/test/{}_info.txt".format(type) 13 | 14 | root_dir = "/home/dmmm/Dataset/DenseUAV/data_2022/test/{}_tif".format(type) 15 | 16 | source_loc_info = os.path.join(root_dir, "PosInfo.txt") 17 | 18 | with open(source_loc_info, "r") as F: 19 | context_map = F.readlines() 20 | 21 | with open(info_file,"r") as F: 22 | context = F.readlines() 23 | 24 | line = context[100] 25 | infos = line.strip().split(" ") 26 | filename = infos[0] 27 | query_image = cv2.imread(filename) 28 | 29 | E = float(infos[1]) 30 | N = float(infos[2]) 31 | 32 | for line in context_map: 33 | info = line.strip().split(" ") 34 | university_name, TN, TE, BN, BE = info 35 | TE = eval(TE.split("TE")[-1]) 36 | TN = eval(TN.split("TN")[-1]) 37 | BE = eval(BE.split("BE")[-1]) 38 | BN = eval(BN.split("BN")[-1]) 39 | 40 | source_map_tif_path = os.path.join(root_dir,"{}.tif".format(university_name)) 41 | source_map_image = cv2.imread(source_map_tif_path) 42 | h,w = source_map_image.shape[:2] 43 | 44 | if TE<=E<=BE and BN<=N<=TN: 45 | x = (E-TE)/(BE-TE)*w 46 | y = (N-TN)/(BN-TN)*h 47 | break 48 | 49 | cv2.circle(source_map_image,(int(x),int(y)),40,(255,0,0),20) 50 | cv2.imwrite("map.jpg",source_map_image) 51 | cv2.imwrite("query.jpg",query_image) 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /tool/get_inference_time.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | sys.path.append("../") 4 | import yaml 5 | import argparse 6 | import torch 7 | from tool.utils import load_network, calc_flops_params 8 | import time 9 | 10 | 11 | parser = argparse.ArgumentParser(description='Training') 12 | parser.add_argument('--name', default='resnet', 13 | type=str, help='save model path') 14 | parser.add_argument('--checkpoint', default='../net_119.pth', 15 | type=str, help='save model path') 16 | parser.add_argument('--test_h', default=224, type=int, help='height') 17 | parser.add_argument('--test_w', default=224, type=int, help='width') 18 | parser.add_argument('--calc_nums', default=2000, type=int, help='width') 19 | opt = parser.parse_args() 20 | 21 | config_path = '../opts.yaml' 22 | with open(config_path, 'r') as stream: 23 | config = yaml.load(stream) 24 | for cfg, value in config.items(): 25 | setattr(opt, cfg, value) 26 | 27 | model = load_network(opt).cuda() 28 | model = model.eval() 29 | 30 | # thop计算MACs 31 | macs, params = calc_flops_params( 32 | model, (1, 3, opt.test_h, opt.test_w), (1, 3, opt.test_h, opt.test_w)) 33 | input_size_drone = (1, 3, opt.test_h, opt.test_w) 34 | input_size_satellite = (1, 3, opt.test_h, opt.test_w) 35 | 36 | inputs_drone = torch.randn(input_size_drone).cuda() 37 | inputs_satellite = torch.randn(input_size_satellite).cuda() 38 | 39 | # 预热 40 | for _ in range(10): 41 | model(inputs_drone,inputs_satellite) 42 | 43 | since = time.time() 44 | for _ in range(opt.calc_nums): 45 | model(inputs_drone,inputs_satellite) 46 | 47 | 
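# CUDA kernels are launched asynchronously, so the elapsed time printed below can be reported
# before the last forward pass has actually finished on the GPU. A minimal sketch of a stricter
# measurement, assuming the same model, inputs and opt.calc_nums defined above (the per-pair
# latency line is an added illustration, not part of the original script):
#
#     torch.cuda.synchronize()
#     since = time.time()
#     for _ in range(opt.calc_nums):
#         model(inputs_drone, inputs_satellite)
#     torch.cuda.synchronize()
#     elapsed = time.time() - since
#     print("avg latency per pair = {:.3f} ms".format(elapsed / opt.calc_nums * 1000))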
48 | print("inference_time = {}s".format(time.time()-since)) -------------------------------------------------------------------------------- /tool/get_model_flops_params.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | sys.path.append("../../") 4 | import yaml 5 | import argparse 6 | 7 | from tool.utils import load_network, calc_flops_params 8 | 9 | parser = argparse.ArgumentParser(description='Training') 10 | parser.add_argument('--name', default='resnet', 11 | type=str, help='save model path') 12 | parser.add_argument('--checkpoint', default='net_119.pth', 13 | type=str, help='save model path') 14 | parser.add_argument('--test_h', default=224, type=int, help='height') 15 | parser.add_argument('--test_w', default=224, type=int, help='width') 16 | opt = parser.parse_args() 17 | 18 | config_path = 'opts.yaml' 19 | with open(config_path, 'r') as stream: 20 | config = yaml.load(stream) 21 | for cfg, value in config.items(): 22 | setattr(opt, cfg, value) 23 | 24 | model = load_network(opt).cuda() 25 | model = model.eval() 26 | 27 | # thop计算MACs 28 | macs, params = calc_flops_params( 29 | model, (1, 3, opt.test_h, opt.test_w), (1, 3, opt.test_h, opt.test_w)) 30 | print("model MACs={}, Params={}".format(macs, params)) 31 | -------------------------------------------------------------------------------- /tool/get_property.py: -------------------------------------------------------------------------------- 1 | import exifread 2 | import re 3 | import json 4 | import requests 5 | 6 | def latitude_and_longitude_convert_to_decimal_system(*arg): 7 | """ 8 | 经纬度转为小数, param arg: 9 | :return: 十进制小数 10 | """ 11 | return float(arg[0]) + ((float(arg[1]) + (float(arg[2].split('/')[0]) / float(arg[2].split('/')[-1]) / 60)) / 60) 12 | 13 | def find_GPS_image(pic_path): 14 | GPS = {} 15 | date = '' 16 | with open(pic_path, 'rb') as f: 17 | tags = exifread.process_file(f) 18 | for tag, value in tags.items(): 19 | if re.match('GPS GPSLatitudeRef', tag): 20 | GPS['GPSLatitudeRef'] = str(value) 21 | elif re.match('GPS GPSLongitudeRef', tag): 22 | GPS['GPSLongitudeRef'] = str(value) 23 | elif re.match('GPS GPSAltitudeRef', tag): 24 | GPS['GPSAltitudeRef'] = str(value) 25 | elif re.match('GPS GPSLatitude', tag): 26 | try: 27 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups() 28 | GPS['GPSLatitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2]) 29 | except: 30 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')] 31 | GPS['GPSLatitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec) 32 | elif re.match('GPS GPSLongitude', tag): 33 | try: 34 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups() 35 | GPS['GPSLongitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2]) 36 | except: 37 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')] 38 | GPS['GPSLongitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec) 39 | elif re.match('GPS GPSAltitude', tag): 40 | GPS['GPSAltitude'] = str(value) 41 | elif re.match('.*Date.*', tag): 42 | date = str(value) 43 | return {'GPS_information': GPS, 'date_information': date} 44 | 45 | def find_address_from_GPS(GPS): 46 | """ 47 | 使用Geocoding API把经纬度坐标转换为结构化地址。 48 | :param GPS: 49 | :return: 50 | """ 51 | secret_key = 'zbLsuDDL4CS2U0M4KezOZZbGUY9iWtVf' 52 | if not GPS['GPS_information']: 53 | return '该照片无GPS信息' 54 | lat, lng = 
GPS['GPS_information']['GPSLatitude'], GPS['GPS_information']['GPSLongitude'] 55 | baidu_map_api = "http://api.map.baidu.com/geocoder/v2/?ak={0}&callback=renderReverse&location={1},{2}s&output=json&pois=0".format( 56 | secret_key, lat, lng) 57 | response = requests.get(baidu_map_api) 58 | content = response.text.replace("renderReverse&&renderReverse(", "")[:-1] 59 | baidu_map_address = json.loads(content) 60 | formatted_address = baidu_map_address["result"]["formatted_address"] 61 | province = baidu_map_address["result"]["addressComponent"]["province"] 62 | city = baidu_map_address["result"]["addressComponent"]["city"] 63 | district = baidu_map_address["result"]["addressComponent"]["district"] 64 | return formatted_address,province,city,district 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /tool/mount_dist.sh: -------------------------------------------------------------------------------- 1 | sudo mkdir /media/dmmm/4T-3 2 | sudo mount /dev/sda3 /media/dmmm/4T-3 -------------------------------------------------------------------------------- /tool/transforms/rotatetranformtest.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from math import cos,sin,pi 4 | from PIL import Image 5 | import os 6 | 7 | #生成随机rotateandcrop的结果并保存到rotateShow文件夹下 8 | 9 | 10 | class RotateAndCrop(object): 11 | def __init__(self): 12 | pass 13 | 14 | def __call__(self, img): 15 | img_=np.array(img).copy() 16 | 17 | def getPosByAngle(img, angle): 18 | h, w, c = img.shape 19 | x_center = y_center = h // 2 20 | r = h // 2 21 | angle_lt = angle - 45 22 | angle_rt = angle + 45 23 | angle_lb = angle - 135 24 | angle_rb = angle + 135 25 | angleList = [angle_lt, angle_rt, angle_lb, angle_rb] 26 | pointsList = [] 27 | for angle in angleList: 28 | x1 = x_center + r * cos(angle * pi / 180) 29 | y1 = y_center + r * sin(angle * pi / 180) 30 | pointsList.append([x1, y1]) 31 | pointsOri = np.float32(pointsList) 32 | pointsListAfter = np.float32([[0, 0], [512, 0], [0, 512], [512, 512]]) 33 | M = cv2.getPerspectiveTransform(pointsOri, pointsListAfter) 34 | res = cv2.warpPerspective(img, M, (512, 512)) 35 | return res 36 | 37 | if not os.path.exists("rotateShow"): 38 | os.mkdir("rotateShow") 39 | img.save("./rotateShow/ori.png") 40 | for i in range(10): 41 | angle = int(np.random.random() * 360) 42 | new_image = getPosByAngle(img_,angle) 43 | image = Image.fromarray(new_image.astype('uint8')).convert('RGB') 44 | image.save("./rotateShow/{}.png".format(i)) 45 | 46 | if __name__ == '__main__': 47 | img = Image.open("/media/dmmm/CE31-3598/DataSets/DenseCV_Data/实际测试图像(计量)/part2/DJI_0317.JPG") 48 | rotate = RotateAndCrop() 49 | rotate(img) 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /tool/transforms/transform_visual.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../../") 3 | from datasets.queryDataset import RotateAndCrop, RandomErasing 4 | from torchvision import transforms 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | import argparse 9 | import os 10 | 11 | 12 | def get_parse(): 13 | parser = argparse.ArgumentParser(description='Transfrom Visualization') 14 | parser.add_argument( 15 | '--image_path', default='/home/dmmm/VscodeProject/demo_DenseUAV/visualization/rotateShow/ori.png', type=str, help='') 16 | parser.add_argument( 17 | '--target_dir', 
default='/home/dmmm/VscodeProject/demo_DenseUAV/visualization/ColorJitter', type=str, help='') 18 | parser.add_argument( 19 | '--num_aug', default=10, type=int, help='') 20 | 21 | opt = parser.parse_args() 22 | return opt 23 | 24 | 25 | if __name__ == '__main__': 26 | opt = get_parse() 27 | image = Image.open(opt.image_path) 28 | 29 | re = RandomErasing(probability=1.0) 30 | ra = transforms.RandomAffine(180) 31 | rac = RotateAndCrop(rate=1.0) 32 | cj = transforms.ColorJitter(brightness=0.5, contrast=0.1, saturation=0.1, hue=0) 33 | 34 | os.makedirs(opt.target_dir, exist_ok=True) 35 | 36 | for ind in range(opt.num_aug): 37 | image_ = cj(image) 38 | image_ = np.array(image_) 39 | image_ = image_[:, :, [2, 1, 0]] 40 | h, w = image_.shape[:2] 41 | # image_ = cv2.circle( 42 | # image_.copy(), (int(w/2), int(h/2)), 3, (0, 0, 255), 2) 43 | cv2.imwrite(os.path.join(opt.target_dir, "{}.jpg".format(ind)), image_) 44 | -------------------------------------------------------------------------------- /tool/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import random 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import cv2 8 | from shutil import copyfile, copytree, rmtree 9 | import logging 10 | from models.taskflow import make_model 11 | from thop import profile, clever_format 12 | import math 13 | 14 | 15 | def get_logger(filename, verbosity=1, name=None): 16 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 17 | formatter = logging.Formatter( 18 | "[%(asctime)s][%(levelname)s] %(message)s" 19 | ) 20 | logger = logging.getLogger(name) 21 | logger.setLevel(level_dict[verbosity]) 22 | 23 | fh = logging.FileHandler(filename, "w") 24 | fh.setFormatter(formatter) 25 | logger.addHandler(fh) 26 | 27 | sh = logging.StreamHandler() 28 | sh.setFormatter(formatter) 29 | logger.addHandler(sh) 30 | 31 | return logger 32 | 33 | 34 | def copy_file_or_tree(path, target_dir): 35 | target_path = os.path.join(target_dir, path) 36 | if os.path.isdir(path): 37 | if os.path.exists(target_path): 38 | rmtree(target_path) 39 | copytree(path, target_path) 40 | elif os.path.isfile(path): 41 | copyfile(path, target_path) 42 | 43 | 44 | def copyfiles2checkpoints(opt): 45 | dir_name = os.path.join('checkpoints', opt.name) 46 | if not os.path.isdir(dir_name): 47 | os.mkdir(dir_name) 48 | # record every run 49 | copy_file_or_tree('train.py', dir_name) 50 | copy_file_or_tree('test.py', dir_name) 51 | copy_file_or_tree('evaluate_gpu.py', dir_name) 52 | copy_file_or_tree('evaluateDistance.py', dir_name) 53 | copy_file_or_tree('datasets', dir_name) 54 | copy_file_or_tree('losses', dir_name) 55 | copy_file_or_tree('models', dir_name) 56 | copy_file_or_tree('optimizers', dir_name) 57 | copy_file_or_tree('tool', dir_name) 58 | copy_file_or_tree('train_test_local.sh', dir_name) 59 | 60 | # save opts 61 | with open('%s/opts.yaml' % dir_name, 'w') as fp: 62 | yaml.dump(vars(opt), fp, default_flow_style=False) 63 | 64 | 65 | def make_weights_for_balanced_classes(images, nclasses): 66 | count = [0] * nclasses 67 | for item in images: 68 | count[item[1]] += 1 # count the image number in every class 69 | weight_per_class = [0.] 
* nclasses 70 | N = float(sum(count)) 71 | for i in range(nclasses): 72 | weight_per_class[i] = N/float(count[i]) 73 | weight = [0] * len(images) 74 | for idx, val in enumerate(images): 75 | weight[idx] = weight_per_class[val[1]] 76 | return weight 77 | 78 | # Get model list for resume 79 | 80 | 81 | def get_model_list(dirname, key): 82 | if os.path.exists(dirname) is False: 83 | print('no dir: %s' % dirname) 84 | return None 85 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 86 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pth" in f] 87 | if gen_models is None: 88 | return None 89 | gen_models.sort() 90 | last_model_name = gen_models[-1] 91 | return last_model_name 92 | 93 | ###################################################################### 94 | # Save model 95 | # --------------------------- 96 | 97 | 98 | def save_network(network, dirname, epoch_label): 99 | if not os.path.isdir('./checkpoints/'+dirname): 100 | os.mkdir('./checkpoints/'+dirname) 101 | if isinstance(epoch_label, int): 102 | save_filename = 'net_%03d.pth' % epoch_label 103 | else: 104 | save_filename = 'net_%s.pth' % epoch_label 105 | save_path = os.path.join('./checkpoints', dirname, save_filename) 106 | torch.save(network.cpu().state_dict(), save_path) 107 | if torch.cuda.is_available: 108 | network.cuda() 109 | 110 | 111 | class UnNormalize(object): 112 | def __init__(self, mean, std): 113 | self.mean = mean 114 | self.std = std 115 | 116 | def __call__(self, tensor): 117 | """ 118 | Args: 119 | :param tensor: tensor image of size (B,C,H,W) to be un-normalized 120 | :return: UnNormalized image 121 | """ 122 | for t, m, s in zip(tensor, self.mean, self.std): 123 | t.mul_(s).add_(m) 124 | return tensor 125 | 126 | 127 | def check_box(images, boxes): 128 | # Unorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 129 | # images = Unorm(images)*255 130 | images = images.permute(0, 2, 3, 1).cpu().detach().numpy() 131 | boxes = (boxes.cpu().detach().numpy()/16*255).astype(np.int) 132 | for img, box in zip(images, boxes): 133 | fig = plt.figure() 134 | ax = fig.add_subplot(111) 135 | plt.imshow(img) 136 | rect = plt.Rectangle(box[0:2], box[2]-box[0], box[3]-box[1]) 137 | ax.add_patch(rect) 138 | plt.show() 139 | 140 | 141 | ###################################################################### 142 | # Load model for resume 143 | # --------------------------- 144 | def load_network(opt): 145 | save_filename = opt.checkpoint 146 | model = make_model(opt) 147 | # print('Load the model from %s' % save_filename) 148 | network = model 149 | network.load_state_dict(torch.load(save_filename)) 150 | return network 151 | 152 | 153 | def toogle_grad(model, requires_grad): 154 | for p in model.parameters(): 155 | p.requires_grad_(requires_grad) 156 | 157 | 158 | def update_average(model_tgt, model_src, beta): 159 | toogle_grad(model_src, False) 160 | toogle_grad(model_tgt, False) 161 | 162 | param_dict_src = dict(model_src.named_parameters()) 163 | 164 | for p_name, p_tgt in model_tgt.named_parameters(): 165 | p_src = param_dict_src[p_name] 166 | assert(p_src is not p_tgt) 167 | p_tgt.copy_(beta*p_tgt + (1. 
- beta)*p_src) 168 | 169 | toogle_grad(model_src, True) 170 | 171 | 172 | def get_preds(outputs, outputs2): 173 | if isinstance(outputs, list): 174 | preds = [] 175 | preds2 = [] 176 | for out, out2 in zip(outputs, outputs2): 177 | preds.append(torch.max(out.data, 1)[1]) 178 | preds2.append(torch.max(out2.data, 1)[1]) 179 | else: 180 | _, preds = torch.max(outputs.data, 1) 181 | _, preds2 = torch.max(outputs2.data, 1) 182 | return preds, preds2 183 | 184 | 185 | def calc_flops_params(model, 186 | input_size_drone, 187 | input_size_satellite, 188 | ): 189 | inputs_drone = torch.randn(input_size_drone).cuda() 190 | inputs_satellite = torch.randn(input_size_satellite).cuda() 191 | total_ops, total_params = profile( 192 | model, (inputs_drone, inputs_satellite,), verbose=False) 193 | macs, params = clever_format([total_ops, total_params], "%.3f") 194 | return macs, params 195 | 196 | 197 | def set_seed(seed): 198 | torch.manual_seed(seed) 199 | torch.cuda.manual_seed_all(seed) 200 | np.random.seed(seed) 201 | torch.backends.cudnn.deterministic = True 202 | torch.backends.cudnn.benchmark = False 203 | torch.backends.cudnn.enabled = True 204 | random.seed(seed) 205 | 206 | 207 | def Distance(lata, loga, latb, logb): 208 | # EARTH_RADIUS = 6371.0 209 | EARTH_RADIUS = 6378.137 210 | PI = math.pi 211 | # // 转弧度 212 | lat_a = lata * PI / 180 213 | lat_b = latb * PI / 180 214 | a = lat_a - lat_b 215 | b = loga * PI / 180 - logb * PI / 180 216 | dis = 2 * math.asin(math.sqrt(math.pow(math.sin(a / 2), 2) + math.cos(lat_a) 217 | * math.cos(lat_b) * math.pow(math.sin(b / 2), 2))) 218 | 219 | distance = EARTH_RADIUS * dis * 1000 220 | return distance -------------------------------------------------------------------------------- /tool/visual/Times New Roman.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/tool/visual/Times New Roman.ttf -------------------------------------------------------------------------------- /tool/visual/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy.io 3 | import torch 4 | import numpy as np 5 | import os 6 | from torchvision import datasets 7 | import matplotlib 8 | #matplotlib.use('agg') 9 | import matplotlib.pyplot as plt 10 | ####################################################################### 11 | # Evaluate 12 | parser = argparse.ArgumentParser(description='Demo') 13 | parser.add_argument('--query_index', default=77, type=int, help='test_image_index') 14 | parser.add_argument('--test_dir',default='/home/dmmm/Dataset/DenseUAV/data_2022/test',type=str, help='./test_data') 15 | parser.add_argument('--config', 16 | default="/home/dmmm/Dataset/DenseUAV/data_2022/Dense_GPS_ALL.txt", type=str, 17 | help='./test_data') 18 | opts = parser.parse_args() 19 | 20 | configDict = {} 21 | with open(opts.config, "r") as F: 22 | context = F.readlines() 23 | for line in context: 24 | splitLineList = line.split(" ") 25 | configDict[splitLineList[0].split("/")[-2]] = [float(splitLineList[1].split("E")[-1]), 26 | float(splitLineList[2].split("N")[-1])] 27 | 28 | gallery_name = 'gallery_satellite' 29 | query_name = 'query_drone' 30 | # gallery_name = 'gallery_drone' 31 | # query_name = 'query_satellite' 32 | 33 | data_dir = opts.test_dir 34 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ) for x in [gallery_name, query_name]} 35 | 36 | 
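# The ranking performed by sort_img() further down assumes the features stored in
# pytorch_result_1.mat are already L2-normalised, so the matrix product between the gallery
# features and a query feature is their cosine similarity; sorting the scores in descending order
# yields the ranked gallery indices. A minimal equivalent sketch, with hypothetical tensors
# gallery_feature (N x D) and a normalised query vector q (length D):
#
#     scores = torch.mm(gallery_feature, q.view(-1, 1)).squeeze(1)  # cosine similarity
#     ranked = torch.argsort(scores, descending=True)               # best match first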
##################################################################### 37 | #Show result 38 | def imshow(path, title=None): 39 | """Imshow for Tensor.""" 40 | im = plt.imread(path) 41 | plt.imshow(im) 42 | if title is not None: 43 | plt.title(title) 44 | plt.pause(0.1) # pause a bit so that plots are updated 45 | 46 | ###################################################################### 47 | result = scipy.io.loadmat('pytorch_result_1.mat') 48 | query_feature = torch.FloatTensor(result['query_f']) 49 | query_label = result['query_label'][0] 50 | gallery_feature = torch.FloatTensor(result['gallery_f']) 51 | gallery_label = result['gallery_label'][0] 52 | 53 | multi = os.path.isfile('multi_query.mat') 54 | 55 | if multi: 56 | m_result = scipy.io.loadmat('multi_query.mat') 57 | mquery_feature = torch.FloatTensor(m_result['mquery_f']) 58 | mquery_cam = m_result['mquery_cam'][0] 59 | mquery_label = m_result['mquery_label'][0] 60 | mquery_feature = mquery_feature.cuda() 61 | 62 | query_feature = query_feature.cuda() 63 | gallery_feature = gallery_feature.cuda() 64 | 65 | ####################################################################### 66 | # sort the images 67 | def sort_img(qf, ql, gf, gl): 68 | query = qf.view(-1,1) 69 | # print(query.shape) 70 | score = torch.mm(gf,query) 71 | score = score.squeeze(1).cpu() 72 | score = score.numpy() 73 | # predict index 74 | index = np.argsort(score) #from small to large 75 | index = index[::-1] 76 | # index = index[0:2000] 77 | # good index 78 | query_index = np.argwhere(gl==ql) 79 | 80 | #good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 81 | junk_index = np.argwhere(gl==-1) 82 | 83 | mask = np.in1d(index, junk_index, invert=True) 84 | index = index[mask] 85 | return index 86 | 87 | i = opts.query_index 88 | index = sort_img(query_feature[i],query_label[i],gallery_feature,gallery_label) 89 | 90 | ######################################################################## 91 | # Visualize the rank result 92 | query_path, _ = image_datasets[query_name].imgs[i] 93 | query_label = query_label[i] 94 | print(query_path) 95 | label6num = query_path.split("/")[-2] 96 | x_q,y_q = configDict[label6num] 97 | 98 | print('Top 10 images are as follow:') 99 | save_folder = 'image_show/%02d' % opts.query_index 100 | if not os.path.exists("image_show"): 101 | os.mkdir("image_show") 102 | if not os.path.isdir(save_folder): 103 | os.mkdir(save_folder) 104 | os.system('cp %s %s/query.jpg'%(query_path, save_folder)) 105 | 106 | try: # Visualize Ranking Result 107 | # Graphical User Interface is needed 108 | fig = plt.figure(figsize=(16,4)) 109 | ax = plt.subplot(1,11,1) 110 | ax.axis('off') 111 | imshow(query_path) 112 | ax.set_title("x:{:.7f}\ny:{:.7f}".format(x_q,y_q), color='blue',fontsize=5) 113 | for i in range(10): 114 | ax = plt.subplot(1,11,i+2) 115 | ax.axis('off') 116 | img_path, _ = image_datasets[gallery_name].imgs[index[i]] 117 | label = gallery_label[index[i]] 118 | labelg6num = img_path.split("/")[-2] 119 | x_g,y_g = configDict[label6num] 120 | print(label) 121 | imshow(img_path) 122 | os.system('cp %s %s/s%02d.tif'%(img_path, save_folder, i)) 123 | if label == query_label: 124 | ax.set_title("x:{:.7f}\ny:{:.7f}".format(x_g,y_g),color='green',fontsize=5) 125 | else: 126 | ax.set_title("x:{:.7f}\ny:{:.7f}".format(x_g,y_g), color='red',fontsize=5) 127 | print(img_path) 128 | #plt.pause(100) # pause a bit so that plots are updated 129 | except RuntimeError: 130 | for i in range(10): 131 | img_path = image_datasets.imgs[index[i]] 132 | 
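# NOTE: image_datasets is a dict keyed by split name, so this headless fallback presumably
# intends image_datasets[gallery_name].imgs[index[i]], mirroring the try branch above.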
print(img_path[0]) 133 | print('If you want to see the visualization of the ranking result, graphical user interface is needed.') 134 | 135 | fig.savefig(save_folder+"/show.png",dpi = 600) 136 | 137 | -------------------------------------------------------------------------------- /tool/visual/demo_custom.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import argparse 3 | import torch 4 | import numpy as np 5 | from datasets.queryDataset import CenterCrop 6 | from torchvision import transforms 7 | from PIL import Image 8 | import yaml 9 | from tool.utils_server import load_network 10 | from torch.autograd import Variable 11 | import torch.backends.cudnn as cudnn 12 | import os 13 | from tool.get_property import find_GPS_image 14 | # import matplotlib.pyplot as plt 15 | import cv2 16 | import json 17 | 18 | University="计量" 19 | parser = argparse.ArgumentParser(description='Demo') 20 | parser.add_argument('--img', default="/media/dmmm/CE31-3598/DataSets/DenseCV_Data/实际测试图像({})/test02/DJI_0297.JPG".format(University), type=str, help='image path for visualization') 21 | parser.add_argument("--galleryFeature",default="features{}.mat".format(University), type=str, help='galleryFeature') 22 | parser.add_argument('--galleryPath', default="/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub({})".format(University), 23 | type=str, help='./test_data') 24 | parser.add_argument('--MapDir', default="../../maps/{}.tif".format(University), type=str, help='./test_data') 25 | parser.add_argument('--K', default=10, type=int, help='./test_data') 26 | # parser.add_argument("--mode",default="1", type=int, help='1:drone->satellite 2:satellite->drone') 27 | opts = parser.parse_args() 28 | 29 | # 30.3257654243082, 120.37341533989152 30.320441588110295, 120.38193743012363 30 | mapPosInfodir = "/home/dmmm/PycharmProjects/DenseCV/demo/maps/pos.txt" 31 | with open(mapPosInfodir,"r") as F: 32 | listLine = F.readlines() 33 | for line in listLine: 34 | name,TN,TE,BN,BE = line.split(" ") 35 | if name==University: 36 | startE = eval(TE.split("TE")[-1]) 37 | startN = eval(TN.split("TN")[-1]) 38 | endE = eval(BE.split("BE")[-1]) 39 | endN = eval(BN.split("BN")[-1]) 40 | 41 | AllImage = cv2.imread(opts.MapDir) 42 | h, w, c = AllImage.shape 43 | 44 | 45 | ##################################################################### 46 | # #Show result 47 | # def imshow(path, title=None): 48 | # """Imshow for Tensor.""" 49 | # im = plt.imread(path) 50 | # plt.imshow(im) 51 | # if title is not None: 52 | # plt.title(title) 53 | # plt.pause(0.1) # pause a bit so that plots are updated 54 | 55 | 56 | def getPosInfo(imgPath): 57 | GPS_info = find_GPS_image(imgPath) 58 | x = list(GPS_info.values()) 59 | gps_dict_formate = x[0] 60 | y = list(gps_dict_formate.values()) 61 | height = eval(y[5]) 62 | E = y[3] 63 | N = y[1] 64 | return [N, E] 65 | 66 | ####################################################################### 67 | # sort the images and return topK index 68 | def getTopKImage(qf, gf, gl,K): 69 | query = qf.view(-1, 1) 70 | # print(query.shape) 71 | score = torch.mm(gf, query) 72 | score = score.squeeze().cpu() 73 | score = score.numpy() 74 | # predict index 75 | index = np.argsort(score) # from small to large 76 | index = index[::-1] 77 | # index = index[0:2000] 78 | # good index 79 | # query_index = np.argwhere(gl == ql) 80 | 81 | # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 82 | # junk_index = np.argwhere(gl == -1) 83 | 84 | # mask = 
np.in1d(index, junk_index, invert=True) 85 | # index = index[mask] 86 | return gl[index[:K]] 87 | 88 | 89 | def extract_feature(img, model): 90 | count = 0 91 | n, c, h, w = img.size() 92 | count += n 93 | input_img = Variable(img.cuda()) 94 | outputs, _ = model(input_img, None) 95 | ff = outputs 96 | # norm feature 97 | if len(ff.shape) == 3: 98 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. 99 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 100 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(opts.block) 101 | ff = ff.div(fnorm.expand_as(ff)) 102 | ff = ff.view(ff.size(0), -1) 103 | else: 104 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 105 | ff = ff.div(fnorm.expand_as(ff)) 106 | 107 | features = ff.data 108 | return features 109 | 110 | def generateDictOfGalleryPosInfo(): 111 | satellite_configDict = {} 112 | with open(os.path.join(opts.galleryPath, "PosInfo.txt"), "r") as F: 113 | context = F.readlines() 114 | for line in context: 115 | splitLineList = line.split(" ") 116 | satellite_configDict[splitLineList[0]] = [float(splitLineList[1].split("N")[-1]), 117 | float(splitLineList[2].split("E")[-1])] 118 | return satellite_configDict 119 | 120 | 121 | def imshowByIndex(index): 122 | galleryPath = os.path.join(opts.galleryPath, index + ".tif") 123 | image = cv2.imread(galleryPath) 124 | cv2.imshow("gallery", image) 125 | 126 | data_transforms = transforms.Compose([ 127 | CenterCrop(), 128 | transforms.Resize((256, 256), interpolation=3), 129 | transforms.ToTensor(), 130 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 131 | ]) 132 | 133 | ###################################################################### 134 | # load network 135 | config_path = 'opts.yaml' 136 | with open(config_path, 'r') as stream: 137 | config = yaml.load(stream) 138 | opts.stride = config['stride'] 139 | opts.views = config['views'] 140 | opts.transformer = config['transformer'] 141 | opts.pool = config['pool'] 142 | opts.views = config['views'] 143 | opts.LPN = config['LPN'] 144 | opts.block = config['block'] 145 | opts.nclasses = config['nclasses'] 146 | opts.droprate = config['droprate'] 147 | opts.share = config['share'] 148 | opts.checkpoint = "net_119.pth" 149 | torch.cuda.set_device("cuda:0") 150 | cudnn.benchmark = True 151 | 152 | model = load_network(opts) 153 | model = model.eval() 154 | model = model.cuda() 155 | 156 | ###################################################################### 157 | result = scipy.io.loadmat(opts.galleryFeature) 158 | gallery_feature = torch.FloatTensor(result['features']) 159 | gallery_label = result['labels'] 160 | gallery_feature = gallery_feature.cuda() 161 | 162 | satellitePosInfoDict = generateDictOfGalleryPosInfo() 163 | 164 | img = Image.open(opts.img) 165 | input = data_transforms(img) 166 | input = torch.unsqueeze(input, 0) 167 | # query Pos info 168 | queryPosInfo = getPosInfo(opts.img) 169 | Q_N = float(queryPosInfo[0]) 170 | Q_E = float(queryPosInfo[1]) 171 | 172 | with torch.no_grad(): 173 | feature = extract_feature(input, model) 174 | indexSorted = getTopKImage(feature, gallery_feature, gallery_label,opts.K) 175 | 176 | 177 | selectedGalleryPath = [os.path.join(opts.galleryPath,"{}.tif".format(index)) for index in indexSorted] 178 | dict_paths = {"K":opts.K,"query":opts.img,"gallery":selectedGalleryPath} 179 | with open("test.json", "w", encoding='utf-8') as f: 180 | # indent 超级好用,格式化保存字典,默认为None,小于0为零个空格 181 | # 
f.write(json.dumps(dict_paths, indent=4)) 182 | json.dump(dict_paths, f, indent=4) # 传入文件描述符,和dumps一样的结果 183 | 184 | 185 | galleryPosDict = {"N": [], "E": []} 186 | for index in indexSorted: 187 | bestMatchedPosInfo = satellitePosInfoDict[index] 188 | galleryPosDict["N"].append(float(bestMatchedPosInfo[0])) 189 | galleryPosDict["E"].append(float(bestMatchedPosInfo[1])) 190 | 191 | ## query visualization 192 | X = int((Q_E - startE) / (endE - startE) * w) 193 | Y = int((Q_N - startN) / (endN - startN) * h) 194 | cv2.circle(AllImage, (X,Y), 30, color=(0, 0, 255), thickness=10) 195 | 196 | ## gallery visualization 197 | index = 1 198 | for N, E in zip(galleryPosDict["N"], galleryPosDict["E"]): 199 | X = int((E - startE) / (endE - startE) * w) 200 | Y = int((N - startN) / (endN - startN) * h) 201 | if index>=10: 202 | cv2.circle(AllImage, (X, Y), 30, color=(255, 0, 0), thickness=6) 203 | cv2.putText(AllImage, str(index), (X - 20, Y + 15), cv2.FONT_HERSHEY_COMPLEX, 1.2, color=(255, 0, 0), 204 | thickness=2) 205 | else: 206 | cv2.circle(AllImage, (X, Y), 30, color=(255, 0, 0), thickness=6) 207 | cv2.putText(AllImage, str(index), (X - 20, Y + 20), cv2.FONT_HERSHEY_COMPLEX, 2, color=(255, 0, 0), 208 | thickness=2) 209 | index += 1 210 | 211 | AllImage = cv2.resize(AllImage,(0,0),fx=0.25,fy=0.25) 212 | cv2.imwrite("topKLocationBySingleImage.tif",AllImage) 213 | 214 | # os.system("python visualization.py") 215 | 216 | ###报qt的错误 217 | # try: 218 | # fig = plt.figure(figsize=(12, 4)) 219 | # ax = plt.subplot(1, opts.K+1, 1) 220 | # ax.axis('off') 221 | # imshow(opts.img, 'query') 222 | # for i,path in enumerate(selectedGalleryPath): 223 | # ax = plt.subplot(1, 11, i + 2) 224 | # ax.axis('off') 225 | # imshow(path) 226 | # except: 227 | # print('If you want to see the visualization of the ranking result, graphical user interface is needed.') 228 | # 229 | # fig.savefig("show.png",dpi=600) 230 | 231 | -------------------------------------------------------------------------------- /tool/visual/demo_custom_visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import json 3 | 4 | 5 | with open("test.json", "r", encoding='utf-8') as f: 6 | data = json.loads(f.read()) # load的传入参数为字符串类型 7 | 8 | K = data["K"] 9 | query = data["query"] 10 | gallery = data["gallery"] 11 | 12 | #Show result 13 | def imshow(path, title=None): 14 | """Imshow for Tensor.""" 15 | im = plt.imread(path) 16 | plt.imshow(im) 17 | if title is not None: 18 | plt.title(title) 19 | # plt.pause(0.1) # pause a bit so that plots are updated 20 | 21 | ##报qt的错误 22 | try: 23 | fig = plt.figure(figsize=(12, 4)) 24 | ax = plt.subplot(1, K+1, 1) 25 | ax.axis('off') 26 | imshow(query, 'query') 27 | for i,path in enumerate(gallery): 28 | ax = plt.subplot(1, 11, i + 2) 29 | ax.axis('off') 30 | imshow(path) 31 | except: 32 | print('If you want to see the visualization of the ranking result, graphical user interface is needed.') 33 | 34 | fig.savefig("show.png",dpi=600) 35 | -------------------------------------------------------------------------------- /tool/visual/draw_SDMcurve.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | import json 4 | from glob import glob 5 | from matplotlib import font_manager 6 | 7 | output_dir = "/home/dmmm/VscodeProject/demo_DenseUAV/visualization/SDMCurve/Backbone" 8 | 9 | os.makedirs(output_dir, exist_ok = True) 10 | 11 | 
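# The rcParams below assume a Times New Roman font is visible to matplotlib. The repository
# bundles tool/visual/Times New Roman.ttf and font_manager is already imported above; a minimal
# sketch for registering that file explicitly (the relative path is an assumption about the
# working directory):
#
#     font_manager.fontManager.addfont("Times New Roman.ttf")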
plt.rcParams['font.sans-serif'] = ['Times New Roman'] 12 | 13 | source_dir = { 14 | # backbone 15 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_ConvnextT":"ConvNeXt-T", 16 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_DeitS":"DeiT-S", 17 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_EfficientNet-B3":"EfficientNet-B3", 18 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_EfficientNet-B5":"EfficientNet-B5", 19 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_PvTv2b2":"PvTv2-B2", 20 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_resnet50":"ResNet50", 21 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_Swinv2T-256":"Swinv2-T", 22 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-Global":"ViT-S", 23 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_ViTB":"ViT-B", 24 | # head 25 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-MaxPool":"MaxPool", 26 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-AvgMaxPool":"AvgMaxPool", 27 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-AvgPool":"AvgPool", 28 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-Global":"Global", 29 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-GeM":"GemPool", 30 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-FSRA2B":"FSRA(block=2)", 31 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-LPN2B":"LPN(Block=2)", 32 | } 33 | 34 | json_name = "SDM*.json" 35 | 36 | fig = plt.figure(figsize=(4,4)) 37 | plt.grid() 38 | plt.yticks(fontproperties='Times New Roman', size=12) 39 | plt.xticks(fontproperties='Times New Roman', size=12) 40 | 41 | x = list(range(1,101)) 42 | 43 | color = ['r', 'k', 'y', 'c', 'g', 'm', 'b', 'coral', 'tan'] 44 | ind = 0 45 | for path, name in source_dir.items(): 46 | print(name) 47 | target_file = glob(os.path.join(path, json_name))[0] 48 | with open(target_file, 'r') as F: 49 | data = json.load(F) 50 | y = list(data.values()) 51 | plt.plot(x,y,c=color[ind],marker = 'o',label=name,linewidth=1.0,markersize=1) 52 | ind+=1 53 | 54 | plt.legend(loc="upper right",prop={'family' : 'Times New Roman', 'size' : 12}) 55 | plt.ylabel("SDM@K",fontdict={'family' : 'Times New Roman', 'size': 12}) 56 | plt.xlabel("K",fontdict={'family' : 'Times New Roman', 'size': 12}) 57 | plt.tight_layout() 58 | 59 | 60 | fig.savefig(os.path.join(output_dir, "backbone.eps"), dpi=600, format='eps') 61 | plt.show() 62 | 63 | -------------------------------------------------------------------------------- /tool/visual/grad_cam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | from torchvision import transforms 7 | from utils import GradCAM, show_cam_on_image, center_crop_img 8 | from vit_model import vit_base_patch16_224 9 | 10 | 11 | class ReshapeTransform: 12 | def __init__(self, model): 13 | input_size = model.patch_embed.img_size 14 | patch_size = model.patch_embed.patch_size 15 | self.h = input_size[0] // patch_size[0] 16 | self.w = input_size[1] // patch_size[1] 17 | 18 | def __call__(self, x): 19 | # remove cls token and reshape 20 | # [batch_size, num_tokens, token_dim] 21 | result = x[:, 1:, 
:].reshape(x.size(0), 22 | self.h, 23 | self.w, 24 | x.size(2)) 25 | 26 | # Bring the channels to the first dimension, 27 | # like in CNNs. 28 | # [batch_size, H, W, C] -> [batch, C, H, W] 29 | result = result.permute(0, 3, 1, 2) 30 | return result 31 | 32 | 33 | def main(): 34 | model = vit_base_patch16_224() 35 | # Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA Password: eu9f 36 | weights_path = "./vit_base_patch16_224.pth" 37 | model.load_state_dict(torch.load(weights_path, map_location="cpu")) 38 | # Since the final classification is done on the class token computed in the last attention block, 39 | # the output will not be affected by the 14x14 channels in the last layer. 40 | # The gradient of the output with respect to them will be 0! 41 | # We should choose any layer before the final attention block. 42 | target_layers = [model.blocks[-1].norm1] 43 | 44 | data_transform = transforms.Compose([transforms.ToTensor(), 45 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 46 | # load image 47 | img_path = "both.png" 48 | assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path) 49 | img = Image.open(img_path).convert('RGB') 50 | img = np.array(img, dtype=np.uint8) 51 | img = center_crop_img(img, 224) 52 | # [C, H, W] 53 | img_tensor = data_transform(img) 54 | # expand batch dimension 55 | # [C, H, W] -> [N, C, H, W] 56 | input_tensor = torch.unsqueeze(img_tensor, dim=0) 57 | 58 | cam = GradCAM(model=model, 59 | target_layers=target_layers, 60 | use_cuda=False, 61 | reshape_transform=ReshapeTransform(model)) 62 | target_category = 281 # tabby, tabby cat 63 | # target_category = 254 # pug, pug-dog 64 | 65 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category) 66 | 67 | grayscale_cam = grayscale_cam[0, :] 68 | visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True) 69 | plt.imshow(visualization) 70 | plt.show() 71 | 72 | 73 | if __name__ == '__main__': 74 | main() -------------------------------------------------------------------------------- /tool/visual/heatmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import matplotlib 5 | matplotlib.use('agg') 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from tool.utils_server import load_network 9 | import yaml 10 | import argparse 11 | import torch 12 | from torchvision import datasets, models, transforms 13 | from PIL import Image 14 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 15 | parser = argparse.ArgumentParser(description='Training') 16 | import math 17 | 18 | parser.add_argument('--data_dir',default='/media/dmmm/CE31-3598/DataSets/DenseCV_Data/improvedOriData/test',type=str, help='./test_data') 19 | parser.add_argument('--name', default='from_transreid_256_4B_small_lr005_kl', type=str, help='save model path') 20 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 21 | parser.add_argument('--checkpoint',default="net_119.pth", help='weights' ) 22 | opt = parser.parse_args() 23 | 24 | config_path = 'opts.yaml' 25 | with open(config_path, 'r') as stream: 26 | config = yaml.load(stream, Loader=yaml.FullLoader) 27 | opt.stride = config['stride'] 28 | opt.views = config['views'] 29 | opt.transformer = config['transformer'] 30 | opt.pool = config['pool'] 31 | opt.views = config['views'] 32 | opt.LPN = config['LPN'] 33 | opt.block = config['block'] 34 | opt.nclasses = config['nclasses'] 35 | opt.droprate = config['droprate'] 36 | opt.share = config['share'] 37 | 38 | if 'h' in config: 
39 | opt.h = config['h'] 40 | opt.w = config['w'] 41 | if 'nclasses' in config: # to be compatible with old config files 42 | opt.nclasses = config['nclasses'] 43 | else: 44 | opt.nclasses = 751 45 | 46 | 47 | def heatmap2d(img, arr): 48 | # fig = plt.figure() 49 | # ax0 = fig.add_subplot(121, title="Image") 50 | # ax1 = fig.add_subplot(122, title="Heatmap") 51 | # fig, ax = plt.subplots() 52 | # ax[0].imshow(Image.open(img)) 53 | plt.figure() 54 | heatmap = plt.imshow(arr, cmap='viridis') 55 | plt.axis('off') 56 | # fig.colorbar(heatmap, fraction=0.046, pad=0.04) 57 | #plt.show() 58 | plt.savefig('heatmap_dbase') 59 | 60 | data_transforms = transforms.Compose([ 61 | transforms.Resize((opt.h, opt.w), interpolation=3), 62 | transforms.ToTensor(), 63 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 64 | ]) 65 | 66 | # image_datasets = {x: datasets.ImageFolder( os.path.join(opt.data_dir,x) ,data_transforms) for x in ['satellite']} 67 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, 68 | # shuffle=False, num_workers=1) for x in ['satellite']} 69 | 70 | # imgpath = image_datasets['satellite'].imgs 71 | # print(imgpath) 72 | from glob import glob 73 | print(opt.data_dir) 74 | #list = os.listdir(os.path.join(opt.data_dir,"gallery_drone")) 75 | for i in ["000009","000013","000015","000016","000018","000035","000039","000116","000130"]: 76 | print(i) 77 | imgpath = os.path.join(opt.data_dir,"gallery_drone/"+i) 78 | #imgname = 'gallery_drone/0726/image-28.jpeg' 79 | # imgname = 'query_satellite/0721/0721.jpg' 80 | #imgpath = os.path.join(opt.data_dir,imgname) 81 | imgpath = os.path.join(imgpath, "1.JPG") 82 | print(imgpath) 83 | img = Image.open(imgpath) 84 | img = data_transforms(img) 85 | img = torch.unsqueeze(img,0) 86 | #print(img.shape) 87 | model = load_network(opt) 88 | 89 | model = model.eval().cuda() 90 | 91 | # data = next(iter(dataloaders['satellite'])) 92 | # img, label = data 93 | with torch.no_grad(): 94 | # x = model.model_1.model.conv1(img.cuda()) 95 | # x = model.model_1.model.bn1(x) 96 | # x = model.model_1.model.relu(x) 97 | # x = model.model_1.model.maxpool(x) 98 | # x = model.model_1.model.layer1(x) 99 | # x = model.model_1.model.layer2(x) 100 | # x = model.model_1.model.layer3(x) 101 | # output = model.model_1.model.layer4(x) 102 | features = model.model_1.transformer(img.cuda()) 103 | part_features = features[:,1:] 104 | part_features = part_features.view(part_features.size(0),int(math.sqrt(part_features.size(1))),int(math.sqrt(part_features.size(1))),part_features.size(2)) 105 | output = part_features.permute(0,3,1,2) 106 | #print(output.shape) 107 | heatmap = output.squeeze().sum(dim=0).cpu().numpy() 108 | # print(heatmap.shape) 109 | # print(heatmap) 110 | # heatmap = np.mean(heatmap, axis=0) 111 | # 112 | # heatmap = np.maximum(heatmap, 0) 113 | # heatmap /= np.max(heatmap) 114 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap)-np.min(heatmap)) 115 | 116 | #print(heatmap) 117 | # cv2.imshow("1",heatmap) 118 | # cv2.waitKey(0) 119 | #test_array = np.arange(100 * 100).reshape(100, 100) 120 | # Result is saved as `heatmap.png` 121 | img = cv2.imread(imgpath) # load the original image with cv2 122 | heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) # resize the heatmap to the original image size 123 | heatmap = np.uint8(255 * heatmap) # scale the heatmap to 0-255 uint8 124 | #print(heatmap) 125 | heatmap = cv2.applyColorMap(heatmap, 2) # apply a color map (2 = cv2.COLORMAP_JET) to the heatmap 126 | superimposed_img = heatmap * 0.8 + img # 0.8 is the heatmap intensity factor 127 | if not os.path.exists("heatout"): 128 | 
os.mkdir("./heatout") 129 | cv2.imwrite("./heatout/"+i+".jpg", superimposed_img) 130 | #heatmap2d(imgpath,heatmap) -------------------------------------------------------------------------------- /tool/visual_demo.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import random 3 | from glob import glob 4 | import numpy as np 5 | 6 | m_50 = [] 7 | m_20 = [] 8 | m_10 = [] 9 | m_5 = [] 10 | m_3 = [] 11 | m = [m_3,m_5,m_10,m_20,m_50] 12 | 13 | x_label=[] 14 | RDS_value = [] 15 | num_a=[3,5,10,20,50] 16 | filenames = glob("output/result_files/height-level/*.txt") 17 | filenames.sort(key=lambda x:int(x.split("/")[-1].split("_")[0])) 18 | for num, path in enumerate(filenames): 19 | x_label.append(path.split("/")[-1].split("_")[0]) 20 | RDS_value.append(float(path.split("/")[-1].split(".txt")[0].split("=")[-1])) 21 | result = [] 22 | with open(path, "r") as F: 23 | lines = F.readlines() 24 | for i,a in enumerate(num_a): 25 | line = lines[a-1] 26 | out = line.split(' ')[-1] 27 | out = out.split("\n")[0] 28 | m[i].append(float(out)) 29 | 30 | 31 | 32 | fig = plt.figure(figsize=(6, 10)) 33 | ax1 = fig.subplots() 34 | # ax1.tick_params(axis='x', labelrotation=10) 35 | plt.xlabel("Height (m)",fontdict={'family' : 'Times New Roman', 'size': 21}) 36 | plt.xticks(fontproperties='Times New Roman',fontsize=18) 37 | ax1.set_ylabel('Accuracy', fontdict={'family': 'Times New Roman', 'size': 21}) # set the y-axis label 38 | ax1.set_ylim(0,1.1) 39 | ax1.set_yticklabels([0.0,0.2,0.4,0.6,0.8,1.0], fontproperties='Times New Roman', fontsize=16) # set the y-axis tick labels 40 | 41 | 42 | # draw the MA@K bar chart 43 | bar1 = ax1.bar(x_label, m_50, width=0.5, color="#45a776", label="MA@"+str(num_a[4]),edgecolor="k",linewidth=2) 44 | bar2 = ax1.bar(x_label, m_20, width=0.5, color="#3682be", label="MA@"+str(num_a[3]),edgecolor="k",linewidth=2) 45 | bar3 = ax1.bar(x_label, m_10, width=0.5, color="#b3974e", label="MA@"+str(num_a[2]),edgecolor="k",linewidth=2) 46 | bar4 = ax1.bar(x_label, m_5, width=0.5, color="#eed777", label="MA@"+str(num_a[1]),edgecolor="k",linewidth=2) 47 | bar5 = ax1.bar(x_label, m_3, width=0.5, color="#f05326", label="MA@"+str(num_a[0]),edgecolor="k",linewidth=2) 48 | 49 | # add value labels on the bars 50 | for (a, b) in zip(x_label, m_3): 51 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman') 52 | for (a, b) in zip(x_label, m_5): 53 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman') 54 | for (a, b) in zip(x_label, m_10): 55 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman') 56 | for (a, b) in zip(x_label, m_20): 57 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman') 58 | for (a, b) in zip(x_label, m_50): 59 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman') 60 | # legend_bar = plt.legend(handles=[bar1,bar2,bar3,bar4,bar5], loc="upper left", ncol=2, prop={'family': 'Times New Roman', 'size': 16}) 61 | 62 | 63 | ax2 = ax1.twinx() 64 | rds, = ax2.plot(x_label,RDS_value,color="red",linestyle="--", marker='*',label="RDS", markersize=16, linewidth=3) 65 | # legend_rds = plt.legend(handles=[rds],loc = "upper right",prop={'family': 'Times New Roman', 'size': 16}) 66 | for ind, (a, b) in enumerate(zip(x_label, 
RDS_value)): 67 | bias = + 0.002 68 | plt.text(a, b+bias, "{:.3f}".format(b), color='darkred', fontsize=17, ha="center",va="bottom", fontproperties='Times New Roman') 69 | 70 | ax2.set_ylabel('RDS', fontdict={'family': 'Times New Roman', 'size': 21}) # set the y-axis label of the RDS axis 71 | ax2.set_ylim(0.68,0.82) 72 | # ax1.set_xticks(fontproperties='Times New Roman',fontsize=22) # set the axis ticks 73 | ax2.set_yticklabels([0.68,0.70,0.72,0.74,0.76,0.78,0.80,0.82,0.84], fontproperties='Times New Roman', fontsize=16) # set the y-axis tick labels 74 | 75 | # ax = ax1.add_artist(legend_bar) 76 | 77 | legend_bar = plt.legend(handles=[bar1,bar2,bar3,bar4,bar5,rds], loc="upper left", ncol=2, prop={'family': 'Times New Roman', 'size': 16}) 78 | 79 | 80 | # save the figure to tool/visual/MA_curve/ 81 | plt.tight_layout() 82 | plt.savefig('tool/visual/MA_curve/RDS_MA_height.jpg', dpi=600) 83 | plt.savefig('tool/visual/MA_curve/RDS_MA_height.eps', dpi=600) -------------------------------------------------------------------------------- /train_test_local.sh: -------------------------------------------------------------------------------- 1 | name="baseline" 2 | root_dir="/data/datasets/crossview/DenseUAV/data_2022" 3 | data_dir=$root_dir/train 4 | test_dir=$root_dir/test 5 | gpu_ids=0 6 | num_worker=8 7 | lr=0.01 8 | batchsize=16 9 | sample_num=1 10 | block=1 11 | num_bottleneck=512 12 | backbone="ViTS-224" # resnet50 ViTS-224 senet 13 | head="SingleBranch" 14 | head_pool="avg" # global avg max avg+max 15 | cls_loss="CELoss" # CELoss FocalLoss 16 | feature_loss="WeightedSoftTripletLoss" # TripletLoss HardMiningTripletLoss WeightedSoftTripletLoss ContrastiveLoss 17 | kl_loss="KLLoss" # KLLoss 18 | h=224 19 | w=224 20 | load_from="no" 21 | ra="satellite" # random affine 22 | re="satellite" # random erasing 23 | cj="no" # color jitter 24 | rr="uav" # random rotate 25 | 26 | python train.py --name $name --data_dir $data_dir --gpu_ids $gpu_ids --sample_num $sample_num \ 27 | --block $block --lr $lr --num_worker $num_worker --head $head --head_pool $head_pool \ 28 | --num_bottleneck $num_bottleneck --backbone $backbone --h $h --w $w --batchsize $batchsize --load_from $load_from \ 29 | --ra $ra --re $re --cj $cj --rr $rr --cls_loss $cls_loss --feature_loss $feature_loss --kl_loss $kl_loss 30 | 31 | cd checkpoints/$name 32 | python test.py --name $name --test_dir $test_dir --gpu_ids $gpu_ids --num_worker $num_worker 33 | python evaluate_gpu.py 34 | python evaluateDistance.py --root_dir $root_dir 35 | cd ../../ --------------------------------------------------------------------------------
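Note on re-running only the evaluation stage of train_test_local.sh: once a model has already been trained, the last block of the script can be executed on its own against the existing checkpoint. The snippet below is a minimal sketch under that assumption, reusing the same directory layout and flag names as the script above; the checkpoint name "baseline" and the dataset root are illustrative values copied from the script, not fixed requirements.

# re-run evaluation only, assuming checkpoints/baseline already exists
name="baseline"
root_dir="/data/datasets/crossview/DenseUAV/data_2022"
cd checkpoints/$name
python test.py --name $name --test_dir $root_dir/test --gpu_ids 0 --num_worker 8
python evaluate_gpu.py
python evaluateDistance.py --root_dir $root_dir
cd ../../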