├── .gitignore
├── .vscode
│   └── launch.json
├── LICENSE
├── README.md
├── batch_inference_time.sh
├── datasets
│   ├── Dataloader_University.py
│   ├── autoaugment.py
│   ├── make_dataloader.py
│   └── queryDataset.py
├── docs
│   ├── Get_started.md
│   ├── Request.md
│   ├── images
│   │   ├── data.jpg
│   │   ├── framework.jpg
│   │   └── model.png
│   └── training_parameters.md
├── evaluateDistance.py
├── evaluateDistance_DifHeight.py
├── evaluateMA.py
├── evaluateMA_dense.py
├── evaluate_RDS.py
├── evaluate_gpu.py
├── heatmap.py
├── losses
│   ├── ArcfaceLoss.py
│   ├── FocalLoss.py
│   ├── TripletLoss.py
│   ├── __Init__.py
│   ├── cal_loss.py
│   └── loss.py
├── models
│   ├── Backbone
│   │   ├── RKNet.py
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   └── cvt.py
│   ├── Head
│   │   ├── FSRA.py
│   │   ├── GeM.py
│   │   ├── LPN.py
│   │   ├── NeXtVLAD.py
│   │   ├── NetVLAD.py
│   │   ├── SingleBranch.py
│   │   ├── __init__.py
│   │   ├── head.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── model.py
│   └── taskflow.py
├── optimizers
│   └── make_optimizer.py
├── requirments.txt
├── test.py
├── test_hard.py
├── tool
│   ├── SDM@K_analyze.py
│   ├── SDM@K_compare.py
│   ├── applications
│   │   ├── forwardAllSatelliteHub.py
│   │   ├── inference_global.py
│   │   └── inference_neibor.py
│   ├── dataset_preprocess
│   │   ├── 1-generateSatelliteByUav.py
│   │   ├── 2-generate_new_croped_resized_images_difheights.py
│   │   ├── 3-generate_format_testset.py
│   │   ├── TEST-1-downloadALlAndReCut.py
│   │   ├── TEST-2-generateSatelliteHub.py
│   │   ├── TEST-3-preprocess_difHeightTest.py
│   │   ├── TEST-Regenerate_dense_testset.py
│   │   ├── get_property.py
│   │   ├── google_interface.py
│   │   ├── google_interface_2.py
│   │   ├── split_dataset_long_middle_short.py
│   │   ├── utils.py
│   │   └── validation_testset.py
│   ├── get_inference_time.py
│   ├── get_model_flops_params.py
│   ├── get_property.py
│   ├── mount_dist.sh
│   ├── transforms
│   │   ├── rotatetranformtest.py
│   │   └── transform_visual.py
│   ├── utils.py
│   ├── visual
│   │   ├── Times New Roman.ttf
│   │   ├── demo.py
│   │   ├── demo_custom.py
│   │   ├── demo_custom_visualization.py
│   │   ├── draw_SDMcurve.py
│   │   ├── grad_cam.py
│   │   └── heatmap.py
│   └── visual_demo.py
├── train.py
└── train_test_local.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | *.pth
132 | maps/
133 | checkpoints/*
134 | visualization/
135 | .history/
136 | .vscode/
137 | experiment_*.sh
138 | experiments/
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: 当前文件",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal",
13 | "justMyCode": true,
14 | "cwd": "${workspaceFolder}/checkpoints/Head_Experiment-FSRA2B"
15 | }
16 | ]
17 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vision-Based UAV Self-Positioning in Low-Altitude Urban Environments
2 |
3 | This repository contains code and dataset for the paper titled [Vision-Based UAV Self-Positioning in Low-Altitude Urban Environments](https://arxiv.org/abs/2201.09201). In this paper, we propose a method for accurately self-positioning unmanned aerial vehicles (UAVs) in challenging low-altitude urban environments using vision-based techniques. We provide the DenseUAV dataset and a Baseline model implementation to facilitate research in this task. Thank you for your kind attention.
4 |
5 | 
6 |
7 | 
8 |
9 | 
10 |
11 | ## News
12 |
13 | - **`2023/12/18`**: Our paper was accepted by IEEE Transactions on Image Processing.
14 | - **`2023/8/14`**: Our dataset and code are released.
15 |
16 | ## Table of contents
17 |
18 | - [News](#news)
19 | - [Table of contents](#table-of-contents)
20 | - [About Dataset](#about-dataset)
21 | - [Prerequisites](#prerequisites)
22 | - [Installation](#installation)
23 | - [Dataset \& Preparation](#dataset--preparation)
24 | - [Train \& Evaluation](#train--evaluation)
25 | - [Training and Testing](#training-and-testing)
26 | - [Evaluation](#evaluation)
27 | - [Supported Methods](#supported-methods)
28 | - [License](#license)
29 | - [Citation](#citation)
30 | - [Related Work](#related-work)
31 |
32 | ## About Dataset
33 |
34 | The dataset split is as follows:
35 | | Subset | UAV-view | Satellite-view | Classes | Universities |
36 | | -------- | ----- | ---- | ---- | ---- |
37 | | Training | 6,768 | 13,536 | 2,256 | 10 |
38 | | Query | 2,331 | 4,662 | 777 | 4 |
39 | | Gallery | 9,099 | 18,198 | 3,033 | 14 |
40 |
41 | More detailed file structure:
42 |
43 | ```
44 | ├── DenseUAV/
45 | │ ├── Dense_GPS_ALL.txt /* format as: path latitude longitude height
46 | │ ├── Dense_GPS_test.txt
47 | │ ├── Dense_GPS_train.txt
48 | │ ├── train/
49 | │ ├── drone/ /* drone-view training images
50 | │ ├── 000001
51 | │ ├── H100.JPG
52 | │ ├── H90.JPG
53 | │ ├── H80.JPG
54 | | ...
55 | │ ├── satellite/ /* satellite-view training images
56 | │ ├── 000001
57 | │ ├── H100_old.tif
58 | │ ├── H90_old.tif
59 | │ ├── H80_old.tif
60 | │ ├── H100.tif
61 | │ ├── H90.tif
62 | │ ├── H80.tif
63 | | ...
64 | │ ├── test/
65 | │ ├── query_drone/ /* UAV-view testing images
66 | │ ├── query_satellite/ /* satellite-view testing images
67 | ```
68 |
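For reference, the GPS annotation files can be parsed with a few lines of Python. Below is a minimal sketch that mirrors the parsing in `evaluateDistance.py` (where the two coordinate fields carry `E`/`N` prefixes); the path is a placeholder for your local copy of the dataset.

```python
# Minimal sketch: map each class folder (e.g. "000001") to its coordinate pair,
# following the "path latitude longitude height" lines of Dense_GPS_ALL.txt.
gps = {}
with open("DenseUAV/Dense_GPS_ALL.txt", "r") as f:
    for line in f:
        fields = line.split(" ")
        cls_id = fields[0].split("/")[-2]
        # evaluateDistance.py strips the "E"/"N" prefixes from the coordinate fields
        gps[cls_id] = [float(fields[1].split("E")[-1]),
                       float(fields[2].split("N")[-1])]
```
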
69 | ## Prerequisites
70 |
71 | - Python 3.7+
72 | - GPU Memory >= 8G
73 | - Numpy 1.21.2
74 | - Pytorch 1.10.0+cu113
75 | - Torchvision 0.11.1+cu113
76 |
77 | ## Installation
78 |
79 | We recommend CUDA 11.3 and PyTorch 1.10.0. You can download the corresponding wheels from this [website](https://download.pytorch.org/whl/torch_stable.html) and install them with `pip install`. Then run the following command to install the remaining dependencies.
80 |
81 | ```
82 | pip install -r requirments.txt
83 | ```
84 |
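Optionally, you can verify that the GPU build of PyTorch is active before training:

```python
import torch

# Expect a "+cu113" build string and CUDA availability on a correctly configured machine.
print(torch.__version__)
print(torch.cuda.is_available())
```
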
85 | Create the directory for saving the training log and ckpts.
86 |
87 | ```
88 | mkdir checkpoints
89 | ```
90 |
91 | ## Dataset & Preparation
92 |
93 | Download DenseUAV upon request. You may use the request [Template](https://github.com/Dmmm1997/DenseUAV//blob/main/docs/Request.md).
94 |
95 | ## Train & Evaluation
96 |
97 | ### Training and Testing
98 |
99 | You can execute the following command to run the entire training and testing process.
100 |
101 | ```
102 | bash train_test_local.sh
103 | ```
104 |
105 | For the parameter settings in **train_test_local.sh**, refer to [Get Started](https://github.com/Dmmm1997/DenseUAV/blob/main/docs/training_parameters.md).
106 |
107 | ### Evaluation
108 |
109 | The following commands are required to evaluate Recall and SDM separately.
110 |
111 | ```
112 | cd checkpoints/
113 | python test.py --name <name> --test_dir <test_data_dir> --gpu_ids 0 --num_worker 4
114 | ```
115 |
116 | Here, `<name>` is the experiment name from your training setting; you can find it under `checkpoints/`.
117 |
118 | **For Recall**
119 |
120 | ```
121 | python evaluate_gpu.py
122 | ```
123 |
124 | **For SDM**
125 |
126 | ```
127 | python evaluateDistance.py --root_dir <dataset_root>
128 | ```
129 |
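For reference, the SDM@K metric computed by `evaluateDistance.py` weights the GPS error of each of the top-K retrieved gallery images, giving higher-ranked results larger weights. A condensed sketch of `evaluateSingle` from that script:

```python
import numpy as np

# Condensed sketch of the SDM@K weighting in evaluateSingle() of evaluateDistance.py.
# `distances` holds the coordinate-space errors of the top-K retrieved images, in rank order.
def sdm_at_k(distances, K):
    weight = np.ones(K) - np.arange(K) / K        # rank weights: 1, 1-1/K, 1-2/K, ...
    score = np.exp(-distances * 5e3) * weight     # small positioning errors give scores near 1
    return np.sum(score) / np.sum(weight)
```
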
130 | We also provide the baseline checkpoint: [quark](https://pan.quark.cn/s/3ced42633793) / [one-drive](https://seunic-my.sharepoint.cn/:u:/g/personal/230238525_seu_edu_cn/EUFoYjIdK_JNuxmvpb5QjLcB1hUHyedGwOnT3wTeN7Zqdg?e=LZuUxz).
131 |
132 | ```
133 | unzip <checkpoint_zip> -d checkpoints
134 | cd checkpoints/baseline
135 | python test.py --test_dir <dataset_root>/test
136 | python evaluate_gpu.py
137 | python evaluateDistance.py --root_dir <dataset_root>
138 | ```
139 |
140 | ## Supported Methods
141 |
142 | | Augment | Backbone | Head | Loss |
143 | | ----------------- | --------------- | ----------- | -------------------------- |
144 | | Random Rotate | ResNet | MaxPool | CrossEntropy Loss |
145 | | Random Affine | EfficientNet | AvgPool | Focal Loss |
146 | | Random Brightness | ConvNext | MaxAvgPool | Triplet Loss |
147 | | Random Erasing | DeiT | GlobalPool | Hard-Mining Triplet Loss |
148 | | | PvT | GemPool | Same-Domain Triplet Loss |
149 | | | SwinTransformer | LPN | Soft-Weighted Triplet Loss |
150 | | | ViT | FSRA | KL Loss |
151 |
152 | ## License
153 |
154 | This project is licensed under the [Apache 2.0 license](https://github.com/Dmmm1997/DenseUAV//blob/main/LICENSE).
155 |
156 | ## Citation
157 |
158 | The following paper reports the results of the baseline model. You may cite it in your paper.
159 |
160 | ```bibtex
161 | @ARTICLE{DenseUAV,
162 | author={Dai, Ming and Zheng, Enhui and Feng, Zhenhua and Qi, Lei and Zhuang, Jiedong and Yang, Wankou},
163 | journal={IEEE Transactions on Image Processing},
164 | title={Vision-Based UAV Self-Positioning in Low-Altitude Urban Environments},
165 | year={2024},
166 | volume={33},
167 | number={},
168 | pages={493-508},
169 | doi={10.1109/TIP.2023.3346279}}
170 | ```
171 |
172 | ## Related Work
173 |
174 | - University-1652 [https://github.com/layumi/University1652-Baseline](https://github.com/layumi/University1652-Baseline)
175 | - FSRA [https://github.com/Dmmm1997/FSRA](https://github.com/Dmmm1997/FSRA)
176 |
--------------------------------------------------------------------------------
/batch_inference_time.sh:
--------------------------------------------------------------------------------
1 | data_dir="/home/dmmm/Dataset/DenseUAV/data_2022/train" #"/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/train"
2 | # data_dir="/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/train"
3 | test_dir="/home/dmmm/Dataset/DenseUAV/data_2022/test" #"/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/test"
4 | # test_dir="/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/test"
5 | num_worker=8
6 | gpu_ids=0
7 |
8 | name="checkpoints/Backbone_Experiment_SENet"
9 | cd $name
10 | cd tool
11 | python get_inference_time.py --name $name
12 | cd ../../../
13 |
14 | # name="checkpoints/Backbone_Experiment_ConvnextT"
15 | # cd $name
16 | # cd tool
17 | # python get_inference_time.py --name $name
18 | # cd ../../../
19 |
20 | # name="checkpoints/Backbone_Experiment_DeitS"
21 | # cd $name
22 | # cd tool
23 | # python get_inference_time.py --name $name
24 | # cd ../../../
25 |
26 | # name="checkpoints/Backbone_Experiment_EfficientNet-B2"
27 | # cd $name
28 | # cd tool
29 | # python get_inference_time.py --name $name
30 | # cd ../../../
31 |
32 | # name="checkpoints/Backbone_Experiment_EfficientNet-B3"
33 | # cd $name
34 | # cd tool
35 | # python get_inference_time.py --name $name
36 | # cd ../../../
37 |
38 | # name="checkpoints/Backbone_Experiment_PvTv2b2"
39 | # cd $name
40 | # cd tool
41 | # python get_inference_time.py --name $name
42 | # cd ../../../
43 |
44 | # name="checkpoints/Backbone_Experiment_resnet50"
45 | # cd $name
46 | # cd tool
47 | # python get_inference_time.py --name $name
48 | # cd ../../../
49 |
50 | # name="checkpoints/Backbone_Experiment_Swinv2T-256"
51 | # cd $name
52 | # cd tool
53 | # python get_inference_time.py --name $name --test_h 256 --test_w 256
54 | # cd ../../../
55 |
56 | # name="checkpoints/Backbone_Experiment_VGG16"
57 | # cd $name
58 | # cd tool
59 | # python get_inference_time.py --name $name
60 | # cd ../../../
61 |
62 | # name="checkpoints/Backbone_Experiment_ViTB"
63 | # cd $name
64 | # cd tool
65 | # python get_inference_time.py --name $name
66 | # cd ../../../
67 |
68 | # name="checkpoints/Head_Experiment-FSRA2B"
69 | # cd $name
70 | # cd tool
71 | # python get_inference_time.py --name $name
72 | # cd ../../../
73 |
74 | # name="checkpoints/Head_Experiment-FSRA3B"
75 | # cd $name
76 | # cd tool
77 | # python get_inference_time.py --name $name
78 | # cd ../../../
79 |
80 | # name="checkpoints/Head_Experiment-GeM"
81 | # cd $name
82 | # cd tool
83 | # python get_inference_time.py --name $name
84 | # cd ../../../
85 |
86 | # name="checkpoints/Head_Experiment-LPN2B"
87 | # cd $name
88 | # cd tool
89 | # python get_inference_time.py --name $name
90 | # cd ../../../
91 |
92 | # name="checkpoints/Head_Experiment-LPN3B"
93 | # cd $name
94 | # cd tool
95 | # python get_inference_time.py --name $name
96 | # cd ../../../
97 |
--------------------------------------------------------------------------------
/datasets/Dataloader_University.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | from torchvision import datasets, transforms
4 | import os
5 | import numpy as np
6 | from PIL import Image
7 | import glob
8 |
9 |
10 | class Dataloader_University(Dataset):
11 | def __init__(self, root, transforms, names=['satellite', 'drone']):
12 | super(Dataloader_University, self).__init__()
13 | self.transforms_drone_street = transforms['train']
14 | self.transforms_satellite = transforms['satellite']
15 | self.root = root
16 | self.names = names
17 | # Collect the relative path of every image, grouped by view and class id
18 | # e.g. {satellite: {0839: [0839.jpg], 0840: [0840.jpg]}}
19 | dict_path = {}
20 | for name in names:
21 | dict_ = {}
22 | for cls_name in os.listdir(os.path.join(root, name)):
23 | img_list = os.listdir(os.path.join(root, name, cls_name))
24 | img_path_list = [os.path.join(
25 | root, name, cls_name, img) for img in img_list]
26 | dict_[cls_name] = img_path_list
27 | dict_path[name] = dict_
28 | # dict_path[name+"/"+cls_name] = img_path_list
29 |
30 | # Build the mapping between dataset indices and class names
31 | cls_names = os.listdir(os.path.join(root, names[0]))
32 | cls_names.sort()
33 | map_dict = {i: cls_names[i] for i in range(len(cls_names))}
34 |
35 | self.cls_names = cls_names
36 | self.map_dict = map_dict
37 | self.dict_path = dict_path
38 | self.index_cls_nums = 2
39 |
40 | # Sample one image from the given class
41 | def sample_from_cls(self, name, cls_num):
42 | img_path = self.dict_path[name][cls_num]
43 | img_path = np.random.choice(img_path, 1)[0]
44 | img = Image.open(img_path).convert("RGB")
45 | return img
46 |
47 | def __getitem__(self, index):
48 | cls_nums = self.map_dict[index]
49 | img = self.sample_from_cls("satellite", cls_nums)
50 | img_s = self.transforms_satellite(img)
51 |
52 | # img = self.sample_from_cls("street",cls_nums)
53 | # img_st = self.transforms_drone_street(img)
54 |
55 | img = self.sample_from_cls("drone", cls_nums)
56 | img_d = self.transforms_drone_street(img)
57 | return img_s, img_d, index
58 |
59 | def __len__(self):
60 | return len(self.cls_names)
61 |
62 |
63 | class DataLoader_Inference(Dataset):
64 | def __init__(self, root, transforms):
65 | super(DataLoader_Inference, self).__init__()
66 | self.root = root
67 | self.imgs = glob.glob(root+"/*.tif")
68 | self.tranforms = transforms
69 | self.imgs = sorted(self.imgs)
70 | self.labels = [os.path.basename(img).split(".tif")[
71 | 0] for img in self.imgs]
72 |
73 | def __getitem__(self, index):
74 | img = Image.open(self.imgs[index])
75 | return self.tranforms(img), self.labels[index]
76 |
77 | def __len__(self):
78 | return len(self.imgs)
79 |
80 |
81 | class Sampler_University(object):
82 | r"""Base class for all Samplers.
83 | Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
84 | way to iterate over indices of dataset elements, and a :meth:`__len__` method
85 | that returns the length of the returned iterators.
86 | .. note:: The :meth:`__len__` method isn't strictly required by
87 | :class:`~torch.utils.data.DataLoader`, but is expected in any
88 | calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
89 | """
90 |
91 | def __init__(self, data_source, batchsize=8, sample_num=4):
92 | self.data_len = len(data_source)
93 | self.batchsize = batchsize
94 | self.sample_num = sample_num
95 |
96 | def __iter__(self):
97 | indices = np.arange(0, self.data_len)
98 | np.random.shuffle(indices)
99 | nums = np.repeat(indices, self.sample_num, axis=0)
100 | return iter(nums)
101 |
102 | def __len__(self):
103 | return self.data_len * self.sample_num  # __iter__ yields each index sample_num times
104 |
105 |
106 | def train_collate_fn(batch):
107 | """
108 | The input to collate_fn is a list of length batch_size; each element is one result returned by __getitem__.
109 | """
110 | img_s, img_d, ids = zip(*batch)
111 | ids = torch.tensor(ids, dtype=torch.int64)
112 | return [torch.stack(img_s, dim=0), ids], [torch.stack(img_d, dim=0), ids]
113 |
114 |
115 | if __name__ == '__main__':
116 | transform_train_list = [
117 | # transforms.RandomResizedCrop(size=(opt.h, opt.w), scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
118 | transforms.Resize((256, 256), interpolation=3),
119 | transforms.Pad(10, padding_mode='edge'),
120 | transforms.RandomCrop((256, 256)),
121 | transforms.RandomHorizontalFlip(),
122 | transforms.ToTensor(),
123 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
124 | ]
125 |
126 | transform_train_list = {"satellite": transforms.Compose(transform_train_list),
127 | "train": transforms.Compose(transform_train_list)}
128 | datasets = Dataloader_University(root="/home/dmmm/University-Release/train",
129 | transforms=transform_train_list, names=['satellite', 'drone'])
130 | samper = Sampler_University(datasets, 8)
131 | dataloader = DataLoader(datasets, batch_size=8, num_workers=0,
132 | sampler=samper, collate_fn=train_collate_fn)
133 | for data_s, data_d in dataloader:
134 | print()
135 |
--------------------------------------------------------------------------------
/datasets/make_dataloader.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from .Dataloader_University import Sampler_University, Dataloader_University, train_collate_fn
3 | from .autoaugment import ImageNetPolicy
4 | import torch
5 | from .queryDataset import RotateAndCrop, RandomCrop, RandomErasing
6 |
7 |
8 | def make_dataset(opt):
9 | transform_train_list = []
10 | transform_satellite_list = []
11 | if "uav" in opt.rr:
12 | transform_train_list.append(RotateAndCrop(0.5))
13 | if "satellite" in opt.rr:
14 | transform_satellite_list.append(RotateAndCrop(0.5))
15 | transform_train_list += [
16 | transforms.Resize((opt.h, opt.w), interpolation=3),
17 | transforms.Pad(opt.pad, padding_mode='edge'),
18 | transforms.RandomHorizontalFlip(),
19 | ]
20 |
21 | transform_satellite_list += [
22 | transforms.Resize((opt.h, opt.w), interpolation=3),
23 | transforms.Pad(opt.pad, padding_mode='edge'),
24 | transforms.RandomHorizontalFlip(),
25 | ]
26 |
27 | transform_val_list = [
28 | transforms.Resize(size=(opt.h, opt.w),
29 | interpolation=3), # Image.BICUBIC
30 | ]
31 |
32 | if "uav" in opt.ra:
33 | transform_train_list = transform_train_list + \
34 | [transforms.RandomAffine(180)]
35 | if "satellite" in opt.ra:
36 | transform_satellite_list = transform_satellite_list + \
37 | [transforms.RandomAffine(180)]
38 |
39 | if "uav" in opt.re:
40 | transform_train_list = transform_train_list + \
41 | [RandomErasing(probability=opt.erasing_p)]
42 | if "satellite" in opt.re:
43 | transform_satellite_list = transform_satellite_list + \
44 | [RandomErasing(probability=opt.erasing_p)]
45 |
46 | if "uav" in opt.cj:
47 | transform_train_list = transform_train_list + \
48 | [transforms.ColorJitter(brightness=0.5, contrast=0.1, saturation=0.1,
49 | hue=0)]
50 | if "satellite" in opt.cj:
51 | transform_satellite_list = transform_satellite_list + \
52 | [transforms.ColorJitter(brightness=0.5, contrast=0.1, saturation=0.1,
53 | hue=0)]
54 |
55 | if opt.DA:
56 | transform_train_list = [ImageNetPolicy()] + transform_train_list
57 |
58 | last_aug = [
59 | transforms.ToTensor(),
60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
61 | ]
62 |
63 | transform_train_list += last_aug
64 | transform_satellite_list += last_aug
65 | transform_val_list += last_aug
66 |
67 | print(transform_train_list)
68 | print(transform_satellite_list)
69 |
70 | data_transforms = {
71 | 'train': transforms.Compose(transform_train_list),
72 | 'val': transforms.Compose(transform_val_list),
73 | 'satellite': transforms.Compose(transform_satellite_list)}
74 |
75 | # custom Dataset
76 | image_datasets = Dataloader_University(
77 | opt.data_dir, transforms=data_transforms)
78 | samper = Sampler_University(
79 | image_datasets, batchsize=opt.batchsize, sample_num=opt.sample_num)
80 | dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=opt.batchsize,
81 | sampler=samper, num_workers=opt.num_worker, pin_memory=True, collate_fn=train_collate_fn)
82 | dataset_sizes = {x: len(image_datasets) *
83 | opt.sample_num for x in ['satellite', 'drone']}
84 | class_names = image_datasets.cls_names
85 | return dataloaders, class_names, dataset_sizes
86 |
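The following is a minimal, hypothetical sketch of how `make_dataset` could be driven outside `train.py`; the field names of the `opt` namespace follow their usage above, but the actual options (and their defaults) are defined by the training argument parser (see `docs/training_parameters.md`).

```python
from argparse import Namespace
from datasets.make_dataloader import make_dataset

# Hypothetical option namespace; field names follow their usage in make_dataset above.
opt = Namespace(
    data_dir="/path/to/DenseUAV/train",
    h=256, w=256, pad=0,
    rr="uav", ra="", re="", cj="", erasing_p=0.3, DA=False,
    batchsize=8, sample_num=1, num_worker=4,
)

dataloaders, class_names, dataset_sizes = make_dataset(opt)
```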
--------------------------------------------------------------------------------
/docs/Get_started.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/Get_started.md
--------------------------------------------------------------------------------
/docs/Request.md:
--------------------------------------------------------------------------------
1 | [Title] Request of DenseUAV Dataset
2 |
3 | This database will only be used for research purposes. I will not make any part of this database available to a third party.
4 | I will not sell any part of this database or make any profit from its use.
5 |
6 | Thank you!
7 |
8 |
9 | *** Please send it to 869906992@qq.com from your academic email. I will reply with the dataset address. ***
--------------------------------------------------------------------------------
/docs/images/data.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/images/data.jpg
--------------------------------------------------------------------------------
/docs/images/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/images/framework.jpg
--------------------------------------------------------------------------------
/docs/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/docs/images/model.png
--------------------------------------------------------------------------------
/docs/training_parameters.md:
--------------------------------------------------------------------------------
1 | # Parameter Introduction
2 |
3 | ### Data-Related Parameters
4 | - `--name`: Experiment name, used for saving models and parameters under `checkpoints/`, facilitating management and tracking of different experimental results.
5 | - `--data_dir`: Directory path for training data.
6 | - `--num_worker`: Number of worker threads used for data loading, affecting the parallelism and efficiency of data preprocessing.
7 | - `--pad`: Amount of padding for input data. Please distinguish this from the `--pad` in Position Shifting.
8 | - `--h, --w`: Height and width of the input images.
9 | - `--rr`: Random rotation applied to one or more views to enhance data diversity.
10 | - `--ra`: Random affine transformation applied to one or more views to enhance data diversity.
11 | - `--re`: Random occlusion applied to one or more views to enhance data diversity.
12 | - `--cj`: Color jitter applied to one or more views to enhance data diversity.
13 | - `--erasing_p`: Probability of random occlusion, controlling the proportion of randomly occluded areas in the images.
14 |
15 | ### Training-Related Parameters
16 | - `--warm_epoch`: Number of warm-up epochs `K`; the learning rate is gradually increased over the first `K` epochs.
17 | - `--lr`: Learning rate.
18 | - `--DA`: Whether to use color data augmentation.
19 | - `--droprate`: Dropout rate.
20 | - `--autocast`: Whether to use mixed precision training.
21 | - `--load_from`: Path to the pre-loaded checkpoint for restoring the model from a previous training state.
22 | - `--gpu_ids`: Specification of the GPU devices used, supporting multi-GPU configurations for flexible training environments.
23 | - `--batchsize`: Number of samples per training step.
24 |
25 | ### Model-Related Parameters
26 | - `--block`: Number of ClassBlocks in the model.
27 | - `--cls_loss`: Type of loss function for Representation Learning. Various preset or custom losses can be used, with `CELoss` as the default.
28 | - `--feature_loss`: Type of loss function for Metric Learning. Various preset or custom losses can be used, with no loss applied by default.
29 | - `--kl_loss`: Type of loss function for Mutual Learning. Various preset or custom losses can be used, with no loss applied by default.
30 | - `--num_bottleneck`: Dimensionality of feature embeddings.
31 | - `--backbone`: Backbone architecture used. Various preset or custom backbones can be selected, with `cvt13` as the default.
32 | - `--head`: Head architecture used. Various preset or custom heads can be selected, with `FSRA_CNN` as the default.
33 | - `--head_pool`: Type of pooling used in the head, with various preset or custom pooling methods available, defaulting to `max pooling`.
34 |
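As a rough illustration (not the actual definitions in `train.py`), the parameters above map onto standard `argparse` options. A minimal, hypothetical sketch using the defaults stated in this document:

```python
import argparse

# Hypothetical sketch of a few of the options described above; the real parser
# lives in train.py and may differ in names, types, and defaults.
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="experiment name under checkpoints/")
parser.add_argument("--data_dir", type=str, help="training data directory")
parser.add_argument("--h", type=int, help="input image height")
parser.add_argument("--w", type=int, help="input image width")
parser.add_argument("--lr", type=float, help="learning rate")
parser.add_argument("--batchsize", type=int, help="number of samples per training step")
parser.add_argument("--backbone", type=str, default="cvt13", help="backbone architecture")
parser.add_argument("--head", type=str, default="FSRA_CNN", help="head architecture")
parser.add_argument("--cls_loss", type=str, default="CELoss", help="representation-learning loss")
opt = parser.parse_args()
```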
--------------------------------------------------------------------------------
/evaluateDistance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import scipy.io
3 | import torch
4 | import numpy as np
5 | import os
6 | from torchvision import datasets
7 | import matplotlib
8 | # matplotlib.use('agg')
9 | import matplotlib.pyplot as plt
10 | import json
11 | from tqdm import tqdm
12 | import math
13 |
14 |
15 | #####################################################################
16 | # Show result
17 | def imshow(path, title=None):
18 | """Imshow for Tensor."""
19 | im = plt.imread(path)
20 | plt.imshow(im)
21 | if title is not None:
22 | plt.title(title)
23 | plt.pause(0.1) # pause a bit so that plots are updated
24 |
25 |
26 | ######################################################################
27 |
28 |
29 | def getLatitudeAndLongitude(imgPath):
30 | if isinstance(imgPath, list):
31 | posInfo = [configDict[p.split("/")[-2]] for p in imgPath]
32 | else:
33 | posInfo = configDict[imgPath.split("/")[-2]]
34 | return posInfo
35 |
36 |
37 | def euclideanDistance(query, gallery):
38 | query = np.array(query, dtype=np.float32)
39 | gallery = np.array(gallery, dtype=np.float32)
40 | A = gallery - query
41 | A_T = A.transpose()
42 | distance = np.matmul(A, A_T)
43 | mask = np.eye(distance.shape[0], dtype=np.bool8)
44 | distance = distance[mask]
45 | distance = np.sqrt(distance.reshape(-1))
46 | return distance
47 |
48 |
49 | def evaluateSingle(distance, K):
50 | # maxDistance = max(distance) + 1e-14
51 | # weight = np.ones(K) - np.log(range(1, K + 1, 1)) / np.log(opts.M * K)
52 | weight = np.ones(K) - np.array(range(0, K, 1))/K
53 | # m1 = distance / maxDistance
54 | m2 = 1 / np.exp(distance*5e3)
55 | m3 = m2 * weight
56 | result = np.sum(m3) / np.sum(weight)
57 | return result
58 |
59 |
60 | def latlog2meter(lata, loga, latb, logb):
61 | # lat*: latitude, log*: longitude (in degrees); returns the haversine distance in meters
62 | # EARTH_RADIUS = 6371.0
63 | EARTH_RADIUS = 6378.137
64 | PI = math.pi
65 | # convert degrees to radians
66 | lat_a = lata * PI / 180
67 | lat_b = latb * PI / 180
68 | a = lat_a - lat_b
69 | b = loga * PI / 180 - logb * PI / 180
70 | dis = 2 * math.asin(
71 | math.sqrt(math.pow(math.sin(a / 2), 2) + math.cos(lat_a) * math.cos(lat_b) * math.pow(math.sin(b / 2), 2)))
72 |
73 | distance = EARTH_RADIUS * dis * 1000
74 | return distance
75 |
76 |
77 | def evaluate_SDM(indexOfTopK, queryIndex, K):
78 | query_path, _ = image_datasets[query_name].imgs[queryIndex]
79 | galleryTopKPath = [image_datasets[gallery_name].imgs[i][0]
80 | for i in indexOfTopK[:K]]
81 | # get position information including latitude and longitude
82 | queryPosInfo = getLatitudeAndLongitude(query_path)
83 | galleryTopKPosInfo = getLatitudeAndLongitude(galleryTopKPath)
84 | # compute Euclidean distance of query and gallery
85 | distance = euclideanDistance(queryPosInfo, galleryTopKPosInfo)
86 | # compute single query evaluate result
87 | P = evaluateSingle(distance, K)
88 | return P
89 |
90 |
91 | def evaluate_MA(indexOfTop1, queryIndex):
92 | query_path, _ = image_datasets[query_name].imgs[queryIndex]
93 | galleryTopKPath = image_datasets[gallery_name].imgs[indexOfTop1][0]
94 | # get position information including latitude and longitude
95 | queryPosInfo = getLatitudeAndLongitude(query_path)
96 | galleryTopKPosInfo = getLatitudeAndLongitude(galleryTopKPath)
97 | # get real distance
98 | distance_meter = latlog2meter(queryPosInfo[1],queryPosInfo[0],galleryTopKPosInfo[1],galleryTopKPosInfo[0])
99 | return distance_meter
100 |
101 |
102 | if '__main__' == __name__:
103 | #######################################################################
104 | # Evaluate
105 | parser = argparse.ArgumentParser(description='Demo')
106 | # parser.add_argument('--query_index', default=10, type=int, help='test_image_index')
107 | parser.add_argument(
108 | '--root_dir', default='/home/dmmm/Dataset/DenseUAV/data_2022/', type=str, help='./test_data')
109 | parser.add_argument('--K', default=[1, 3, 5, 10], type=str, help='K values at which SDM@K is printed')
110 | parser.add_argument('--M', default=5e3, type=str, help='scaling factor for the SDM weighting')
111 | parser.add_argument('--mode', default="1", type=str,
112 | help='1:drone->satellite 2:satellite->drone')
113 | opts = parser.parse_args()
114 |
115 | opts.config = os.path.join(opts.root_dir, "Dense_GPS_ALL.txt")
116 | opts.test_dir = os.path.join(opts.root_dir, "test")
117 | configDict = {}
118 | with open(opts.config, "r") as F:
119 | context = F.readlines()
120 | for line in context:
121 | splitLineList = line.split(" ")
122 | configDict[splitLineList[0].split("/")[-2]] = [float(splitLineList[1].split("E")[-1]),
123 | float(splitLineList[2].split("N")[-1])]
124 |
125 | if opts.mode == "1":
126 | gallery_name = 'gallery_satellite'
127 | query_name = 'query_drone'
128 | else:
129 | gallery_name = 'gallery_drone'
130 | query_name = 'query_satellite'
131 |
132 | data_dir = opts.test_dir
133 | image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x)) for x in [
134 | gallery_name, query_name]}
135 |
136 |
137 |
138 | if opts.mode == "1":
139 | result = scipy.io.loadmat('pytorch_result_1.mat')
140 | else:
141 | result = scipy.io.loadmat('pytorch_result_2.mat')
142 | query_feature = torch.FloatTensor(result['query_f'])
143 | query_label = result['query_label'][0]
144 | gallery_feature = torch.FloatTensor(result['gallery_f'])
145 | gallery_label = result['gallery_label'][0]
146 |
147 | multi = os.path.isfile('multi_query.mat')
148 |
149 | if multi:
150 | m_result = scipy.io.loadmat('multi_query.mat')
151 | mquery_feature = torch.FloatTensor(m_result['mquery_f'])
152 | mquery_cam = m_result['mquery_cam'][0]
153 | mquery_label = m_result['mquery_label'][0]
154 | mquery_feature = mquery_feature.cuda()
155 |
156 | query_feature = query_feature.cuda()
157 | gallery_feature = gallery_feature.cuda()
158 |
159 |
160 | #######################################################################
161 | # sort the images and return topK index
162 | def sort_img(qf, ql, gf, gl, K):
163 | query = qf.view(-1, 1)
164 | # print(query.shape)
165 | score = torch.mm(gf, query)
166 | score = score.squeeze(1).cpu()
167 | score = score.numpy()
168 | # predict index
169 | index = np.argsort(score) # from small to large
170 | index = index[::-1]
171 | # index = index[0:2000]
172 | # good index
173 | query_index = np.argwhere(gl == ql)
174 |
175 | # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
176 | junk_index = np.argwhere(gl == -1)
177 |
178 | mask = np.in1d(index, junk_index, invert=True)
179 | index = index[mask]
180 | return index[:K]
181 |
182 |
183 | indexOfTopK_list = []
184 | for i in range(len(query_label)):
185 | indexOfTopK = sort_img(
186 | query_feature[i], query_label[i], gallery_feature, gallery_label, 100)
187 | indexOfTopK_list.append(indexOfTopK)
188 |
189 | SDM_dict = {}
190 | for K in tqdm(range(1, 101, 1)):
191 | metric = 0
192 | for i in range(len(query_label)):
193 | P_ = evaluate_SDM(indexOfTopK_list[i], i, K)
194 | metric += P_
195 | metric = metric / len(query_label)
196 | if K in opts.K:
197 | print("metric{} = {:.2f}%".format(K, metric * 100))
198 | SDM_dict[K] = metric
199 |
200 | MA_dict = {}
201 | for meter in tqdm(range(1,101,1)):
202 | MA_K = 0
203 | for i in range(len(query_label)):
204 | MA_meter = evaluate_MA(indexOfTopK_list[i][0],i)
205 | if MA_meter < meter:
206 | MA_K += 1
207 | MA_dict[meter] = MA_K / len(query_label)
--------------------------------------------------------------------------------
/losses/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from pytorch_metric_learning import losses
6 | from .FocalLoss import FocalLoss
7 | from .TripletLoss import TripletLoss, HardMiningTripletLoss, SameDomainTripletLoss, WeightedSoftTripletLoss
8 |
9 | class Loss(nn.Module):
10 | def __init__(self, opt) -> None:
11 | super(Loss,self).__init__()
12 | self.opt = opt
13 | # classification loss
14 | if opt.cls_loss == "CELoss":
15 | self.cls_loss = nn.CrossEntropyLoss()
16 | elif opt.cls_loss == "FocalLoss":
17 | self.cls_loss = FocalLoss(alpha=0.25, gamma=2, num_classes = opt.nclasses)
18 | else:
19 | self.cls_loss = None
20 |
21 | # metric (contrastive/triplet) loss
22 | if opt.feature_loss == "TripletLoss":
23 | self.feature_loss = TripletLoss(margin=0.3, normalize_feature=True)
24 | elif opt.feature_loss == "HardMiningTripletLoss":
25 | self.feature_loss = HardMiningTripletLoss(margin=0.3, normalize_feature=True)
26 | elif opt.feature_loss == "SameDomainTripletLoss":
27 | self.feature_loss = SameDomainTripletLoss(margin=0.3)
28 | elif opt.feature_loss == "WeightedSoftTripletLoss":
29 | self.feature_loss = WeightedSoftTripletLoss()
30 | elif opt.feature_loss == "ContrastiveLoss":
31 | self.feature_loss = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
32 | else:
33 | self.feature_loss = None
34 |
35 | # KL (mutual learning) loss
36 | if opt.kl_loss == "KLLoss":
37 | self.kl_loss = nn.KLDivLoss(reduction='batchmean')
38 | else:
39 | self.kl_loss = None
40 |
41 |
42 | def forward(self, outputs, outputs2, labels, labels2):
43 | cls1,feature1 = outputs
44 | cls2,feature2 = outputs2
45 | loss = 0
46 |
47 | # classification loss
48 | res_cls_loss = torch.tensor((0))
49 | if self.cls_loss is not None:
50 | res_cls_loss = self.calc_cls_loss(cls1, labels, self.cls_loss) + \
51 | self.calc_cls_loss(cls2, labels2, self.cls_loss)
52 | loss += res_cls_loss
53 |
54 | # feature metric loss
55 | res_triplet_loss = torch.tensor((0))
56 | if self.feature_loss is not None:
57 | split_num = self.opt.batchsize//self.opt.sample_num
58 | res_triplet_loss = self.calc_triplet_loss(
59 | feature1, feature2, labels, self.feature_loss, split_num)
60 | loss += res_triplet_loss
61 |
62 | # mutual learning (KL divergence between the two branches)
63 | res_kl_loss = torch.tensor((0))
64 | if self.kl_loss is not None:
65 | res_kl_loss = self.calc_kl_loss(cls1, cls2, self.kl_loss)
66 | loss += res_kl_loss
67 |
68 | # if self.opt.epoch < self.opt.warm_epoch:
69 | # warm_up = 0.1 # We start from the 0.1*lrRate
70 | # warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch # first 5 epoch
71 | # warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
72 | # loss *= warm_up
73 |
74 | return loss, res_cls_loss, res_triplet_loss, res_kl_loss
75 |
76 |
77 | def calc_cls_loss(self, outputs, labels, loss_func):
78 | loss = 0
79 | if isinstance(outputs, list):
80 | for i in outputs:
81 | loss += loss_func(i, labels)
82 | loss = loss/len(outputs)
83 | else:
84 | loss = loss_func(outputs, labels)
85 | return loss
86 |
87 |
88 | def calc_kl_loss(self, outputs, outputs2, loss_func):
89 | loss = 0
90 | if isinstance(outputs, list):
91 | for i in range(len(outputs)):
92 | loss += loss_func(F.log_softmax(outputs[i], dim=1),
93 | F.softmax(Variable(outputs2[i]), dim=1))
94 | loss = loss/len(outputs)
95 | else:
96 | loss = loss_func(F.log_softmax(outputs, dim=1),
97 | F.softmax(Variable(outputs2), dim=1))
98 | return loss
99 |
100 |
101 | def calc_triplet_loss(self, outputs, outputs2, labels, loss_func, split_num=8):
102 | if isinstance(outputs, list):
103 | loss = 0
104 | for i in range(len(outputs)):
105 | out_concat = torch.cat((outputs[i], outputs2[i]), dim=0)
106 | labels_concat = torch.cat((labels, labels), dim=0)
107 | loss += loss_func(out_concat, labels_concat)
108 | loss = loss/len(outputs)
109 | else:
110 | out_concat = torch.cat((outputs, outputs2), dim=0)
111 | labels_concat = torch.cat((labels, labels), dim=0)
112 | loss = loss_func(out_concat, labels_concat)
113 | return loss
114 |
115 |
116 |
--------------------------------------------------------------------------------
/models/Backbone/RKNet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | class USAM(nn.Module):
7 | def __init__(self, kernel_size=3, padding=1, polish=False):
8 | super(USAM, self).__init__()
9 |
10 | kernel = torch.ones((kernel_size, kernel_size))
11 | kernel = kernel.unsqueeze(0).unsqueeze(0)
12 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
13 |
14 |
15 | kernel2 = torch.ones((1, 1)) * (kernel_size * kernel_size)
16 | kernel2 = kernel2.unsqueeze(0).unsqueeze(0)
17 | self.weight2 = nn.Parameter(data=kernel2, requires_grad=False)
18 |
19 | self.polish = polish
20 | self.pad = padding
21 | self.relu = nn.ReLU()
22 | self.bn = nn.BatchNorm2d(1)
23 |
24 | def __call__(self, x):
25 | fmap = x.sum(1, keepdim=True)
26 | x1 = F.conv2d(fmap, self.weight, padding=self.pad)
27 | x2 = F.conv2d(fmap, self.weight2, padding=0)
28 |
29 | att = x2 - x1
30 | att = self.bn(att)
31 | att = self.relu(att)
32 |
33 | if self.polish:
34 | att[:, :, :, 0] = 0
35 | att[:, :, :, -1] = 0
36 | att[:, :, 0, :] = 0
37 | att[:, :, -1, :] = 0
38 |
39 | output = x + att * x
40 |
41 | return output
42 |
43 |
44 |
45 | class RKNet(nn.Module):
46 | def __init__(self, stride=2, init_model=None, pool='avg'):
47 | super(RKNet, self).__init__()
48 | model_ft = models.resnet50(pretrained=True)
49 | # avg pooling to global pooling
50 | if stride == 1:
51 | model_ft.layer4[0].downsample[0].stride = (1,1)
52 | model_ft.layer4[0].conv2.stride = (1,1)
53 |
54 | self.pool = pool
55 | if pool =='avg+max':
56 | model_ft.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
57 | model_ft.maxpool2 = nn.AdaptiveMaxPool2d((1,1))
58 | #self.classifier = ClassBlock(4096, class_num, droprate)
59 | elif pool=='avg':
60 | model_ft.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
61 | #self.classifier = ClassBlock(2048, class_num, droprate)
62 | elif pool=='max':
63 | model_ft.maxpool2 = nn.AdaptiveMaxPool2d((1,1))
64 | elif pool=='gem':
65 | model_ft.gem2 = GeM(dim=2048)
66 |
67 | self.model = model_ft
68 |
69 | if init_model!=None:
70 | self.model = init_model.model
71 | self.pool = init_model.pool
72 | #self.classifier.add_block = init_model.classifier.add_block
73 |
74 | self.usam_1 = USAM()
75 | self.usam_2 = USAM()
76 |
77 | def forward_features(self, x):
78 | x = self.model.conv1(x)
79 | x = self.model.bn1(x)
80 | x = self.model.relu(x)
81 | x = self.usam_1(x)
82 | x = self.model.maxpool(x)
83 | x = self.model.layer1(x)
84 | x = self.usam_2(x)
85 | x = self.model.layer2(x)
86 | x = self.model.layer3(x)
87 | x = self.model.layer4(x)
88 |
89 | return x
--------------------------------------------------------------------------------
/models/Backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/models/Backbone/__init__.py
--------------------------------------------------------------------------------
/models/Backbone/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import timm
3 | from .RKNet import RKNet
4 | from .cvt import get_cvt_models
5 | import torch
6 |
7 | def make_backbone(opt):
8 | backbone_model = Backbone(opt)
9 | return backbone_model
10 |
11 |
12 | class Backbone(nn.Module):
13 | def __init__(self, opt):
14 | super().__init__()
15 | self.opt = opt
16 | self.img_size = (opt.h,opt.w)
17 | self.backbone,self.output_channel = self.init_backbone(opt.backbone)
18 |
19 |
20 | def init_backbone(self, backbone):
21 | if backbone=="resnet50":
22 | backbone_model = timm.create_model('resnet50', pretrained=True)
23 | output_channel = 2048
24 | elif backbone=="RKNet":
25 | backbone_model = RKNet()
26 | output_channel = 2048
27 | elif backbone=="senet":
28 | backbone_model = timm.create_model('legacy_seresnet50', pretrained=True)
29 | output_channel = 2048
30 | elif backbone=="ViTS-224":
31 | backbone_model = timm.create_model("vit_small_patch16_224", pretrained=True, img_size=self.img_size)
32 | output_channel = 384
33 | elif backbone=="ViTS-384":
34 | backbone_model = timm.create_model("vit_small_patch16_384", pretrained=True)
35 | output_channel = 384
36 | elif backbone=="DeitS-224":
37 | backbone_model = timm.create_model("deit_small_distilled_patch16_224", pretrained=True)
38 | output_channel = 384
39 | elif backbone=="DeitB-224":
40 | backbone_model = timm.create_model("deit_base_distilled_patch16_224", pretrained=True)
41 | output_channel = 768
42 | elif backbone=="Pvtv2b2":
43 | backbone_model = timm.create_model("pvt_v2_b2", pretrained=True)
44 | output_channel = 512
45 | elif backbone=="ViTB-224":
46 | backbone_model = timm.create_model("vit_base_patch16_224", pretrained=True)
47 | output_channel = 768
48 | elif backbone=="SwinB-224":
49 | backbone_model = timm.create_model("swin_base_patch4_window7_224", pretrained=True)
50 | output_channel = 1024
51 | elif backbone=="Swinv2S-256":
52 | backbone_model = timm.create_model("swinv2_small_window8_256", pretrained=True)
53 | output_channel = 768
54 | elif backbone=="Swinv2T-256":
55 | backbone_model = timm.create_model("swinv2_tiny_window16_256", pretrained=True)
56 | output_channel = 768
57 | elif backbone=="Convnext-T":
58 | backbone_model = timm.create_model("convnext_tiny", pretrained=True)
59 | output_channel = 768
60 | elif backbone=="EfficientNet-B2":
61 | backbone_model = timm.create_model("efficientnet_b2", pretrained=True)
62 | output_channel = 1408
63 | elif backbone=="EfficientNet-B3":
64 | backbone_model = timm.create_model("efficientnet_b3", pretrained=True)
65 | output_channel = 1536
66 | elif backbone=="EfficientNet-B5":
67 | backbone_model = timm.create_model("tf_efficientnet_b5", pretrained=True)
68 | output_channel = 2048
69 | elif backbone=="EfficientNet-B6":
70 | backbone_model = timm.create_model("tf_efficientnet_b6", pretrained=True)
71 | output_channel = 2304
72 | elif backbone=="vgg16":
73 | backbone_model = timm.create_model("vgg16", pretrained=True)
74 | output_channel = 512
75 | elif backbone=="cvt13":
76 | backbone_model, channels = get_cvt_models(model_size="cvt13")
77 | output_channel = channels[-1]
78 | checkpoint_weight = "/home/dmmm/VscodeProject/FPI/pretrain_model/CvT-13-384x384-IN-22k.pth"
79 | backbone_model = self.load_checkpoints(checkpoint_weight, backbone_model)
80 | else:
81 | raise NameError("{} not in the backbone list!!!".format(backbone))
82 | return backbone_model,output_channel
83 |
84 | def load_checkpoints(self, checkpoint_path, model):
85 | ckpt = torch.load(checkpoint_path, map_location='cpu')
86 | filter_ckpt = {k: v for k, v in ckpt.items() if "pos_embed" not in k}
87 | missing_keys, unexpected_keys = model.load_state_dict(filter_ckpt, strict=False)
88 | print("Load pretrained backbone checkpoint from:", checkpoint_path)
89 | print("missing keys:", missing_keys)
90 | print("unexpected keys:", unexpected_keys)
91 | return model
92 |
93 | def forward(self, image):
94 | features = self.backbone.forward_features(image)
95 | return features
--------------------------------------------------------------------------------
/models/Head/FSRA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .utils import ClassBlock
4 |
5 |
6 | class FSRA(nn.Module):
7 | def __init__(self, opt) -> None:
8 | super().__init__()
9 |
10 | self.opt = opt
11 | num_classes = opt.nclasses
12 | droprate = opt.droprate
13 | in_planes = opt.in_planes
14 | self.class_name = "classifier_heat"
15 | self.block = opt.block
16 | # global classifier
17 | self.classifier1 = ClassBlock(in_planes, num_classes, droprate)
18 | # local classifier
19 | for i in range(self.block):
20 | name = self.class_name + str(i+1)
21 | setattr(self, name, ClassBlock(in_planes, num_classes, droprate))
22 |
23 | def forward(self, features):
24 | global_cls, global_feature = self.classifier1(features[:, 0])
25 | # tranformer_feature = torch.mean(features,dim=1)
26 | # tranformer_feature = self.classifier1(tranformer_feature)
27 | if self.block == 1:
28 | return global_cls, global_feature
29 |
30 | part_features = features[:, 1:]
31 |
32 | heat_result = self.get_heartmap_pool(part_features)
33 | cls_list, features_list = self.part_classifier(
34 | self.block, heat_result, cls_name=self.class_name)
35 |
36 | total_cls = [global_cls] + cls_list
37 | total_features = [global_feature] + features_list
38 | if not self.training:
39 | total_features = torch.stack(total_features,dim=-1)
40 | return [total_cls, total_features]
41 |
42 | def get_heartmap_pool(self, part_features, add_global=False, otherbranch=False):
43 | heatmap = torch.mean(part_features, dim=-1)
44 | size = part_features.size(1)
45 | arg = torch.argsort(heatmap, dim=1, descending=True)
46 | x_sort = [part_features[i, arg[i], :]
47 | for i in range(part_features.size(0))]
48 | x_sort = torch.stack(x_sort, dim=0)
49 |
50 | split_each = size / self.block
51 | split_list = [int(split_each) for i in range(self.block - 1)]
52 | split_list.append(size - sum(split_list))
53 | split_x = x_sort.split(split_list, dim=1)
54 |
55 | split_list = [torch.mean(split, dim=1) for split in split_x]
56 | part_featuers_ = torch.stack(split_list, dim=2)
57 | if add_global:
58 | global_feat = torch.mean(part_features, dim=1).view(
59 | part_features.size(0), -1, 1).expand(-1, -1, self.block)
60 | part_featuers_ = part_featuers_ + global_feat
61 | if otherbranch:
62 | otherbranch_ = torch.mean(
63 | torch.stack(split_list[1:], dim=2), dim=-1)
64 | return part_featuers_, otherbranch_
65 | return part_featuers_
66 |
67 | def part_classifier(self, block, x, cls_name='classifier_lpn'):
68 | part = {}
69 | cls_list, features_list = [], []
70 | for i in range(block):
71 | part[i] = x[:, :, i].view(x.size(0), -1)
72 | # part[i] = torch.squeeze(x[:,:,i])
73 | name = cls_name + str(i+1)
74 | c = getattr(self, name)
75 | res = c(part[i])
76 | cls_list.append(res[0])
77 | features_list.append(res[1])
78 | return cls_list, features_list
79 |
80 |
81 | class FSRA_CNN(nn.Module):
82 | def __init__(self, opt) -> None:
83 | super().__init__()
84 |
85 | self.opt = opt
86 | num_classes = opt.nclasses
87 | droprate = opt.droprate
88 | in_planes = opt.in_planes
89 | self.class_name = "classifier_heat"
90 | self.block = opt.block
91 | # global classifier
92 | self.classifier1 = ClassBlock(in_planes, num_classes, droprate)
93 | # local classifier
94 | for i in range(self.block):
95 | name = self.class_name + str(i+1)
96 | setattr(self, name, ClassBlock(in_planes, num_classes, droprate))
97 |
98 | def forward(self, features):
99 | # global_cls, global_feature = self.classifier1(features[:, 0])
100 | features = features.reshape(features.shape[0], features.shape[1], -1).transpose(1,2)
101 | global_feature = torch.mean(features,dim=1)
102 | global_cls, global_feature = self.classifier1(global_feature)
103 | if self.block == 1:
104 | return global_cls, global_feature
105 |
106 | part_features = features
107 | # print(part_features.shape)
108 |
109 |
110 | heat_result = self.get_heartmap_pool(part_features)
111 | cls_list, features_list = self.part_classifier(
112 | self.block, heat_result, cls_name=self.class_name)
113 |
114 | total_cls = [global_cls] + cls_list
115 | total_features = [global_feature] + features_list
116 | if not self.training:
117 | total_features = torch.stack(total_features,dim=-1)
118 | return [total_cls, total_features]
119 |
120 | def get_heartmap_pool(self, part_features, add_global=False, otherbranch=False):
121 | heatmap = torch.mean(part_features, dim=-1)
122 | size = part_features.size(1)
123 | arg = torch.argsort(heatmap, dim=1, descending=True)
124 | x_sort = [part_features[i, arg[i], :]
125 | for i in range(part_features.size(0))]
126 | x_sort = torch.stack(x_sort, dim=0)
127 |
128 | split_each = size / self.block
129 | split_list = [int(split_each) for i in range(self.block - 1)]
130 | split_list.append(size - sum(split_list))
131 | split_x = x_sort.split(split_list, dim=1)
132 |
133 | split_list = [torch.mean(split, dim=1) for split in split_x]
134 | part_featuers_ = torch.stack(split_list, dim=2)
135 | if add_global:
136 | global_feat = torch.mean(part_features, dim=1).view(
137 | part_features.size(0), -1, 1).expand(-1, -1, self.block)
138 | part_featuers_ = part_featuers_ + global_feat
139 | if otherbranch:
140 | otherbranch_ = torch.mean(
141 | torch.stack(split_list[1:], dim=2), dim=-1)
142 | return part_featuers_, otherbranch_
143 | return part_featuers_
144 |
145 | def part_classifier(self, block, x, cls_name='classifier_lpn'):
146 | part = {}
147 | cls_list, features_list = [], []
148 | for i in range(block):
149 | part[i] = x[:, :, i].view(x.size(0), -1)
150 | # part[i] = torch.squeeze(x[:,:,i])
151 | name = cls_name + str(i+1)
152 | c = getattr(self, name)
153 | res = c(part[i])
154 | cls_list.append(res[0])
155 | features_list.append(res[1])
156 | return cls_list, features_list
--------------------------------------------------------------------------------
/models/Head/GeM.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .utils import ClassBlock, Pooling, vector2image
3 |
4 |
5 |
6 | class GeM(nn.Module):
7 | def __init__(self, opt) -> None:
8 | super().__init__()
9 | self.opt = opt
10 | self.classifier = ClassBlock(
11 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck)
12 | self.pool = Pooling(opt.h//16*opt.w//16, "gem")
13 |
14 | def forward(self, features):# (N,(H*W+1),C)
15 | local_feature = features[:, 1:]
16 | local_feature = local_feature.transpose(1,2).contiguous()
17 | # local_feature = vector2image(local_feature,dim = 2)
18 | global_feature = self.pool(local_feature)
19 | cls, feature = self.classifier(global_feature)
20 | return [cls, feature]
--------------------------------------------------------------------------------
/models/Head/NeXtVLAD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .utils import ClassBlock, Pooling, vector2image
5 |
6 |
7 | class NeXtVLAD(nn.Module):
8 | def __init__(self, opt) -> None:
9 | super(NeXtVLAD, self).__init__()
10 | self.opt = opt
11 | self.classifier = ClassBlock(
12 | int(opt.in_planes*opt.block), opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck)
13 | self.netvlad = NeXtVLAD_block(
14 | num_clusters=opt.block, dim=opt.in_planes)
15 |
16 | def forward(self, features):
17 | local_feature = features[:, 1:]
18 | local_feature = local_feature.transpose(1, 2)
19 |
20 | local_feature = vector2image(local_feature, dim=2)
21 | local_features = self.netvlad(local_feature)
22 |
23 | cls, feature = self.classifier(local_features)
24 | return [cls, feature]
25 |
26 |
27 |
28 | class NeXtVLAD_block(nn.Module):
29 | """NeXtVLAD layer implementation"""
30 |
31 | def __init__(self, num_clusters=64, dim=1024, lamb=2, groups=8, max_frames=300):
32 | super(NeXtVLAD_block, self).__init__()
33 | self.num_clusters = num_clusters
34 | self.dim = dim
35 | self.alpha = 0
36 | self.K = num_clusters
37 | self.G = groups
38 | self.group_size = int((lamb * dim) // self.G)
39 | # expansion FC
40 | self.fc0 = nn.Linear(dim, lamb * dim)
41 | # soft assignment FC (the cluster weights)
42 | self.fc_gk = nn.Linear(lamb * dim, self.G * self.K)
43 | # attention over groups FC
44 | self.fc_g = nn.Linear(lamb * dim, self.G)
45 | self.cluster_weights2 = nn.Parameter(torch.rand(1, self.group_size, self.K))
46 |
47 | self.bn0 = nn.BatchNorm1d(max_frames)
48 | self.bn1 = nn.BatchNorm1d(1)
49 |
50 | def forward(self, x, mask=None):
51 | # print(f"x: {x.shape}")
52 |
53 | _, M, N = x.shape
54 | # expansion FC: B x M x N -> B x M x λN
55 | x_dot = self.fc0(x)
56 |
57 | # reshape into groups: B x M x λN -> B x M x G x (λN/G)
58 | x_tilde = x_dot.reshape(-1, M, self.G, self.group_size)
59 |
60 | # residuals across groups and clusters: B x M x λN -> B x M x (G*K)
61 | WgkX = self.fc_gk(x_dot)
62 | WgkX = self.bn0(WgkX)
63 |
64 | # residuals reshape across clusters: B x M x (G*K) -> B x (M*G) x K
65 | WgkX = WgkX.reshape(-1, M * self.G, self.K)
66 |
67 | # softmax over assignment: B x (M*G) x K -> B x (M*G) x K
68 | alpha_gk = F.softmax(WgkX, dim=-1)
69 |
70 | # attention across groups: B x M x λN -> B x M x G
71 | alpha_g = torch.sigmoid(self.fc_g(x_dot))
72 | if mask is not None:
73 | alpha_g = torch.mul(alpha_g, mask.unsqueeze(2))
74 |
75 | # reshape across time: B x M x G -> B x (M*G) x 1
76 | alpha_g = alpha_g.reshape(-1, M * self.G, 1)
77 |
78 | # apply attention: B x (M*G) x K (X) B x (M*G) x 1 -> B x (M*G) x K
79 | activation = torch.mul(alpha_gk, alpha_g)
80 |
81 | # sum over time and group: B x (M*G) x K -> B x 1 x K
82 | a_sum = torch.sum(activation, -2, keepdim=True)
83 |
84 | # calculate group centers: B x 1 x K (X) 1 x (λN/G) x K -> B x (λN/G) x K
85 | a = torch.mul(a_sum, self.cluster_weights2)
86 |
87 | # permute: B x (M*G) x K -> B x K x (M*G)
88 | activation = activation.permute(0, 2, 1)
89 |
90 | # reshape: B x M x G x (λN/G) -> B x (M*G) x (λN/G)
91 | reshaped_x_tilde = x_tilde.reshape(-1, M * self.G, self.group_size)
92 |
93 | # cluster activation: B x K x (M*G) (X) B x (M*G) x (λN/G) -> B x K x (λN/G)
94 | vlad = torch.matmul(activation, reshaped_x_tilde)
95 | # print(f"vlad: {vlad.shape}")
96 |
97 | # permute: B x K x (λN/G) (X) B x (λN/G) x K
98 | vlad = vlad.permute(0, 2, 1)
99 | # distance to centers: B x (λN/G) x K (-) B x (λN/G) x K
100 | vlad = torch.sub(vlad, a)
101 | # normalize: B x (λN/G) x K
102 | vlad = F.normalize(vlad, 1)
103 | # reshape: B x (λN/G) x K -> B x 1 x (K * (λN/G))
104 | vlad = vlad.reshape(-1, 1, self.K * self.group_size)
105 | vlad = self.bn1(vlad)
106 | # reshape: B x 1 x (K * (λN/G)) -> B x (K * (λN/G))
107 | vlad = vlad.reshape(-1, self.K * self.group_size)
108 |
109 | return vlad
--------------------------------------------------------------------------------
/models/Head/NetVLAD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .utils import ClassBlock, Pooling, vector2image
5 |
6 |
7 | class NetVLAD(nn.Module):
8 | def __init__(self, opt) -> None:
9 | super(NetVLAD, self).__init__()
10 | self.opt = opt
11 | self.classifier = ClassBlock(
12 | int(opt.in_planes*opt.block), opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck)
13 | self.netvlad = NetVLAD_block(
14 | num_clusters=opt.block, dim=opt.in_planes, alpha=100.0, normalize_input=True)
15 |
16 | def forward(self, features):
17 | local_feature = features[:, 1:]
18 | local_feature = local_feature.transpose(1, 2)
19 |
20 | local_feature = vector2image(local_feature, dim=2)
21 | local_features = self.netvlad(local_feature)
22 |
23 | cls, feature = self.classifier(local_features)
24 | return [cls, feature]
25 |
26 |
27 | class NetVLAD_block(nn.Module):
28 | """NetVLAD layer implementation"""
29 |
30 | def __init__(self, num_clusters=64, dim=128, alpha=100.0,
31 | normalize_input=True):
32 | """
33 | Args:
34 | num_clusters : int
35 | The number of clusters
36 | dim : int
37 | Dimension of descriptors
38 | alpha : float
39 | Parameter of initialization. Larger value is harder assignment.
40 | normalize_input : bool
41 | If true, descriptor-wise L2 normalization is applied to input.
42 | """
43 | super(NetVLAD_block, self).__init__()
44 | self.num_clusters = num_clusters
45 | self.dim = dim
46 | self.alpha = alpha
47 | self.normalize_input = normalize_input
48 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
49 | self.centroids = nn.Parameter(
50 |             torch.rand(num_clusters, dim))  # cluster centers (see note 1)
51 | self._init_params()
52 |
53 | def _init_params(self):
54 | self.conv.weight = nn.Parameter(
55 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
56 | )
57 | self.conv.bias = nn.Parameter(
58 | - self.alpha * self.centroids.norm(dim=1)
59 | )
60 |
61 |     def forward(self, x):  # x: (N, C, H, W); H * W corresponds to N in the paper (number of local descriptors), C corresponds to D (descriptor dimension)
62 | N, C = x.shape[:2]
63 |
64 | if self.normalize_input:
65 |             x = F.normalize(x, p=2, dim=1)  # L2-normalize across the descriptor dimension
66 |
67 | # soft-assignment
68 | # (N, C, H, W)->(N, num_clusters, H, W)->(N, num_clusters, H * W)
69 | soft_assign = self.conv(x).view(N, self.num_clusters, -1)
70 |         # (N, num_clusters, H * W), see note 3
71 | soft_assign = F.softmax(soft_assign, dim=1)
72 |
73 | x_flatten = x.view(N, C, -1) # (N, C, H, W) -> (N, C, H * W)
74 |
75 |         # calculate residuals to each cluster
76 |         # denote the part before the minus sign as a and the part after it as b: residual = a - b
77 | # a: (N, C, H * W) -> (num_clusters, N, C, H * W) -> (N, num_clusters, C, H * W)
78 | # b: (num_clusters, C) -> (H * W, num_clusters, C) -> (num_clusters, C, H * W)
79 |         # residual: (N, num_clusters, C, H * W), see note 2
80 | residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \
81 | self.centroids.expand(x_flatten.size(-1), -
82 | 1, -1).permute(1, 2, 0).unsqueeze(0)
83 |
84 | # soft_assign: (N, num_clusters, H * W) -> (N, num_clusters, 1, H * W)
85 | # (N, num_clusters, C, H * W) * (N, num_clusters, 1, H * W)
86 | residual *= soft_assign.unsqueeze(2)
87 | # (N, num_clusters, C, H * W) -> (N, num_clusters, C)
88 | vlad = residual.sum(dim=-1)
89 |
90 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
91 |         # flatten; vlad: (N, num_clusters, C) -> (N, num_clusters * C)
92 | vlad = vlad.view(x.size(0), -1)
93 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
94 |
95 | return vlad
96 |
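For reference, a small sketch of the descriptor size NetVLAD_block produces (num_clusters * dim after flattening); the import path and sizes are assumptions:

import torch
from models.Head.NetVLAD import NetVLAD_block  # assumed module path

block = NetVLAD_block(num_clusters=8, dim=64, alpha=100.0, normalize_input=True)
x = torch.randn(2, 64, 14, 14)   # (N, C, H, W): 14*14 local descriptors of dimension 64
vlad = block(x)
print(vlad.shape)                # torch.Size([2, 512]) == (N, num_clusters * dim)
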
--------------------------------------------------------------------------------
/models/Head/SingleBranch.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .utils import ClassBlock, Pooling
3 | import torch.nn.functional as F
4 | import torch
5 |
6 | class SingleBranch(nn.Module):
7 | def __init__(self, opt) -> None:
8 | super().__init__()
9 | self.opt = opt
10 | self.head_pool = opt.head_pool
11 | self.classifier = ClassBlock(
12 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck)
13 |
14 | def forward(self, features):
15 | global_feature = features[:, 0]
16 | local_feature = features[:, 1:]
17 | if self.head_pool == "global":
18 | feature = global_feature
19 | elif self.head_pool == "avg":
20 | local_feature = local_feature.transpose(1, 2)
21 | feature = torch.mean(local_feature, 2).squeeze()
22 | elif self.head_pool == "max":
23 | local_feature = local_feature.transpose(1, 2)
24 | feature = torch.max(local_feature, 2)[0].squeeze()
25 | elif self.head_pool == "avg+max":
26 | local_feature = local_feature.transpose(1, 2)
27 | avg_feature = torch.mean(local_feature, 2).squeeze()
28 | max_feature = torch.max(local_feature, 2)[0].squeeze()
29 | feature = avg_feature+max_feature
30 | else:
31 |             raise TypeError("head_pool is not in the supported list!")
32 |
33 | cls, feature = self.classifier(feature)
34 | return [cls, feature]
35 |
36 |
37 | class SingleBranchCNN(nn.Module):
38 | def __init__(self, opt) -> None:
39 | super().__init__()
40 | self.opt = opt
41 | self.pool = nn.AdaptiveAvgPool2d(1)
42 | self.classifier = ClassBlock(
43 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck)
44 |
45 | def forward(self, features):
46 | global_feature = self.pool(features).reshape(features.shape[0], -1)
47 | cls, feature = self.classifier(global_feature)
48 | return [cls, feature]
49 |
50 |
51 | class SingleBranchSwin(nn.Module):
52 | def __init__(self, opt) -> None:
53 | super().__init__()
54 | self.opt = opt
55 | self.pool = nn.AdaptiveAvgPool1d(1)
56 | self.classifier = ClassBlock(
57 | opt.in_planes, opt.nclasses, opt.droprate, num_bottleneck=opt.num_bottleneck)
58 |
59 | def forward(self, features):
60 | global_feature = self.pool(features.transpose(
61 | 2, 1)).reshape(features.shape[0], -1)
62 | cls, feature = self.classifier(global_feature)
63 | return [cls, feature]
64 |
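A sketch of how the head_pool variants consume ViT-style token features ([CLS] token followed by patch tokens); the SimpleNamespace stands in for the real argparse options and its field values are assumptions:

import torch
from types import SimpleNamespace
from models.Head.SingleBranch import SingleBranch  # assumed module path

opt = SimpleNamespace(head_pool="avg+max", in_planes=384, nclasses=10,
                      droprate=0.5, num_bottleneck=256)
head = SingleBranch(opt)
tokens = torch.randn(4, 1 + 196, 384)   # [CLS] token + 14x14 patch tokens
cls_logits, feature = head(tokens)
print(cls_logits.shape, feature.shape)  # torch.Size([4, 10]) torch.Size([4, 256])
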
--------------------------------------------------------------------------------
/models/Head/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/models/Head/__init__.py
--------------------------------------------------------------------------------
/models/Head/head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .SingleBranch import SingleBranch, SingleBranchCNN, SingleBranchSwin
3 | from .FSRA import FSRA, FSRA_CNN
4 | from .LPN import LPN, LPN_CNN
5 | from .GeM import GeM
6 | from .NetVLAD import NetVLAD
7 |
8 | def make_head(opt):
9 | return Head(opt)
10 |
11 |
12 | class Head(nn.Module):
13 | def __init__(self, opt) -> None:
14 | super().__init__()
15 | self.head = self.init_head(opt)
16 | self.opt = opt
17 |
18 | def init_head(self, opt):
19 | head = opt.head
20 | if head == "SingleBranch":
21 | head_model = SingleBranch(opt)
22 | elif head == "SingleBranchCNN":
23 | head_model = SingleBranchCNN(opt)
24 | elif head == "SingleBranchSwin":
25 | head_model = SingleBranchSwin(opt)
26 | elif head == "NetVLAD":
27 | head_model = NetVLAD(opt)
28 | elif head == "FSRA":
29 | head_model = FSRA(opt)
30 | elif head == "FSRA_CNN":
31 | head_model = FSRA_CNN(opt)
32 | elif head == "LPN":
33 | head_model = LPN(opt)
34 | elif head == "LPN_CNN":
35 | head_model = LPN_CNN(opt)
36 | elif head == "GeM":
37 | head_model = GeM(opt)
38 | else:
39 | raise NameError("{} not in the head list!!!".format(head))
40 | return head_model
41 |
42 | def forward(self, features):
43 | features = self.head(features)
44 | return features
45 |
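A hypothetical call of the head factory above; opt carries the fields the selected head reads, and an unknown opt.head raises NameError:

from types import SimpleNamespace
from models.Head.head import make_head  # assumed module path

opt = SimpleNamespace(head="SingleBranch", head_pool="avg", in_planes=768,
                      nclasses=701, droprate=0.5, num_bottleneck=512)
head = make_head(opt)
print(type(head.head).__name__)  # SingleBranch
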
--------------------------------------------------------------------------------
/models/Head/utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 |
7 | def weights_init_kaiming(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Linear') != -1:
10 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
11 | nn.init.constant_(m.bias, 0.0)
12 |
13 | elif classname.find('Conv') != -1:
14 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
15 | if m.bias is not None:
16 | nn.init.constant_(m.bias, 0.0)
17 | elif classname.find('BatchNorm') != -1:
18 | if m.affine:
19 | nn.init.constant_(m.weight, 1.0)
20 | nn.init.constant_(m.bias, 0.0)
21 |
22 |
23 | def weights_init_classifier(m):
24 | # classname = m.__class__.__name__
25 | # if classname.find('Linear') != -1:
26 | # nn.init.normal_(m.weight, std=0.001)
27 | # if m.bias:
28 | # nn.init.constant_(m.bias, 0.0)
29 | classname = m.__class__.__name__
30 | if classname.find('Linear') != -1:
31 | nn.init.normal_(m.weight.data, std=0.001)
32 | nn.init.constant_(m.bias.data, 0.0)
33 |
34 |
35 | class ClassBlock(nn.Module):
36 | def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f=False):
37 | super(ClassBlock, self).__init__()
38 | self.return_f = return_f
39 | add_block = []
40 | if linear:
41 | add_block += [nn.Linear(input_dim, num_bottleneck)]
42 | else:
43 | num_bottleneck = input_dim
44 | if bnorm:
45 | add_block += [nn.BatchNorm1d(num_bottleneck)]
46 | if relu:
47 | add_block += [nn.LeakyReLU(0.1)]
48 | if droprate > 0:
49 | add_block += [nn.Dropout(p=droprate)]
50 | add_block = nn.Sequential(*add_block)
51 | add_block.apply(weights_init_kaiming)
52 |
53 | classifier = []
54 | classifier += [nn.Linear(num_bottleneck, class_num)]
55 | classifier = nn.Sequential(*classifier)
56 | classifier.apply(weights_init_classifier)
57 |
58 | self.add_block = add_block
59 | self.classifier = classifier
60 |
61 | def forward(self, x):
62 | feature_ = self.add_block(x)
63 | cls_ = self.classifier(feature_)
64 | return cls_, feature_
65 |
66 |
67 | class Gem_heat(nn.Module):
68 | def __init__(self, dim=768, p=3, eps=1e-6):
69 | super(Gem_heat, self).__init__()
70 | self.p = nn.Parameter(torch.ones(dim) * p) # initial p
71 | self.eps = eps
72 |
73 | def forward(self, x):
74 | return self.gem(x, p=self.p, eps=self.eps)
75 |
76 | def gem(self, x, p=3, eps=1e-6):
77 | # x = torch.transpose(x, 1, -1)
78 |         p = F.softmax(p, dim=0).unsqueeze(-1)
79 | x = torch.matmul(x, p)
80 | # x = torch.transpose(x, 1, -1)
81 | # x = F.avg_pool1d(x, x.size(-1))
82 | x = x.view(x.size(0), x.size(1))
83 | # x = x.pow(1. / p)
84 | return x
85 |
86 |
87 | class GeM(nn.Module):
88 | # channel-wise GeM zhedong zheng
89 | def __init__(self, dim=2048, p=1, eps=1e-6):
90 | super(GeM, self).__init__()
91 | self.p = nn.Parameter(torch.ones(dim)*p) # initial p
92 | self.eps = eps
93 |
94 | def forward(self, x):
95 | x = torch.transpose(x, 1, -1)
96 | x = (x+self.eps).pow(self.p)
97 | x = torch.transpose(x, 1, -1)
98 | x = F.avg_pool2d(x, (x.size(-2), x.size(-1)))
99 | x = x.view(x.size(0), x.size(1)).contiguous()
100 | x = x.pow(1./self.p)
101 | return x
102 |
103 |
104 |
105 | class Pooling(nn.Module):
106 | def __init__(self, dim, pool="avg"):
107 | super(Pooling, self).__init__()
108 | self.pool = pool
109 | if pool == 'avg+max':
110 | self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
111 | self.maxpool2 = nn.AdaptiveMaxPool2d((1, 1))
112 | elif pool == 'avg':
113 | self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
114 | elif pool == 'max':
115 | self.maxpool2 = nn.AdaptiveMaxPool2d((1, 1))
116 | elif pool == 'gem':
117 | self.gem2 = Gem_heat(dim=dim)
118 |
119 | def forward(self, x):
120 | if self.pool == 'avg+max':
121 | x1 = self.avgpool2(x)
122 | x2 = self.maxpool2(x)
123 | x = torch.cat((x1, x2), dim=1)
124 | elif self.pool == 'avg':
125 | x = self.avgpool2(x)
126 | elif self.pool == 'max':
127 | x = self.maxpool2(x)
128 | elif self.pool == 'gem':
129 | x = self.gem2(x)
130 | return x
131 |
132 |
133 | def vector2image(x, dim=1): # (B,N,C)
134 | B, N, C = x.shape
135 | if dim == 1:
136 | return x.reshape(B, int(np.sqrt(N)), int(np.sqrt(N)), C)
137 | if dim == 2:
138 | return x.reshape(B, N, int(np.sqrt(C)), int(np.sqrt(C)))
139 | else:
140 | raise TypeError("dim is not correct!!")
141 |
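Two short sketches of the helpers above: vector2image folds a token sequence back into a square grid, and the channel-wise GeM pools a CNN feature map; import path and sizes are illustrative assumptions:

import torch
from models.Head.utils import vector2image, GeM  # assumed module path

tokens = torch.randn(2, 196, 768)          # (B, N, C) with N = 14*14 patch tokens
print(vector2image(tokens, dim=1).shape)   # torch.Size([2, 14, 14, 768])

pool = GeM(dim=768, p=1)                   # dim must equal the channel count C
fmap = torch.randn(2, 768, 7, 7)           # (B, C, H, W) CNN feature map
print(pool(fmap).shape)                    # torch.Size([2, 768])
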
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/models/__init__.py
--------------------------------------------------------------------------------
/models/taskflow.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .Backbone.backbone import make_backbone
3 | from .Head.head import make_head
4 | import os
5 | import torch
6 |
7 |
8 | class Model(nn.Module):
9 | def __init__(self, opt):
10 | super().__init__()
11 | self.backbone = make_backbone(opt)
12 | opt.in_planes = self.backbone.output_channel
13 | self.head = make_head(opt)
14 | self.opt = opt
15 |
16 | def forward(self, drone_image, satellite_image):
17 | if drone_image is None:
18 | drone_res = None
19 | else:
20 | drone_features = self.backbone(drone_image)
21 | drone_res = self.head(drone_features)
22 | if satellite_image is None:
23 | satellite_res = None
24 | else:
25 | satellite_features = self.backbone(satellite_image)
26 | satellite_res = self.head(satellite_features)
27 |
28 | return drone_res,satellite_res
29 |
30 | def load_params(self, load_from):
31 | pretran_model = torch.load(load_from)
32 | model2_dict = self.state_dict()
33 | state_dict = {k: v for k, v in pretran_model.items() if k in model2_dict.keys() and v.size() == model2_dict[k].size()}
34 | model2_dict.update(state_dict)
35 | self.load_state_dict(model2_dict)
36 |
37 |
38 | def make_model(opt):
39 | model = Model(opt)
40 | if os.path.exists(opt.load_from):
41 | model.load_params(opt.load_from)
42 | return model
43 |
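The partial loading in Model.load_params keeps only checkpoint entries whose name and shape both match the current model; a standalone sketch of that filtering with toy tensors:

import torch
import torch.nn as nn

current = nn.Linear(8, 4)
pretrained = {"weight": torch.zeros(4, 8), "bias": torch.zeros(10)}  # bias shape mismatch

own = current.state_dict()
filtered = {k: v for k, v in pretrained.items()
            if k in own.keys() and v.size() == own[k].size()}
own.update(filtered)          # "weight" is overwritten, the mismatched "bias" stays untouched
current.load_state_dict(own)
print(sorted(filtered))       # ['weight']
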
--------------------------------------------------------------------------------
/optimizers/make_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torch.optim import lr_scheduler
3 |
4 |
5 | def make_optimizer(model,opt):
6 | ignored_params = []
7 | ignored_params += list(map(id, model.backbone.parameters()))
8 | extra_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
9 | base_params = filter(lambda p: id(p) in ignored_params, model.parameters())
10 | optimizer_ft = optim.SGD([
11 | {'params': base_params, 'lr': 0.3 * opt.lr},
12 | {'params': extra_params, 'lr': opt.lr}
13 | ], weight_decay=5e-4, momentum=0.9, nesterov=True)
14 |
15 |
16 | exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[70,110], gamma=0.1)
17 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=80, gamma=0.1)
18 | # exp_lr_scheduler = lr_scheduler.ExponentialLR(optimizer_ft, gamma=0.95)
19 | # exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, mode='min', factor=0.5, patience=4, verbose=True,threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-5, eps=1e-08)
20 |
21 | return optimizer_ft,exp_lr_scheduler
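
A hypothetical usage of make_optimizer: the backbone parameter group trains at 0.3x the base learning rate and everything else at the full rate; the toy module and lr value are assumptions:

import torch.nn as nn
from types import SimpleNamespace
from optimizers.make_optimizer import make_optimizer  # assumed module path

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

opt = SimpleNamespace(lr=0.01)
optimizer, scheduler = make_optimizer(Toy(), opt)
print([group["lr"] for group in optimizer.param_groups])  # [0.003, 0.01]
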
--------------------------------------------------------------------------------
/requirments.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | tqdm
3 | pyyaml
4 | matplotlib
5 | opencv-python
6 | timm
7 | pillow
8 | einops
9 | thop
10 | pytorch_metric_learning
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 | from datasets.queryDataset import Dataset_query, Query_transforms
5 |
6 | import argparse
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | import torch.backends.cudnn as cudnn
11 | import numpy as np
12 | from torchvision import datasets, models, transforms
13 | import time
14 | import os
15 | import scipy.io
16 | import yaml
17 | import math
18 | from tool.utils import load_network
19 | from tqdm import tqdm
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 |
23 | parser = argparse.ArgumentParser(description='Training')
24 | parser.add_argument('--gpu_ids', default='0', type=str,
25 | help='gpu_ids: e.g. 0 0,1,2 0,2')
26 | parser.add_argument(
27 | '--test_dir', default='', type=str, help='./test_data')
28 | parser.add_argument('--name', default='',
29 | type=str, help='save model path')
30 | parser.add_argument('--checkpoint', default='net_119.pth',
31 | type=str, help='save model path')
32 | parser.add_argument('--batchsize', default=128, type=int, help='batchsize')
33 | parser.add_argument('--h', default=256, type=int, help='height')
34 | parser.add_argument('--w', default=256, type=int, help='width')
35 | parser.add_argument('--ms', default='1', type=str,
36 | help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2')
37 | parser.add_argument('--num_worker', default=4, type=int,help='')
38 | parser.add_argument('--mode',default='1', type=int,help='1:drone->satellite 2:satellite->drone')
39 | opt = parser.parse_args()
40 |
41 | print(opt.name)
42 |
43 | ###load config###
44 | # load the training config
45 | config_path = 'opts.yaml'
46 | with open(config_path, 'r') as stream:
47 | config = yaml.load(stream, Loader=yaml.FullLoader)
48 | for cfg, value in config.items():
49 | setattr(opt, cfg, value)
50 |
51 | str_ids = opt.gpu_ids.split(',')
52 | test_dir = opt.test_dir
53 |
54 | gpu_ids = []
55 | for str_id in str_ids:
56 | id = int(str_id)
57 | if id >= 0:
58 | gpu_ids.append(id)
59 |
60 | print('We use the scale: %s' % opt.ms)
61 | str_ms = opt.ms.split(',')
62 | ms = []
63 |
64 | for s in str_ms:
65 | s_f = float(s)
66 | ms.append(math.sqrt(s_f))
67 |
68 | if len(gpu_ids) > 0:
69 | torch.cuda.set_device(gpu_ids[0])
70 | cudnn.benchmark = True
71 |
72 |
73 | data_transforms = transforms.Compose([
74 | transforms.Resize((opt.h, opt.w), interpolation=3),
75 | transforms.ToTensor(),
76 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
77 | ])
78 |
79 | data_query_transforms = transforms.Compose([
80 | transforms.Resize((opt.h, opt.w), interpolation=3),
81 | # Query_transforms(pad=10,size=opt.w),
82 | transforms.ToTensor(),
83 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
84 | ])
85 |
86 |
87 | data_dir = test_dir
88 |
89 | image_datasets_query = {x: datasets.ImageFolder(os.path.join(
90 | data_dir, x), data_query_transforms) for x in ['query_drone']}
91 |
92 | image_datasets_gallery = {x: datasets.ImageFolder(os.path.join(
93 | data_dir, x), data_transforms) for x in ['gallery_satellite']}
94 |
95 | image_datasets = {**image_datasets_query, **image_datasets_gallery}
96 |
97 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
98 | shuffle=False, num_workers=opt.num_worker) for x in ['gallery_satellite', 'query_drone']}
99 | use_gpu = torch.cuda.is_available()
100 |
101 |
102 | def fliplr(img):
103 | '''flip horizontal'''
104 | inv_idx = torch.arange(img.size(3)-1, -1, -1).long() # N x C x H x W
105 | img_flip = img.index_select(3, inv_idx)
106 | return img_flip
107 |
108 |
109 | def which_view(name):
110 | if 'satellite' in name:
111 | return 1
112 | elif 'street' in name:
113 | return 2
114 | elif 'drone' in name:
115 | return 3
116 | else:
117 | print('unknown view')
118 | return -1
119 |
120 |
121 | def extract_feature(model, dataloaders, view_index=1):
122 | features = torch.FloatTensor()
123 | count = 0
124 | for data in tqdm(dataloaders):
125 | img, _ = data
126 | batchsize = img.size()[0]
127 | count += batchsize
128 | # if opt.LPN:
129 | # # ff = torch.FloatTensor(n,2048,6).zero_().cuda()
130 | # ff = torch.FloatTensor(n,512,opt.block).zero_().cuda()
131 | # else:
132 | # ff = torch.FloatTensor(n, 2048).zero_().cuda()
133 | for i in range(2):
134 | if(i == 1):
135 | img = fliplr(img)
136 | input_img = Variable(img.cuda())
137 | if view_index == 1:
138 | outputs, _ = model(input_img, None)
139 | elif view_index == 3:
140 | _, outputs = model(None, input_img)
141 | outputs = outputs[1]
142 | if i == 0:
143 | ff = outputs
144 | else:
145 | ff += outputs
146 | # norm feature
147 | if len(ff.shape) == 3:
148 | # feature size (n,2048,6)
149 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
150 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
151 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * \
152 | np.sqrt(opt.block)
153 | ff = ff.div(fnorm.expand_as(ff))
154 | ff = ff.view(ff.size(0), -1)
155 | else:
156 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
157 | ff = ff.div(fnorm.expand_as(ff))
158 |
159 | features = torch.cat((features, ff.data.cpu()), 0)
160 | return features
161 |
162 |
163 | def get_id(img_path):
164 | camera_id = []
165 | labels = []
166 | paths = []
167 | for path, v in img_path:
168 | folder_name = os.path.basename(os.path.dirname(path))
169 | labels.append(int(folder_name))
170 | paths.append(path)
171 | return labels, paths
172 |
173 |
174 | ######################################################################
175 | # Load Collected data Trained model
176 | print('-------test-----------')
177 |
178 | model = load_network(opt)
179 | print("Results for %s" % opt.checkpoint)
180 | # model.classifier.classifier = nn.Sequential()
181 | model = model.eval()
182 | if use_gpu:
183 | model = model.cuda()
184 |
185 | # Extract feature
186 | since = time.time()
187 |
188 | if opt.mode == 1:
189 | query_name = 'query_drone'
190 | gallery_name = 'gallery_satellite'
191 | elif opt.mode == 2:
192 | query_name = 'query_satellite'
193 | gallery_name = 'gallery_drone'
194 | else:
195 |     raise Exception("opt.mode is not supported")
196 |
197 |
198 | which_gallery = which_view(gallery_name)
199 | which_query = which_view(query_name)
200 | print('%d -> %d:' % (which_query, which_gallery))
201 | print(query_name.split("_")[-1], "->", gallery_name.split("_")[-1])
202 |
203 | gallery_path = image_datasets[gallery_name].imgs
204 |
205 | query_path = image_datasets[query_name].imgs
206 |
207 | gallery_label, gallery_path = get_id(gallery_path)
208 | query_label, query_path = get_id(query_path)
209 |
210 | if __name__ == "__main__":
211 | with torch.no_grad():
212 | query_feature = extract_feature(
213 | model, dataloaders[query_name], which_query)
214 | gallery_feature = extract_feature(
215 | model, dataloaders[gallery_name], which_gallery)
216 |
217 | # For street-view image, we use the avg feature as the final feature.
218 |
219 | time_elapsed = time.time() - since
220 | print('Test complete in {:.0f}m {:.0f}s'.format(
221 | time_elapsed // 60, time_elapsed % 60))
222 |
223 | with open('inference_time.txt', 'w') as F:
224 | F.write('Test complete in {:.0f}m {:.0f}s'.format(
225 | time_elapsed // 60, time_elapsed % 60))
226 |
227 | # Save to Matlab for check
228 | result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_path': gallery_path,
229 | 'query_f': query_feature.numpy(), 'query_label': query_label, 'query_path': query_path}
230 | scipy.io.savemat('pytorch_result_{}.mat'.format(opt.mode), result)
231 |
232 | # print(opt.name)
233 | # result = 'result.txt'
234 | # os.system('python evaluate_gpu.py | tee -a %s'%result)
235 |
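A reduced sketch of the flip-and-sum test-time augmentation inside extract_feature; `embed` is a placeholder for the model's descriptor branch, not the real network:

import torch

def fliplr(img):
    '''flip horizontal, as above'''
    inv_idx = torch.arange(img.size(3) - 1, -1, -1).long()
    return img.index_select(3, inv_idx)

embed = lambda x: x.mean(dim=(2, 3))          # placeholder descriptor head
img = torch.randn(2, 3, 224, 224)
ff = embed(img) + embed(fliplr(img))          # sum original and flipped descriptors
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
ff = ff.div(fnorm.expand_as(ff))              # L2-normalize so cosine score reduces to a dot product
print(ff.norm(dim=1))                         # ~ tensor([1., 1.])
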
--------------------------------------------------------------------------------
/test_hard.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 | from datasets.queryDataset import Dataset_query, Query_transforms, Dataset_gallery, test_collate_fn
5 |
6 | import argparse
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | import torch.backends.cudnn as cudnn
11 | import numpy as np
12 | from torchvision import datasets, models, transforms
13 | import time
14 | import os
15 | import scipy.io
16 | import yaml
17 | import math
18 | from tool.utils import load_network
19 | from tqdm import tqdm
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 |
23 | parser = argparse.ArgumentParser(description='Training')
24 | parser.add_argument('--gpu_ids', default='0', type=str,
25 | help='gpu_ids: e.g. 0 0,1,2 0,2')
26 | parser.add_argument(
27 | '--test_dir', default='/home/dmmm/Dataset/DenseUAV/data_2022/test', type=str, help='./test_data')
28 | parser.add_argument('--name', default='resnet',
29 | type=str, help='save model path')
30 | parser.add_argument('--checkpoint', default='net_119.pth',
31 | type=str, help='save model path')
32 | parser.add_argument('--batchsize', default=128, type=int, help='batchsize')
33 | parser.add_argument('--h', default=224, type=int, help='height')
34 | parser.add_argument('--w', default=224, type=int, help='width')
35 | parser.add_argument('--ms', default='1', type=str,
36 | help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2')
37 | parser.add_argument('--mode', default='hard', type=str,
38 | help='1:drone->satellite 2:satellite->drone')
39 | parser.add_argument('--num_worker', default=8, type=int,
40 |                     help='number of dataloader workers')
41 |
42 | parser.add_argument('--split_feature', default=1, type=int, help='')
43 |
44 | opt = parser.parse_args()
45 | print(opt.name)
46 | ###load config###
47 | # load the training config
48 | config_path = 'opts.yaml'
49 | with open(config_path, 'r') as stream:
50 |     config = yaml.load(stream, Loader=yaml.FullLoader)
51 | for cfg, value in config.items():
52 | if cfg not in opt:
53 | setattr(opt, cfg, value)
54 |
55 | str_ids = opt.gpu_ids.split(',')
56 | test_dir = opt.test_dir
57 |
58 | gpu_ids = []
59 | for str_id in str_ids:
60 | id = int(str_id)
61 | if id >= 0:
62 | gpu_ids.append(id)
63 |
64 | print('We use the scale: %s' % opt.ms)
65 | str_ms = opt.ms.split(',')
66 | ms = []
67 | for s in str_ms:
68 | s_f = float(s)
69 | ms.append(math.sqrt(s_f))
70 |
71 | if len(gpu_ids) > 0:
72 | torch.cuda.set_device(gpu_ids[0])
73 | cudnn.benchmark = True
74 |
75 |
76 | data_transforms = transforms.Compose([
77 | transforms.Resize((opt.h, opt.w), interpolation=3),
78 | transforms.ToTensor(),
79 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
80 | ])
81 |
82 | data_query_transforms = transforms.Compose([
83 | transforms.Resize((opt.h, opt.w), interpolation=3),
84 | # Query_transforms(pad=10,size=opt.w),
85 | transforms.ToTensor(),
86 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
87 | ])
88 |
89 |
90 | data_dir = test_dir
91 |
92 | image_datasets_query = Dataset_query(os.path.join(data_dir, "query_drone"), data_query_transforms)
93 |
94 | # image_datasets_gallery = Dataset_gallery(os.path.join(opt.test_dir, "total_info_ms_i10m.txt"), data_transforms)
95 |
96 | image_datasets_gallery = Dataset_gallery(os.path.join(opt.test_dir, "total_info_ss_i10m.txt"), data_transforms)
97 |
98 |
99 | dataloaders_query = torch.utils.data.DataLoader(image_datasets_query, batch_size=opt.batchsize, shuffle=False, num_workers=opt.num_worker, collate_fn=test_collate_fn)
100 |
101 | split_nums = len(image_datasets_gallery)//opt.split_feature
102 |
103 | list_split = [split_nums]*opt.split_feature
104 |
105 | list_split[-1] = len(image_datasets_gallery)-(opt.split_feature-1)*split_nums
106 |
107 | gallery_datasets_list = torch.utils.data.random_split(image_datasets_gallery, list_split)
108 |
109 | dataloaders_gallery = {ind: torch.utils.data.DataLoader(gallery_datasets_list[ind], batch_size=opt.batchsize, shuffle=False, num_workers=opt.num_worker, collate_fn=test_collate_fn) for ind in range(opt.split_feature)}
110 |
111 | use_gpu = torch.cuda.is_available()
112 |
113 | def extract_feature(model, dataloaders, view_index=1):
114 | features = torch.FloatTensor()
115 | infos_list = np.zeros((0,2),dtype=np.float32)
116 | path_list = []
117 | for data in tqdm(dataloaders):
118 | img, infos, path = data
119 | path_list.extend(path)
120 | # infos_list.extend(infos)
121 | infos_list = np.concatenate((infos_list,infos),0)
122 | # if opt.LPN:
123 | # # ff = torch.FloatTensor(n,2048,6).zero_().cuda()
124 | # ff = torch.FloatTensor(n,512,opt.block).zero_().cuda()
125 | # else:
126 | # ff = torch.FloatTensor(n, 2048).zero_().cuda()
127 |
128 | input_img = Variable(img.cuda())
129 | if view_index == 1:
130 | outputs, _ = model(input_img, None)
131 | elif view_index == 3:
132 | _, outputs = model(None, input_img)
133 | outputs = outputs[1]
134 | ff = outputs
135 | # # norm feature
136 | if len(ff.shape) == 3:
137 | # feature size (n,2048,6)
138 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
139 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
140 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * \
141 | np.sqrt(opt.block)
142 | ff = ff.div(fnorm.expand_as(ff))
143 | ff = ff.view(ff.size(0), -1)
144 | else:
145 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
146 | ff = ff.div(fnorm.expand_as(ff))
147 |
148 | features = torch.cat((features, ff.data.cpu()), 0)
149 | return features,infos_list,path_list
150 |
151 |
152 | model = load_network(opt)
153 | print("Results for %s" % opt.checkpoint)
154 | # model.classifier.classifier = nn.Sequential()
155 | model = model.eval()
156 | if use_gpu:
157 | model = model.cuda()
158 |
159 | # Extract feature
160 | since = time.time()
161 |
162 | if __name__ == "__main__":
163 | with torch.no_grad():
164 | query_feature, query_infos, query_path = extract_feature(
165 | model, dataloaders_query, 1)
166 | gallery_features = torch.FloatTensor()
167 | gallery_infos = np.zeros((0,2),dtype=np.float32)
168 | gallery_paths = []
169 | for i in range(opt.split_feature):
170 | gallery_feature,gallery_info,gallery_path = extract_feature(
171 | model, dataloaders_gallery[i], 3)
172 | gallery_infos = np.concatenate((gallery_infos,gallery_info),0)
173 | gallery_features = torch.cat((gallery_features,gallery_feature),0)
174 | gallery_paths.extend(gallery_path)
175 |
176 |
177 | # For street-view image, we use the avg feature as the final feature.
178 |
179 | time_elapsed = time.time() - since
180 | print('Test complete in {:.0f}m {:.0f}s'.format(
181 | time_elapsed // 60, time_elapsed % 60))
182 |
183 | with open('inference_time.txt', 'w') as F:
184 | F.write('Test complete in {:.0f}m {:.0f}s'.format(
185 | time_elapsed // 60, time_elapsed % 60))
186 |
187 | # Save to Matlab for check
188 |
189 | result = {'gallery_f': gallery_features.numpy(), 'gallery_infos': gallery_infos.astype(np.float32), 'query_path':query_path,
190 | 'query_f': query_feature.numpy(), 'query_infos': query_infos.astype(np.float32), 'gallery_path':gallery_paths}
191 | scipy.io.savemat('pytorch_result_{}_ss.mat'.format(opt.mode), result)
192 |
193 | # print(opt.name)
194 | # result = 'result.txt'
195 | # os.system('python evaluate_gpu.py | tee -a %s'%result)
196 |
197 |
198 |
199 |
200 |
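The gallery here is chunked into opt.split_feature pieces so features can be extracted piecewise; a small sketch of the split-size arithmetic used above, with illustrative numbers:

total, split_feature = 10007, 4                 # illustrative gallery size and chunk count
split_nums = total // split_feature
list_split = [split_nums] * split_feature
list_split[-1] = total - (split_feature - 1) * split_nums
print(list_split, sum(list_split) == total)     # [2501, 2501, 2501, 2504] True
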
--------------------------------------------------------------------------------
/tool/SDM@K_analyze.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | ResNet50 = [0.165, 0.343, 0.473, 0.535, 1.0]
5 | ConvNextT = [0.6023, 0.7404, 0.7975, 0.828, 1.0]
6 | ViTS = [0.8018, 0.888, 0.9197, 0.9335, 1.0]
7 |
8 | y1 = ConvNextT
9 | for i in range(len(y1)-1,0,-1):
10 | y1[i] = y1[i]-y1[i-1]
11 |
12 | y2 = ViTS
13 | for i in range(len(y2)-1,0,-1):
14 | y2[i] = y2[i]-y2[i-1]
15 |
16 | y0 = ResNet50
17 | for i in range(len(y0)-1,0,-1):
18 | y0[i] = y0[i]-y0[i-1]
19 |
20 | total_images = 2331
21 |
22 | # create the data
23 | categories = ["0","1","2","3", "other"]
24 |
25 | bar_width = 0.3
26 |
27 | r0 = np.arange(len(categories))
28 | r1 = [x + bar_width for x in r0]
29 | r2 = [x + bar_width for x in r1]
30 |
31 | # create the figure
32 | plt.subplots(figsize=(7, 5))
33 |
34 | plt.xlabel("Error on Sampling Interval",fontdict={'family' : 'Times New Roman', 'size': 17})
35 | plt.xticks(fontproperties='Times New Roman',fontsize=15)
36 |
37 | plt.ylabel('Proportion(%)', fontdict={'family': 'Times New Roman', 'size': 17})  # add the y-axis label
38 | plt.yticks(fontproperties='Times New Roman',fontsize=15)
39 | plt.ylim(0,0.85)
40 |
41 | # draw the grouped bars
42 | plt.bar(r0, y0, width=bar_width, color="#eed777", edgecolor="k",linewidth=2, label='ResNet-50')
43 | plt.bar(r1, y1, width=bar_width, color="#45a776", edgecolor="k",linewidth=2, label='ConvNext-T')
44 | plt.bar(r2, y2, width=bar_width, color="#b3974e", edgecolor="k",linewidth=2, label="ViT-S")
45 | for i, value in zip(r0,y0):
46 | plt.text(i, value, "{:.1f}".format(value*100), ha='center', va='bottom', color="black", fontsize=13, fontproperties='Times New Roman')
47 | for i, value in zip(r1,y1):
48 | plt.text(i, value, "{:.1f}".format(value*100), ha='center', va='bottom', color="black", fontsize=13, fontproperties='Times New Roman')
49 | for i, value in zip(r2,y2):
50 | plt.text(i, value, "{:.1f}".format(value*100), ha='center', va='bottom', color="black", fontsize=13, fontproperties='Times New Roman')
51 |
52 | # add the tick labels
53 | plt.xticks([r + bar_width for r in range(len(categories))], categories)
54 |
55 | plt.legend(prop={'family': 'Times New Roman', 'size': 15})
56 |
57 | # adjust the layout
58 | plt.tight_layout()
59 |
60 | # save the figure
61 | plt.savefig("SDM@K_analyze.eps", dpi = 600)
62 |
63 |
64 |
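The backward loops above turn cumulative SDM curves into per-bin increments; a tiny worked example of that in-place conversion:

vals = [0.60, 0.74, 0.80, 0.83, 1.0]        # cumulative proportions (illustrative)
for i in range(len(vals) - 1, 0, -1):
    vals[i] = vals[i] - vals[i - 1]
print([round(v, 2) for v in vals])          # [0.6, 0.14, 0.06, 0.03, 0.17]
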
--------------------------------------------------------------------------------
/tool/SDM@K_compare.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import sys
4 | import matplotlib as mpl
5 | mpl.rcParams['font.family'] = 'Times New Roman'
6 |
7 |
8 | def evaluateSingle(distance, K):
9 | # maxDistance = max(distance) + 1e-14
10 | # weight = np.ones(K) - np.log(range(1, K + 1, 1)) / np.log(opts.M * K)
11 | weight = np.ones(K) - np.array(range(0, K, 1))/K
12 | # m1 = distance / maxDistance
13 | m2 = 1 / np.exp(distance*5e3)
14 | m3 = m2 * weight
15 | result = np.sum(m3) / np.sum(weight)
16 | return result
17 |
18 |
19 | def evaluate_enclidean(distance, K):
20 | # maxDistance = max(distance) + 1e-14
21 | # weight = np.ones(K) - np.log(range(1, K + 1, 1)) / np.log(opts.M * K)
22 | weight = np.ones(K) - np.array(range(0, K, 1))/K
23 | # m1 = distance / maxDistance
24 | m2 = 1-distance*1e3
25 | m3 = m2 * weight
26 | result = np.sum(m3) / np.sum(weight)
27 | return result
28 |
29 |
30 | def Recall_Data(data):
31 | x_len, y_len = data.shape
32 | data = np.zeros_like(data)
33 | data[x_len//2, y_len//2] = 1
34 | return data
35 |
36 | def euclideanDistance(query, gallery):
37 | query = np.array(query, dtype=np.float32)
38 | gallery = np.array(gallery, dtype=np.float32)
39 | A = gallery - query
40 | A_T = A.transpose()
41 | distance = np.matmul(A, A_T)
42 |     mask = np.eye(distance.shape[0], dtype=np.bool_)
43 | distance = distance[mask]
44 | distance = np.sqrt(distance.reshape(-1))
45 | return distance
46 |
47 |
48 | def SDM_Data(data):
49 | x_len, y_len = data.shape
50 | x = np.linspace(120.358111-0.0003, 120.358111+0.0003, x_len)
51 | y = np.linspace(30.317842-0.0003, 30.317842+0.0003, y_len)
52 | x_,y_ = np.meshgrid(x,y)
53 | x_ = x_.reshape(-1,1)
54 | y_ = y_.reshape(-1,1)
55 | input = np.concatenate((x_,y_),axis=-1)
56 |
57 | target = np.array((120.358111,30.317842)).reshape(-1,2)
58 | distance = euclideanDistance(input, target)
59 | # compute single query evaluate result
60 | P_list = np.array([evaluateSingle(dist_single, 1) for dist_single in distance])
61 | return P_list.reshape(x_len,y_len)
62 |
63 |
64 | def Enclidean_Data(data):
65 | x_len, y_len = data.shape
66 | x = np.linspace(120.358111-0.0003, 120.358111+0.0003, x_len)
67 | y = np.linspace(30.317842-0.0003, 30.317842+0.0003, y_len)
68 | x_,y_ = np.meshgrid(x,y)
69 | x_ = x_.reshape(-1,1)
70 | y_ = y_.reshape(-1,1)
71 | input = np.concatenate((x_,y_),axis=-1)
72 |
73 | target = np.array((120.358111,30.317842)).reshape(-1,2)
74 | distance = euclideanDistance(input, target)
75 | # compute single query evaluate result
76 | P_list = np.array([evaluate_enclidean(dist_single, 1) for dist_single in distance])
77 | return P_list.reshape(x_len,y_len)
78 |
79 |
80 |
81 | # create a random 2D array as the grid data
82 | data = np.random.rand(7, 7)
83 |
84 | Recall_data = Recall_Data(data)
85 | SDM_data = SDM_Data(data)
86 | Enclidean_data = Enclidean_Data(data)
87 |
88 |
89 | # create the figure and axes objects
90 | fig, ax = plt.subplots(1,3,figsize=(14,4))
91 |
92 | # display the grid data with imshow
93 | img1 = ax[0].imshow(Recall_data, cmap='coolwarm', interpolation='nearest')
94 | img2 = ax[2].imshow(SDM_data, cmap='coolwarm', interpolation='nearest')
95 | img3 = ax[1].imshow(Enclidean_data, cmap='coolwarm', interpolation='nearest')
96 |
97 |
98 | # show the value at the center of each cell
99 | for i in range(data.shape[0]):
100 | for j in range(data.shape[1]):
101 | ax[0].text(j, i, f'{Recall_data[i, j]:.1f}', ha='center', va='center', color='white')
102 | ax[2].text(j, i, f'{SDM_data[i, j]:.1f}', ha='center', va='center', color='white')
103 | ax[1].text(j, i, f'{Enclidean_data[i, j]:.1f}', ha='center', va='center', color='white')
104 |
105 | for a in ax:
106 | a.set_xticks([])
107 | a.set_yticks([])
108 | a.set_xticklabels([])
109 | a.set_yticklabels([])
110 |
111 | # add a color bar
112 | cbar = plt.colorbar(img1)
113 |
114 |
115 | # set the subplot titles
116 | ax[0].set_title('(a) Recall', fontsize=16, pad=10)
117 | ax[2].set_title('(b) SDM', fontsize=16, pad=10)
118 | ax[1].set_title('(c) Euclidean Distance', fontsize=16, pad=10)
119 |
120 | plt.subplots_adjust(wspace=0.0, hspace=0.000)
121 |
122 | # adjust the layout so the subplots are properly spaced
123 | plt.tight_layout()
124 |
125 |
126 |
127 | # ax[0].grid(True, which='both', linestyle='-', linewidth=2, color='black')
128 | # ax[1].grid(True, which='both', linestyle='-', linewidth=2, color='black')
129 | # ax[2].grid(True, which='both', linestyle='-', linewidth=2, color='black')
130 |
131 | plt.savefig("SDM_Recall_Enclidean_compare.eps", dpi=300)
132 |
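A worked example of the SDM scoring in evaluateSingle: each of the K retrieved positions gets a rank weight 1 - i/K and a distance weight exp(-5e3 * d), so nearby, highly ranked matches dominate; the distances here are illustrative (in degrees):

import numpy as np

def sdm_at_k(distance, K):
    weight = np.ones(K) - np.arange(K) / K        # rank weights: 1, 1-1/K, ...
    score = np.exp(-distance * 5e3) * weight      # distance weights, same as 1/exp(d*5e3)
    return np.sum(score) / np.sum(weight)

print(round(sdm_at_k(np.array([0.0, 1e-4, 5e-4]), 3), 4))   # ~0.7159
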
--------------------------------------------------------------------------------
/tool/applications/forwardAllSatelliteHub.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | import torch.backends.cudnn as cudnn
10 | import numpy as np
11 | from torchvision import datasets, models, transforms
12 | import time
13 | import os
14 | import scipy.io
15 | import yaml
16 | import math
17 | from tool.utils import load_network
18 | from tqdm import tqdm
19 | import warnings
20 | from datasets.Dataloader_University import DataLoader_Inference
21 | warnings.filterwarnings("ignore")
22 | from datasets.queryDataset import Dataset_query,Query_transforms
23 | # Options
24 | # --------
25 | University="计量"
26 | parser = argparse.ArgumentParser(description='Training')
27 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
28 | parser.add_argument('--root',default='/media/dmmm/4T-3/DataSets/DenseCV_Data/inference_data/satelliteHub({})'.format(University),type=str, help='./test_data')
29 | parser.add_argument('--savename', default='features{}.mat'.format(University), type=str, help='save model path')
30 | parser.add_argument('--checkpoint', default='net_119.pth', type=str, help='save model path')
31 | parser.add_argument('--batchsize', default=128, type=int, help='batchsize')
32 | parser.add_argument('--h', default=256, type=int, help='height')
33 | parser.add_argument('--w', default=256, type=int, help='width')
34 | parser.add_argument('--ms',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2')
35 | parser.add_argument('--num_worker',default=4, type=int,help='number of dataloader workers')
36 |
37 | opt = parser.parse_args()
38 | ###load config###
39 | # load the training config
40 | config_path = 'opts.yaml'
41 | with open(config_path, 'r') as stream:
42 |     config = yaml.load(stream, Loader=yaml.FullLoader)
43 | for cfg,value in config.items():
44 | setattr(opt,cfg,value)
45 |
46 | if 'h' in config:
47 | opt.h = config['h']
48 | opt.w = config['w']
49 |
50 | if 'nclasses' in config: # to be compatible with old config files
51 | opt.nclasses = config['nclasses']
52 | else:
53 | opt.nclasses = 729
54 |
55 | str_ids = opt.gpu_ids.split(',')
56 |
57 | gpu_ids = []
58 | for str_id in str_ids:
59 | id = int(str_id)
60 | if id >=0:
61 | gpu_ids.append(id)
62 |
63 | print('We use the scale: %s'%opt.ms)
64 | str_ms = opt.ms.split(',')
65 | ms = []
66 | for s in str_ms:
67 | s_f = float(s)
68 | ms.append(math.sqrt(s_f))
69 |
70 | # set gpu ids
71 | if len(gpu_ids)>0:
72 | torch.cuda.set_device(gpu_ids[0])
73 | cudnn.benchmark = True
74 |
75 | ######################################################################
76 | # Load Data
77 | # ---------
78 | #
79 | # We will use torchvision and torch.utils.data packages for loading the
80 | # data.
81 | #
82 | def fliplr(img):
83 | '''flip horizontal'''
84 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W
85 | img_flip = img.index_select(3,inv_idx)
86 | return img_flip
87 |
88 | data_transforms = transforms.Compose([
89 | transforms.Resize((opt.h, opt.w), interpolation=3),
90 | transforms.ToTensor(),
91 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
92 | ])
93 |
94 |
95 | image_datasets = DataLoader_Inference(root=opt.root,transforms=data_transforms)
96 |
97 | dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=opt.batchsize,
98 | shuffle=False, num_workers=opt.num_worker)
99 | use_gpu = torch.cuda.is_available()
100 |
101 |
102 |
103 | def extract_feature(model,dataloaders, view_index = 1):
104 | features = torch.FloatTensor()
105 | count = 0
106 | for data in tqdm(dataloaders):
107 | img, label = data
108 | n, c, h, w = img.size()
109 | count += n
110 | # if opt.LPN:
111 | # # ff = torch.FloatTensor(n,2048,6).zero_().cuda()
112 | # ff = torch.FloatTensor(n,512,opt.block).zero_().cuda()
113 | # else:
114 | # ff = torch.FloatTensor(n, 2048).zero_().cuda()
115 | input_img = Variable(img.cuda())
116 | outputs, _ = model(input_img, None)
117 | outputs = outputs[1]
118 | ff=outputs
119 | # norm feature
120 | if len(ff.shape)==3:
121 | # feature size (n,2048,6)
122 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
123 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
124 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(opt.block)
125 | ff = ff.div(fnorm.expand_as(ff))
126 | ff = ff.view(ff.size(0), -1)
127 | else:
128 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
129 | ff = ff.div(fnorm.expand_as(ff))
130 |
131 | features = torch.cat((features,ff.data.cpu()), 0)
132 | return features
133 |
134 |
135 |
136 | # Load Collected data Trained model
137 | model = load_network(opt)
138 | model = model.eval()
139 | if use_gpu:
140 | model = model.cuda()
141 | # Extract feature
142 | since = time.time()
143 |
144 |
145 | images = image_datasets.imgs
146 |
147 | labels = image_datasets.labels
148 |
149 | if __name__ == "__main__":
150 | with torch.no_grad():
151 | features = extract_feature(model,dataloaders)
152 |
153 | # Save to Matlab for check
154 | result = {'features':features.numpy(),'labels':labels}
155 | scipy.io.savemat(opt.savename,result)
156 |
--------------------------------------------------------------------------------
/tool/applications/inference_global.py:
--------------------------------------------------------------------------------
1 | from __future__ import with_statement
2 | import argparse
3 | import scipy.io
4 | import torch
5 | import numpy as np
6 | from datasets.queryDataset import CenterCrop
7 | import glob
8 | from torchvision import transforms
9 | from PIL import Image
10 | import yaml
11 | from tool.utils import load_network
12 | from torch.autograd import Variable
13 | import torch.backends.cudnn as cudnn
14 | import os
15 | from tool.get_property import find_GPS_image
16 | import cv2
17 |
18 | #######################################################################
19 | # Evaluate
20 | University="计量"
21 | parser = argparse.ArgumentParser(description='Demo')
22 | parser.add_argument('--imageDir', default="/media/dmmm/CE31-3598/DataSets/DenseCV_Data/实际测试图像({})/test02".format(University), type=str,
23 | help='test_image_index')
24 | parser.add_argument('--satelliteMat', default="features{}.mat".format(University), type=str, help='./test_data')
25 | parser.add_argument('--MapDir', default="../../maps/{}.tif".format(University), type=str, help='./test_data')
26 | parser.add_argument('--galleryPath', default="/media/dmmm/4T-3/DataSets/DenseCV_Data/inference_data/satelliteHub({})".format(University),
27 | type=str, help='./test_data')
28 | opts = parser.parse_args()
29 | # TN30.325763471673625 TE120.37341802729506 BN30.320529681696023 BE120.38174250997761 jinrong
30 |
31 | mapPosInfodir = "/home/dmmm/PycharmProjects/DenseCV/demo/maps/pos.txt"
32 | with open(mapPosInfodir,"r") as F:
33 | listLine = F.readlines()
34 | for line in listLine:
35 | name,TN,TE,BN,BE = line.split(" ")
36 | if name==University:
37 | startE = eval(TE.split("TE")[-1])
38 | startN = eval(TN.split("TN")[-1])
39 | endE = eval(BE.split("BE")[-1])
40 | endN = eval(BN.split("BN")[-1])
41 |
42 | AllImage = cv2.imread(opts.MapDir)
43 | h, w, c = AllImage.shape
44 |
45 |
46 | def generateDictOfGalleryPosInfo():
47 | satellite_configDict = {}
48 | with open(os.path.join(opts.galleryPath, "PosInfo.txt"), "r") as F:
49 | context = F.readlines()
50 | for line in context:
51 | splitLineList = line.split(" ")
52 | satellite_configDict[splitLineList[0]] = [float(splitLineList[1].split("N")[-1]),
53 | float(splitLineList[2].split("E")[-1])]
54 | return satellite_configDict
55 |
56 |
57 | def sort_(list):
58 | list_ = [int(i.split(".JPG")[-2].split("DJI_")[-1]) for i in list]
59 | arg = np.argsort(list_)
60 | newlist = np.array(list)[arg]
61 | return newlist
62 |
63 |
64 | #######################################################################
65 | # sort the images and return topK index
66 | def getBestImage(qf, gf, gl):
67 | query = qf.view(-1, 1)
68 | # print(query.shape)
69 | score = torch.mm(gf, query)
70 | score = score.squeeze().cpu()
71 | score = score.numpy()
72 | # predict index
73 | index = np.argsort(score) # from small to large
74 | index = index[::-1]
75 | # index = index[0:2000]
76 | # good index
77 | # query_index = np.argwhere(gl == ql)
78 |
79 | # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
80 | # junk_index = np.argwhere(gl == -1)
81 |
82 | # mask = np.in1d(index, junk_index, invert=True)
83 | # index = index[mask]
84 | return gl[index[0]]
85 |
86 |
87 | data_transforms = transforms.Compose([
88 | CenterCrop(),
89 | transforms.Resize((256, 256), interpolation=3),
90 | transforms.ToTensor(),
91 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
92 | ])
93 |
94 | query_paths = glob.glob(opts.imageDir + "/*.JPG")
95 | # sorted(query_paths,key=lambda x : int(x.split(".JPG")[-2].split("DJI_")[-1]))
96 | query_paths = sort_(query_paths)
97 |
98 |
99 | #####################################################################
100 | def extract_feature(img, model):
101 | count = 0
102 | n, c, h, w = img.size()
103 | count += n
104 | input_img = Variable(img.cuda())
105 | outputs, _ = model(input_img, None)
106 | ff = outputs
107 | # norm feature
108 | if len(ff.shape) == 3:
109 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
110 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
111 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(opts.block)
112 | ff = ff.div(fnorm.expand_as(ff))
113 | ff = ff.view(ff.size(0), -1)
114 | else:
115 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
116 | ff = ff.div(fnorm.expand_as(ff))
117 |
118 | features = ff.data
119 | return features
120 |
121 |
122 | def getPosInfo(imgPath):
123 | GPS_info = find_GPS_image(imgPath)
124 | x = list(GPS_info.values())
125 | gps_dict_formate = x[0]
126 | y = list(gps_dict_formate.values())
127 | height = eval(y[5])
128 | E = y[3]
129 | N = y[1]
130 | return [N, E]
131 |
132 |
133 | def imshowByIndex(index):
134 | galleryPath = os.path.join(opts.galleryPath, index + ".tif")
135 | image = cv2.imread(galleryPath)
136 | cv2.imshow("gallery", image)
137 |
138 |
139 | ######################################################################
140 | # load network
141 | config_path = 'opts.yaml'
142 | with open(config_path, 'r') as stream:
143 |     config = yaml.load(stream, Loader=yaml.FullLoader)
144 | opts.stride = config['stride']
145 | opts.views = config['views']
146 | opts.transformer = config['transformer']
147 | opts.pool = config['pool']
148 | opts.views = config['views']
149 | opts.LPN = config['LPN']
150 | opts.block = config['block']
151 | opts.nclasses = config['nclasses']
152 | opts.droprate = config['droprate']
153 | opts.share = config['share']
154 | opts.checkpoint = "net_119.pth"
155 | torch.cuda.set_device("cuda:0")
156 | cudnn.benchmark = True
157 |
158 | model = load_network(opts)
159 | model = model.eval()
160 | model = model.cuda()
161 |
162 | ######################################################################
163 | result = scipy.io.loadmat(opts.satelliteMat)
164 | gallery_feature = torch.FloatTensor(result['features'])
165 | gallery_label = result['labels']
166 | gallery_feature = gallery_feature.cuda()
167 |
168 | satellitePosInfoDict = generateDictOfGalleryPosInfo()  # in each entry, the first value is N (latitude), the second is E (longitude)
169 | firstPos = getPosInfo(query_paths[0])
170 | # gmap = gmplot.GoogleMapPlotter(firstPos[0], firstPos[1], 19)
171 |
172 | queryPosDict = {"N": [], "E": []}
173 | galleryPosDict = {"N": [], "E": []}
174 | for query in query_paths:
175 | queryPosInfo = getPosInfo(query)
176 | queryPosDict["N"].append(float(queryPosInfo[0]))
177 | queryPosDict["E"].append(float(queryPosInfo[1]))
178 | img = Image.open(query)
179 | input = data_transforms(img)
180 | input = torch.unsqueeze(input, 0)
181 | with torch.no_grad():
182 | feature = extract_feature(input, model)
183 | bestIndex = getBestImage(feature, gallery_feature, gallery_label)
184 | imshowByIndex(bestIndex)
185 | bestMatchedPosInfo = satellitePosInfoDict[bestIndex]
186 | galleryPosDict["N"].append(float(bestMatchedPosInfo[0]))
187 | galleryPosDict["E"].append(float(bestMatchedPosInfo[1]))
188 | print("query--N:{} E:{} gallery--N:{} E:{}".format(queryPosInfo[0], queryPosInfo[1], bestMatchedPosInfo[0],
189 | bestMatchedPosInfo[1]))
190 | # cv2.waitKey(0)
191 |
192 | result = {"query": queryPosDict, "gallery": galleryPosDict}
193 | scipy.io.savemat('global_{}.mat'.format(University), result)
194 |
195 | index = 1
196 | for N, E in zip(queryPosDict["N"], queryPosDict["E"]):
197 | X = int((E - startE) / (endE - startE) * w)
198 | Y = int((N - startN) / (endN - startN) * h)
199 | if index>=10:
200 | cv2.circle(AllImage, (X, Y), 50, color=(255, 0, 0), thickness=8)
201 | cv2.putText(AllImage, str(index), (X - 40, Y + 25), cv2.FONT_HERSHEY_COMPLEX, 2.2, color=(255, 0, 0),
202 | thickness=3)
203 | else:
204 | cv2.circle(AllImage, (X, Y), 50, color=(255, 0, 0), thickness=8)
205 | cv2.putText(AllImage, str(index), (X - 30, Y + 30), cv2.FONT_HERSHEY_COMPLEX, 3, color=(255, 0, 0),
206 | thickness=3)
207 | index += 1
208 |
209 | index = 1
210 | for N, E in zip(galleryPosDict["N"], galleryPosDict["E"]):
211 | X = int((E - startE) / (endE - startE) * w)
212 | Y = int((N - startN) / (endN - startN) * h)
213 | if index>=10:
214 | cv2.circle(AllImage, (X, Y), 50, color=(0, 0, 255), thickness=8)
215 | cv2.putText(AllImage, str(index), (X - 40, Y + 25), cv2.FONT_HERSHEY_COMPLEX, 2.2, color=(0, 0, 255),
216 | thickness=3)
217 | else:
218 | cv2.circle(AllImage, (X, Y), 50, color=(0, 0, 255), thickness=8)
219 | cv2.putText(AllImage, str(index), (X - 30, Y + 30), cv2.FONT_HERSHEY_COMPLEX, 3, color=(0, 0, 255),
220 | thickness=3)
221 | index += 1
222 |
223 | # AllImage = cv2.resize(AllImage,(0,0),fx=0.25,fy=0.25)
224 | cv2.imwrite("global_{}.tif".format(University), AllImage)
225 | # gmap.plot(queryPosDict["N"], queryPosDict["E"], color="red")
226 | # gmap.plot(galleryPosDict["N"], galleryPosDict["E"], color="blue")
227 | #
228 | # gmap.draw("user001_map.html")
229 |
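The marker drawing above maps GPS coordinates linearly into map pixels; a compact sketch of that mapping with illustrative corner coordinates (not the real pos.txt values):

def gps_to_pixel(N, E, startN, startE, endN, endE, w, h):
    X = int((E - startE) / (endE - startE) * w)   # longitude -> x
    Y = int((N - startN) / (endN - startN) * h)   # latitude  -> y
    return X, Y

print(gps_to_pixel(30.3230, 120.3775, 30.3258, 120.3734, 30.3205, 120.3817, 4000, 3000))  # pixel location on the stitched map
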
--------------------------------------------------------------------------------
/tool/dataset_preprocess/1-generateSatelliteByUav.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import glob
4 | from collections import defaultdict
5 | from get_property import find_GPS_image
6 | from utils import get_fileNames
7 | import sys
8 | from tqdm import tqdm
9 | from multiprocessing import Pool
10 | '''
11 | Crop the corresponding satellite image for each UAV image, used for the training split
12 | '''
13 | root_dir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/oridata/train/University_UAV_Images"
14 | tif_dir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/oridata/train/old_tif"
15 | PlaceNameList = []
16 |
17 | for root, PlaceName, files in os.walk(root_dir):
18 | PlaceNameList = PlaceName
19 | break
20 |
21 | # PlaceNameList = os.listdir(root_dir)
22 | # y = 10x-500
23 | correspond_size = {'80':640,'90':768,'100':896}
24 |
25 | place_info_dict = defaultdict(list)
26 | with open(os.path.join(tif_dir,"PosInfo.txt"),"r") as F:
27 | context = F.readlines()
28 | for line in context:
29 | name = line.split(" ")[0]
30 | TN = float(line.split((" "))[1].split("TN")[-1])
31 | TE = float(line.split((" "))[2].split("TE")[-1])
32 | BN = float(line.split((" "))[3].split("BN")[-1])
33 | BE = float(line.split((" "))[4].split("BE")[-1])
34 | place_info_dict[name] = [TN,TE,BN,BE]
35 |
36 |
37 | def process(place):
38 | place_root = os.path.join(root_dir,place)
39 | cur_TN,cur_TE,cur_BN,cur_BE = place_info_dict[place]
40 | satellite_tif = os.path.join(tif_dir,place + ".tif")
41 |
42 | BigSatellite = cv2.imread(satellite_tif)
43 | h,w,c = BigSatellite.shape
44 | JPG_List = get_fileNames(place_root,endwith=".JPG")
45 | for JPG in tqdm(JPG_List):
46 | satelliteTif = JPG.replace(".JPG","_satellite_old.tif")
47 | GPS_info = find_GPS_image(JPG)
48 | gps_dict_formate = list(GPS_info.values())[0]
49 | y = list(gps_dict_formate.values())
50 | E,N = y[3],y[1]
51 | satellite_size = correspond_size[JPG.split("/")[-2]]
52 |         centerX = (E-cur_TE)/(cur_BE-cur_TE)*w  # map the current drone position to pixel coordinates in the big map
53 | centerY = (N-cur_TN)/(cur_BN-cur_TN)*h
54 |
55 |         if centerX<satellite_size/2 or centerY<satellite_size/2 or centerX>w-satellite_size/2 or centerY>h-satellite_size/2:
56 |             raise ValueError("the crop window falls outside the satellite image")
57 |
58 | TLX = int(centerX-satellite_size/2)
59 | TLY = int(centerY-satellite_size/2)
60 | BRX = int(centerX+satellite_size/2)
61 | BRY = int(centerY+satellite_size/2)
62 |
63 |         cropImage = BigSatellite[TLY:BRY,TLX:BRX,:]  # crop the patch at the specified position
64 |
65 | cv2.imwrite(satelliteTif,cropImage)
66 |
67 | p = Pool(10)
68 | # iterate over each place's folder: first load the big satellite map, then crop it
69 | for res in p.imap(process, PlaceNameList):
70 | pass
71 |
72 | p.close()
73 |
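A compact sketch of the crop-window arithmetic in process(): a satellite_size x satellite_size patch is cut around the drone's projected position after a bounds check; the numbers are illustrative:

def crop_window(centerX, centerY, satellite_size, w, h):
    if (centerX < satellite_size / 2 or centerY < satellite_size / 2
            or centerX > w - satellite_size / 2 or centerY > h - satellite_size / 2):
        raise ValueError("the crop window falls outside the satellite image")
    TLX, TLY = int(centerX - satellite_size / 2), int(centerY - satellite_size / 2)
    BRX, BRY = int(centerX + satellite_size / 2), int(centerY + satellite_size / 2)
    return TLX, TLY, BRX, BRY

print(crop_window(1000, 800, 640, 4000, 3000))   # (680, 480, 1320, 1120)
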
--------------------------------------------------------------------------------
/tool/dataset_preprocess/2-generate_new_croped_resized_images_difheights.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from utils import get_fileNames
3 | import os
4 | from get_property import find_GPS_image
5 | from tqdm import tqdm
6 | import numpy as np
7 | import functools
8 | import sys
9 | from multiprocessing import Pool
10 |
11 |
12 | UAV_target_size = (1440,1080)
13 | Satellite_target_size = (512,512)
14 |
15 | def sixNumber(str_number):
16 | str_number=str(str_number)
17 | while(len(str_number)<6):
18 | str_number='0'+str_number
19 | return str_number
20 |
21 | def compare_personal(x,y):
22 | x1 = int(os.path.split(x)[1].split(".JPG")[0])
23 | y1 = int(os.path.split(y)[1].split(".JPG")[0])
24 | return x1-y1
25 |
26 | # from center crop image
27 | def center_crop_and_resize(img,target_size=None):
28 | h,w,c = img.shape
29 | min_edge = min((h,w))
30 | if min_edge==h:
31 | edge_lenth = int((w-min_edge)/2)
32 | new_image = img[:,edge_lenth:w-edge_lenth,:]
33 | else:
34 | edge_lenth = int((h - min_edge) / 2)
35 | new_image = img[edge_lenth:h-edge_lenth, :, :]
36 | assert new_image.shape[0]==new_image.shape[1],"the shape is not correct"
37 | # LINEAR Interpolation
38 | if target_size:
39 | new_image = cv2.resize(new_image,target_size)
40 |
41 | return new_image
42 |
43 | def resize(img,target_size=None):
44 | # LINEAR Interpolation
45 | return cv2.resize(img,target_size)
46 |
47 |
48 | def getFileNameList(fullnamelist):
49 | list_return = []
50 | for i in fullnamelist:
51 | _,filename = os.path.split(i)
52 | list_return.append(filename)
53 | return list_return
54 |
55 | def process(info):
56 | index, [drone_80,drone_90,drone_100] = info
57 | satellite_80 = drone_80.replace(".JPG","_satellite.tif")
58 | satellite_90 = drone_90.replace(".JPG","_satellite.tif")
59 | satellite_100 = drone_100.replace(".JPG","_satellite.tif")
60 | if not (os.path.exists(satellite_80) and os.path.exists(satellite_90) and os.path.exists(satellite_100)):
61 |         print("No corresponding satellite image exists, please check {}".format(satellite_80+" "+satellite_90+" "+satellite_100))
62 | sys.exit(0)
63 | # ----new added----
64 | satellite_80_old = drone_80.replace(".JPG","_satellite_old.tif")
65 | satellite_90_old = drone_90.replace(".JPG","_satellite_old.tif")
66 | satellite_100_old = drone_100.replace(".JPG","_satellite_old.tif")
67 | if not (os.path.exists(satellite_80_old) and os.path.exists(satellite_90_old) and os.path.exists(satellite_100_old)):
68 |         print("No corresponding satellite image exists, please check {}".format(satellite_80_old+" "+satellite_90_old+" "+satellite_100_old))
69 | sys.exit(0)
70 |
71 | name = sixNumber(index)
72 | droneCurDir = os.path.join(dronePath, name)
73 | SatelliteCurDir = os.path.join(satellitePath, name)
74 | os.makedirs(droneCurDir,exist_ok=True)
75 | os.makedirs(SatelliteCurDir,exist_ok=True)
76 |
77 | # # load drone and satellite image
78 | # drone80_img = cv2.imread(drone_80)
79 | # drone90_img = cv2.imread(drone_90)
80 | # drone100_img = cv2.imread(drone_100)
81 | # satellite80_img = cv2.imread(satellite_80)
82 | # satellite90_img = cv2.imread(satellite_90)
83 | # satellite100_img = cv2.imread(satellite_100)
84 | # # ---new added---
85 | # satellite80_img_old = cv2.imread(satellite_80_old)
86 | # satellite90_img_old = cv2.imread(satellite_90_old)
87 | # satellite100_img_old = cv2.imread(satellite_100_old)
88 |
89 | # # process image including crop and resize
90 | # processed_drone80_img = resize(drone80_img,target_size=UAV_target_size)
91 | # processed_drone90_img = resize(drone90_img,target_size=UAV_target_size)
92 | # processed_drone100_img = resize(drone100_img,target_size=UAV_target_size)
93 | # processed_satellite80_img = resize(satellite80_img,target_size=Satellite_target_size)
94 | # processed_satellite90_img = resize(satellite90_img,target_size=Satellite_target_size)
95 | # processed_satellite100_img = resize(satellite100_img,target_size=Satellite_target_size)
96 | # # ---new added---
97 | # processed_satellite80_img_old = resize(satellite80_img_old,target_size=Satellite_target_size)
98 | # processed_satellite90_img_old = resize(satellite90_img_old,target_size=Satellite_target_size)
99 | # processed_satellite100_img_old = resize(satellite100_img_old,target_size=Satellite_target_size)
100 |
101 | # cv2.imwrite(os.path.join(droneCurDir, "H80.JPG"), processed_drone80_img)
102 | # cv2.imwrite(os.path.join(droneCurDir, "H90.JPG"), processed_drone90_img)
103 | # cv2.imwrite(os.path.join(droneCurDir, "H100.JPG"), processed_drone100_img)
104 | satelliteImgPath80 = os.path.join(SatelliteCurDir,"H80.tif")
105 | satelliteImgPath90 = os.path.join(SatelliteCurDir,"H90.tif")
106 | satelliteImgPath100 = os.path.join(SatelliteCurDir,"H100.tif")
107 | # cv2.imwrite(satelliteImgPath80, processed_satellite80_img)
108 | # cv2.imwrite(satelliteImgPath90, processed_satellite90_img)
109 | # cv2.imwrite(satelliteImgPath100, processed_satellite100_img)
110 | # # # ---new added---
111 | # satelliteImgPath80_old = os.path.join(SatelliteCurDir,"H80_old.tif")
112 | # satelliteImgPath90_old = os.path.join(SatelliteCurDir,"H90_old.tif")
113 | # satelliteImgPath100_old = os.path.join(SatelliteCurDir,"H100_old.tif")
114 | # cv2.imwrite(satelliteImgPath80_old, processed_satellite80_img_old)
115 | # cv2.imwrite(satelliteImgPath90_old, processed_satellite90_img_old)
116 | # cv2.imwrite(satelliteImgPath100_old, processed_satellite100_img_old)
117 |
118 | # write the GPS information
119 | GPS_info = find_GPS_image(drone_80)
120 | x = list(GPS_info.values())
121 | gps_dict_formate = x[0]
122 | y = list(gps_dict_formate.values())
123 | height = eval(y[5])
124 | information = "{} {}{} {}{} {}\n".format(satelliteImgPath80,y[2],y[3],y[0],y[1],height)
125 | # ---new added---
126 | # information_old = "{} {}{} {}{} {}\n".format(satelliteImgPath80_old,y[2],y[3],y[0],y[1],height)
127 | # return information,information_old
128 | return [information]
129 |
130 | heightList = ["80","90","100"]
131 | index = 2256 # for the test split, the starting index must follow the total length of the training split
132 | dir_path = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/"
133 | mode = "test"
134 | oriPath = os.path.join(dir_path,"oridata", mode, "University_UAV_Images")
135 | dirList = os.listdir(oriPath)
136 | root_dir = os.path.join(dir_path, "data_2021")
137 | mode_dir_path = os.path.join(root_dir, mode)
138 | os.makedirs(mode_dir_path,exist_ok=True)
139 |
140 | # GPS txt file
141 | txt_path = os.path.join(root_dir, "Dense_GPS_{}.txt".format(mode))
142 | file = open(txt_path, 'w')
143 | dronePath = os.path.join(mode_dir_path, "drone")
144 | os.makedirs(dronePath,exist_ok=True)
145 | satellitePath = os.path.join(mode_dir_path, "satellite")
146 | os.makedirs(satellitePath,exist_ok=True)
147 |
148 | p = Pool(8)
149 | for p_idx, place in enumerate(dirList):
150 | if not os.path.isdir(os.path.join(oriPath,place)):
151 | continue
152 |
153 | Drone_JPG_paths_80 = get_fileNames(os.path.join(oriPath, place, "80"),endwith=".JPG")
154 | Drone_JPG_paths_90 = get_fileNames(os.path.join(oriPath, place, "90"), endwith=".JPG")
155 | Drone_JPG_paths_100 = get_fileNames(os.path.join(oriPath, place, "100"), endwith=".JPG")
156 | Drone_JPG_paths_80 = sorted(Drone_JPG_paths_80,key=functools.cmp_to_key(compare_personal))
157 | Drone_JPG_paths_90 = sorted(Drone_JPG_paths_90,key=functools.cmp_to_key(compare_personal))
158 | Drone_JPG_paths_100 = sorted(Drone_JPG_paths_100,key=functools.cmp_to_key(compare_personal))
159 | f_80 = getFileNameList(Drone_JPG_paths_80)
160 | f_90 = getFileNameList(Drone_JPG_paths_90)
161 | f_100 = getFileNameList(Drone_JPG_paths_100)
162 |     assert f_80==f_90 and f_90==f_100, "The data are not aligned, please check {}".format(place)
163 |     print("Set index for every iteration")
164 | indexed_iters = []
165 | for ind, (drone_80,drone_90,drone_100) in tqdm(enumerate(zip(Drone_JPG_paths_80,Drone_JPG_paths_90,Drone_JPG_paths_100))):
166 | indexed_iters.append([index+ind, [drone_80,drone_90,drone_100]])
167 | index+=len(indexed_iters)
168 | for idx, res in enumerate(p.imap(process,indexed_iters)):
169 | if idx%50==0:
170 | print("-{}- {}/{} {}".format(p_idx, idx, len(indexed_iters), place))
171 | for info in res:
172 | file.write(info)
173 |
174 | file.close()
175 |
176 | p.close()
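177 | 
178 | # Note on the output format: assuming the EXIF GPS tags are read in the usual Ref/value
179 | # order, each line written to Dense_GPS_{mode}.txt looks like
180 | #   <satellite_H80_path> E<longitude> N<latitude> <altitude>
181 | # which is the layout later parsed (split on "E" / "N") by tool/visual/demo.py.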
--------------------------------------------------------------------------------
/tool/dataset_preprocess/3-generate_format_testset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from tqdm import tqdm
4 | import glob
5 | from multiprocessing import Pool
6 |
7 |
8 | def processData(datalist,targetDir):
9 | os.makedirs(targetDir,exist_ok=True)
10 | for dir in tqdm(datalist):
11 | name = os.path.basename(dir)
12 | target_dir = os.path.join(targetDir,name)
13 | shutil.copytree(dir,target_dir)
14 |
15 |
16 | def main():
17 | rootDir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/data_2021/"
18 |     # # test set
19 | # testDir = os.path.join(rootDir,"test")
20 | # ClassForTestDrone = glob.glob(os.path.join(testDir, "drone/*"))
21 | # ClassForTestSatellite = glob.glob(os.path.join(testDir, "satellite/*"))
22 |
23 | # query_drone = os.path.join(testDir,"query_drone")
24 | # gallery_drone = os.path.join(testDir,"gallery_drone")
25 | # query_satellite = os.path.join(testDir,"query_satellite")
26 | # gallery_satellite = os.path.join(testDir,"gallery_satellite")
27 |
28 |     # # training set
29 | # trainDir = os.path.join(rootDir, "train")
30 | # ClassForTrainDrone = glob.glob(os.path.join(trainDir, "drone/*"))
31 | # ClassForTrainSatellite = glob.glob(os.path.join(trainDir, "satellite/*"))
32 |
33 | # # process test data
34 | # processData(ClassForTestDrone,query_drone)
35 | # processData(ClassForTestSatellite,query_satellite)
36 |     # # the gallery needs to merge the test set and the training set
37 | # processData(ClassForTrainDrone+ClassForTestDrone,gallery_drone)
38 | # processData(ClassForTrainSatellite+ClassForTestSatellite,gallery_satellite)
39 |
40 |     # generate the final Dense_GPS_ALL.txt
41 | mode_list = ["train","test"]
42 | total_lines = []
43 | for mode in mode_list:
44 | mode_txt = os.path.join(rootDir,"Dense_GPS_{}.txt".format(mode))
45 | with open(mode_txt,"r") as F:
46 | lines = F.readlines()
47 | total_lines.extend(lines)
48 | ALL_filename = os.path.join(rootDir,"Dense_GPS_ALL.txt")
49 | with open(ALL_filename,"w") as F:
50 | for info in total_lines:
51 | F.write(info)
52 |     print("successfully wrote {}".format(ALL_filename))
53 |
54 |
55 | if __name__ == '__main__':
56 | main()
--------------------------------------------------------------------------------
/tool/dataset_preprocess/TEST-1-downloadALlAndReCut.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from tqdm import tqdm
4 | import os
5 | import shutil
6 |
7 | # TN30.325763481658377 TE120.3524959718277 BN30.318953289529816 BE120.36312306473656 jiliang
8 | # TN30.32631193830858 TE120.36678791233928 BN30.317870065161575 BE120.37745052386553 xianke
9 | # TN30.325763471673625 TE120.37341802729506 BN30.320441588110295 BE120.38193743012363
10 | mapPosInfodir = "/home/dmmm/PycharmProjects/DenseCV/demo/maps/pos.txt"
11 | # get_picture(startE, startN, endE, endN, 20, "./计量.tif", server="Google")
12 |
13 | University = "Jiliang"
14 |
15 | with open(mapPosInfodir,"r") as F:
16 | listLine = F.readlines()
17 | for line in listLine:
18 | name,TN,TE,BN,BE = line.split(" ")
19 | if name==University:
20 | startE = eval(TE.split("TE")[-1])
21 | startN = eval(TN.split("TN")[-1])
22 | endE = eval(BE.split("BE")[-1])
23 | endN = eval(BN.split("BN")[-1])
24 |
25 | # root folder of the images
26 | dir_path = r"/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub({})".format(University)
27 | if os.path.exists(dir_path):
28 | shutil.rmtree(dir_path)
29 | os.mkdir(dir_path)
30 | # file where the latitude/longitude information is stored
31 | infoFile = "/media/dmmm/4T-31/DataSets/DenseCV_Data/高度数据集/oridata/train/old_tif/PosInfo.txt".format(University)
32 | file = open(infoFile, "w")
33 |
34 | def sixNumber(str_number):
35 | str_number=str(str_number)
36 | while(len(str_number)<6):
37 | str_number='0'+str_number
38 | return str_number
39 |
40 | oriImage = cv2.imread("/home/dmmm/PycharmProjects/DenseCV/demo/maps/{}.tif".format(University))
41 | h,w,c = oriImage.shape
42 |
43 | cropedSizeList = [640,768,896]
44 | marginrate = 4
45 | index = 0
46 | for cropedSize in cropedSizeList:
47 |     margin = cropedSize//marginrate # the stride is 1/marginrate of the crop size
48 | TopLeftEast = startE + (endE - startE)*cropedSize/2/w
49 | TopLeftNorth = startN + (endN - startN)*cropedSize/2/h
50 | BottomRightEast = endE - (endE - startE)*cropedSize/2/w
51 | BottomRightNorth = endN - (endN - startN)*cropedSize/2/h
52 |
53 | # YY = list(range(cropedSize//2,h-cropedSize//2+margin,margin))
54 | # XX = list(range(cropedSize//2,w-cropedSize//2+margin,margin))
55 |     # Y_Size = len(YY)  # total number of images along the y-axis
56 |     # X_Size = len(XX)  # total number of images along the x-axis
57 | X_Size = (w-cropedSize)//margin
58 | Y_Size = (h-cropedSize)//margin
59 | YY = np.linspace(cropedSize//2,h-cropedSize//2,Y_Size,dtype=np.int16)
60 | XX = np.linspace(cropedSize//2,w-cropedSize//2,X_Size,dtype=np.int16)
61 |
62 | pbar = tqdm(total=Y_Size*X_Size)
63 |
64 | PosInfoN = np.linspace(TopLeftNorth,BottomRightNorth,Y_Size)
65 | PosInfoE = np.linspace(TopLeftEast,BottomRightEast,X_Size)
66 | for n,y in zip(PosInfoN,YY):
67 | for e,x in zip(PosInfoE,XX):
68 | topLX = x - cropedSize//2
69 | topLY = y - cropedSize//2
70 | BottomRX = x + cropedSize//2
71 | BottomRY = y + cropedSize//2
72 | cropImage = oriImage[topLY:BottomRY,topLX:BottomRX,:]
73 |
74 | index_ = sixNumber(index)
75 |
76 | cv2.imwrite(os.path.join(dir_path,index_+".tif"),cropImage)
77 |
78 | file.write("{} N{:.10f} E{:.10f}\n".format(index_,n,e))
79 |
80 | index+=1
81 | pbar.update()
--------------------------------------------------------------------------------
/tool/dataset_preprocess/TEST-2-generateSatelliteHub.py:
--------------------------------------------------------------------------------
1 | # from google_interface import get_picture
2 | import time
3 | from google_interface_2 import get_picture
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 |
8 | '''
9 | Used at test time to generate the satellite image gallery
10 | '''
11 |
12 | def sixNumber(str_number):
13 | str_number=str(str_number)
14 | while(len(str_number)<6):
15 | str_number='0'+str_number
16 | return str_number
17 |
18 |
19 | if __name__ == '__main__':# 30.32331706,120.37025917
20 |     startE = 120.3524959718277 # start: top-left corner
21 |     startN = 30.325763481658377
22 |     endE = 120.36312306473656 # end: bottom-right corner 30.32128231,120.37425118,
23 | endN = 30.318953289529816
24 | start_time = time.time()
25 | margin = 0.0001
26 |     # root folder of the images
27 | dir_path = r"/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub(现科dense)"
28 | os.mkdir(dir_path)
29 |     # file where the latitude/longitude information is stored
30 | infoFile = "/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub(现科dense)/PosInfo.txt"
31 | file = open(infoFile,"w")
32 |     # get the paths of all JPG images in the folder
33 | # jpg_paths = get_fileNames(dir_path,endwith=".JPG")
34 | index = 0
35 | NorthList = np.linspace(startN, endN, int((startN - endN) / (margin)))
36 | EastList = np.linspace(startE,endE,int((endE-startE)/(margin)))
37 | pbar = tqdm(total=len(NorthList)*len(EastList))
38 | for North in NorthList:
39 | for East in EastList:
40 | east_left = East - margin
41 | east_right = East + margin
42 | north_top = North + margin
43 | north_bottom = North - margin
44 |
45 |             # set the new output path
46 | index_ = sixNumber(index)
47 | satellite_tif_path = os.path.join(dir_path,"{}.tif".format(index_))
48 | get_picture(east_left, north_top, east_right, north_bottom, 19, satellite_tif_path, server="Google")
49 | file.write("{} N{:.10f} E{:.10f}\n".format(index_,North,East))
50 | index+=1
51 | pbar.update()
52 |
53 | file.close()
54 | pbar.close()
--------------------------------------------------------------------------------
/tool/dataset_preprocess/TEST-3-preprocess_difHeightTest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import glob
4 | import tqdm
5 |
6 |
7 | def mkdirAndCopyfile(ori_path,to_path):
8 | if not os.path.exists(to_path):
9 | os.mkdir(to_path)
10 | basename,name = ori_path.split("/")[-2:]
11 | new_basename = os.path.join(to_path,basename)
12 | if os.path.exists(new_basename):
13 | shutil.rmtree(new_basename)
14 | os.mkdir(new_basename)
15 | shutil.copyfile(ori_path,os.path.join(new_basename,name))
16 |
17 |
18 |
19 | if __name__ == '__main__':
20 | root = "/home/dmmm/Dataset/DenseUAV/data_2022/test"
21 | root_80 = os.path.join(root,"queryDrone80")
22 | root_90 = os.path.join(root,"queryDrone90")
23 | root_100 = os.path.join(root,"queryDrone100")
24 | root_Drone_all = os.path.join(root,"query_drone")
25 | Drone_80_List = glob.glob(os.path.join(root_Drone_all,"*","H80.JPG"))
26 | Drone_90_List = glob.glob(os.path.join(root_Drone_all,"*","H90.JPG"))
27 | Drone_100_List = glob.glob(os.path.join(root_Drone_all,"*","H100.JPG"))
28 |     tq = tqdm.tqdm(total=len(Drone_80_List))
29 | for H80,H90,H100 in zip(Drone_80_List,Drone_90_List,Drone_100_List):
30 | mkdirAndCopyfile(H80,root_80)
31 | mkdirAndCopyfile(H90,root_90)
32 | mkdirAndCopyfile(H100,root_100)
33 | tq.update()
34 |
35 |
36 |
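37 | # Note: this script splits the combined query_drone test set into three height-specific
38 | # query sets (queryDrone80/90/100), copying H80/H90/H100.JPG into one sub-folder per class.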
--------------------------------------------------------------------------------
/tool/dataset_preprocess/TEST-Regenerate_dense_testset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import os
4 | import shutil
5 | import glob
6 | import cv2
7 | from multiprocessing import Pool,Manager
8 | import copy
9 | from utils import Distance
10 |
11 | # Densely crop the test satellite maps (old version), generating a single GPS_ALL.txt file
12 |
13 | def sixNumber(str_number):
14 | str_number=str(str_number)
15 | while(len(str_number)<6):
16 | str_number='0'+str_number
17 | return str_number
18 |
19 |
20 | type = "new"
21 |
22 | root_dir = "/media/dmmm/4T-3/DataSets/DenseCV_Data/高度数据集/oridata/test/{}_tif".format(type)
23 |
24 | source_loc_info = os.path.join(root_dir, "PosInfo.txt")
25 |
26 | output_dir = "/home/dmmm/Dataset/DenseUAV/data_2022/test/hard_gallery_satellite_ss_interval10m"
27 | os.makedirs(output_dir,exist_ok=True)
28 |
29 | with open(source_loc_info, "r") as F:
30 | context = F.readlines()
31 |
32 | correspond_size = [640,768,896]
33 | correspond_size = correspond_size[1:2]
34 | gap=77
35 | output_size = [256,256]
36 | total_info = []
37 |
38 |
39 |
40 | # p = Pool(10)
41 |
42 | for line in context:
43 | info = line.strip().split(" ")
44 | university_name, TN, TE, BN, BE = info
45 | TE = eval(TE.split("TE")[-1])
46 | TN = eval(TN.split("TN")[-1])
47 | BE = eval(BE.split("BE")[-1])
48 | BN = eval(BN.split("BN")[-1])
49 |     print("actual_height:{}--width:{}".format(Distance(TN,TE,BN,TE),Distance(TN,TE,TN,BE)))
50 |
51 | source_map_tif_path = os.path.join(root_dir,"{}.tif".format(university_name))
52 | source_map_image = cv2.imread(source_map_tif_path)
53 |
54 | manager = Manager()
55 | shared_dict = manager.dict()
56 | shared_dict['total_info'] = []
57 |
58 | h,w,c = source_map_image.shape
59 | print("image_height:{}--width:{}".format(h,w))
60 | x,y = np.meshgrid(list(range(0,w-896,gap)),list(range(0,h-896,gap)))
61 | x,y = x.reshape(-1),y.reshape(-1)
62 | inds = np.array(list(range(0,len(x),1)))
63 |
64 | def process(infos):
65 | ind = infos[0]
66 | position = infos[1:]
67 |
68 | info_list = []
69 | for size in correspond_size:
70 | East = TE + (position[0]+size/2)/w*(BE-TE)
71 | North = TN - (position[1]+size/2)/h*(TN-BN)
72 | image = source_map_image[position[1]:position[1]+size, position[0]:position[0]+size, :]
73 | image = cv2.resize(image, output_size)
74 | filepath = os.path.join(output_dir, "{}_{}_{}_{}.jpg".format(university_name,sixNumber(ind),size,type))
75 | cv2.imwrite(filepath, image)
76 | pos_info = [filepath, East, North, size]
77 | info_list.append(pos_info)
78 |
79 | return info_list
80 |
81 | p = Pool(20)
82 | for ind,res in enumerate(p.imap(process,zip(inds,x,y))):
83 | if ind % 100 == 0:
84 | print("{}/{} process the image!!".format(ind,len(inds)))
85 | total_info.extend(res)
86 | p.close()
87 |
88 | info_path = "/home/dmmm/Dataset/DenseUAV/data_2022/test/{}_info.txt".format(type)
89 |
90 | F = open(info_path, "w")
91 | for info in total_info:
92 | F.write("{} {} {} {}\n".format(*info))
93 |
94 | F.close()
95 |
96 | # cat file1.txt file2.txt > merged.txt  (merge the files)
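97 | 
98 | # Each line of the resulting {type}_info.txt has the form
99 | #   <cropped_jpg_path> <East> <North> <crop_size>
100 | # and is consumed by validation_testset.py to project a query back onto its source map.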
--------------------------------------------------------------------------------
/tool/dataset_preprocess/get_property.py:
--------------------------------------------------------------------------------
1 | import exifread
2 | import re
3 | import json
4 | import requests
5 |
6 | def latitude_and_longitude_convert_to_decimal_system(*arg):
7 | """
8 |     Convert latitude/longitude (degrees, minutes, seconds) to decimal degrees. :param arg:
9 |     :return: decimal degrees
10 | """
11 | return float(arg[0]) + ((float(arg[1]) + (float(arg[2].split('/')[0]) / float(arg[2].split('/')[-1]) / 60)) / 60)
12 |
13 | def find_GPS_image(pic_path):
14 | GPS = {}
15 | date = ''
16 | with open(pic_path, 'rb') as f:
17 | tags = exifread.process_file(f)
18 | for tag, value in tags.items():
19 | if re.match('GPS GPSLatitudeRef', tag):
20 | GPS['GPSLatitudeRef'] = str(value)
21 | elif re.match('GPS GPSLongitudeRef', tag):
22 | GPS['GPSLongitudeRef'] = str(value)
23 | elif re.match('GPS GPSAltitudeRef', tag):
24 | GPS['GPSAltitudeRef'] = str(value)
25 | elif re.match('GPS GPSLatitude', tag):
26 | try:
27 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups()
28 | GPS['GPSLatitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2])
29 | except:
30 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')]
31 | GPS['GPSLatitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec)
32 | elif re.match('GPS GPSLongitude', tag):
33 | try:
34 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups()
35 | GPS['GPSLongitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2])
36 | except:
37 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')]
38 | GPS['GPSLongitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec)
39 | elif re.match('GPS GPSAltitude', tag):
40 | GPS['GPSAltitude'] = str(value)
41 | elif re.match('.*Date.*', tag):
42 | date = str(value)
43 | return {'GPS_information': GPS, 'date_information': date}
44 |
45 | def find_address_from_GPS(GPS):
46 | """
47 |     Use the Baidu Geocoding API to convert latitude/longitude coordinates into a structured address.
48 | :param GPS:
49 | :return:
50 | """
51 | secret_key = 'zbLsuDDL4CS2U0M4KezOZZbGUY9iWtVf'
52 | if not GPS['GPS_information']:
53 |         return 'This photo has no GPS information'
54 | lat, lng = GPS['GPS_information']['GPSLatitude'], GPS['GPS_information']['GPSLongitude']
55 | baidu_map_api = "http://api.map.baidu.com/geocoder/v2/?ak={0}&callback=renderReverse&location={1},{2}s&output=json&pois=0".format(
56 | secret_key, lat, lng)
57 | response = requests.get(baidu_map_api)
58 | content = response.text.replace("renderReverse&&renderReverse(", "")[:-1]
59 | baidu_map_address = json.loads(content)
60 | formatted_address = baidu_map_address["result"]["formatted_address"]
61 | province = baidu_map_address["result"]["addressComponent"]["province"]
62 | city = baidu_map_address["result"]["addressComponent"]["city"]
63 | district = baidu_map_address["result"]["addressComponent"]["district"]
64 | return formatted_address,province,city,district
65 |
66 |
67 |
68 |
69 |
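70 | 
71 | # Hedged usage sketch (hypothetical path): the dict returned above exposes two keys,
72 | #   info = find_GPS_image("/path/to/DJI_0001.JPG")
73 | #   gps  = info['GPS_information']    # GPSLatitude / GPSLongitude / GPSAltitude tags
74 | #   date = info['date_information']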
--------------------------------------------------------------------------------
/tool/dataset_preprocess/split_dataset_long_middle_short.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import glob
4 | base_dir = "/home/dmmm/University-Release"
5 |
6 | class MakeDataset():
7 | def __init__(self,name):
8 | self.name = name
9 | self.dir = os.path.join(base_dir, name)
10 | self.ori_dir = os.path.join(base_dir,"test")
11 | self.target_path = os.path.join(self.dir,"query_drone")
12 |         # copy the query_drone images
13 | self.copy_pictures()
14 |         # copy the other required folders
15 | self.copy_other()
16 |
17 |
18 | def copy_pictures(self):
19 | class_list = os.listdir(os.path.join(self.ori_dir,"query_drone"))
20 | for i in class_list:
21 |             # create the folder for this class
22 | self.mkdir(os.path.join(self.target_path,i))
23 |
24 |             # select the images belonging to each of Long / Middle / Short
25 | path = os.path.join(self.ori_dir,"query_drone",i)
26 | tar_path = os.path.join(self.target_path,i)
27 | img_list = os.listdir(path)
28 | img_list.sort()
29 | long_num = len(img_list)//3
30 | middle_num = len(img_list)//3*2
31 | short_num = len(img_list)
32 | if self.name =="Long":
33 | list = img_list[:long_num]
34 | elif self.name == "Middle":
35 | list = img_list[long_num:middle_num]
36 | elif self.name == "Short":
37 | list = img_list[middle_num:short_num]
38 | else:
39 |                 raise ValueError("Invalid name argument; it must be Long, Middle or Short")
40 |
41 |             # copy the images to the target path
42 | for j in list:
43 | path_j = os.path.join(path,j)
44 | path_t = os.path.join(tar_path,j)
45 | shutil.copyfile(path_j,path_t)
46 |
47 |
48 |
49 |
50 | def mkdir(self,path):
51 | if not os.path.exists(path):
52 | os.makedirs(path)
53 |
54 | def copy_other(self):
55 | filename_list = ["gallery_drone","gallery_satellite","query_satellite"]
56 | for i in filename_list:
57 | source_path = os.path.join(self.ori_dir,i)
58 | target_path = os.path.join(self.dir,i)
59 | shutil.copytree(source_path, target_path)
60 |
61 |
62 |
63 |
64 | if __name__ == '__main__':
65 | MakeDataset("Long")
66 | MakeDataset("Middle")
67 | MakeDataset("Short")
--------------------------------------------------------------------------------
/tool/dataset_preprocess/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import math
4 |
5 | def get_fileNames(rootdir,endwith=".JPG"):
6 | fs = []
7 | for root, dirs, files in os.walk(rootdir,topdown = True):
8 | for name in files:
9 | _, ending = os.path.splitext(name)
10 | if ending == endwith:
11 | fs.append(os.path.join(root,name))
12 | return fs
13 |
14 |
15 | def Distance(lata, loga, latb, logb):
16 | # EARTH_RADIUS = 6371.0
17 | EARTH_RADIUS = 6378.137
18 | PI = math.pi
19 |     # convert to radians
20 | lat_a = lata * PI / 180
21 | lat_b = latb * PI / 180
22 | a = lat_a - lat_b
23 | b = loga * PI / 180 - logb * PI / 180
24 | dis = 2 * math.asin(math.sqrt(math.pow(math.sin(a / 2), 2) + math.cos(lat_a)
25 | * math.cos(lat_b) * math.pow(math.sin(b / 2), 2)))
26 |
27 | distance = EARTH_RADIUS * dis * 1000
28 | return distance
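29 | 
30 | # Distance() is the haversine great-circle distance on a sphere of radius 6378.137 km,
31 | # returned in metres; inputs are decimal-degree latitudes and longitudes, e.g.
32 | # (hypothetical values) Distance(30.3258, 120.3525, 30.3190, 120.3631).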
--------------------------------------------------------------------------------
/tool/dataset_preprocess/validation_testset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import os
4 | import shutil
5 | import glob
6 | import cv2
7 | from multiprocessing import Pool
8 | import copy
9 |
10 | type = "new"
11 |
12 | info_file = "/home/dmmm/Dataset/DenseUAV/data_2022/test/{}_info.txt".format(type)
13 |
14 | root_dir = "/home/dmmm/Dataset/DenseUAV/data_2022/test/{}_tif".format(type)
15 |
16 | source_loc_info = os.path.join(root_dir, "PosInfo.txt")
17 |
18 | with open(source_loc_info, "r") as F:
19 | context_map = F.readlines()
20 |
21 | with open(info_file,"r") as F:
22 | context = F.readlines()
23 |
24 | line = context[100]
25 | infos = line.strip().split(" ")
26 | filename = infos[0]
27 | query_image = cv2.imread(filename)
28 |
29 | E = float(infos[1])
30 | N = float(infos[2])
31 |
32 | for line in context_map:
33 | info = line.strip().split(" ")
34 | university_name, TN, TE, BN, BE = info
35 | TE = eval(TE.split("TE")[-1])
36 | TN = eval(TN.split("TN")[-1])
37 | BE = eval(BE.split("BE")[-1])
38 | BN = eval(BN.split("BN")[-1])
39 |
40 | source_map_tif_path = os.path.join(root_dir,"{}.tif".format(university_name))
41 | source_map_image = cv2.imread(source_map_tif_path)
42 | h,w = source_map_image.shape[:2]
43 |
44 | if TE<=E<=BE and BN<=N<=TN:
45 | x = (E-TE)/(BE-TE)*w
46 | y = (N-TN)/(BN-TN)*h
47 | break
48 |
49 | cv2.circle(source_map_image,(int(x),int(y)),40,(255,0,0),20)
50 | cv2.imwrite("map.jpg",source_map_image)
51 | cv2.imwrite("query.jpg",query_image)
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
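66 | # Sanity-check output: map.jpg is the source satellite map with a circle drawn at the GPS
67 | # position recovered for the entry at index 100 of {type}_info.txt, and query.jpg is the
68 | # corresponding cropped query patch; the circle should land on the same scene.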
--------------------------------------------------------------------------------
/tool/get_inference_time.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | sys.path.append("../")
4 | import yaml
5 | import argparse
6 | import torch
7 | from tool.utils import load_network, calc_flops_params
8 | import time
9 |
10 |
11 | parser = argparse.ArgumentParser(description='Training')
12 | parser.add_argument('--name', default='resnet',
13 | type=str, help='save model path')
14 | parser.add_argument('--checkpoint', default='../net_119.pth',
15 | type=str, help='save model path')
16 | parser.add_argument('--test_h', default=224, type=int, help='height')
17 | parser.add_argument('--test_w', default=224, type=int, help='width')
18 | parser.add_argument('--calc_nums', default=2000, type=int, help='number of timed forward passes')
19 | opt = parser.parse_args()
20 |
21 | config_path = '../opts.yaml'
22 | with open(config_path, 'r') as stream:
23 |     config = yaml.load(stream, Loader=yaml.FullLoader)
24 | for cfg, value in config.items():
25 | setattr(opt, cfg, value)
26 |
27 | model = load_network(opt).cuda()
28 | model = model.eval()
29 |
30 | # compute MACs with thop
31 | macs, params = calc_flops_params(
32 | model, (1, 3, opt.test_h, opt.test_w), (1, 3, opt.test_h, opt.test_w))
33 | input_size_drone = (1, 3, opt.test_h, opt.test_w)
34 | input_size_satellite = (1, 3, opt.test_h, opt.test_w)
35 |
36 | inputs_drone = torch.randn(input_size_drone).cuda()
37 | inputs_satellite = torch.randn(input_size_satellite).cuda()
38 |
39 | # warm-up
40 | for _ in range(10):
41 | model(inputs_drone,inputs_satellite)
42 |
43 | torch.cuda.synchronize()  # make sure the warm-up kernels have finished
44 | since = time.time()
45 | for _ in range(opt.calc_nums):
46 |     model(inputs_drone,inputs_satellite)
47 | torch.cuda.synchronize()  # wait for all queued forward passes before stopping the clock
48 | print("inference_time = {}s".format(time.time()-since))
--------------------------------------------------------------------------------
/tool/get_model_flops_params.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | sys.path.append("../../")
4 | import yaml
5 | import argparse
6 |
7 | from tool.utils import load_network, calc_flops_params
8 |
9 | parser = argparse.ArgumentParser(description='Training')
10 | parser.add_argument('--name', default='resnet',
11 | type=str, help='save model path')
12 | parser.add_argument('--checkpoint', default='net_119.pth',
13 | type=str, help='save model path')
14 | parser.add_argument('--test_h', default=224, type=int, help='height')
15 | parser.add_argument('--test_w', default=224, type=int, help='width')
16 | opt = parser.parse_args()
17 |
18 | config_path = 'opts.yaml'
19 | with open(config_path, 'r') as stream:
20 |     config = yaml.load(stream, Loader=yaml.FullLoader)
21 | for cfg, value in config.items():
22 | setattr(opt, cfg, value)
23 |
24 | model = load_network(opt).cuda()
25 | model = model.eval()
26 |
27 | # compute MACs with thop
28 | macs, params = calc_flops_params(
29 | model, (1, 3, opt.test_h, opt.test_w), (1, 3, opt.test_h, opt.test_w))
30 | print("model MACs={}, Params={}".format(macs, params))
31 |
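32 | # The reported MACs/Params come from thop.profile (see tool/utils.calc_flops_params),
33 | # run on one drone input and one satellite input of shape (1, 3, test_h, test_w).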
--------------------------------------------------------------------------------
/tool/get_property.py:
--------------------------------------------------------------------------------
1 | import exifread
2 | import re
3 | import json
4 | import requests
5 |
6 | def latitude_and_longitude_convert_to_decimal_system(*arg):
7 | """
8 |     Convert latitude/longitude (degrees, minutes, seconds) to decimal degrees. :param arg:
9 |     :return: decimal degrees
10 | """
11 | return float(arg[0]) + ((float(arg[1]) + (float(arg[2].split('/')[0]) / float(arg[2].split('/')[-1]) / 60)) / 60)
12 |
13 | def find_GPS_image(pic_path):
14 | GPS = {}
15 | date = ''
16 | with open(pic_path, 'rb') as f:
17 | tags = exifread.process_file(f)
18 | for tag, value in tags.items():
19 | if re.match('GPS GPSLatitudeRef', tag):
20 | GPS['GPSLatitudeRef'] = str(value)
21 | elif re.match('GPS GPSLongitudeRef', tag):
22 | GPS['GPSLongitudeRef'] = str(value)
23 | elif re.match('GPS GPSAltitudeRef', tag):
24 | GPS['GPSAltitudeRef'] = str(value)
25 | elif re.match('GPS GPSLatitude', tag):
26 | try:
27 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups()
28 | GPS['GPSLatitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2])
29 | except:
30 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')]
31 | GPS['GPSLatitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec)
32 | elif re.match('GPS GPSLongitude', tag):
33 | try:
34 | match_result = re.match('\[(\w*),(\w*),(\w.*)/(\w.*)\]', str(value)).groups()
35 | GPS['GPSLongitude'] = int(match_result[0]), int(match_result[1]), int(match_result[2])
36 | except:
37 | deg, min, sec = [x.replace(' ', '') for x in str(value)[1:-1].split(',')]
38 | GPS['GPSLongitude'] = latitude_and_longitude_convert_to_decimal_system(deg, min, sec)
39 | elif re.match('GPS GPSAltitude', tag):
40 | GPS['GPSAltitude'] = str(value)
41 | elif re.match('.*Date.*', tag):
42 | date = str(value)
43 | return {'GPS_information': GPS, 'date_information': date}
44 |
45 | def find_address_from_GPS(GPS):
46 | """
47 |     Use the Baidu Geocoding API to convert latitude/longitude coordinates into a structured address.
48 | :param GPS:
49 | :return:
50 | """
51 | secret_key = 'zbLsuDDL4CS2U0M4KezOZZbGUY9iWtVf'
52 | if not GPS['GPS_information']:
53 |         return 'This photo has no GPS information'
54 | lat, lng = GPS['GPS_information']['GPSLatitude'], GPS['GPS_information']['GPSLongitude']
55 | baidu_map_api = "http://api.map.baidu.com/geocoder/v2/?ak={0}&callback=renderReverse&location={1},{2}s&output=json&pois=0".format(
56 | secret_key, lat, lng)
57 | response = requests.get(baidu_map_api)
58 | content = response.text.replace("renderReverse&&renderReverse(", "")[:-1]
59 | baidu_map_address = json.loads(content)
60 | formatted_address = baidu_map_address["result"]["formatted_address"]
61 | province = baidu_map_address["result"]["addressComponent"]["province"]
62 | city = baidu_map_address["result"]["addressComponent"]["city"]
63 | district = baidu_map_address["result"]["addressComponent"]["district"]
64 | return formatted_address,province,city,district
65 |
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/tool/mount_dist.sh:
--------------------------------------------------------------------------------
1 | sudo mkdir /media/dmmm/4T-3
2 | sudo mount /dev/sda3 /media/dmmm/4T-3
--------------------------------------------------------------------------------
/tool/transforms/rotatetranformtest.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from math import cos,sin,pi
4 | from PIL import Image
5 | import os
6 |
7 | # Generate random RotateAndCrop results and save them under the rotateShow folder
8 |
9 |
10 | class RotateAndCrop(object):
11 | def __init__(self):
12 | pass
13 |
14 | def __call__(self, img):
15 | img_=np.array(img).copy()
16 |
17 | def getPosByAngle(img, angle):
18 | h, w, c = img.shape
19 | x_center = y_center = h // 2
20 | r = h // 2
21 | angle_lt = angle - 45
22 | angle_rt = angle + 45
23 | angle_lb = angle - 135
24 | angle_rb = angle + 135
25 | angleList = [angle_lt, angle_rt, angle_lb, angle_rb]
26 | pointsList = []
27 | for angle in angleList:
28 | x1 = x_center + r * cos(angle * pi / 180)
29 | y1 = y_center + r * sin(angle * pi / 180)
30 | pointsList.append([x1, y1])
31 | pointsOri = np.float32(pointsList)
32 | pointsListAfter = np.float32([[0, 0], [512, 0], [0, 512], [512, 512]])
33 | M = cv2.getPerspectiveTransform(pointsOri, pointsListAfter)
34 | res = cv2.warpPerspective(img, M, (512, 512))
35 | return res
36 |
37 | if not os.path.exists("rotateShow"):
38 | os.mkdir("rotateShow")
39 | img.save("./rotateShow/ori.png")
40 | for i in range(10):
41 | angle = int(np.random.random() * 360)
42 | new_image = getPosByAngle(img_,angle)
43 | image = Image.fromarray(new_image.astype('uint8')).convert('RGB')
44 | image.save("./rotateShow/{}.png".format(i))
45 |
46 | if __name__ == '__main__':
47 | img = Image.open("/media/dmmm/CE31-3598/DataSets/DenseCV_Data/实际测试图像(计量)/part2/DJI_0317.JPG")
48 | rotate = RotateAndCrop()
49 | rotate(img)
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/tool/transforms/transform_visual.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../../")
3 | from datasets.queryDataset import RotateAndCrop, RandomErasing
4 | from torchvision import transforms
5 | from PIL import Image
6 | import numpy as np
7 | import cv2
8 | import argparse
9 | import os
10 |
11 |
12 | def get_parse():
13 |     parser = argparse.ArgumentParser(description='Transform Visualization')
14 | parser.add_argument(
15 | '--image_path', default='/home/dmmm/VscodeProject/demo_DenseUAV/visualization/rotateShow/ori.png', type=str, help='')
16 | parser.add_argument(
17 | '--target_dir', default='/home/dmmm/VscodeProject/demo_DenseUAV/visualization/ColorJitter', type=str, help='')
18 | parser.add_argument(
19 | '--num_aug', default=10, type=int, help='')
20 |
21 | opt = parser.parse_args()
22 | return opt
23 |
24 |
25 | if __name__ == '__main__':
26 | opt = get_parse()
27 | image = Image.open(opt.image_path)
28 |
29 | re = RandomErasing(probability=1.0)
30 | ra = transforms.RandomAffine(180)
31 | rac = RotateAndCrop(rate=1.0)
32 | cj = transforms.ColorJitter(brightness=0.5, contrast=0.1, saturation=0.1, hue=0)
33 |
34 | os.makedirs(opt.target_dir, exist_ok=True)
35 |
36 | for ind in range(opt.num_aug):
37 | image_ = cj(image)
38 | image_ = np.array(image_)
39 | image_ = image_[:, :, [2, 1, 0]]
40 | h, w = image_.shape[:2]
41 | # image_ = cv2.circle(
42 | # image_.copy(), (int(w/2), int(h/2)), 3, (0, 0, 255), 2)
43 | cv2.imwrite(os.path.join(opt.target_dir, "{}.jpg".format(ind)), image_)
44 |
--------------------------------------------------------------------------------
/tool/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import yaml
4 | import random
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import cv2
8 | from shutil import copyfile, copytree, rmtree
9 | import logging
10 | from models.taskflow import make_model
11 | from thop import profile, clever_format
12 | import math
13 |
14 |
15 | def get_logger(filename, verbosity=1, name=None):
16 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
17 | formatter = logging.Formatter(
18 | "[%(asctime)s][%(levelname)s] %(message)s"
19 | )
20 | logger = logging.getLogger(name)
21 | logger.setLevel(level_dict[verbosity])
22 |
23 | fh = logging.FileHandler(filename, "w")
24 | fh.setFormatter(formatter)
25 | logger.addHandler(fh)
26 |
27 | sh = logging.StreamHandler()
28 | sh.setFormatter(formatter)
29 | logger.addHandler(sh)
30 |
31 | return logger
32 |
33 |
34 | def copy_file_or_tree(path, target_dir):
35 | target_path = os.path.join(target_dir, path)
36 | if os.path.isdir(path):
37 | if os.path.exists(target_path):
38 | rmtree(target_path)
39 | copytree(path, target_path)
40 | elif os.path.isfile(path):
41 | copyfile(path, target_path)
42 |
43 |
44 | def copyfiles2checkpoints(opt):
45 | dir_name = os.path.join('checkpoints', opt.name)
46 | if not os.path.isdir(dir_name):
47 | os.mkdir(dir_name)
48 | # record every run
49 | copy_file_or_tree('train.py', dir_name)
50 | copy_file_or_tree('test.py', dir_name)
51 | copy_file_or_tree('evaluate_gpu.py', dir_name)
52 | copy_file_or_tree('evaluateDistance.py', dir_name)
53 | copy_file_or_tree('datasets', dir_name)
54 | copy_file_or_tree('losses', dir_name)
55 | copy_file_or_tree('models', dir_name)
56 | copy_file_or_tree('optimizers', dir_name)
57 | copy_file_or_tree('tool', dir_name)
58 | copy_file_or_tree('train_test_local.sh', dir_name)
59 |
60 | # save opts
61 | with open('%s/opts.yaml' % dir_name, 'w') as fp:
62 | yaml.dump(vars(opt), fp, default_flow_style=False)
63 |
64 |
65 | def make_weights_for_balanced_classes(images, nclasses):
66 | count = [0] * nclasses
67 | for item in images:
68 | count[item[1]] += 1 # count the image number in every class
69 | weight_per_class = [0.] * nclasses
70 | N = float(sum(count))
71 | for i in range(nclasses):
72 | weight_per_class[i] = N/float(count[i])
73 | weight = [0] * len(images)
74 | for idx, val in enumerate(images):
75 | weight[idx] = weight_per_class[val[1]]
76 | return weight
77 |
78 | # Get model list for resume
79 |
80 |
81 | def get_model_list(dirname, key):
82 | if os.path.exists(dirname) is False:
83 | print('no dir: %s' % dirname)
84 | return None
85 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
86 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pth" in f]
87 | if gen_models is None:
88 | return None
89 | gen_models.sort()
90 | last_model_name = gen_models[-1]
91 | return last_model_name
92 |
93 | ######################################################################
94 | # Save model
95 | # ---------------------------
96 |
97 |
98 | def save_network(network, dirname, epoch_label):
99 | if not os.path.isdir('./checkpoints/'+dirname):
100 | os.mkdir('./checkpoints/'+dirname)
101 | if isinstance(epoch_label, int):
102 | save_filename = 'net_%03d.pth' % epoch_label
103 | else:
104 | save_filename = 'net_%s.pth' % epoch_label
105 | save_path = os.path.join('./checkpoints', dirname, save_filename)
106 | torch.save(network.cpu().state_dict(), save_path)
107 |     if torch.cuda.is_available():
108 | network.cuda()
109 |
110 |
111 | class UnNormalize(object):
112 | def __init__(self, mean, std):
113 | self.mean = mean
114 | self.std = std
115 |
116 | def __call__(self, tensor):
117 | """
118 | Args:
119 | :param tensor: tensor image of size (B,C,H,W) to be un-normalized
120 | :return: UnNormalized image
121 | """
122 | for t, m, s in zip(tensor, self.mean, self.std):
123 | t.mul_(s).add_(m)
124 | return tensor
125 |
126 |
127 | def check_box(images, boxes):
128 | # Unorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
129 | # images = Unorm(images)*255
130 | images = images.permute(0, 2, 3, 1).cpu().detach().numpy()
131 |     boxes = (boxes.cpu().detach().numpy()/16*255).astype(int)
132 | for img, box in zip(images, boxes):
133 | fig = plt.figure()
134 | ax = fig.add_subplot(111)
135 | plt.imshow(img)
136 | rect = plt.Rectangle(box[0:2], box[2]-box[0], box[3]-box[1])
137 | ax.add_patch(rect)
138 | plt.show()
139 |
140 |
141 | ######################################################################
142 | # Load model for resume
143 | # ---------------------------
144 | def load_network(opt):
145 | save_filename = opt.checkpoint
146 | model = make_model(opt)
147 | # print('Load the model from %s' % save_filename)
148 | network = model
149 | network.load_state_dict(torch.load(save_filename))
150 | return network
151 |
152 |
153 | def toogle_grad(model, requires_grad):
154 | for p in model.parameters():
155 | p.requires_grad_(requires_grad)
156 |
157 |
158 | def update_average(model_tgt, model_src, beta):
159 | toogle_grad(model_src, False)
160 | toogle_grad(model_tgt, False)
161 |
162 | param_dict_src = dict(model_src.named_parameters())
163 |
164 | for p_name, p_tgt in model_tgt.named_parameters():
165 | p_src = param_dict_src[p_name]
166 | assert(p_src is not p_tgt)
167 | p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src)
168 |
169 | toogle_grad(model_src, True)
170 |
171 |
172 | def get_preds(outputs, outputs2):
173 | if isinstance(outputs, list):
174 | preds = []
175 | preds2 = []
176 | for out, out2 in zip(outputs, outputs2):
177 | preds.append(torch.max(out.data, 1)[1])
178 | preds2.append(torch.max(out2.data, 1)[1])
179 | else:
180 | _, preds = torch.max(outputs.data, 1)
181 | _, preds2 = torch.max(outputs2.data, 1)
182 | return preds, preds2
183 |
184 |
185 | def calc_flops_params(model,
186 | input_size_drone,
187 | input_size_satellite,
188 | ):
189 | inputs_drone = torch.randn(input_size_drone).cuda()
190 | inputs_satellite = torch.randn(input_size_satellite).cuda()
191 | total_ops, total_params = profile(
192 | model, (inputs_drone, inputs_satellite,), verbose=False)
193 | macs, params = clever_format([total_ops, total_params], "%.3f")
194 | return macs, params
195 |
196 |
197 | def set_seed(seed):
198 | torch.manual_seed(seed)
199 | torch.cuda.manual_seed_all(seed)
200 | np.random.seed(seed)
201 | torch.backends.cudnn.deterministic = True
202 | torch.backends.cudnn.benchmark = False
203 | torch.backends.cudnn.enabled = True
204 | random.seed(seed)
205 |
206 |
207 | def Distance(lata, loga, latb, logb):
208 | # EARTH_RADIUS = 6371.0
209 | EARTH_RADIUS = 6378.137
210 | PI = math.pi
211 |     # convert to radians
212 | lat_a = lata * PI / 180
213 | lat_b = latb * PI / 180
214 | a = lat_a - lat_b
215 | b = loga * PI / 180 - logb * PI / 180
216 | dis = 2 * math.asin(math.sqrt(math.pow(math.sin(a / 2), 2) + math.cos(lat_a)
217 | * math.cos(lat_b) * math.pow(math.sin(b / 2), 2)))
218 |
219 | distance = EARTH_RADIUS * dis * 1000
220 | return distance
--------------------------------------------------------------------------------
/tool/visual/Times New Roman.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dmmm1997/DenseUAV/d3e4335fb73e1eeeb8db6771d11f731ac8ef3c14/tool/visual/Times New Roman.ttf
--------------------------------------------------------------------------------
/tool/visual/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import scipy.io
3 | import torch
4 | import numpy as np
5 | import os
6 | from torchvision import datasets
7 | import matplotlib
8 | #matplotlib.use('agg')
9 | import matplotlib.pyplot as plt
10 | #######################################################################
11 | # Evaluate
12 | parser = argparse.ArgumentParser(description='Demo')
13 | parser.add_argument('--query_index', default=77, type=int, help='test_image_index')
14 | parser.add_argument('--test_dir',default='/home/dmmm/Dataset/DenseUAV/data_2022/test',type=str, help='./test_data')
15 | parser.add_argument('--config',
16 | default="/home/dmmm/Dataset/DenseUAV/data_2022/Dense_GPS_ALL.txt", type=str,
17 | help='./test_data')
18 | opts = parser.parse_args()
19 |
20 | configDict = {}
21 | with open(opts.config, "r") as F:
22 | context = F.readlines()
23 | for line in context:
24 | splitLineList = line.split(" ")
25 | configDict[splitLineList[0].split("/")[-2]] = [float(splitLineList[1].split("E")[-1]),
26 | float(splitLineList[2].split("N")[-1])]
27 |
28 | gallery_name = 'gallery_satellite'
29 | query_name = 'query_drone'
30 | # gallery_name = 'gallery_drone'
31 | # query_name = 'query_satellite'
32 |
33 | data_dir = opts.test_dir
34 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ) for x in [gallery_name, query_name]}
35 |
36 | #####################################################################
37 | #Show result
38 | def imshow(path, title=None):
39 | """Imshow for Tensor."""
40 | im = plt.imread(path)
41 | plt.imshow(im)
42 | if title is not None:
43 | plt.title(title)
44 | plt.pause(0.1) # pause a bit so that plots are updated
45 |
46 | ######################################################################
47 | result = scipy.io.loadmat('pytorch_result_1.mat')
48 | query_feature = torch.FloatTensor(result['query_f'])
49 | query_label = result['query_label'][0]
50 | gallery_feature = torch.FloatTensor(result['gallery_f'])
51 | gallery_label = result['gallery_label'][0]
52 |
53 | multi = os.path.isfile('multi_query.mat')
54 |
55 | if multi:
56 | m_result = scipy.io.loadmat('multi_query.mat')
57 | mquery_feature = torch.FloatTensor(m_result['mquery_f'])
58 | mquery_cam = m_result['mquery_cam'][0]
59 | mquery_label = m_result['mquery_label'][0]
60 | mquery_feature = mquery_feature.cuda()
61 |
62 | query_feature = query_feature.cuda()
63 | gallery_feature = gallery_feature.cuda()
64 |
65 | #######################################################################
66 | # sort the images
67 | def sort_img(qf, ql, gf, gl):
68 | query = qf.view(-1,1)
69 | # print(query.shape)
70 | score = torch.mm(gf,query)
71 | score = score.squeeze(1).cpu()
72 | score = score.numpy()
73 | # predict index
74 | index = np.argsort(score) #from small to large
75 | index = index[::-1]
76 | # index = index[0:2000]
77 | # good index
78 | query_index = np.argwhere(gl==ql)
79 |
80 | #good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
81 | junk_index = np.argwhere(gl==-1)
82 |
83 | mask = np.in1d(index, junk_index, invert=True)
84 | index = index[mask]
85 | return index
86 |
87 | i = opts.query_index
88 | index = sort_img(query_feature[i],query_label[i],gallery_feature,gallery_label)
89 |
90 | ########################################################################
91 | # Visualize the rank result
92 | query_path, _ = image_datasets[query_name].imgs[i]
93 | query_label = query_label[i]
94 | print(query_path)
95 | label6num = query_path.split("/")[-2]
96 | x_q,y_q = configDict[label6num]
97 |
98 | print('Top 10 images are as follow:')
99 | save_folder = 'image_show/%02d' % opts.query_index
100 | if not os.path.exists("image_show"):
101 | os.mkdir("image_show")
102 | if not os.path.isdir(save_folder):
103 | os.mkdir(save_folder)
104 | os.system('cp %s %s/query.jpg'%(query_path, save_folder))
105 |
106 | try: # Visualize Ranking Result
107 | # Graphical User Interface is needed
108 | fig = plt.figure(figsize=(16,4))
109 | ax = plt.subplot(1,11,1)
110 | ax.axis('off')
111 | imshow(query_path)
112 | ax.set_title("x:{:.7f}\ny:{:.7f}".format(x_q,y_q), color='blue',fontsize=5)
113 | for i in range(10):
114 | ax = plt.subplot(1,11,i+2)
115 | ax.axis('off')
116 | img_path, _ = image_datasets[gallery_name].imgs[index[i]]
117 | label = gallery_label[index[i]]
118 | labelg6num = img_path.split("/")[-2]
119 |             x_g,y_g = configDict[labelg6num]
120 | print(label)
121 | imshow(img_path)
122 | os.system('cp %s %s/s%02d.tif'%(img_path, save_folder, i))
123 | if label == query_label:
124 | ax.set_title("x:{:.7f}\ny:{:.7f}".format(x_g,y_g),color='green',fontsize=5)
125 | else:
126 | ax.set_title("x:{:.7f}\ny:{:.7f}".format(x_g,y_g), color='red',fontsize=5)
127 | print(img_path)
128 | #plt.pause(100) # pause a bit so that plots are updated
129 | except RuntimeError:
130 | for i in range(10):
131 | img_path = image_datasets.imgs[index[i]]
132 | print(img_path[0])
133 | print('If you want to see the visualization of the ranking result, graphical user interface is needed.')
134 |
135 | fig.savefig(save_folder+"/show.png",dpi = 600)
136 |
137 |
--------------------------------------------------------------------------------
/tool/visual/demo_custom.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import argparse
3 | import torch
4 | import numpy as np
5 | from datasets.queryDataset import CenterCrop
6 | from torchvision import transforms
7 | from PIL import Image
8 | import yaml
9 | from tool.utils_server import load_network
10 | from torch.autograd import Variable
11 | import torch.backends.cudnn as cudnn
12 | import os
13 | from tool.get_property import find_GPS_image
14 | # import matplotlib.pyplot as plt
15 | import cv2
16 | import json
17 |
18 | University="计量"
19 | parser = argparse.ArgumentParser(description='Demo')
20 | parser.add_argument('--img', default="/media/dmmm/CE31-3598/DataSets/DenseCV_Data/实际测试图像({})/test02/DJI_0297.JPG".format(University), type=str, help='image path for visualization')
21 | parser.add_argument("--galleryFeature",default="features{}.mat".format(University), type=str, help='galleryFeature')
22 | parser.add_argument('--galleryPath', default="/media/dmmm/CE31-3598/DataSets/DenseCV_Data/satelliteHub({})".format(University),
23 | type=str, help='./test_data')
24 | parser.add_argument('--MapDir', default="../../maps/{}.tif".format(University), type=str, help='./test_data')
25 | parser.add_argument('--K', default=10, type=int, help='./test_data')
26 | # parser.add_argument("--mode",default="1", type=int, help='1:drone->satellite 2:satellite->drone')
27 | opts = parser.parse_args()
28 |
29 | # 30.3257654243082, 120.37341533989152 30.320441588110295, 120.38193743012363
30 | mapPosInfodir = "/home/dmmm/PycharmProjects/DenseCV/demo/maps/pos.txt"
31 | with open(mapPosInfodir,"r") as F:
32 | listLine = F.readlines()
33 | for line in listLine:
34 | name,TN,TE,BN,BE = line.split(" ")
35 | if name==University:
36 | startE = eval(TE.split("TE")[-1])
37 | startN = eval(TN.split("TN")[-1])
38 | endE = eval(BE.split("BE")[-1])
39 | endN = eval(BN.split("BN")[-1])
40 |
41 | AllImage = cv2.imread(opts.MapDir)
42 | h, w, c = AllImage.shape
43 |
44 |
45 | #####################################################################
46 | # #Show result
47 | # def imshow(path, title=None):
48 | # """Imshow for Tensor."""
49 | # im = plt.imread(path)
50 | # plt.imshow(im)
51 | # if title is not None:
52 | # plt.title(title)
53 | # plt.pause(0.1) # pause a bit so that plots are updated
54 |
55 |
56 | def getPosInfo(imgPath):
57 | GPS_info = find_GPS_image(imgPath)
58 | x = list(GPS_info.values())
59 | gps_dict_formate = x[0]
60 | y = list(gps_dict_formate.values())
61 | height = eval(y[5])
62 | E = y[3]
63 | N = y[1]
64 | return [N, E]
65 |
66 | #######################################################################
67 | # sort the images and return topK index
68 | def getTopKImage(qf, gf, gl,K):
69 | query = qf.view(-1, 1)
70 | # print(query.shape)
71 | score = torch.mm(gf, query)
72 | score = score.squeeze().cpu()
73 | score = score.numpy()
74 | # predict index
75 | index = np.argsort(score) # from small to large
76 | index = index[::-1]
77 | # index = index[0:2000]
78 | # good index
79 | # query_index = np.argwhere(gl == ql)
80 |
81 | # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
82 | # junk_index = np.argwhere(gl == -1)
83 |
84 | # mask = np.in1d(index, junk_index, invert=True)
85 | # index = index[mask]
86 | return gl[index[:K]]
87 |
88 |
89 | def extract_feature(img, model):
90 | count = 0
91 | n, c, h, w = img.size()
92 | count += n
93 | input_img = Variable(img.cuda())
94 | outputs, _ = model(input_img, None)
95 | ff = outputs
96 | # norm feature
97 | if len(ff.shape) == 3:
98 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
99 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
100 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(opts.block)
101 | ff = ff.div(fnorm.expand_as(ff))
102 | ff = ff.view(ff.size(0), -1)
103 | else:
104 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
105 | ff = ff.div(fnorm.expand_as(ff))
106 |
107 | features = ff.data
108 | return features
109 |
110 | def generateDictOfGalleryPosInfo():
111 | satellite_configDict = {}
112 | with open(os.path.join(opts.galleryPath, "PosInfo.txt"), "r") as F:
113 | context = F.readlines()
114 | for line in context:
115 | splitLineList = line.split(" ")
116 | satellite_configDict[splitLineList[0]] = [float(splitLineList[1].split("N")[-1]),
117 | float(splitLineList[2].split("E")[-1])]
118 | return satellite_configDict
119 |
120 |
121 | def imshowByIndex(index):
122 | galleryPath = os.path.join(opts.galleryPath, index + ".tif")
123 | image = cv2.imread(galleryPath)
124 | cv2.imshow("gallery", image)
125 |
126 | data_transforms = transforms.Compose([
127 | CenterCrop(),
128 | transforms.Resize((256, 256), interpolation=3),
129 | transforms.ToTensor(),
130 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
131 | ])
132 |
133 | ######################################################################
134 | # load network
135 | config_path = 'opts.yaml'
136 | with open(config_path, 'r') as stream:
137 |     config = yaml.load(stream, Loader=yaml.FullLoader)
138 | opts.stride = config['stride']
139 | opts.views = config['views']
140 | opts.transformer = config['transformer']
141 | opts.pool = config['pool']
142 | opts.views = config['views']
143 | opts.LPN = config['LPN']
144 | opts.block = config['block']
145 | opts.nclasses = config['nclasses']
146 | opts.droprate = config['droprate']
147 | opts.share = config['share']
148 | opts.checkpoint = "net_119.pth"
149 | torch.cuda.set_device("cuda:0")
150 | cudnn.benchmark = True
151 |
152 | model = load_network(opts)
153 | model = model.eval()
154 | model = model.cuda()
155 |
156 | ######################################################################
157 | result = scipy.io.loadmat(opts.galleryFeature)
158 | gallery_feature = torch.FloatTensor(result['features'])
159 | gallery_label = result['labels']
160 | gallery_feature = gallery_feature.cuda()
161 |
162 | satellitePosInfoDict = generateDictOfGalleryPosInfo()
163 |
164 | img = Image.open(opts.img)
165 | input = data_transforms(img)
166 | input = torch.unsqueeze(input, 0)
167 | # query Pos info
168 | queryPosInfo = getPosInfo(opts.img)
169 | Q_N = float(queryPosInfo[0])
170 | Q_E = float(queryPosInfo[1])
171 |
172 | with torch.no_grad():
173 | feature = extract_feature(input, model)
174 | indexSorted = getTopKImage(feature, gallery_feature, gallery_label,opts.K)
175 |
176 |
177 | selectedGalleryPath = [os.path.join(opts.galleryPath,"{}.tif".format(index)) for index in indexSorted]
178 | dict_paths = {"K":opts.K,"query":opts.img,"gallery":selectedGalleryPath}
179 | with open("test.json", "w", encoding='utf-8') as f:
180 |     # indent is very handy: it pretty-prints the saved dictionary; the default is None, and values below 0 mean zero spaces
181 |     # f.write(json.dumps(dict_paths, indent=4))
182 |     json.dump(dict_paths, f, indent=4) # pass in a file object; same result as dumps
183 |
184 |
185 | galleryPosDict = {"N": [], "E": []}
186 | for index in indexSorted:
187 | bestMatchedPosInfo = satellitePosInfoDict[index]
188 | galleryPosDict["N"].append(float(bestMatchedPosInfo[0]))
189 | galleryPosDict["E"].append(float(bestMatchedPosInfo[1]))
190 |
191 | ## query visualization
192 | X = int((Q_E - startE) / (endE - startE) * w)
193 | Y = int((Q_N - startN) / (endN - startN) * h)
194 | cv2.circle(AllImage, (X,Y), 30, color=(0, 0, 255), thickness=10)
195 |
196 | ## gallery visualization
197 | index = 1
198 | for N, E in zip(galleryPosDict["N"], galleryPosDict["E"]):
199 | X = int((E - startE) / (endE - startE) * w)
200 | Y = int((N - startN) / (endN - startN) * h)
201 | if index>=10:
202 | cv2.circle(AllImage, (X, Y), 30, color=(255, 0, 0), thickness=6)
203 | cv2.putText(AllImage, str(index), (X - 20, Y + 15), cv2.FONT_HERSHEY_COMPLEX, 1.2, color=(255, 0, 0),
204 | thickness=2)
205 | else:
206 | cv2.circle(AllImage, (X, Y), 30, color=(255, 0, 0), thickness=6)
207 | cv2.putText(AllImage, str(index), (X - 20, Y + 20), cv2.FONT_HERSHEY_COMPLEX, 2, color=(255, 0, 0),
208 | thickness=2)
209 | index += 1
210 |
211 | AllImage = cv2.resize(AllImage,(0,0),fx=0.25,fy=0.25)
212 | cv2.imwrite("topKLocationBySingleImage.tif",AllImage)
213 |
214 | # os.system("python visualization.py")
215 |
216 | ### may raise a Qt error in headless environments
217 | # try:
218 | # fig = plt.figure(figsize=(12, 4))
219 | # ax = plt.subplot(1, opts.K+1, 1)
220 | # ax.axis('off')
221 | # imshow(opts.img, 'query')
222 | # for i,path in enumerate(selectedGalleryPath):
223 | # ax = plt.subplot(1, 11, i + 2)
224 | # ax.axis('off')
225 | # imshow(path)
226 | # except:
227 | # print('If you want to see the visualization of the ranking result, graphical user interface is needed.')
228 | #
229 | # fig.savefig("show.png",dpi=600)
230 |
231 |
--------------------------------------------------------------------------------
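The query and gallery markers drawn above come from a simple linear interpolation of longitude/latitude into pixel coordinates of the stitched satellite image. A minimal sketch of that mapping, assuming startE/endE/startN/endN are the geographic bounds of AllImage and w/h its pixel size (names taken from the script above):

def latlon_to_pixel(N, E, startN, endN, startE, endE, w, h):
    # Linearly interpolate a (lat, lon) pair onto (x, y) pixel coordinates of the overview image.
    x = int((E - startE) / (endE - startE) * w)
    y = int((N - startN) / (endN - startN) * h)
    return x, y

# A point halfway between both bounds lands at the image centre.
assert latlon_to_pixel(30.5, 120.5, 30.0, 31.0, 120.0, 121.0, 1000, 1000) == (500, 500)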
/tool/visual/demo_custom_visualization.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import json
3 |
4 |
5 | with open("test.json", "r", encoding='utf-8') as f:
6 |     data = json.loads(f.read())  # json.loads takes a string; json.load would accept the file object directly
7 |
8 | K = data["K"]
9 | query = data["query"]
10 | gallery = data["gallery"]
11 |
12 | #Show result
13 | def imshow(path, title=None):
14 | """Imshow for Tensor."""
15 | im = plt.imread(path)
16 | plt.imshow(im)
17 | if title is not None:
18 | plt.title(title)
19 | # plt.pause(0.1) # pause a bit so that plots are updated
20 |
21 | ## may raise a Qt error in headless environments
22 | try:
23 | fig = plt.figure(figsize=(12, 4))
24 | ax = plt.subplot(1, K+1, 1)
25 | ax.axis('off')
26 | imshow(query, 'query')
27 | for i,path in enumerate(gallery):
28 |         ax = plt.subplot(1, K + 1, i + 2)  # one subplot for the query plus K gallery images
29 | ax.axis('off')
30 | imshow(path)
31 | except Exception:
32 |     print('If you want to see the visualization of the ranking result, a graphical user interface is needed.')
33 |
34 | fig.savefig("show.png",dpi=600)
35 |
--------------------------------------------------------------------------------
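test.json produced by the retrieval demo above is a flat dictionary holding the retrieval size K, the query image path and the top-K gallery paths. An illustrative stub (the paths are made up) that this visualization script can read directly:

import json

# Hypothetical paths for illustration only; the real file is written by the retrieval demo above.
dict_paths = {
    "K": 3,
    "query": "/path/to/query.JPG",
    "gallery": ["/gallery/0001.tif", "/gallery/0002.tif", "/gallery/0003.tif"],
}
with open("test.json", "w", encoding="utf-8") as f:
    json.dump(dict_paths, f, indent=4)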
/tool/visual/draw_SDMcurve.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | import json
4 | from glob import glob
5 | from matplotlib import font_manager
6 |
7 | output_dir = "/home/dmmm/VscodeProject/demo_DenseUAV/visualization/SDMCurve/Backbone"
8 |
9 | os.makedirs(output_dir, exist_ok = True)
10 |
11 | plt.rcParams['font.sans-serif'] = ['Times New Roman']
12 |
13 | source_dir = {
14 | # backbone
15 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_ConvnextT":"ConvNeXt-T",
16 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_DeitS":"DeiT-S",
17 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_EfficientNet-B3":"EfficientNet-B3",
18 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_EfficientNet-B5":"EfficientNet-B5",
19 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_PvTv2b2":"PvTv2-B2",
20 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_resnet50":"ResNet50",
21 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_Swinv2T-256":"Swinv2-T",
22 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-Global":"ViT-S",
23 | "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Backbone_Experiment_ViTB":"ViT-B",
24 | # head
25 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-MaxPool":"MaxPool",
26 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-AvgMaxPool":"AvgMaxPool",
27 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-AvgPool":"AvgPool",
28 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-Global":"Global",
29 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-GeM":"GemPool",
30 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-FSRA2B":"FSRA(block=2)",
31 | # "/home/dmmm/VscodeProject/demo_DenseUAV/checkpoints/Head_Experiment-LPN2B":"LPN(Block=2)",
32 | }
33 |
34 | json_name = "SDM*.json"
35 |
36 | fig = plt.figure(figsize=(4,4))
37 | plt.grid()
38 | plt.yticks(fontproperties='Times New Roman', size=12)
39 | plt.xticks(fontproperties='Times New Roman', size=12)
40 |
41 | x = list(range(1,101))
42 |
43 | color = ['r', 'k', 'y', 'c', 'g', 'm', 'b', 'coral', 'tan']
44 | ind = 0
45 | for path, name in source_dir.items():
46 | print(name)
47 | target_file = glob(os.path.join(path, json_name))[0]
48 | with open(target_file, 'r') as F:
49 | data = json.load(F)
50 | y = list(data.values())
51 | plt.plot(x,y,c=color[ind],marker = 'o',label=name,linewidth=1.0,markersize=1)
52 | ind+=1
53 |
54 | plt.legend(loc="upper right",prop={'family' : 'Times New Roman', 'size' : 12})
55 | plt.ylabel("SDM@K",fontdict={'family' : 'Times New Roman', 'size': 12})
56 | plt.xlabel("K",fontdict={'family' : 'Times New Roman', 'size': 12})
57 | plt.tight_layout()
58 |
59 |
60 | fig.savefig(os.path.join(output_dir, "backbone.eps"), dpi=600, format='eps')
61 | plt.show()
62 |
63 |
--------------------------------------------------------------------------------
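The loop above assumes each checkpoint directory holds one SDM*.json file whose values, read in insertion order, give SDM@K for K = 1..100. A minimal stub that satisfies that assumption (the curve values are synthetic):

import json

# Synthetic SDM@K values for illustration only; real files come from the evaluation scripts.
sdm = {str(k): 0.5 + 0.3 * (1 - 1 / k) for k in range(1, 101)}
with open("SDM_demo.json", "w") as f:
    json.dump(sdm, f, indent=2)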
/tool/visual/grad_cam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | import matplotlib.pyplot as plt
6 | from torchvision import transforms
7 | from utils import GradCAM, show_cam_on_image, center_crop_img
8 | from vit_model import vit_base_patch16_224
9 |
10 |
11 | class ReshapeTransform:
12 | def __init__(self, model):
13 | input_size = model.patch_embed.img_size
14 | patch_size = model.patch_embed.patch_size
15 | self.h = input_size[0] // patch_size[0]
16 | self.w = input_size[1] // patch_size[1]
17 |
18 | def __call__(self, x):
19 | # remove cls token and reshape
20 | # [batch_size, num_tokens, token_dim]
21 | result = x[:, 1:, :].reshape(x.size(0),
22 | self.h,
23 | self.w,
24 | x.size(2))
25 |
26 | # Bring the channels to the first dimension,
27 | # like in CNNs.
28 | # [batch_size, H, W, C] -> [batch, C, H, W]
29 | result = result.permute(0, 3, 1, 2)
30 | return result
31 |
32 |
33 | def main():
34 | model = vit_base_patch16_224()
35 |     # download link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
36 | weights_path = "./vit_base_patch16_224.pth"
37 | model.load_state_dict(torch.load(weights_path, map_location="cpu"))
38 | # Since the final classification is done on the class token computed in the last attention block,
39 | # the output will not be affected by the 14x14 channels in the last layer.
40 | # The gradient of the output with respect to them, will be 0!
41 |     # We should choose any layer before the final attention block.
42 | target_layers = [model.blocks[-1].norm1]
43 |
44 | data_transform = transforms.Compose([transforms.ToTensor(),
45 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
46 | # load image
47 | img_path = "both.png"
48 |     assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
49 | img = Image.open(img_path).convert('RGB')
50 | img = np.array(img, dtype=np.uint8)
51 | img = center_crop_img(img, 224)
52 | # [C, H, W]
53 | img_tensor = data_transform(img)
54 | # expand batch dimension
55 | # [C, H, W] -> [N, C, H, W]
56 | input_tensor = torch.unsqueeze(img_tensor, dim=0)
57 |
58 | cam = GradCAM(model=model,
59 | target_layers=target_layers,
60 | use_cuda=False,
61 | reshape_transform=ReshapeTransform(model))
62 | target_category = 281 # tabby, tabby cat
63 | # target_category = 254 # pug, pug-dog
64 |
65 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
66 |
67 | grayscale_cam = grayscale_cam[0, :]
68 | visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True)
69 | plt.imshow(visualization)
70 | plt.show()
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
--------------------------------------------------------------------------------
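ReshapeTransform converts the ViT token sequence back into a CNN-style feature map so that Grad-CAM can treat patch tokens as spatial positions. A small sanity check of that reshaping with dummy tensors, assuming a 224x224 input, 16x16 patches and a 768-dim ViT-B embedding (14x14 patch tokens plus one class token):

import torch

x = torch.randn(2, 197, 768)               # [batch, cls + 14*14 patch tokens, dim]
h = w = 224 // 16                           # 14 patches per side
fmap = x[:, 1:, :].reshape(2, h, w, 768)    # drop the cls token, restore the patch grid
fmap = fmap.permute(0, 3, 1, 2)             # [batch, H, W, C] -> [batch, C, H, W]
assert fmap.shape == (2, 768, 14, 14)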
/tool/visual/heatmap.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import matplotlib
5 | matplotlib.use('agg')
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from tool.utils_server import load_network
9 | import yaml
10 | import argparse
11 | import torch
12 | from torchvision import datasets, models, transforms
13 | from PIL import Image
14 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
15 | parser = argparse.ArgumentParser(description='Training')
16 | import math
17 |
18 | parser.add_argument('--data_dir',default='/media/dmmm/CE31-3598/DataSets/DenseCV_Data/improvedOriData/test',type=str, help='./test_data')
19 | parser.add_argument('--name', default='from_transreid_256_4B_small_lr005_kl', type=str, help='save model path')
20 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize')
21 | parser.add_argument('--checkpoint',default="net_119.pth", help='weights' )
22 | opt = parser.parse_args()
23 |
24 | config_path = 'opts.yaml'
25 | with open(config_path, 'r') as stream:
26 |     config = yaml.load(stream, Loader=yaml.FullLoader)
27 | opt.stride = config['stride']
28 | opt.views = config['views']
29 | opt.transformer = config['transformer']
30 | opt.pool = config['pool']
31 | opt.views = config['views']
32 | opt.LPN = config['LPN']
33 | opt.block = config['block']
34 | opt.nclasses = config['nclasses']
35 | opt.droprate = config['droprate']
36 | opt.share = config['share']
37 |
38 | if 'h' in config:
39 | opt.h = config['h']
40 | opt.w = config['w']
41 | if 'nclasses' in config: # to be compatible with old config files
42 | opt.nclasses = config['nclasses']
43 | else:
44 | opt.nclasses = 751
45 |
46 |
47 | def heatmap2d(img, arr):
48 | # fig = plt.figure()
49 | # ax0 = fig.add_subplot(121, title="Image")
50 | # ax1 = fig.add_subplot(122, title="Heatmap")
51 | # fig, ax = plt.subplots()
52 | # ax[0].imshow(Image.open(img))
53 | plt.figure()
54 | heatmap = plt.imshow(arr, cmap='viridis')
55 | plt.axis('off')
56 | # fig.colorbar(heatmap, fraction=0.046, pad=0.04)
57 | #plt.show()
58 | plt.savefig('heatmap_dbase')
59 |
60 | data_transforms = transforms.Compose([
61 | transforms.Resize((opt.h, opt.w), interpolation=3),
62 | transforms.ToTensor(),
63 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
64 | ])
65 |
66 | # image_datasets = {x: datasets.ImageFolder( os.path.join(opt.data_dir,x) ,data_transforms) for x in ['satellite']}
67 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
68 | # shuffle=False, num_workers=1) for x in ['satellite']}
69 |
70 | # imgpath = image_datasets['satellite'].imgs
71 | # print(imgpath)
72 | from glob import glob
73 | print(opt.data_dir)
74 | #list = os.listdir(os.path.join(opt.data_dir,"gallery_drone"))
75 | for i in ["000009","000013","000015","000016","000018","000035","000039","000116","000130"]:
76 | print(i)
77 | imgpath = os.path.join(opt.data_dir,"gallery_drone/"+i)
78 | #imgname = 'gallery_drone/0726/image-28.jpeg'
79 | # imgname = 'query_satellite/0721/0721.jpg'
80 | #imgpath = os.path.join(opt.data_dir,imgname)
81 | imgpath = os.path.join(imgpath, "1.JPG")
82 | print(imgpath)
83 | img = Image.open(imgpath)
84 | img = data_transforms(img)
85 | img = torch.unsqueeze(img,0)
86 | #print(img.shape)
87 | model = load_network(opt)
88 |
89 | model = model.eval().cuda()
90 |
91 | # data = next(iter(dataloaders['satellite']))
92 | # img, label = data
93 | with torch.no_grad():
94 | # x = model.model_1.model.conv1(img.cuda())
95 | # x = model.model_1.model.bn1(x)
96 | # x = model.model_1.model.relu(x)
97 | # x = model.model_1.model.maxpool(x)
98 | # x = model.model_1.model.layer1(x)
99 | # x = model.model_1.model.layer2(x)
100 | # x = model.model_1.model.layer3(x)
101 | # output = model.model_1.model.layer4(x)
102 | features = model.model_1.transformer(img.cuda())
103 | part_features = features[:,1:]
104 | part_features = part_features.view(part_features.size(0),int(math.sqrt(part_features.size(1))),int(math.sqrt(part_features.size(1))),part_features.size(2))
105 | output = part_features.permute(0,3,1,2)
106 | #print(output.shape)
107 | heatmap = output.squeeze().sum(dim=0).cpu().numpy()
108 | # print(heatmap.shape)
109 | # print(heatmap)
110 | # heatmap = np.mean(heatmap, axis=0)
111 | #
112 | # heatmap = np.maximum(heatmap, 0)
113 | # heatmap /= np.max(heatmap)
114 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap)-np.min(heatmap))
115 |
116 | #print(heatmap)
117 | # cv2.imshow("1",heatmap)
118 | # cv2.waitKey(0)
119 | #test_array = np.arange(100 * 100).reshape(100, 100)
120 |     # The result is saved as ./heatout/<id>.jpg below
121 |     img = cv2.imread(imgpath) # load the original image with cv2
122 |     heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) # resize the heatmap to the original image size
123 |     heatmap = np.uint8(255 * heatmap) # scale the heatmap to uint8 in [0, 255]
124 |     #print(heatmap)
125 |     heatmap = cv2.applyColorMap(heatmap, 2) # apply an OpenCV color map (2 = cv2.COLORMAP_JET)
126 |     superimposed_img = heatmap * 0.8 + img # 0.8 is the heatmap intensity factor
127 | if not os.path.exists("heatout"):
128 | os.mkdir("./heatout")
129 | cv2.imwrite("./heatout/"+i+".jpg", superimposed_img)
130 | #heatmap2d(imgpath,heatmap)
--------------------------------------------------------------------------------
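The heatmap above is simply the patch-token grid summed over channels and min-max normalised before being colour-mapped onto the input image. A minimal sketch of that step on a dummy token tensor, assuming the backbone returns [batch, 1 + H*W, dim] features as in the loop above:

import math
import torch

features = torch.randn(1, 1 + 16 * 16, 384)               # dummy [batch, cls + patches, dim]
part = features[:, 1:]                                      # drop the cls token
side = int(math.sqrt(part.size(1)))                         # 16 patches per side
grid = part.reshape(1, side, side, part.size(2)).permute(0, 3, 1, 2)
heat = grid.squeeze().sum(dim=0).numpy()                    # sum over channels -> [H, W]
heat = (heat - heat.min()) / (heat.max() - heat.min())      # min-max normalise to [0, 1]
assert heat.shape == (16, 16)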
/tool/visual_demo.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import random
3 | from glob import glob
4 | import numpy as np
5 |
6 | m_50 = []
7 | m_20 = []
8 | m_10 = []
9 | m_5 = []
10 | m_3 = []
11 | m = [m_3,m_5,m_10,m_20,m_50]
12 |
13 | x_label=[]
14 | RDS_value = []
15 | num_a=[3,5,10,20,50]
16 | filenames = glob("output/result_files/height-level/*.txt")
17 | filenames.sort(key=lambda x:int(x.split("/")[-1].split("_")[0]))
18 | for num, path in enumerate(filenames):
19 | x_label.append(path.split("/")[-1].split("_")[0])
20 | RDS_value.append(float(path.split("/")[-1].split(".txt")[0].split("=")[-1]))
21 | result = []
22 | with open(path, "r") as F:
23 | lines = F.readlines()
24 | for i,a in enumerate(num_a):
25 | line = lines[a-1]
26 | out = line.split(' ')[-1]
27 | out = out.split("\n")[0]
28 | m[i].append(float(out))
29 |
30 |
31 |
32 | fig = plt.figure(figsize=(6, 10))
33 | ax1 = fig.subplots()
34 | # ax1.tick_params(axis='x', labelrotation=10)
35 | plt.xlabel("Height (m)",fontdict={'family' : 'Times New Roman', 'size': 21})
36 | plt.xticks(fontproperties='Times New Roman',fontsize=18)
37 | ax1.set_ylabel('Accuracy', fontdict={'family': 'Times New Roman', 'size': 21}) # y-axis label
38 | ax1.set_ylim(0,1.1)
39 | ax1.set_yticklabels([0.0,0.2,0.4,0.6,0.8,1.0], fontproperties='Times New Roman', fontsize=16) # y-axis tick labels
40 |
41 |
42 | # draw the MA@K bars
43 | bar1 = ax1.bar(x_label, m_50, width=0.5, color="#45a776", label="MA@"+str(num_a[4]),edgecolor="k",linewidth=2)
44 | bar2 = ax1.bar(x_label, m_20, width=0.5, color="#3682be", label="MA@"+str(num_a[3]),edgecolor="k",linewidth=2)
45 | bar3 = ax1.bar(x_label, m_10, width=0.5, color="#b3974e", label="MA@"+str(num_a[2]),edgecolor="k",linewidth=2)
46 | bar4 = ax1.bar(x_label, m_5, width=0.5, color="#eed777", label="MA@"+str(num_a[1]),edgecolor="k",linewidth=2)
47 | bar5 = ax1.bar(x_label, m_3, width=0.5, color="#f05326", label="MA@"+str(num_a[0]),edgecolor="k",linewidth=2)
48 |
49 | # annotate each bar with its value
50 | for (a, b) in zip(x_label, m_3):
51 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman')
52 | for (a, b) in zip(x_label, m_5):
53 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman')
54 | for (a, b) in zip(x_label, m_10):
55 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman')
56 | for (a, b) in zip(x_label, m_20):
57 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman')
58 | for (a, b) in zip(x_label, m_50):
59 | plt.text(a, b-0.03, "{:.3f}".format(b), color='black', fontsize=15, ha="center",va="bottom", fontproperties='Times New Roman')
60 | # legend_bar = plt.legend(handles=[bar1,bar2,bar3,bar4,bar5], loc="upper left", ncol=2, prop={'family': 'Times New Roman', 'size': 16})
61 |
62 |
63 | ax2 = ax1.twinx()
64 | rds, = ax2.plot(x_label,RDS_value,color="red",linestyle="--", marker='*',label="RDS", markersize=16, linewidth=3)
65 | # legend_rds = plt.legend(handles=[rds],loc = "upper right",prop={'family': 'Times New Roman', 'size': 16})
66 | for ind, (a, b) in enumerate(zip(x_label, RDS_value)):
67 | bias = + 0.002
68 | plt.text(a, b+bias, "{:.3f}".format(b), color='darkred', fontsize=17, ha="center",va="bottom", fontproperties='Times New Roman')
69 |
70 | ax2.set_ylabel('RDS', fontdict={'family': 'Times New Roman', 'size': 21}) # y-axis label for the RDS curve
71 | ax2.set_ylim(0.68,0.82)
72 | # ax1.set_xticks(fontproperties='Times New Roman',fontsize=22) # x-axis ticks
73 | ax2.set_yticklabels([0.68,0.70,0.72,0.74,0.76,0.78,0.80,0.82,0.84], fontproperties='Times New Roman', fontsize=16) # y-axis tick labels for RDS
74 |
75 | # ax = ax1.add_artist(legend_bar)
76 |
77 | legend_bar = plt.legend(handles=[bar1,bar2,bar3,bar4,bar5,rds], loc="upper left", ncol=2, prop={'family': 'Times New Roman', 'size': 16})
78 |
79 |
80 | # save the figure to tool/visual/MA_curve/ as .jpg and .eps
81 | plt.tight_layout()
82 | plt.savefig('tool/visual/MA_curve/RDS_MA_height.jpg', dpi=600)
83 | plt.savefig('tool/visual/MA_curve/RDS_MA_height.eps', dpi=600)
--------------------------------------------------------------------------------
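The parsing at the top of the script above takes the flight height from the file-name prefix, the RDS value from the text after the last '=' in the name, and MA@K from the last whitespace-separated token of line K. An illustrative result file (all numbers made up) that satisfies those rules:

import os

# Made-up values for illustration only; real files are produced by the evaluation scripts.
os.makedirs("output/result_files/height-level", exist_ok=True)
lines = ["MA@{} {:.3f}\n".format(k, min(0.40 + 0.01 * k, 0.95)) for k in range(1, 51)]
with open("output/result_files/height-level/80_RDS=0.756.txt", "w") as f:
    f.writelines(lines)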
/train_test_local.sh:
--------------------------------------------------------------------------------
1 | name="baseline"
2 | root_dir="/data/datasets/crossview/DenseUAV/data_2022"
3 | data_dir=$root_dir/train
4 | test_dir=$root_dir/test
5 | gpu_ids=0
6 | num_worker=8
7 | lr=0.01
8 | batchsize=16
9 | sample_num=1
10 | block=1
11 | num_bottleneck=512
12 | backbone="ViTS-224" # resnet50 ViTS-224 senet
13 | head="SingleBranch"
14 | head_pool="avg" # global avg max avg+max
15 | cls_loss="CELoss" # CELoss FocalLoss
16 | feature_loss="WeightedSoftTripletLoss" # TripletLoss HardMiningTripletLoss WeightedSoftTripletLoss ContrastiveLoss
17 | kl_loss="KLLoss" # KLLoss
18 | h=224
19 | w=224
20 | load_from="no"
21 | ra="satellite" # random affine
22 | re="satellite" # random erasing
23 | cj="no" # color jitter
24 | rr="uav" # random rotate
25 |
26 | python train.py --name $name --data_dir $data_dir --gpu_ids $gpu_ids --sample_num $sample_num \
27 | --block $block --lr $lr --num_worker $num_worker --head $head --head_pool $head_pool \
28 | --num_bottleneck $num_bottleneck --backbone $backbone --h $h --w $w --batchsize $batchsize --load_from $load_from \
29 | --ra $ra --re $re --cj $cj --rr $rr --cls_loss $cls_loss --feature_loss $feature_loss --kl_loss $kl_loss
30 |
31 | cd checkpoints/$name
32 | python test.py --name $name --test_dir $test_dir --gpu_ids $gpu_ids --num_worker $num_worker
33 | python evaluate_gpu.py
34 | python evaluateDistance.py --root_dir $root_dir
35 | cd ../../
--------------------------------------------------------------------------------