├── .gitignore
├── README.md
├── aachen_452.npy
├── assets
│   ├── overview.png
│   └── samples
│       ├── 1116.png
│       └── 1417176916116788.png
├── configs
│   ├── config_test_aachen.json
│   ├── config_test_robotcar.json
│   ├── config_train_aachen.json
│   └── config_train_robotcar.json
├── dataloader
│   ├── aachen.py
│   ├── augmentation.py
│   ├── place_recog.py
│   └── robotcar.py
├── datasets
│   ├── aachen
│   │   ├── aachen_db_imglist.txt
│   │   ├── aachen_db_query_imglist.txt
│   │   ├── aachen_grgb_gid_v5.txt
│   │   ├── aachen_query_imglist.txt
│   │   ├── aachen_test_file_list.txt
│   │   ├── aachen_test_file_list_v5.txt
│   │   ├── aachen_train_file_list_v5.txt
│   │   └── pairs-query-netvlad50.txt
│   └── robotcar
│       ├── pairs-query-netvlad20-percam-perloc-rear.txt
│       ├── robotcar_db_imglist.txt
│       ├── robotcar_db_query_imglist.txt
│       ├── robotcar_grgb_gid.txt
│       ├── robotcar_queries_with_intrinsics_rear.txt
│       ├── robotcar_query_imglist.txt
│       ├── robotcar_rear_grgb_gid.txt
│       ├── robotcar_rear_test_file_list.txt
│       ├── robotcar_rear_test_file_list_full.txt
│       └── robotcar_rear_train_file_list.txt
├── localization
│   ├── coarse
│   │   ├── coarselocalization.py
│   │   └── evaluate.py
│   ├── fine
│   │   ├── extractor.py
│   │   ├── features
│   │   │   ├── extract_d2net.py
│   │   │   ├── extract_r2d2.py
│   │   │   ├── extract_sift.py
│   │   │   └── extract_spp.py
│   │   ├── localize_cv2.py
│   │   ├── matcher.py
│   │   ├── test.py
│   │   └── triangulation.py
│   ├── localizer.py
│   ├── tools.py
│   └── utils
│       ├── database.py
│       ├── parsers.py
│       └── read_write_model.py
├── loss
│   ├── __pycache__
│   │   ├── accuracy.cpython-37.pyc
│   │   ├── aploss.cpython-36.pyc
│   │   └── aploss.cpython-37.pyc
│   ├── accuracy.py
│   ├── aploss.py
│   └── seg_loss
│       ├── __pycache__
│       │   ├── crossentropy_loss.cpython-37.pyc
│       │   └── segloss.cpython-37.pyc
│       ├── crossentropy_loss.py
│       ├── focal_loss.py
│       ├── segloss.py
│       └── utils.py
├── net
│   ├── layers.py
│   ├── locnets
│   │   ├── r2d2.py
│   │   ├── resnet.py
│   │   └── superpoint.py
│   ├── plceregnets
│   │   ├── gem.py
│   │   └── pregnet.py
│   ├── regnets
│   │   ├── deeplab.py
│   │   └── pspnet.py
│   └── segnet.py
├── robotcar.npy
├── run_loc_aachen
├── run_loc_robotcar
├── run_reconstruct_aachen
├── run_reconstruct_robotcar
├── test.py
├── test_aachen
├── test_robotcar
├── tools
│   ├── common.py
│   ├── config_parser.py
│   ├── loc_tools.py
│   ├── optim.py
│   └── seg_tools.py
├── train.py
├── train_aachen
├── train_place_recog.py
├── train_robotcar
└── trainer_recog.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | localization/coarse/__pycache__
4 | localization/__pycache__
5 | localization/fine/__pycache__
6 | localization/utils/__pycache__
7 | tools/__pycache__
8 | weights
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Large-scale Localization by Global Instance Recognition
2 |
3 |
4 |
5 |
6 |
7 | In this work, we propose to leverage global instances, which are robust to illumination and season changes, for both
8 | coarse and fine localization. For coarse localization, instead of performing a global reference search directly, we
9 | search for reference images from recognized global instances progressively. The recognized instances are further
10 | utilized for instance-wise feature detection and matching to enhance localization accuracy.
11 |
12 | * Full paper
13 | PDF: [Efficient Large-scale Localization by Global Instance Recognition](https://openaccess.thecvf.com/content/CVPR2022/papers/Xue_Efficient_Large-Scale_Localization_by_Global_Instance_Recognition_CVPR_2022_paper.pdf)
14 |
15 |
16 | * Authors: *Fei Xue, Ignas Budvytis, Daniel Olmeda Reino, Roberto Cipolla*
17 |
18 | * Website: [lbr](https://github.com/feixue94/feixue94.github.io/lbr) for videos, slides, recent updates, and datasets.
19 |
20 | ## Dependencies
21 |
22 | * Python 3 >= 3.6
23 | * PyTorch >= 1.8
24 | * OpenCV >= 3.4
25 | * NumPy >= 1.18
26 | * segmentation-models-pytorch = 0.1.3
27 | * colmap
28 | * pycolmap = 0.0.1
29 |
30 | ## Data preparation
31 |
32 | Please follow the instructions on the [VisualLocalization Benchmark](https://www.visuallocalization.net/datasets/) to
33 | download the images and reference 3D models of the Aachen_v1.1 and RobotCar-Seasons datasets:
34 |
35 | * [Images of Aachen_v1.1 dataset](https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/)
36 | * [Images of RobotCar-Seasons dataset](https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/RobotCar-Seasons/)
37 |
38 | You can download global instance labels used in this work here:
39 |
40 | * [Global instances of Aachen_v1.1](https://drive.google.com/file/d/17qerRcU8Iemjwz7tUtlX9syfN-WVSs4K/view?usp=sharing)
41 | * [Global instances of RobotCar-Seasons](https://drive.google.com/file/d/1Ns5I3YGoMCBURzWKZTxqsugeG4jUcj4a/view?usp=sharing)
42 |
43 | Since the Aachen_v1.1 and RobotCar-Seasons databases contain only daytime images, which may cause large errors in
44 | the recognition of nighttime query images, we augment the training data by generating stylized images, which can
45 | also be downloaded along with the global instances. The file structure of the Aachen_v1.1 dataset should look like
46 | this:
47 |
48 | ```
49 | - aachen_v1.1
50 |   - global_seg_instance
51 |     - db
52 |     - query
53 |     - sequences
54 |   - images
55 |   - 3D-models
56 |     - images.bin
57 |     - points3D.bin
58 |     - cameras.bin
59 |   - stylized
60 |     - images (raw database images)
61 |     - images_aug1
62 |     - images_aug2
63 | ```
64 |
65 | For the RobotCar-Seasons dataset, the file structure should look like this:
66 |
67 | ```
68 | - RobotCar-Seasons
69 |   - global_seg_instance
70 |     - overcast-reference
71 |       - rear
72 |   - images
73 |     - overcast-reference
74 |     - night
75 |     ...
76 |     - night-rain
77 |   - 3D-models
78 |     - sfm-sift
79 |       - cameras.bin
80 |       - images.bin
81 |       - points3D.bin
82 |   - stylized
83 |     - overcast-reference (raw database images)
84 |     - overcast-reference_aug_d
85 |     - overcast-reference_aug_r
86 |     - overcast-reference_aug_nr
87 | ```
88 |
89 | ## Pretrained weights
90 |
91 | We provide pretrained weights for local feature detection and extraction, and for global instance recognition on the
92 | Aachen_v1.1 and RobotCar-Seasons datasets, respectively. They can be downloaded
93 | from [here](https://drive.google.com/file/d/1N4j7PkZoy2CkWhS7u6dFzMIoai3ShG9p/view?usp=sharing).
94 |
95 | ## Testing of global instance recognition
96 |
97 | Run the following commands to obtain predicted masks of global instances, confidence maps, global features, and visualization images.
98 |
99 | * testing recognition on Aachen_v1.1
100 |
101 | ```
102 | ./test_aachen
103 | ```
104 |
105 |
106 |
107 |
108 |
109 | * testing recognition on RobotCar-Seasons
110 |
111 | ```
112 | ./test_robotcar
113 | ```
114 |
115 |
116 |
117 |
118 |
119 | ## 3D Reconstruction
120 |
121 | For fine localization, you also need a 3D map of the environment reconstructed by structure-from-motion.
122 |
123 | * feature extraction and 3D reconstruction for Aachen_v1.1
124 |
125 | ```
126 | ./run_reconstruct_aachen
127 | ```
128 |
129 | * feature extraction and 3D reconstruction for RobotCar-Seasons
130 |
131 | ```
132 | ./run_reconstruct_robotcar
133 | ```
134 |
135 | ## Localization with global instances
136 |
137 | Once you have the global instance masks of the query and database images and the 3D map of the scene, you can run the
138 | following commands for localization.
139 |
140 | * localization on Aachen_v1.1
141 |
142 | ```
143 | ./run_loc_aachen
144 | ```
145 |
146 | You will get results like this:
147 |
148 | | | Day | Night |
149 | | -------- | ------- | -------- |
150 | | CVPR | 89.1 / 96.1 / 99.3 | 77.0 / 90.1 / 99.5 |
151 | | post-CVPR | 88.8 / 95.8 / 99.2 | 75.4 / 91.6 / 100 |
152 |
153 | * localization on RobotCar-Seasons
154 |
155 | ```
156 | ./run_loc_robotcar
157 | ```
158 |
159 | You will get results like this:
160 |
161 | | | Night | Night-rain |
162 | | -------- | ----- | ------- |
163 | | CVPR | 24.9 / 62.3 / 86.1 | 47.5 / 73.4 / 90.0 |
164 | | post-CVPR | 28.1 / 66.9 / 91.8 | 46.1 / 73.6 / 92.5 |
165 |
166 | ## Training
167 |
168 | If you want to retrain the recognition network, you can run the following commands.
169 |
170 | * training recognition on Aachen_v1.1
171 |
172 | ```
173 | ./train_aachen
174 | ```
175 |
176 | * training recognition on RobotCar-Seasons
177 |
178 | ```
179 | ./train_robotcar
180 | ```
181 |
182 | ## BibTeX Citation
183 |
184 | If you use any ideas from the paper or code from this repo, please consider citing:
185 |
186 | ```
187 | @inproceedings{xue2022efficient,
188 | author = {Fei Xue and Ignas Budvytis and Daniel Olmeda Reino and Roberto Cipolla},
189 | title = {Efficient Large-scale Localization by Global Instance Recognition},
190 | booktitle = {CVPR},
191 | year = {2022}
192 | }
193 | ```
194 |
195 | ## Acknowledgements
196 |
197 | Part of the code is borrowed from previous excellent works,
198 | including [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), [R2D2](https://github.com/naver/r2d2),
199 | and [HLoc](https://github.com/cvg/Hierarchical-Localization). You can find more details in their released repositories
200 | if you are interested in their work.
201 |
--------------------------------------------------------------------------------
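Note on the "Data preparation" section of the README above: before training or testing, it can help to verify that the dataset root actually follows the described layout. A minimal sketch, assuming a hypothetical local path for the Aachen_v1.1 root; adjust it to wherever you placed the data:

```python
import os.path as osp

# Hypothetical dataset root; replace with your own download location.
aachen_root = "/data/aachen_v1.1"

# Sub-paths taken from the layout described in the README's "Data preparation" section.
expected = [
    "global_seg_instance",
    "images",
    "3D-models/images.bin",
    "3D-models/points3D.bin",
    "3D-models/cameras.bin",
    "stylized/images",
    "stylized/images_aug1",
    "stylized/images_aug2",
]

for rel in expected:
    status = "ok     " if osp.exists(osp.join(aachen_root, rel)) else "MISSING"
    print(f"{status} {rel}")
```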
/aachen_452.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/aachen_452.npy
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/assets/overview.png
--------------------------------------------------------------------------------
/assets/samples/1116.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/assets/samples/1116.png
--------------------------------------------------------------------------------
/assets/samples/1417176916116788.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/assets/samples/1417176916116788.png
--------------------------------------------------------------------------------
/configs/config_test_aachen.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "SHLoc",
3 | "dataset": "aachen",
4 | "gpu": 0,
5 | "image_path": "/scratches/flyer_2/fx221/localization/aachen_v1_1/images/images_upright",
6 | "image_list": "datasets/aachen/aachen_db_query_imglist.txt",
7 | "grgb_gid_file": "datasets/aachen/aachen_grgb_gid_v5.txt",
8 | "classes": 452,
9 | "segmentation": 1,
10 | "classification": 1,
11 | "augmentation": 1,
12 | "network": "pspf",
13 | "encoder_name": "resnext101_32x4d",
14 | "encoder_weights": "ssl",
15 | "out_channels": 2048,
16 | "encoder_depth": 4,
17 | "R": 256,
18 | "upsampling": 8,
19 | "bs": 8,
20 | "save_dir": "/scratches/flyer_2/fx221/exp/shloc/data_release/aachen_v1.1",
21 | "pretrained_weight": "weights/2021_08_29_12_49_48_aachen_pspf_resnext101_32x4d_d4_u8_b16_R256_E120_ceohem_adam_seg_cls_aug.pth"
22 | }
--------------------------------------------------------------------------------
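The JSON configs in `configs/` are consumed by the training and test scripts. As a quick orientation, a minimal sketch (assuming it is run from the repository root) that loads one of them and prints the fields most relevant to recognition:

```python
import json

# Load the Aachen test config shipped with the repository.
with open("configs/config_test_aachen.json", "r") as f:
    cfg = json.load(f)

# Fields referenced throughout the code: dataset name, number of global instance
# classes, the image list to process, the grgb -> gid mapping file, and the weights.
print(cfg["dataset"], cfg["classes"])
print(cfg["image_list"])
print(cfg["grgb_gid_file"])
print(cfg["pretrained_weight"])
```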
/configs/config_test_robotcar.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "SHLoc",
3 | "dataset": "robotcar",
4 | "gpu": 0,
5 | "image_path": "/scratches/flyer_2/fx221/localization/RobotCar-Seasons/images",
6 | "image_list": "datasets/robotcar/robotcar_rear_test_file_list_full.txt",
7 | "grgb_gid_file": "datasets/robotcar/robotcar_rear_grgb_gid.txt",
8 | "classes": 692,
9 | "segmentation": 1,
10 | "classification": 1,
11 | "network": "pspf",
12 | "encoder_name": "resnext101_32x4d",
13 | "encoder_weights": "ssl",
14 | "out_channels": 2048,
15 | "encoder_depth": 4,
16 | "upsampling": 8,
17 | "bs": 8,
18 | "save_dir": "/scratches/flyer_2/fx221/exp/shloc/data_release/RobotCar-Seasons",
19 | "pretrained_weight": "weights/2021_10_17_00_21_32_robotcar_pspf_resnext101_32x4d_d4_u8_b16_R256_E500_ceohem_adam_poly_mlr_seg_cls_aug.pth"
20 | }
--------------------------------------------------------------------------------
/configs/config_train_aachen.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "SHLoc",
3 | "dataset": "aachen",
4 | "gpu": [
5 | 0
6 | ],
7 | "root": "/scratches/flyer_2/fx221/localization/aachen_v1_1",
8 | "save_root": "/scratches/flyer_2/fx221/exp/shloc/aachen",
9 | "train_label_path": "global_seg_instance",
10 | "train_image_path": "stylized",
11 | "val_label_path": "global_seg_instance",
12 | "val_image_path": "stylized",
13 | "tag": "stylized",
14 | "grgb_gid_file": "./datasets/aachen/aachen_grgb_gid_v5.txt",
15 | "train_imglist": "./datasets/aachen/aachen_train_file_list_v5.txt",
16 | "test_imglist": "./datasets/aachen/aachen_test_file_list_v5.txt",
17 | "train_cats": [
18 | "images/images_upright",
19 | "images_aug1/images_upright",
20 | "images_aug2/images_upright"
21 | ],
22 | "val_cats": [
23 | "images_aug/images_upright"
24 | ],
25 | "classes": 452,
26 | "weight_cls": 2.0,
27 | "segmentation": 1,
28 | "classification": 1,
29 | "augmentation": 1,
30 | "network": "pspf",
31 | "encoder_name": "resnext101_32x4d",
32 | "encoder_weights": "ssl",
33 | "seg_loss": "ceohem",
34 | "seg_loss_sce": "sceohem",
35 | "out_channels": 2048,
36 | "encoder_depth": 4,
37 | "upsampling": 8,
38 | "loss": "ce",
39 | "bs": 16,
40 | "R": 256,
41 | "optimizer": "adam",
42 | "weight_decay": 0.0005,
43 | "lr_policy": "step",
44 | "multi_lr": 1,
45 | "epochs": 120,
46 | "milestones": [
47 | 80,
48 | 100
49 | ],
50 | "lr": 0.0001,
51 | "workers": 4,
52 | "log_interval": 50,
53 | "save_dir": "",
54 | "val": 1,
55 | "aug": 1
56 | }
--------------------------------------------------------------------------------
/configs/config_train_robotcar.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "SHLoc",
3 | "dataset": "robotcar",
4 | "gpu": [
5 | 0
6 | ],
7 | "root": "/scratches/flyer_2/fx221/localization/RobotCar-Seasons",
8 | "save_root": "/scratches/flyer_2/fx221/exp/shloc/robotcar",
9 | "train_label_path": "global_seg_instance/mixed",
10 | "train_image_path": "stylized",
11 | "val_label_path": "global_seg_instance",
12 | "val_image_path": "stylized",
13 | "tag": "stylized",
14 | "grgb_gid_file": "./datasets/robotcar/robotcar_grgb_gid.txt",
15 | "train_imglist": "./datasets/robotcar/robotcar_train_file_list_full.txt",
16 | "test_imglist": "./datasets/robotcar/robotcar_test_file_list.txt",
17 | "train_cats": [
18 | "overcast-reference",
19 | "overcast-reference_aug_d",
20 | "overcast-reference_aug_n"
21 | ],
22 | "val_cats": [
23 | "overcast-reference_aug_nr"
24 | ],
25 | "classes": 692,
26 | "weight_cls": 2.0,
27 | "segmentation": 1,
28 | "classification": 1,
29 | "augmentation": 1,
30 | "network": "pspf",
31 | "encoder_name": "resnext101_32x4d",
32 | "encoder_weights": "ssl",
33 | "seg_loss": "ceohem",
34 | "seg_loss_sce": "sceohem",
35 | "out_channels": 2048,
36 | "encoder_depth": 4,
37 | "upsampling": 8,
38 | "loss": "ce",
39 | "bs": 16,
40 | "R": 256,
41 | "optimizer": "adam",
42 | "weight_decay": 0.0005,
43 | "lr_policy": "poly",
44 | "multi_lr": 1,
45 | "epochs": 500,
46 | "milestones": [
47 | 80,
48 | 100
49 | ],
50 | "lr": 0.0001,
51 | "workers": 4,
52 | "log_interval": 50,
53 | "val": 1,
54 | "aug": 1
55 | }
--------------------------------------------------------------------------------
/dataloader/aachen.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> aachen
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 18/07/2021 10:22
7 | =================================================='''
8 | import torch.utils.data as data
9 | from torch.utils.data.dataloader import *
10 | import numpy as np
11 | import os.path as osp
12 | import cv2
13 | import random
14 | from tools.common import resize_img
15 |
16 |
17 | class AachenSegFull(data.Dataset):
18 | def __init__(self,
19 | image_path,
20 | label_path,
21 | grgb_gid_file,
22 | n_classes,
23 | cats,
24 | train=True,
25 | transform=None,
26 | use_cls=False,
27 | img_list=None,
28 | preload=False,
29 | aug=None,
30 | R=256,
31 | keep_ratio=False,
32 | ):
33 | super(AachenSegFull, self).__init__()
34 | self.image_path = image_path
35 | self.label_path = label_path
36 | self.transform = transform
37 | self.train = train
38 | self.cls = use_cls
39 | self.aug = aug
40 | self.cats = cats
41 | self.R = R
42 | print("augmentation: ", self.aug)
43 | self.classes = n_classes
44 | self.keep_ratio = keep_ratio
45 |
46 | self.valid_fns = []
47 | with open(img_list, "r") as f:
48 | lines = f.readlines()
49 | for l in lines:
50 | l = l.strip()
51 | if not osp.exists(osp.join(self.label_path, l.replace("jpg", "png"))):
52 | continue
53 |
54 | keep = True
55 | for c in self.cats:
56 | if not osp.exists(osp.join(self.image_path, c, l)):
57 | keep = False
58 | if not keep:
59 | continue
60 | self.valid_fns.append(l)
61 | #
62 | self.grgb_gid = {}
63 | with open(grgb_gid_file, "r") as f:
64 | lines = f.readlines()
65 | for l in lines:
66 | l = l.strip().split(" ")
67 | self.grgb_gid[int(l[0])] = int(l[1])
68 |
69 | print("Load {:d} valid samples.".format(len(self.valid_fns)))
70 | print("No. of gids: {:d}".format(len(self.grgb_gid.keys())))
71 |
72 | def seg_to_gid(self, seg_img, grgb_gid_map):
73 | id_img = np.int32(seg_img[:, :, 2]) * 256 * 256 + np.int32(seg_img[:, :, 1]) * 256 + np.int32(
74 | seg_img[:, :, 0])
75 | luids = np.unique(id_img).tolist()
76 | # print("id_img: ", id_img.shape)
77 | out_img = np.zeros_like(seg_img)
78 | gid_img = np.zeros_like(id_img)
79 | for id in luids:
80 | if id in grgb_gid_map.keys():
81 | gid = grgb_gid_map[id]
82 | mask = (id_img == id)
83 | gid_img[mask] = gid
84 |
85 | out_img[mask] = seg_img[mask]
86 |
87 | return out_img, gid_img
88 |
89 | def get_item_seg(self, idx):
90 | fn = self.valid_fns[idx]
91 |
92 | if len(self.cats) > 1:
93 | cat_id = random.randint(1, len(self.cats))
94 | else:
95 | cat_id = 0
96 | # print(osp.join(self.image_path, self.cats[cat_id - 1], fn))
97 | # exit(0)
98 | # print("fn: ", fn)
99 | img = cv2.imread(osp.join(self.image_path, self.cats[cat_id - 1], fn))
100 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
101 | # print(osp.join(self.label_path, fn.replace('jpg', 'png')))
102 | seg_img = cv2.imread(osp.join(self.label_path, fn.replace('jpg', 'png')))
103 | # print(img.shape, seg_img.shape)
104 | seg_img = cv2.resize(seg_img, dsize=(img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
105 | # print(img.shape, seg_img.shape)
106 | # exit(0)
107 |
108 | if self.aug is not None:
109 | for a in self.aug:
110 | img, seg_img = a((img, seg_img))
111 | else:
112 | if self.keep_ratio:
113 | img = resize_img(img=img, nh=self.R, mode=cv2.INTER_NEAREST)
114 | else:
115 | img = cv2.resize(img.astype(np.uint8), dsize=(self.R, self.R))
116 | seg_img = cv2.resize(seg_img.astype(np.uint8), dsize=(img.shape[1], img.shape[0]),
117 | interpolation=cv2.INTER_NEAREST)
118 | if self.transform is not None:
119 | img = self.transform(img)
120 | seg_img = np.array(seg_img)
121 |
122 | filtered_seg, gids = self.seg_to_gid(seg_img=seg_img, grgb_gid_map=self.grgb_gid)
123 | gids = np.asarray(gids, np.int64)
124 | gids = torch.from_numpy(gids)
125 | gids = torch.LongTensor(gids)
126 |
127 | output = {
128 | "img": img,
129 | "label": [gids],
130 | "label_img": filtered_seg,
131 | }
132 |
133 | if self.cls:
134 |             cls_label = np.zeros(shape=(self.classes,), dtype=np.float32)
135 | uids = np.unique(gids).tolist()
136 | for id in uids:
137 | cls_label[id] = 1.0
138 | cls_label = torch.from_numpy(cls_label)
139 |
140 | output["cls"] = [cls_label]
141 |
142 | return output
143 |
144 | def __getitem__(self, idx):
145 | return self.get_item_seg(idx=idx)
146 |
147 | def __len__(self):
148 | return len(self.valid_fns)
149 |
--------------------------------------------------------------------------------
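A note on `seg_to_gid` above: each pixel of a label PNG (loaded by `cv2.imread` in BGR order) is packed into a single integer `R * 256 * 256 + G * 256 + B`, which is then mapped to a global instance id via the `grgb_gid` table. A self-contained sketch of that logic, with a toy 2x2 label image and a toy mapping (values are illustrative only):

```python
import numpy as np

def pack_bgr(seg_img: np.ndarray) -> np.ndarray:
    """Pack an HxWx3 label image (BGR, as loaded by cv2.imread) into HxW integer ids,
    mirroring seg_to_gid in dataloader/aachen.py."""
    return (np.int32(seg_img[:, :, 2]) * 256 * 256
            + np.int32(seg_img[:, :, 1]) * 256
            + np.int32(seg_img[:, :, 0]))

# Toy 2x2 "label image" (BGR pixels) and a toy grgb -> gid mapping.
seg = np.array([[[0, 0, 0], [10, 2, 50]],
                [[10, 2, 50], [0, 0, 0]]], dtype=np.uint8)
grgb_gid = {0: 0, 50 * 65536 + 2 * 256 + 10: 7}

ids = pack_bgr(seg)
gid_img = np.zeros_like(ids)
for packed, gid in grgb_gid.items():
    gid_img[ids == packed] = gid
print(gid_img)  # [[0 7]
                #  [7 0]]
```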
/dataloader/robotcar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> robotcar
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 18/07/2021 10:22
7 | =================================================='''
8 | import torch.utils.data as data
9 | from torch.utils.data.dataloader import *
10 | import numpy as np
11 | import os.path as osp
12 | import cv2
13 | import random
14 | from tools.common import resize_img
15 |
16 |
17 | class RobotCarSegFull(data.Dataset):
18 | def __init__(self,
19 | image_path,
20 | label_path,
21 | grgb_gid_file,
22 | n_classes,
23 | cats,
24 | train=True,
25 | transform=None,
26 | use_cls=False,
27 | img_list=None,
28 | preload=False,
29 | aug=None,
30 | R=256,
31 | keep_ratio=False,
32 | ):
33 | super(RobotCarSegFull, self).__init__()
34 | self.image_path = image_path
35 | self.label_path = label_path
36 | self.transform = transform
37 | self.train = train
38 | self.cls = use_cls
39 | self.aug = aug
40 | self.cats = cats
41 | self.R = R
42 | print("augmentation: ", self.aug)
43 | self.classes = n_classes
44 | self.keep_ratio = keep_ratio
45 |
46 | self.valid_fns = []
47 | with open(img_list, "r") as f:
48 | lines = f.readlines()
49 | for l in lines:
50 | l = l.strip()
51 | if not osp.exists(osp.join(self.label_path, l.replace("jpg", "png"))):
52 | continue
53 |
54 | keep = True
55 | for c in self.cats:
56 | if not osp.exists(osp.join(self.image_path, c, l)):
57 | keep = False
58 | if not keep:
59 | continue
60 | self.valid_fns.append(l)
61 |
62 | self.grgb_gid = {}
63 | with open(grgb_gid_file, "r") as f:
64 | lines = f.readlines()
65 | for l in lines:
66 | l = l.strip().split(" ")
67 | self.grgb_gid[int(l[0])] = int(l[1])
68 |
69 | print("Load {:d} valid samples.".format(len(self.valid_fns)))
70 | print("No. of gids: {:d}".format(len(self.grgb_gid.keys())))
71 |
72 | def seg_to_gid(self, seg_img, grgb_gid_map):
73 | id_img = np.int32(seg_img[:, :, 2]) * 256 * 256 + np.int32(seg_img[:, :, 1]) * 256 + np.int32(
74 | seg_img[:, :, 0])
75 | luids = np.unique(id_img).tolist()
76 | # print("id_img: ", id_img.shape)
77 | out_img = np.zeros_like(seg_img)
78 | gid_img = np.zeros_like(id_img)
79 | for id in luids:
80 | if id in grgb_gid_map.keys():
81 | gid = grgb_gid_map[id]
82 | mask = (id_img == id)
83 | gid_img[mask] = gid
84 |
85 | out_img[mask] = seg_img[mask]
86 |
87 | return out_img, gid_img
88 |
89 | def get_item_seg(self, idx):
90 | fn = self.valid_fns[idx]
91 |
92 | cat_id = random.randint(1, len(self.cats))
93 | # print(osp.join(self.image_path, self.cats[cat_id - 1], fn))
94 | img = cv2.imread(osp.join(self.image_path, self.cats[cat_id - 1], fn))
95 | # raw_img = cv2.resize(img, dsize=(self.R, self.R))
96 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
97 |
98 | seg_img = cv2.imread(osp.join(self.label_path, fn.replace('jpg', 'png')))
99 |
100 | if self.aug is not None:
101 | for a in self.aug:
102 | img, seg_img = a((img, seg_img))
103 |
104 | else:
105 | if self.keep_ratio:
106 | img = resize_img(img=img, nh=self.R, mode=cv2.INTER_NEAREST)
107 | else:
108 | img = cv2.resize(img.astype(np.uint8), dsize=(self.R, self.R))
109 | seg_img = cv2.resize(seg_img.astype(np.uint8), dsize=(img.shape[1], img.shape[0]),
110 | interpolation=cv2.INTER_NEAREST)
111 | raw_img = img.astype(np.uint8)
112 |
113 | if self.transform is not None:
114 | img = self.transform(img)
115 | seg_img = np.array(seg_img)
116 |
117 | filtered_seg, gids = self.seg_to_gid(seg_img=seg_img, grgb_gid_map=self.grgb_gid)
118 | gids = np.asarray(gids, np.int64)
119 | gids = torch.from_numpy(gids)
120 | gids = torch.LongTensor(gids)
121 |
122 | output = {
123 | "raw_img": raw_img,
124 | "img": img,
125 | "label": [gids],
126 | "label_img": filtered_seg,
127 | }
128 |
129 | if self.cls:
130 |             cls_label = np.zeros(shape=(self.classes,), dtype=np.float32)
131 | uids = np.unique(gids).tolist()
132 | for id in uids:
133 | cls_label[id] = 1.0
134 | cls_label = torch.from_numpy(cls_label)
135 |
136 | output["cls"] = [cls_label]
137 |
138 | return output
139 |
140 | def __getitem__(self, idx):
141 | return self.get_item_seg(idx=idx)
142 |
143 | def __len__(self):
144 | return len(self.valid_fns)
145 |
--------------------------------------------------------------------------------
/datasets/aachen/aachen_grgb_gid_v5.txt:
--------------------------------------------------------------------------------
1 | 0 0
2 | 3278090 1
3 | 3278100 2
4 | 3278115 3
5 | 3278120 4
6 | 3278135 5
7 | 3278150 6
8 | 3278155 7
9 | 3278160 8
10 | 3278165 9
11 | 3278175 10
12 | 3278180 11
13 | 3278185 12
14 | 3278195 13
15 | 3278205 14
16 | 3278215 15
17 | 3278225 16
18 | 3278230 17
19 | 3278235 18
20 | 3278245 19
21 | 3278255 20
22 | 3278260 21
23 | 3278265 22
24 | 3278280 23
25 | 3279365 24
26 | 3279380 25
27 | 3279385 26
28 | 3279390 27
29 | 3279430 28
30 | 3279450 29
31 | 3279455 30
32 | 3279460 31
33 | 3279465 32
34 | 3279470 33
35 | 3279475 34
36 | 3279485 35
37 | 3279490 36
38 | 3279500 37
39 | 3279515 38
40 | 3280805 39
41 | 3280815 40
42 | 3280820 41
43 | 3280825 42
44 | 3281960 43
45 | 3281975 44
46 | 3283215 45
47 | 3285785 46
48 | 3285800 47
49 | 3285830 48
50 | 3285850 49
51 | 3285860 50
52 | 3285865 51
53 | 3285885 52
54 | 3285895 53
55 | 3285910 54
56 | 3285925 55
57 | 3287205 56
58 | 3287260 57
59 | 3288385 58
60 | 3288485 59
61 | 3288490 60
62 | 3288500 61
63 | 3288505 62
64 | 3288510 63
65 | 3288520 64
66 | 3288530 65
67 | 3289615 66
68 | 3289630 67
69 | 3289650 68
70 | 3289660 69
71 | 3289665 70
72 | 3289670 71
73 | 3289680 72
74 | 3289725 73
75 | 3289735 74
76 | 3289820 75
77 | 3289825 76
78 | 3290960 77
79 | 3290970 78
80 | 3290980 79
81 | 3291000 80
82 | 3291010 81
83 | 3291015 82
84 | 3291020 83
85 | 3291025 84
86 | 3291035 85
87 | 3291060 86
88 | 3292190 87
89 | 3292195 88
90 | 3292200 89
91 | 3292210 90
92 | 3292215 91
93 | 3292225 92
94 | 3292230 93
95 | 3292235 94
96 | 3292245 95
97 | 3292260 96
98 | 3292275 97
99 | 3292285 98
100 | 3292290 99
101 | 3292295 100
102 | 3292300 101
103 | 3292310 102
104 | 3292320 103
105 | 3292335 104
106 | 3292340 105
107 | 3292345 106
108 | 3292350 107
109 | 3292355 108
110 | 3292360 109
111 | 3292370 110
112 | 3292380 111
113 | 3292385 112
114 | 3292390 113
115 | 3293445 114
116 | 3293455 115
117 | 3293460 116
118 | 3293470 117
119 | 3293490 118
120 | 3293505 119
121 | 3293510 120
122 | 3293515 121
123 | 3293520 122
124 | 3293525 123
125 | 3293530 124
126 | 3293590 125
127 | 3293595 126
128 | 3293650 127
129 | 3293655 128
130 | 3293660 129
131 | 3293665 130
132 | 3294725 131
133 | 3294730 132
134 | 3294735 133
135 | 3294740 134
136 | 3294750 135
137 | 3294805 136
138 | 3294850 137
139 | 3294860 138
140 | 3294865 139
141 | 3294870 140
142 | 3294875 141
143 | 3294885 142
144 | 3294895 143
145 | 3294900 144
146 | 3297295 145
147 | 3297485 146
148 | 3297500 147
149 | 3297525 148
150 | 3298775 149
151 | 3298780 150
152 | 3298790 151
153 | 3300065 152
154 | 3301170 153
155 | 3301290 154
156 | 3301310 155
157 | 3301315 156
158 | 3301325 157
159 | 3301330 158
160 | 3302470 159
161 | 3302475 160
162 | 3302480 161
163 | 3302490 162
164 | 3302545 163
165 | 3302570 164
166 | 3302595 165
167 | 3302605 166
168 | 3303850 167
169 | 3303870 168
170 | 3303915 169
171 | 3303920 170
172 | 3303925 171
173 | 3304995 172
174 | 3306370 173
175 | 3307600 174
176 | 3307615 175
177 | 3307620 176
178 | 3307630 177
179 | 3307635 178
180 | 3307675 179
181 | 3307700 180
182 | 3307705 181
183 | 3307710 182
184 | 3307715 183
185 | 3307735 184
186 | 3308975 185
187 | 3311395 186
188 | 3313985 187
189 | 3316630 188
190 | 3316635 189
191 | 3316640 190
192 | 3316645 191
193 | 3316695 192
194 | 3319100 193
195 | 3319110 194
196 | 3319120 195
197 | 3321615 196
198 | 3321825 197
199 | 3324240 198
200 | 3325525 199
201 | 3325545 200
202 | 3326760 201
203 | 3326840 202
204 | 3326870 203
205 | 3326875 204
206 | 3328015 205
207 | 3330570 206
208 | 3330665 207
209 | 3330680 208
210 | 3330755 209
211 | 3330760 210
212 | 3331930 211
213 | 3331945 212
214 | 3331950 213
215 | 3332005 214
216 | 3333175 215
217 | 3333290 216
218 | 3333305 217
219 | 3334545 218
220 | 3334560 219
221 | 3334575 220
222 | 3334585 221
223 | 3334595 222
224 | 3334600 223
225 | 3334620 224
226 | 3336985 225
227 | 3338330 226
228 | 3339625 227
229 | 3339695 228
230 | 3339700 229
231 | 3474813 230
232 | 3474873 231
233 | 3477278 232
234 | 3477293 233
235 | 3477303 234
236 | 3478538 235
237 | 3479883 236
238 | 3479983 237
239 | 3481328 238
240 | 3488928 239
241 | 3491403 240
242 | 3492753 241
243 | 3492773 242
244 | 3495268 243
245 | 3496653 244
246 | 3497763 245
247 | 3500388 246
248 | 3501768 247
249 | 3502908 248
250 | 3502918 249
251 | 3502973 250
252 | 3503018 251
253 | 3504168 252
254 | 3504173 253
255 | 3504218 254
256 | 3504253 255
257 | 3504258 256
258 | 3504288 257
259 | 3504308 258
260 | 3504338 259
261 | 3505423 260
262 | 3505433 261
263 | 3505508 262
264 | 3506748 263
265 | 3506758 264
266 | 3506793 265
267 | 3506858 266
268 | 3510613 267
269 | 3510643 268
270 | 3510648 269
271 | 3511823 270
272 | 3513293 271
273 | 3515853 272
274 | 3517038 273
275 | 3517153 274
276 | 3518393 275
277 | 3518403 276
278 | 3518438 277
279 | 3519538 278
280 | 3519563 279
281 | 3519608 280
282 | 3519613 281
283 | 3519638 282
284 | 3519643 283
285 | 3519648 284
286 | 3519653 285
287 | 3519713 286
288 | 3522128 287
289 | 3522203 288
290 | 3522208 289
291 | 3523468 290
292 | 3526123 291
293 | 3527253 292
294 | 3527268 293
295 | 3527298 294
296 | 3527308 295
297 | 3527333 296
298 | 3528453 297
299 | 3529798 298
300 | 3529818 299
301 | 3529843 300
302 | 3529903 301
303 | 3529973 302
304 | 3531013 303
305 | 3531033 304
306 | 3531053 305
307 | 3531073 306
308 | 3531108 307
309 | 3531148 308
310 | 3531183 309
311 | 3531228 310
312 | 3531243 311
313 | 3533618 312
314 | 3533698 313
315 | 3533713 314
316 | 3533718 315
317 | 3533738 316
318 | 3533803 317
319 | 3671501 318
320 | 3672671 319
321 | 3675261 320
322 | 3677806 321
323 | 3677831 322
324 | 3677836 323
325 | 3677846 324
326 | 3677876 325
327 | 3677921 326
328 | 3677926 327
329 | 3677936 328
330 | 3677941 329
331 | 3678981 330
332 | 3678996 331
333 | 3679001 332
334 | 3679006 333
335 | 3679036 334
336 | 3679056 335
337 | 3679061 336
338 | 3679066 337
339 | 3679081 338
340 | 3679091 339
341 | 3679096 340
342 | 3679101 341
343 | 3679111 342
344 | 3679116 343
345 | 3679121 344
346 | 3679131 345
347 | 3679146 346
348 | 3679186 347
349 | 3679216 348
350 | 3680381 349
351 | 3680396 350
352 | 3680461 351
353 | 3680476 352
354 | 3681591 353
355 | 3681606 354
356 | 3681656 355
357 | 3681666 356
358 | 3681671 357
359 | 3682856 358
360 | 3682891 359
361 | 3682896 360
362 | 3682911 361
363 | 3682916 362
364 | 3682936 363
365 | 3682971 364
366 | 3683016 365
367 | 3683021 366
368 | 3683031 367
369 | 3683041 368
370 | 3683046 369
371 | 3683051 370
372 | 3683056 371
373 | 3683061 372
374 | 3684121 373
375 | 3684136 374
376 | 3684331 375
377 | 3685471 376
378 | 3685486 377
379 | 3685506 378
380 | 3685541 379
381 | 3685551 380
382 | 3685556 381
383 | 3685571 382
384 | 3685601 383
385 | 3685616 384
386 | 3686716 385
387 | 3686721 386
388 | 3686726 387
389 | 3686736 388
390 | 3686741 389
391 | 3686751 390
392 | 3686786 391
393 | 3686821 392
394 | 3688101 393
395 | 3688141 394
396 | 3688176 395
397 | 3689311 396
398 | 3689321 397
399 | 3689341 398
400 | 3689366 399
401 | 3689436 400
402 | 3690676 401
403 | 3690706 402
404 | 3691806 403
405 | 3691831 404
406 | 3691846 405
407 | 3691871 406
408 | 3691901 407
409 | 3691906 408
410 | 3691911 409
411 | 3691916 410
412 | 3691921 411
413 | 3693191 412
414 | 3693231 413
415 | 3693291 414
416 | 3693301 415
417 | 3694351 416
418 | 3694376 417
419 | 3694396 418
420 | 3694526 419
421 | 3694556 420
422 | 3694571 421
423 | 3695811 422
424 | 3704806 423
425 | 3705926 424
426 | 3707236 425
427 | 3707256 426
428 | 3709721 427
429 | 3709751 428
430 | 3709791 429
431 | 3709796 430
432 | 3709801 431
433 | 3709861 432
434 | 3709881 433
435 | 3709906 434
436 | 3709916 435
437 | 3712281 436
438 | 3712301 437
439 | 3712306 438
440 | 3712316 439
441 | 3713601 440
442 | 3713611 441
443 | 3713701 442
444 | 3713711 443
445 | 3713716 444
446 | 3714936 445
447 | 3725081 446
448 | 3725091 447
449 | 3727651 448
450 | 3727656 449
451 | 3727681 450
452 | 3694917 451
--------------------------------------------------------------------------------
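Each line of `aachen_grgb_gid_v5.txt` above is a `grgb gid` pair: `grgb` is a label colour packed into one integer (the same packing used by `seg_to_gid` in the dataloaders) and `gid` is the global instance index, 0..451 for Aachen_v1.1. A small sketch, assuming the packing `R * 256^2 + G * 256 + B`, for parsing the file and recovering the colour of an entry:

```python
def load_grgb_gid(path: str) -> dict:
    """Parse 'grgb gid' pairs, as the dataloaders do."""
    mapping = {}
    with open(path, "r") as f:
        for line in f:
            parts = line.strip().split(" ")
            if len(parts) != 2:
                continue  # skip blank or malformed lines
            mapping[int(parts[0])] = int(parts[1])
    return mapping

def unpack_grgb(packed: int) -> tuple:
    """Recover the (R, G, B) colour from a packed grgb id."""
    return packed // 65536, (packed // 256) % 256, packed % 256

mapping = load_grgb_gid("datasets/aachen/aachen_grgb_gid_v5.txt")
print(len(mapping))          # 452 entries for Aachen_v1.1
print(unpack_grgb(3278090))  # colour mapped to gid 1 in the table above
```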
/datasets/robotcar/robotcar_grgb_gid.txt:
--------------------------------------------------------------------------------
1 | 0 0
2 | 6364 1
3 | 23394 2
4 | 57615 3
5 | 60798 4
6 | 150085 5
7 | 175434 6
8 | 176037 7
9 | 181106 8
10 | 199936 9
11 | 210177 10
12 | 212877 11
13 | 213559 12
14 | 247058 13
15 | 270776 14
16 | 288895 15
17 | 319083 16
18 | 320757 17
19 | 325851 18
20 | 353253 19
21 | 381759 20
22 | 437887 21
23 | 453785 22
24 | 498638 23
25 | 526846 24
26 | 531348 25
27 | 543518 26
28 | 543812 27
29 | 545929 28
30 | 564003 29
31 | 576434 30
32 | 580974 31
33 | 590030 32
34 | 610805 33
35 | 702863 34
36 | 709265 35
37 | 709540 36
38 | 765311 37
39 | 782213 38
40 | 810012 39
41 | 819088 40
42 | 824305 41
43 | 850422 42
44 | 884397 43
45 | 889520 44
46 | 912819 45
47 | 978590 46
48 | 995266 47
49 | 995626 48
50 | 999483 49
51 | 1008184 50
52 | 1009542 51
53 | 1019753 52
54 | 1071053 53
55 | 1085329 54
56 | 1119230 55
57 | 1167652 56
58 | 1225038 57
59 | 1238518 58
60 | 1294830 59
61 | 1314892 60
62 | 1327411 61
63 | 1352622 62
64 | 1366495 63
65 | 1395676 64
66 | 1407940 65
67 | 1440984 66
68 | 1464343 67
69 | 1479206 68
70 | 1490275 69
71 | 1497017 70
72 | 1523442 71
73 | 1549427 72
74 | 1654677 73
75 | 1663294 74
76 | 1668363 75
77 | 1673874 76
78 | 1702843 77
79 | 1717718 78
80 | 1739889 79
81 | 1772714 80
82 | 1817813 81
83 | 1829001 82
84 | 1874620 83
85 | 1875618 84
86 | 1889053 85
87 | 1895727 86
88 | 1912346 87
89 | 1912579 88
90 | 1985512 89
91 | 2072595 90
92 | 2102710 91
93 | 2113846 92
94 | 2126589 93
95 | 2131005 94
96 | 2131153 95
97 | 2138826 96
98 | 2174275 97
99 | 2203459 98
100 | 2213499 99
101 | 2219083 100
102 | 2262220 101
103 | 2264533 102
104 | 2277579 103
105 | 2305834 104
106 | 2326657 105
107 | 2332441 106
108 | 2348870 107
109 | 2359303 108
110 | 2405604 109
111 | 2424837 110
112 | 2490602 111
113 | 2514362 112
114 | 2516488 113
115 | 2527590 114
116 | 2542277 115
117 | 2576385 116
118 | 2583196 117
119 | 2587429 118
120 | 2596827 119
121 | 2597173 120
122 | 2601229 121
123 | 2634996 122
124 | 2663514 123
125 | 2715215 124
126 | 2728917 125
127 | 2733281 126
128 | 2766150 127
129 | 2802044 128
130 | 2827180 129
131 | 2847475 130
132 | 2854690 131
133 | 2871057 132
134 | 2903521 133
135 | 2925661 134
136 | 2927162 135
137 | 2929354 136
138 | 2996188 137
139 | 2998206 138
140 | 3014581 139
141 | 3082380 140
142 | 3110452 141
143 | 3116084 142
144 | 3141376 143
145 | 3152933 144
146 | 3181625 145
147 | 3204717 146
148 | 3233591 147
149 | 3244623 148
150 | 3253176 149
151 | 3262203 150
152 | 3305564 151
153 | 3311652 152
154 | 3320315 153
155 | 3325135 154
156 | 3370285 155
157 | 3412498 156
158 | 3417898 157
159 | 3433598 158
160 | 3544405 159
161 | 3578076 160
162 | 3595339 161
163 | 3650993 162
164 | 3685436 163
165 | 3750221 164
166 | 3798912 165
167 | 3861052 166
168 | 3873159 167
169 | 3896946 168
170 | 3903459 169
171 | 3926388 170
172 | 3999718 171
173 | 4026091 172
174 | 4030735 173
175 | 4067188 174
176 | 4069818 175
177 | 4153347 176
178 | 4184294 177
179 | 4234676 178
180 | 4252117 179
181 | 4266213 180
182 | 4309693 181
183 | 4338284 182
184 | 4372118 183
185 | 4384678 184
186 | 4400102 185
187 | 4453066 186
188 | 4462363 187
189 | 4482266 188
190 | 4491681 189
191 | 4523381 190
192 | 4555476 191
193 | 4566179 192
194 | 4569433 193
195 | 4604363 194
196 | 4606003 195
197 | 4616685 196
198 | 4630582 197
199 | 4646095 198
200 | 4647883 199
201 | 4664923 200
202 | 4665749 201
203 | 4671493 202
204 | 4673808 203
205 | 4689288 204
206 | 4777859 205
207 | 4799179 206
208 | 4801898 207
209 | 4811530 208
210 | 4817074 209
211 | 4817818 210
212 | 4858907 211
213 | 4863345 212
214 | 4890813 213
215 | 4898226 214
216 | 4910943 215
217 | 4935911 216
218 | 4936408 217
219 | 4936902 218
220 | 4940944 219
221 | 4971984 220
222 | 5027601 221
223 | 5035672 222
224 | 5039127 223
225 | 5046480 224
226 | 5111995 225
227 | 5138567 226
228 | 5147419 227
229 | 5151530 228
230 | 5205192 229
231 | 5217593 230
232 | 5241512 231
233 | 5242836 232
234 | 5284001 233
235 | 5286259 234
236 | 5315600 235
237 | 5345921 236
238 | 5386746 237
239 | 5395385 238
240 | 5442999 239
241 | 5475736 240
242 | 5510360 241
243 | 5514032 242
244 | 5553568 243
245 | 5569361 244
246 | 5580445 245
247 | 5589550 246
248 | 5615807 247
249 | 5620240 248
250 | 5700847 249
251 | 5731987 250
252 | 5751326 251
253 | 5766574 252
254 | 5787866 253
255 | 6013635 254
256 | 6037937 255
257 | 6041614 256
258 | 6046315 257
259 | 6056857 258
260 | 6063872 259
261 | 6118186 260
262 | 6131942 261
263 | 6179666 262
264 | 6189326 263
265 | 6215098 264
266 | 6279671 265
267 | 6287660 266
268 | 6309752 267
269 | 6460881 268
270 | 6489020 269
271 | 6504191 270
272 | 6509113 271
273 | 6561575 272
274 | 6610995 273
275 | 6613578 274
276 | 6643602 275
277 | 6649335 276
278 | 6658223 277
279 | 6728999 278
280 | 6730265 279
281 | 6791232 280
282 | 6881234 281
283 | 6946263 282
284 | 6990915 283
285 | 6996546 284
286 | 7035810 285
287 | 7052088 286
288 | 7132222 287
289 | 7144087 288
290 | 7154880 289
291 | 7162659 290
292 | 7170321 291
293 | 7211586 292
294 | 7228156 293
295 | 7249365 294
296 | 7253543 295
297 | 7254915 296
298 | 7292702 297
299 | 7308500 298
300 | 7333343 299
301 | 7339663 300
302 | 7366751 301
303 | 7375562 302
304 | 7430941 303
305 | 7497254 304
306 | 7549429 305
307 | 7551394 306
308 | 7595620 307
309 | 7609170 308
310 | 7650576 309
311 | 7666026 310
312 | 7674197 311
313 | 7680288 312
314 | 7680575 313
315 | 7707348 314
316 | 7725442 315
317 | 7744645 316
318 | 7827654 317
319 | 7868719 318
320 | 7892243 319
321 | 7946736 320
322 | 7969433 321
323 | 7973717 322
324 | 7985090 323
325 | 7996993 324
326 | 8016201 325
327 | 8020153 326
328 | 8023008 327
329 | 8048764 328
330 | 8057055 329
331 | 8062996 330
332 | 8108598 331
333 | 8112789 332
334 | 8116651 333
335 | 8123456 334
336 | 8125414 335
337 | 8129018 336
338 | 8174750 337
339 | 8175269 338
340 | 8196328 339
341 | 8228754 340
342 | 8273920 341
343 | 8350040 342
344 | 8403142 343
345 | 8437876 344
346 | 8526921 345
347 | 8540988 346
348 | 8548323 347
349 | 8553226 348
350 | 8562994 349
351 | 8569045 350
352 | 8581435 351
353 | 8606250 352
354 | 8726120 353
355 | 8730736 354
356 | 8749543 355
357 | 8761546 356
358 | 8767900 357
359 | 8777866 358
360 | 8787062 359
361 | 8792788 360
362 | 8818625 361
363 | 8841999 362
364 | 8874784 363
365 | 8892815 364
366 | 8906037 365
367 | 8945531 366
368 | 8980418 367
369 | 9021538 368
370 | 9022943 369
371 | 9041144 370
372 | 9076416 371
373 | 9102346 372
374 | 9107106 373
375 | 9149554 374
376 | 9169103 375
377 | 9169494 376
378 | 9170693 377
379 | 9216740 378
380 | 9282799 379
381 | 9283630 380
382 | 9289720 381
383 | 9310627 382
384 | 9324749 383
385 | 9396141 384
386 | 9436438 385
387 | 9512503 386
388 | 9560712 387
389 | 9562931 388
390 | 9600695 389
391 | 9628261 390
392 | 9629210 391
393 | 9632827 392
394 | 9645554 393
395 | 9719535 394
396 | 9747429 395
397 | 9750312 396
398 | 9782122 397
399 | 9784114 398
400 | 9811084 399
401 | 9811643 400
402 | 9818447 401
403 | 9820147 402
404 | 9847195 403
405 | 9852829 404
406 | 9870536 405
407 | 9878298 406
408 | 9883630 407
409 | 9884893 408
410 | 9902591 409
411 | 9903263 410
412 | 9965108 411
413 | 10002187 412
414 | 10005604 413
415 | 10026878 414
416 | 10058738 415
417 | 10058929 416
418 | 10064207 417
419 | 10103837 418
420 | 10104710 419
421 | 10112543 420
422 | 10137990 421
423 | 10146387 422
424 | 10164310 423
425 | 10183214 424
426 | 10206610 425
427 | 10216704 426
428 | 10230878 427
429 | 10281694 428
430 | 10288524 429
431 | 10385002 430
432 | 10395596 431
433 | 10409534 432
434 | 10439224 433
435 | 10459801 434
436 | 10488652 435
437 | 10493062 436
438 | 10522066 437
439 | 10549138 438
440 | 10601801 439
441 | 10637756 440
442 | 10665783 441
443 | 10677004 442
444 | 10705015 443
445 | 10728864 444
446 | 10731818 445
447 | 10745640 446
448 | 10750859 447
449 | 10800238 448
450 | 10889446 449
451 | 10973279 450
452 | 10982006 451
453 | 10996763 452
454 | 11051513 453
455 | 11068074 454
456 | 11077388 455
457 | 11108289 456
458 | 11123316 457
459 | 11136986 458
460 | 11161653 459
461 | 11175276 460
462 | 11176137 461
463 | 11214427 462
464 | 11260563 463
465 | 11324726 464
466 | 11363293 465
467 | 11371057 466
468 | 11410265 467
469 | 11418527 468
470 | 11421580 469
471 | 11502725 470
472 | 11554693 471
473 | 11561343 472
474 | 11561455 473
475 | 11642146 474
476 | 11647519 475
477 | 11647837 476
478 | 11657865 477
479 | 11667561 478
480 | 11685376 479
481 | 11693576 480
482 | 11723261 481
483 | 11739127 482
484 | 11743117 483
485 | 11763511 484
486 | 11786771 485
487 | 11787975 486
488 | 11789999 487
489 | 11875511 488
490 | 11885186 489
491 | 11957138 490
492 | 11966246 491
493 | 11972745 492
494 | 11998374 493
495 | 12020652 494
496 | 12060872 495
497 | 12086366 496
498 | 12116689 497
499 | 12158955 498
500 | 12203366 499
501 | 12207263 500
502 | 12245620 501
503 | 12252926 502
504 | 12291624 503
505 | 12310383 504
506 | 12371102 505
507 | 12378159 506
508 | 12380007 507
509 | 12385219 508
510 | 12404930 509
511 | 12415030 510
512 | 12434276 511
513 | 12434905 512
514 | 12485852 513
515 | 12493645 514
516 | 12506905 515
517 | 12536153 516
518 | 12572058 517
519 | 12577856 518
520 | 12682672 519
521 | 12689876 520
522 | 12748234 521
523 | 12763375 522
524 | 12780903 523
525 | 12816035 524
526 | 12832517 525
527 | 12871622 526
528 | 12906613 527
529 | 12910727 528
530 | 12958419 529
531 | 12965337 530
532 | 13022185 531
533 | 13026292 532
534 | 13061705 533
535 | 13103964 534
536 | 13142750 535
537 | 13164169 536
538 | 13165124 537
539 | 13171812 538
540 | 13235093 539
541 | 13250309 540
542 | 13252825 541
543 | 13268351 542
544 | 13289735 543
545 | 13316278 544
546 | 13413520 545
547 | 13457974 546
548 | 13458949 547
549 | 13522410 548
550 | 13536838 549
551 | 13545290 550
552 | 13552287 551
553 | 13561967 552
554 | 13575036 553
555 | 13581511 554
556 | 13625816 555
557 | 13627766 556
558 | 13663787 557
559 | 13678685 558
560 | 13730783 559
561 | 13743558 560
562 | 13759183 561
563 | 13792649 562
564 | 13808428 563
565 | 13837615 564
566 | 13877006 565
567 | 13893522 566
568 | 13916954 567
569 | 13917067 568
570 | 13954961 569
571 | 13959251 570
572 | 13974664 571
573 | 13978665 572
574 | 13982792 573
575 | 14003903 574
576 | 14009246 575
577 | 14015053 576
578 | 14017347 577
579 | 14036513 578
580 | 14064916 579
581 | 14073178 580
582 | 14107123 581
583 | 14107593 582
584 | 14151535 583
585 | 14160619 584
586 | 14181127 585
587 | 14189518 586
588 | 14190848 587
589 | 14198842 588
590 | 14208832 589
591 | 14223933 590
592 | 14269518 591
593 | 14312345 592
594 | 14330348 593
595 | 14336072 594
596 | 14339301 595
597 | 14343754 596
598 | 14353177 597
599 | 14355049 598
600 | 14356993 599
601 | 14371602 600
602 | 14384955 601
603 | 14414539 602
604 | 14486186 603
605 | 14543264 604
606 | 14609565 605
607 | 14641169 606
608 | 14665141 607
609 | 14667266 608
610 | 14702863 609
611 | 14708337 610
612 | 14712810 611
613 | 14734180 612
614 | 14758763 613
615 | 14780185 614
616 | 14792991 615
617 | 14817755 616
618 | 14832907 617
619 | 14837073 618
620 | 14862318 619
621 | 14866436 620
622 | 14956803 621
623 | 14974364 622
624 | 14995550 623
625 | 15067389 624
626 | 15067545 625
627 | 15082052 626
628 | 15090153 627
629 | 15093644 628
630 | 15102510 629
631 | 15102568 630
632 | 15139030 631
633 | 15187198 632
634 | 15234898 633
635 | 15250459 634
636 | 15251280 635
637 | 15251388 636
638 | 15260258 637
639 | 15278895 638
640 | 15333752 639
641 | 15389971 640
642 | 15394374 641
643 | 15425435 642
644 | 15441351 643
645 | 15451584 644
646 | 15455566 645
647 | 15480518 646
648 | 15481639 647
649 | 15618964 648
650 | 15654742 649
651 | 15703719 650
652 | 15705344 651
653 | 15817148 652
654 | 15891947 653
655 | 15894739 654
656 | 15908759 655
657 | 15918523 656
658 | 15936857 657
659 | 15946266 658
660 | 15985673 659
661 | 16093211 660
662 | 16101439 661
663 | 16132747 662
664 | 16158626 663
665 | 16175272 664
666 | 16179482 665
667 | 16247824 666
668 | 16252359 667
669 | 16259523 668
670 | 16272330 669
671 | 16278899 670
672 | 16325425 671
673 | 16327790 672
674 | 16328129 673
675 | 16350436 674
676 | 16352135 675
677 | 16381584 676
678 | 16387858 677
679 | 16392621 678
680 | 16393228 679
681 | 16395957 680
682 | 16396044 681
683 | 16416396 682
684 | 16442592 683
685 | 16589510 684
686 | 16600464 685
687 | 16620753 686
688 | 16669188 687
689 | 16705451 688
690 | 16710611 689
691 | 16730133 690
692 | 16740108 691
693 |
--------------------------------------------------------------------------------
/localization/coarse/evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File lbr -> evaluate
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 06/04/2022 11:56
7 | =================================================='''
8 | import numpy as np
9 |
10 |
11 | def ap_k(query_labels, k=1):
12 | ap = 0.0
13 | positives = 0
14 | for idx, i in enumerate(query_labels[:k]):
15 | if i == 1:
16 | positives += 1
17 | ap_at_count = positives / (idx + 1)
18 | ap += ap_at_count
19 | return ap / k
20 |
21 |
22 | def recall_k(query_labels, k=1):
23 | return np.sum(query_labels) / k
24 |
25 |
26 | def eval_k(preds, gts, k=1):
27 | output = {}
28 | mean_ap = []
29 | mean_recall = []
30 | for query in preds.keys():
31 | pred_cands = preds[query]
32 | gt_cands = gts[query]
33 | pred = []
34 |
35 | for c in pred_cands[:k]:
36 | if c in gt_cands:
37 | pred.append(1)
38 | else:
39 | pred.append(0)
40 | ap = ap_k(query_labels=pred, k=k)
41 | recall = recall_k(query_labels=pred, k=len(gt_cands))
42 | mean_ap.append(ap)
43 | mean_recall.append(recall)
44 |
45 | # print("{:s} topk: {:d} ap: {:.4f} recall: {:.4f}".format(query, k, ap, recall))
46 |
47 | output[query] = (ap, recall)
48 |
49 | return np.mean(mean_ap), np.mean(mean_recall), output
50 |
51 |
52 | def evaluate_retrieval(preds, gts, ks=[1, 10, 20, 50]):
53 | output = {}
54 | for k in ks:
55 | mean_ap, mean_recall, _ = eval_k(preds=preds, gts=gts, k=k)
56 | output[k] = {
57 | 'accuracy': mean_ap,
58 | 'recall': mean_recall,
59 | }
60 | return output
61 |
62 |
63 | def evaluate_retrieval_by_query(preds, gts, ks=[1, 10, 20, 50]):
64 | output = {}
65 | for k in ks:
66 | output[k] = 0
67 |
68 | failed_cases = []
69 | for q in preds.keys():
70 | gt_cans = gts[q]
71 | for k in ks:
72 | pred_cans = preds[q][:k]
73 | overlap = [v for v in pred_cans if v in gt_cans]
74 | if len(overlap) >= 1:
75 | output[k] += 1
76 |
77 | if k == 50 and len(overlap) == 0:
78 | failed_cases.append(q)
79 |
80 | for k in ks:
81 | output[k] = output[k] / len(preds.keys())
82 |
83 | output['failed_case'] = failed_cases
84 | return output
85 |
--------------------------------------------------------------------------------
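A tiny usage sketch of the retrieval metrics above, with made-up query and database names just to show the expected input format (both arguments are dicts mapping a query name to an ordered candidate list and to its ground-truth relevant images; assumes the repository root is on PYTHONPATH):

```python
from localization.coarse.evaluate import evaluate_retrieval, evaluate_retrieval_by_query

# Toy predictions: retrieved db images ranked by similarity, and ground truth.
preds = {"query_0": ["db_3", "db_1", "db_9"],
         "query_1": ["db_7", "db_2", "db_5"]}
gts = {"query_0": ["db_1", "db_4"],
       "query_1": ["db_7"]}

# Mean AP / recall at each k, and per-query success rate at each k.
print(evaluate_retrieval(preds, gts, ks=[1, 2]))
print(evaluate_retrieval_by_query(preds, gts, ks=[1, 2]))
```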
/localization/fine/extractor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> extractor
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 26/07/2021 16:49
7 | =================================================='''
8 | import torch
9 | import torch.utils.data as Data
10 |
11 | import argparse
12 | import os
13 | import os.path as osp
14 | import h5py
15 | from tqdm import tqdm
16 | from types import SimpleNamespace
17 | import logging
18 | from pathlib import Path
19 | import cv2
20 | import numpy as np
21 | import pprint
22 |
23 | from localization.fine.features.extract_spp import extract_spp_return
24 | from localization.fine.features.extract_r2d2 import extract_r2d2_return, load_network
25 | from net.locnets.superpoint import SuperPointNet
26 | from net.locnets.resnet import ResNetXN, extract_resnet_return
27 |
28 | confs = {
29 | 'resnetxn-triv2-0001-n4096-r1600-mask': {
30 | 'output': 'feats-resnetxn-triv2-0001-n4096-r1600-mask',
31 | 'model': {
32 | 'name': 'resnetxn',
33 | 'max_keypoints': 4096,
34 | 'conf_th': 0.001,
35 | 'multiscale': True,
36 | 'scales': [1.0],
37 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"),
38 | },
39 | 'preprocessing': {
40 | 'grayscale': False,
41 | 'resize_max': 1600,
42 | },
43 | 'mask': True,
44 | },
45 |
46 | 'resnetxn-triv2-0001-n3000-r1600-mask': {
47 | 'output': 'feats-resnetxn-triv2-0001-n3000-r1600-mask',
48 | 'model': {
49 | 'name': 'resnetxn',
50 | 'max_keypoints': 3000,
51 | 'conf_th': 0.001,
52 | 'multiscale': True,
53 | 'scales': [1.0],
54 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"),
55 | },
56 | 'preprocessing': {
57 | 'grayscale': False,
58 | 'resize_max': 1600,
59 | },
60 | 'mask': True,
61 | },
62 | 'resnetxn-triv2-0001-n2000-r1600-mask': {
63 | 'output': 'feats-resnetxn-triv2-0001-n2000-r1600-mask',
64 | 'model': {
65 | 'name': 'resnetxn',
66 | 'max_keypoints': 2000,
67 | 'conf_th': 0.001,
68 | 'multiscale': True,
69 | 'scales': [1.0],
70 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"),
71 | },
72 | 'preprocessing': {
73 | 'grayscale': False,
74 | 'resize_max': 1600,
75 | },
76 | 'mask': True,
77 | },
78 |
79 | 'resnetxn-triv2-0001-n1000-r1600-mask': {
80 | 'output': 'feats-resnetxn-triv2-0001-n1000-r1600-mask',
81 | 'model': {
82 | 'name': 'resnetxn',
83 | 'max_keypoints': 1000,
84 | 'conf_th': 0.001,
85 | 'multiscale': True,
86 | 'scales': [1.0],
87 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"),
88 | },
89 | 'preprocessing': {
90 | 'grayscale': False,
91 | 'resize_max': 1600,
92 | },
93 | 'mask': True,
94 | },
95 |
96 | 'resnetxn-triv2-0001-n4096-r1024-mask': {
97 | 'output': 'feats-resnetxn-triv2-0001-n4096-r1024-mask',
98 | 'model': {
99 | 'name': 'resnetxn',
100 | 'max_keypoints': 4096,
101 | 'conf_th': 0.001,
102 | 'multiscale': True,
103 | 'scales': [1.0],
104 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"),
105 | },
106 | 'preprocessing': {
107 | 'grayscale': False,
108 | 'resize_max': 1024,
109 | },
110 | 'mask': True,
111 | },
112 |
113 | 'resnetxn-triv2-ms-0001-n4096-r1024-mask': {
114 | 'output': 'feats-resnetxn-triv2-ms-0001-n4096-r1024-mask',
115 | 'model': {
116 | 'name': 'resnetxn',
117 | 'max_keypoints': 4096,
118 | 'conf_th': 0.001,
119 | 'multiscale': True,
120 | 'scales': [1.2, 1.0, 0.8],
121 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"),
122 | },
123 | 'preprocessing': {
124 | 'grayscale': False,
125 | 'resize_max': 1024,
126 | },
127 | 'mask': True,
128 | },
129 |
130 | }
131 |
132 | confs_matcher = {
133 | 'superglue': {
134 | 'output': 'matches-superglue',
135 | 'model': {
136 | 'name': 'superglue',
137 | 'weights': 'outdoor',
138 | 'sinkhorn_iterations': 20,
139 | },
140 | },
141 | 'NNM': {
142 | 'output': 'NNM',
143 | 'model': {
144 | 'name': 'nnm',
145 | 'do_mutual_check': True,
146 | 'distance_threshold': None,
147 | },
148 | },
149 | 'NNML': {
150 | 'output': 'NNML',
151 | 'model': {
152 | 'name': 'nnml',
153 | 'do_mutual_check': True,
154 | 'distance_threshold': None,
155 | },
156 | },
157 |
158 | 'ONN': {
159 | 'output': 'ONN',
160 | 'model': {
161 | 'name': 'nn',
162 | 'do_mutual_check': False,
163 | 'distance_threshold': None,
164 | },
165 | },
166 | 'NNR': {
167 | 'output': 'NNR',
168 | 'model': {
169 | 'name': 'nnr',
170 | 'do_mutual_check': True,
171 | 'distance_threshold': 0.9,
172 | },
173 | }
174 | }
175 |
176 |
177 | class ImageDataset(Data.Dataset):
178 | default_conf = {
179 | 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'],
180 | 'grayscale': False,
181 | 'resize_max': None,
182 | 'resize_force': False,
183 | }
184 |
185 | def __init__(self, root, conf, image_list=None,
186 | mask_root=None):
187 | self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
188 | self.root = root
189 |
190 | self.paths = []
191 | if image_list is None:
192 | for g in conf.globs:
193 | self.paths += list(Path(root).glob('**/' + g))
194 | if len(self.paths) == 0:
195 | raise ValueError(f'Could not find any image in root: {root}.')
196 | self.paths = [i.relative_to(root) for i in self.paths]
197 | else:
198 | with open(image_list, "r") as f:
199 | lines = f.readlines()
200 | for l in lines:
201 | l = l.strip()
202 | self.paths.append(Path(l))
203 |
204 | logging.info(f'Found {len(self.paths)} images in root {root}.')
205 |
206 | if mask_root is not None:
207 | self.mask_root = mask_root
208 | else:
209 | self.mask_root = None
210 |
211 | print("mask_root: ", self.mask_root)
212 |
213 | def __getitem__(self, idx):
214 | path = self.paths[idx]
215 | if self.conf.grayscale:
216 | mode = cv2.IMREAD_GRAYSCALE
217 | else:
218 | mode = cv2.IMREAD_COLOR
219 | image = cv2.imread(str(self.root / path), mode)
220 |         if image is None:
221 |             raise ValueError(f'Cannot read image {str(path)}.')
222 |         if not self.conf.grayscale:
223 |             image = image[:, :, ::-1]  # BGR to RGB
224 | image = image.astype(np.float32)
225 | size = image.shape[:2][::-1]
226 | w, h = size
227 |
228 | if self.conf.resize_max and (self.conf.resize_force
229 | or max(w, h) > self.conf.resize_max):
230 | scale = self.conf.resize_max / max(h, w)
231 | h_new, w_new = int(round(h * scale)), int(round(w * scale))
232 | image = cv2.resize(
233 | image, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
234 |
235 | if self.conf.grayscale:
236 | image = image[None]
237 | else:
238 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
239 | image = image / 255.
240 |
241 | data = {
242 | 'name': str(path),
243 | 'image': image,
244 | 'original_size': np.array(size),
245 | }
246 |
247 | if self.mask_root is not None:
248 | mask_path = Path(str(path).replace("jpg", "png"))
249 |             if osp.exists(str(self.mask_root / mask_path)):
250 | mask = cv2.imread(str(self.mask_root / mask_path))
251 | mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST)
252 | else:
253 | mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8)
254 |
255 | data['mask'] = mask
256 |
257 | return data
258 |
259 | def __len__(self):
260 | return len(self.paths)
261 |
262 |
263 | def get_model(model_name, weight_path):
264 | if model_name == "superpoint":
265 | model = SuperPointNet().eval()
266 | model.load_state_dict(torch.load(weight_path))
267 | extractor = extract_spp_return
268 | elif model_name == "r2d2":
269 | model = load_network(model_fn=weight_path).eval()
270 | extractor = extract_r2d2_return
271 | elif model_name == 'resnetxn-ori':
272 | model = ResNetXN(encoder_depth=2, outdim=128).eval()
273 | model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True)
274 | extractor = extract_resnet_return
275 | elif model_name == 'resnetxn':
276 | model = ResNetXN(encoder_depth=2, outdim=128).eval()
277 | model.load_state_dict(torch.load(weight_path)['model'], strict=True)
278 | extractor = extract_resnet_return
279 |
280 | return model, extractor
281 |
282 |
283 | @torch.no_grad()
284 | def main(conf, image_dir, export_dir, mask_dir=None, tag=None):
285 | logging.info('Extracting local features with configuration:'
286 | f'\n{pprint.pformat(conf)}')
287 |
288 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
289 | # Model = dynamic_load(extractors, conf['model']['name'])
290 | # model = Model(conf['model']).eval().to(device)
291 | model, extractor = get_model(model_name=conf['model']['name'], weight_path=conf["model"]["model_fn"])
292 | model = model.cuda()
293 | print("model: ", model)
294 |
295 | loader = ImageDataset(image_dir, conf['preprocessing'],
296 | image_list=args.image_list,
297 | mask_root=mask_dir)
298 | loader = torch.utils.data.DataLoader(loader, num_workers=4)
299 |
300 | feature_path = Path(export_dir, conf['output'] + '.h5')
301 | feature_path.parent.mkdir(exist_ok=True, parents=True)
302 | feature_file = h5py.File(str(feature_path), 'a')
303 |
304 | with tqdm(total=len(loader)) as t:
305 | for idx, data in enumerate(loader):
306 | t.update()
307 | if tag is not None:
308 | if data['name'][0].find(tag) < 0:
309 | continue
310 | pred = extractor(model, img=data["image"],
311 | topK=conf["model"]["max_keypoints"],
312 | mask=data["mask"][0].numpy().astype(np.uint8) if "mask" in data.keys() else None,
313 | conf_th=conf["model"]["conf_th"],
314 | scales=conf["model"]["scales"],
315 | )
316 |
317 | # pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
318 | pred['descriptors'] = pred['descriptors'].transpose()
319 |
320 | t.set_postfix(npoints=pred['keypoints'].shape[0])
321 | # print(pred['keypoints'].shape)
322 |
323 | pred['image_size'] = original_size = data['original_size'][0].numpy()
324 | # pred['descriptors'] = pred['descriptors'].T
325 | if 'keypoints' in pred.keys():
326 | size = np.array(data['image'].shape[-2:][::-1])
327 | scales = (original_size / size).astype(np.float32)
328 | pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5
329 |
330 | grp = feature_file.create_group(data['name'][0])
331 | for k, v in pred.items():
332 | # print(k, v.shape)
333 | grp.create_dataset(k, data=v)
334 |
335 | del pred
336 |
337 | feature_file.close()
338 | logging.info('Finished exporting features.')
339 |
340 | return feature_path
341 |
342 |
343 | if __name__ == '__main__':
344 | parser = argparse.ArgumentParser()
345 | parser.add_argument('--image_dir', type=Path, required=True)
346 | parser.add_argument('--image_list', type=str, default=None)
347 | parser.add_argument('--tag', type=str, default=None)
348 | parser.add_argument('--mask_dir', type=Path, default=None)
349 | parser.add_argument('--export_dir', type=Path, required=True)
350 | parser.add_argument('--conf', type=str, default='superpoint_aachen',
351 | choices=list(confs.keys()))
352 | args = parser.parse_args()
353 |     main(confs[args.conf], args.image_dir, args.export_dir, image_list=args.image_list,
354 |          mask_dir=args.mask_dir if confs[args.conf]["mask"] else None, tag=args.tag)
355 |
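356 | # A hedged example invocation (all paths are placeholders; --conf accepts any
357 | # key of the `confs` dict defined in this file, e.g. the default
358 | # 'superpoint_aachen'; --mask_dir is only used by configs with "mask" enabled):
359 | #   python <this_script>.py --image_dir /path/to/images \
360 | #       --export_dir /path/to/features --conf superpoint_aachen \
361 | #       --mask_dir /path/to/instance_masks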
--------------------------------------------------------------------------------
/localization/fine/features/extract_d2net.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time  : 2021/2/22 2:23 PM
4 | @Auth : Fei Xue
5 | @File : extract_d2net.py
6 | @Email: fx221@cam.ac.uk
7 | """
8 |
9 | import imageio
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 |
14 |
15 | class EmptyTensorError(Exception):
16 | pass
17 |
18 |
19 | class NoGradientError(Exception):
20 | pass
21 |
22 |
23 | def interpolate_dense_features(pos, dense_features, return_corners=False):
24 | device = pos.device
25 |
26 | ids = torch.arange(0, pos.size(1), device=device)
27 |
28 | _, h, w = dense_features.size()
29 |
30 | i = pos[0, :]
31 | j = pos[1, :]
32 |
33 | # Valid corners
34 | i_top_left = torch.floor(i).long()
35 | j_top_left = torch.floor(j).long()
36 | valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
37 |
38 | i_top_right = torch.floor(i).long()
39 | j_top_right = torch.ceil(j).long()
40 | valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
41 |
42 | i_bottom_left = torch.ceil(i).long()
43 | j_bottom_left = torch.floor(j).long()
44 | valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
45 |
46 | i_bottom_right = torch.ceil(i).long()
47 | j_bottom_right = torch.ceil(j).long()
48 | valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
49 |
50 | valid_corners = torch.min(
51 | torch.min(valid_top_left, valid_top_right),
52 | torch.min(valid_bottom_left, valid_bottom_right)
53 | )
54 |
55 | i_top_left = i_top_left[valid_corners]
56 | j_top_left = j_top_left[valid_corners]
57 |
58 | i_top_right = i_top_right[valid_corners]
59 | j_top_right = j_top_right[valid_corners]
60 |
61 | i_bottom_left = i_bottom_left[valid_corners]
62 | j_bottom_left = j_bottom_left[valid_corners]
63 |
64 | i_bottom_right = i_bottom_right[valid_corners]
65 | j_bottom_right = j_bottom_right[valid_corners]
66 |
67 | ids = ids[valid_corners]
68 | if ids.size(0) == 0:
69 | raise EmptyTensorError
70 |
71 | # Interpolation
72 | i = i[ids]
73 | j = j[ids]
74 | dist_i_top_left = i - i_top_left.float()
75 | dist_j_top_left = j - j_top_left.float()
76 | w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
77 | w_top_right = (1 - dist_i_top_left) * dist_j_top_left
78 | w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
79 | w_bottom_right = dist_i_top_left * dist_j_top_left
80 |
81 | descriptors = (
82 | w_top_left * dense_features[:, i_top_left, j_top_left] +
83 | w_top_right * dense_features[:, i_top_right, j_top_right] +
84 | w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
85 | w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right]
86 | )
87 |
88 | pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
89 |
90 | if not return_corners:
91 | return [descriptors, pos, ids]
92 | else:
93 | corners = torch.stack([
94 | torch.stack([i_top_left, j_top_left], dim=0),
95 | torch.stack([i_top_right, j_top_right], dim=0),
96 | torch.stack([i_bottom_left, j_bottom_left], dim=0),
97 | torch.stack([i_bottom_right, j_bottom_right], dim=0)
98 | ], dim=0)
99 | return [descriptors, pos, ids, corners]
100 |
101 |
102 | def upscale_positions(pos, scaling_steps=0):
103 | for _ in range(scaling_steps):
104 | pos = pos * 2 + 0.5
105 | return pos
106 |
107 |
108 | def process_multiscale(image, model, scales=[.5, 1, 2]):
109 | b, _, h_init, w_init = image.size()
110 | device = image.device
111 | assert (b == 1)
112 |
113 | all_keypoints = torch.zeros([3, 0])
114 | all_descriptors = torch.zeros([
115 | model.dense_feature_extraction.num_channels, 0
116 | ])
117 | all_scores = torch.zeros(0)
118 |
119 | previous_dense_features = None
120 | banned = None
121 | for idx, scale in enumerate(scales):
122 | current_image = F.interpolate(
123 | image, scale_factor=scale,
124 | mode='bilinear', align_corners=True
125 | )
126 | _, _, h_level, w_level = current_image.size()
127 |
128 | dense_features = model.dense_feature_extraction(current_image)
129 | del current_image
130 |
131 | _, _, h, w = dense_features.size()
132 |
133 | # Sum the feature maps.
134 | if previous_dense_features is not None:
135 | dense_features += F.interpolate(
136 | previous_dense_features, size=[h, w],
137 | mode='bilinear', align_corners=True
138 | )
139 | del previous_dense_features
140 |
141 | # Recover detections.
142 | detections = model.detection(dense_features)
143 | if banned is not None:
144 | banned = F.interpolate(banned.float(), size=[h, w]).bool()
145 | detections = torch.min(detections, ~banned)
146 | banned = torch.max(
147 | torch.max(detections, dim=1)[0].unsqueeze(1), banned
148 | )
149 | else:
150 | banned = torch.max(detections, dim=1)[0].unsqueeze(1)
151 | fmap_pos = torch.nonzero(detections[0].cpu()).t()
152 | del detections
153 |
154 | # Recover displacements.
155 | displacements = model.localization(dense_features)[0].cpu()
156 | displacements_i = displacements[
157 | 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
158 | ]
159 | displacements_j = displacements[
160 | 1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
161 | ]
162 | del displacements
163 |
164 | mask = torch.min(
165 | torch.abs(displacements_i) < 0.5,
166 | torch.abs(displacements_j) < 0.5
167 | )
168 | fmap_pos = fmap_pos[:, mask]
169 | valid_displacements = torch.stack([
170 | displacements_i[mask],
171 | displacements_j[mask]
172 | ], dim=0)
173 | del mask, displacements_i, displacements_j
174 |
175 | fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements
176 | del valid_displacements
177 |
178 | try:
179 | raw_descriptors, _, ids = interpolate_dense_features(
180 | fmap_keypoints.to(device),
181 | dense_features[0]
182 | )
183 | except EmptyTensorError:
184 | continue
185 | fmap_pos = fmap_pos[:, ids]
186 | fmap_keypoints = fmap_keypoints[:, ids]
187 | del ids
188 |
189 | keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
190 | del fmap_keypoints
191 |
192 | descriptors = F.normalize(raw_descriptors, dim=0).cpu()
193 | del raw_descriptors
194 |
195 | keypoints[0, :] *= h_init / h_level
196 | keypoints[1, :] *= w_init / w_level
197 |
198 | fmap_pos = fmap_pos.cpu()
199 | keypoints = keypoints.cpu()
200 |
201 | keypoints = torch.cat([
202 | keypoints,
203 | torch.ones([1, keypoints.size(1)]) * 1 / scale,
204 | ], dim=0)
205 |
206 | scores = dense_features[
207 | 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
208 | ].cpu() / (idx + 1)
209 | del fmap_pos
210 |
211 | all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
212 | all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
213 | all_scores = torch.cat([all_scores, scores], dim=0)
214 | del keypoints, descriptors
215 |
216 | previous_dense_features = dense_features
217 | del dense_features
218 | del previous_dense_features, banned
219 |
220 | keypoints = all_keypoints.t().numpy()
221 | del all_keypoints
222 | scores = all_scores.numpy()
223 | del all_scores
224 | descriptors = all_descriptors.t().numpy()
225 | del all_descriptors
226 | return keypoints, scores, descriptors
227 |
228 |
229 | def preprocess_image(image, preprocessing='caffe'):
230 | image = image.astype(np.float32)
231 | image = np.transpose(image, [2, 0, 1])
232 | if preprocessing is None:
233 | pass
234 | elif preprocessing == 'caffe':
235 | # RGB -> BGR
236 | image = image[:: -1, :, :]
237 | # Zero-center by mean pixel
238 | mean = np.array([103.939, 116.779, 123.68])
239 | image = image - mean.reshape([3, 1, 1])
240 | elif preprocessing == 'torch':
241 | image /= 255.0
242 | mean = np.array([0.485, 0.456, 0.406])
243 | std = np.array([0.229, 0.224, 0.225])
244 | image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1])
245 | else:
246 | raise ValueError('Unknown preprocessing parameter.')
247 | return image
248 |
249 |
250 | def extract_d2net_return(d2net, img_path, multi_scale=False, need_nms=False):
251 | image = imageio.imread(img_path)
252 | if len(image.shape) == 2:
253 | image = image[:, :, np.newaxis]
254 | image = np.repeat(image, 3, -1)
255 |
256 | # TODO: switch to PIL.Image due to deprecation of scipy.misc.imresize.
257 | resized_image = image
258 |
259 | fact_i = image.shape[0] / resized_image.shape[0]
260 | fact_j = image.shape[1] / resized_image.shape[1]
261 |
262 | input_image = preprocess_image(
263 | resized_image,
264 | preprocessing='caffe',
265 | )
266 | with torch.no_grad():
267 | if multi_scale:
268 | keypoints, scores, descriptors = process_multiscale(
269 | torch.tensor(
270 | input_image[np.newaxis, :, :, :].astype(np.float32),
271 | device=torch.device("cuda:0"),
272 | ),
273 | model=d2net,
274 | )
275 | else:
276 | keypoints, scores, descriptors = process_multiscale(
277 | torch.tensor(
278 | input_image[np.newaxis, :, :, :].astype(np.float32),
279 | device=torch.device("cuda:0"),
280 | ),
281 | d2net,
282 | scales=[1]
283 | )
284 |
285 | # Input image coordinates
286 | keypoints[:, 0] *= fact_i
287 | keypoints[:, 1] *= fact_j
288 | # i, j -> u, v
289 | keypoints = keypoints[:, [1, 0, 2]]
290 |
291 | return keypoints, descriptors, scores
292 |
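293 | # A hedged usage sketch: `d2net` is assumed to be an already-constructed D2-Net
294 | # model (with dense_feature_extraction / detection / localization heads) living
295 | # on the default CUDA device, matching the hard-coded "cuda:0" above.
296 | #   kpts, descs, scores = extract_d2net_return(d2net, 'query.jpg', multi_scale=True)
297 | # kpts[:, :2] are (u, v) image coordinates; the third column stores 1/scale of
298 | # the pyramid level at which each keypoint was detected.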
--------------------------------------------------------------------------------
/localization/fine/features/extract_sift.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time  : 2021/2/22 10:10 AM
4 | @Auth : Fei Xue
5 | @File : extract_sift.py
6 | @Email: fx221@cam.ac.uk
7 | """
8 |
9 | import argparse
10 | import os
11 | import os.path as osp
12 | from tqdm import tqdm
13 | import cv2
14 | import numpy as np
15 |
16 |
17 | def plot_keypoint(img_path, pts, scores=None):
18 | img = cv2.imread(img_path)
19 | img_out = img.copy()
20 | r = 3
21 | if scores is None:
22 | for i in range(pts.shape[0]):
23 | pt = pts[i]
24 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4)
25 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 4)
26 | else:
27 | scores_norm = scores / np.linalg.norm(scores, ord=2)
28 | print("score median: ", np.median(scores_norm))
29 | for i in range(pts.shape[0]):
30 | pt = pts[i]
31 | s = scores_norm[i]
32 | if s < np.median(scores_norm):
33 | continue
34 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4)
35 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 4)
36 |
37 | cv2.imshow("img", img_out)
38 | cv2.waitKey(0)
39 |
40 |
41 | def extract_sift_return(sift, img_path, need_nms=False):
42 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
43 | # sift = cv2.xfeatures2d.SIFT_create(nOctaveLayers=1, contrastThreshold=0.03, edgeThreshold=8)
44 | # kpts = sift.detect(img)
45 | # descs = np.zeros((10000, 128), np.float)
46 | kpts, descs = sift.detectAndCompute(img, None)
47 |
48 | scores = np.array([kp.response for kp in kpts], np.float32)
49 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts])
50 |
51 | return kpts, descs, scores
52 |
53 |
54 | if __name__ == '__main__':
55 | # Argument parsing
56 |     parser = argparse.ArgumentParser(description='SIFT keypoint extraction on HPatches')
57 |     parser.add_argument('--hpatches_path', type=str, required=True,
58 |                         help='root directory of the HPatches images')
59 |     parser.add_argument("--output_path", type=str, required=True,
60 |                         help='path to save descriptors')
61 |     parser.add_argument("--image_list_file", type=str, default="img_list_hpatches_list.txt",
62 |                         help='file containing the list of images to process')
63 |
64 | args = parser.parse_args()
65 | hpatches_path = args.hpatches_path
66 | output_dir = args.output_path
67 | os.makedirs(output_dir, exist_ok=True)
68 |
69 | with open(args.image_list_file, 'r') as f:
70 | lines = f.readlines()
71 |
72 | sift = cv2.xfeatures2d.SIFT_create(4000)
73 | for line in tqdm(lines, total=len(lines)):
74 | path = line.strip()
75 | img_path = osp.join(args.hpatches_path, path)
76 | print("img_path: ", img_path)
77 |
78 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
79 | kpts = sift.detect(img)
80 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts])
81 | print(kpts.shape)
82 | plot_keypoint(img_path=img_path, pts=kpts)
83 |
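84 | # A minimal usage sketch (assumes an OpenCV build that ships SIFT; recent
85 | # releases expose it as cv2.SIFT_create instead of cv2.xfeatures2d.SIFT_create):
86 | #   sift = cv2.SIFT_create(nfeatures=4000)
87 | #   kpts, descs, scores = extract_sift_return(sift, '/path/to/image.jpg')
88 | # kpts is (N, 2) pixel coordinates, descs is (N, 128), and scores are the
89 | # keypoint responses used for ranking.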
--------------------------------------------------------------------------------
/localization/fine/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> test
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 03/08/2021 12:12
7 | =================================================='''
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 | import torch
12 | import cv2
13 | from localization.fine.extractor import get_model
14 | from localization.fine.matcher import Matcher
15 | from localization.tools import plot_keypoint, plot_matches
16 | from tools.common import resize_img
17 |
18 | match_confs = {
19 | 'superglue': {
20 | 'output': 'matches-superglue',
21 | 'model': {
22 | 'name': 'superglue',
23 | 'weights': 'outdoor',
24 | 'sinkhorn_iterations': 50,
25 | },
26 | },
27 | 'NNM': {
28 | 'output': 'NNM',
29 | 'model': {
30 | 'name': 'nnm',
31 | 'do_mutual_check': True,
32 | 'distance_threshold': None,
33 | },
34 | },
35 | 'NNML': {
36 | 'output': 'NNML',
37 | 'model': {
38 | 'name': 'nnml',
39 | 'do_mutual_check': True,
40 | 'distance_threshold': None,
41 | },
42 | },
43 |
44 | 'ONN': {
45 | 'output': 'ONN',
46 | 'model': {
47 | 'name': 'nn',
48 | 'do_mutual_check': False,
49 | 'distance_threshold': None,
50 | },
51 | },
52 | 'NNR': {
53 | 'output': 'NNR',
54 | 'model': {
55 | 'name': 'nnr',
56 | 'do_mutual_check': True,
57 | 'distance_threshold': 0.9,
58 | },
59 | }
60 | }
61 |
62 | extract_confs = {
63 | # 'superpoint_aachen': {
64 | 'superpoint-n2000-r1024-mask': {
65 | 'output': 'feats-superpoint-n2000-r1024-mask',
66 | 'model': {
67 | 'name': 'superpoint',
68 | 'nms_radius': 4,
69 | 'max_keypoints': 2000,
70 | 'model_fn': osp.join(os.getcwd(), "models/superpoint_v1.pth"),
71 | },
72 | 'preprocessing': {
73 | 'grayscale': True,
74 | 'resize_max': 1024,
75 | },
76 | },
77 |
78 | 'superpoint-n4096-r1024-mask': {
79 | 'output': 'feats-superpoint-n4096-r1024-mask',
80 | 'model': {
81 | 'name': 'superpoint',
82 | 'nms_radius': 4,
83 | 'max_keypoints': 4096,
84 | 'model_fn': osp.join(os.getcwd(), "models/superpoint_v1.pth"),
85 | },
86 | 'preprocessing': {
87 | 'grayscale': True,
88 | 'resize_max': 1024,
89 | },
90 | 'mask': True,
91 | },
92 |
93 | 'r2d2-n2000-r1024-mask': {
94 | 'output': 'feats-r2d2-n2000-r1024-mask',
95 | 'model': {
96 | 'name': 'r2d2',
97 | 'nms_radius': 4,
98 | 'max_keypoints': 2000,
99 | 'model_fn': osp.join(os.getcwd(), "models/r2d2_WASF_N16.pt"),
100 | },
101 | 'preprocessing': {
102 | 'grayscale': False,
103 | 'resize_max': 1024,
104 | },
105 | 'mask': True,
106 | },
107 |
108 | 'r2d2-n4096-r1024-mask': {
109 |         'output': 'feats-r2d2-n4096-r1024-mask',
110 | 'model': {
111 | 'name': 'r2d2',
112 | 'nms_radius': 4,
113 | 'max_keypoints': 4096,
114 | 'model_fn': osp.join(os.getcwd(), "models/r2d2_WASF_N16.pt"),
115 | },
116 | 'preprocessing': {
117 | 'grayscale': False,
118 | 'resize_max': 1024,
119 | },
120 | 'mask': True,
121 | },
122 | }
123 |
124 |
125 | def test(use_mask=True):
126 | img_dir = "/home/mifs/fx221/fx221/localization/aachen_v1_1/images/images_upright"
127 | mask_dir = "/home/mifs/fx221/fx221/localization/aachen_v1_1/global_seg_instance"
128 | # pair_fn = "datasets/aachen/pairs-db-covis20.txt"
129 | # pair_fn = '/home/mifs/fx221/fx221/exp/shloc/aachen/2021_08_05_23_29_59_aachen_pspf_resnext101_32x4d_d4_u8_ce_b8_seg_cls_aug_stylized/loc_by_seg/loc_by_sec2_top30.txt'
130 | pair_fn = '/home/mifs/fx221/fx221/exp/shloc/aachen/2021_08_05_23_29_59_aachen_pspf_resnext101_32x4d_d4_u8_ce_b8_seg_cls_aug_stylized/loc_by_seg/loc_by_sec_top30_fail_list.txt'
131 | all_pairs = []
132 | with open(pair_fn, "r") as f:
133 | lines = f.readlines()
134 | for l in lines:
135 | l = l.strip().split(" ")
136 | all_pairs.append((l[0], l[1]))
137 |
138 | if use_mask:
139 | m_conf = match_confs["NNML"]
140 | else:
141 | m_conf = match_confs["NNM"]
142 |
143 | # e_conf = extract_confs["superpoint-n4096-r1024-mask"]
144 | e_conf = extract_confs["r2d2-n4096-r1024-mask"]
145 | matcher = Matcher(conf=m_conf)
146 | matcher = matcher.eval().cuda()
147 |
148 | model, extractor = get_model(model_name=e_conf['model']['name'], weight_path=e_conf["model"]["model_fn"])
149 | model = model.cuda()
150 |
151 | # matcher = Matcher[]
152 |
153 | cv2.namedWindow("pt0", cv2.WINDOW_NORMAL)
154 | cv2.namedWindow("pt1", cv2.WINDOW_NORMAL)
155 | cv2.namedWindow("match", cv2.WINDOW_NORMAL)
156 |
157 | next_q = False
158 | for p in all_pairs:
159 | fn0 = p[0]
160 | fn1 = p[1]
161 | img0 = cv2.imread(osp.join(img_dir, fn0))
162 | img1 = cv2.imread(osp.join(img_dir, fn1))
163 | if e_conf["preprocessing"]["grayscale"]:
164 | img0_gray = cv2.cvtColor(img0, cv2.COLOR_BGR2GRAY)[None]
165 | img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)[None]
166 | else:
167 | img0_gray = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
168 | img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
169 | img0_gray = img0_gray.transpose(2, 0, 1)
170 | img1_gray = img1_gray.transpose(2, 0, 1)
171 | mask0 = cv2.imread(osp.join(mask_dir, fn0.replace("jpg", "png")))
172 | mask0 = cv2.resize(mask0, dsize=(img0.shape[1], img0.shape[0]), interpolation=cv2.INTER_NEAREST)
173 | # label0 = np.int32(mask0[:, :, 2]) * 256 * 256 + np.int32(mask0[:, :, 1]) * 256 + np.int32(mask0[:, :, 0])
174 | mask1 = cv2.imread(osp.join(mask_dir, fn1.replace("jpg", "png")))
175 | mask1 = cv2.resize(mask1, dsize=(img1.shape[1], img1.shape[0]), interpolation=cv2.INTER_NEAREST)
176 | # label1 = np.int32(mask1[:, :, 2]) * 256 * 256 + np.int32(mask1[:, :, 1]) * 256 + np.int32(mask1[:, :, 0])
177 | print(img0.shape, img1.shape)
178 | print(mask0.shape, mask1.shape)
179 |
180 | pred0 = extractor(model, img=torch.from_numpy(img0_gray / 255.).unsqueeze(0).float(),
181 | topK=e_conf["model"]["max_keypoints"], mask=mask0 if use_mask else None)
182 | pred1 = extractor(model, img=torch.from_numpy(img1_gray / 255.).unsqueeze(0).float(),
183 | topK=e_conf["model"]["max_keypoints"], mask=mask1 if use_mask else None)
184 |
185 | img_pt0 = plot_keypoint(img_path=img0, pts=pred0["keypoints"])
186 | img_pt1 = plot_keypoint(img_path=img1, pts=pred1["keypoints"])
187 |
188 | if use_mask:
189 | match_data = {
190 | "descriptors0": pred0["descriptors"],
191 | "labels0": pred0["labels"],
192 | "descriptors1": pred1["descriptors"],
193 | "labels1": pred1["labels"],
194 | }
195 | else:
196 | match_data = {
197 | "descriptors0": pred0["descriptors"],
198 | # "labels0": pred0["labels"],
199 | "descriptors1": pred1["descriptors"],
200 | # "labels1": pred1["labels"],
201 | }
202 | matches = matcher(match_data)["matches0"]
203 | # matches = pred['matches0'] # [0].cpu().short().numpy()
204 | valid_matches = []
205 | for i in range(matches.shape[0]):
206 |             if matches[i] >= 0:
207 | valid_matches.append([i, matches[i]])
208 |         valid_matches = np.array(valid_matches, int)
209 |         img_matches = plot_matches(img1=img0, img2=img1,
210 |                                    pts1=pred0["keypoints"][valid_matches[:, 0]],
211 |                                    pts2=pred1["keypoints"][valid_matches[:, 1]],
212 |                                    inliers=np.ones((valid_matches.shape[0],), dtype=bool),
213 |                                    )
214 |
215 | img_pt0 = resize_img(img_pt0, nh=512)
216 | mask0 = resize_img(mask0, nh=512)
217 | img_pt1 = resize_img(img_pt1, nh=512)
218 | mask1 = resize_img(mask1, nh=512)
219 | img_matches = resize_img(img_matches, nh=512)
220 |
221 | cv2.imshow("pt0", np.hstack([img_pt0, mask0]))
222 | cv2.imshow("pt1", np.hstack([img_pt1, mask1]))
223 | cv2.imshow("match", img_matches)
224 | cv2.waitKey(0)
225 |
226 |
227 | if __name__ == '__main__':
228 | test(use_mask=True)
229 |
--------------------------------------------------------------------------------
/localization/fine/triangulation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from pathlib import Path
4 | from tqdm import tqdm
5 | import h5py
6 | import numpy as np
7 | import subprocess
8 | import pprint
9 | import shutil
10 |
11 | from localization.utils.read_write_model import (read_model, read_images_binary,
12 | CAMERA_MODEL_NAMES, read_cameras_binary,
13 | write_points3d_binary, write_images_binary, Image)
14 | from localization.utils.database import COLMAPDatabase
15 | from localization.utils.parsers import (
16 | parse_image_lists_with_intrinsics, parse_retrieval, names_to_pair)
17 |
18 |
19 | def create_empty_model(reference_model, empty_model):
20 | logging.info('Creating an empty model.')
21 | empty_model.mkdir(exist_ok=True)
22 | shutil.copy(reference_model / 'cameras.bin', empty_model)
23 | write_points3d_binary(dict(), empty_model / 'points3D.bin')
24 | images = read_images_binary(str(reference_model / 'images.bin'))
25 | images_empty = dict()
26 | for id_, image in images.items():
27 | image = image._asdict()
28 | image['xys'] = np.zeros((0, 2), float)
29 | image['point3D_ids'] = np.full(0, -1, int)
30 | images_empty[id_] = Image(**image)
31 | write_images_binary(images_empty, empty_model / 'images.bin')
32 |
33 |
34 | def create_db_from_model(empty_model, database_path):
35 | if database_path.exists():
36 | logging.warning('Database already exists.')
37 |
38 | cameras = read_cameras_binary(str(empty_model / 'cameras.bin'))
39 | images = read_images_binary(str(empty_model / 'images.bin'))
40 |
41 | db = COLMAPDatabase.connect(database_path)
42 | db.create_tables()
43 |
44 | for i, camera in cameras.items():
45 | model_id = CAMERA_MODEL_NAMES[camera.model].model_id
46 | db.add_camera(
47 | model_id, camera.width, camera.height, camera.params, camera_id=i,
48 | prior_focal_length=True)
49 |
50 | for i, image in images.items():
51 | db.add_image(image.name, image.camera_id, image_id=i)
52 |
53 | db.commit()
54 | db.close()
55 | return {image.name: i for i, image in images.items()}
56 |
57 |
58 | def import_features(image_ids, database_path, features_path):
59 | logging.info('Importing features into the database...')
60 | hfile = h5py.File(str(features_path), 'r')
61 | db = COLMAPDatabase.connect(database_path)
62 |
63 | for image_name, image_id in tqdm(image_ids.items()):
64 | keypoints = hfile[image_name]['keypoints'].__array__()
65 | keypoints += 0.5 # COLMAP origin
66 | db.add_keypoints(image_id, keypoints)
67 |
68 | hfile.close()
69 | db.commit()
70 | db.close()
71 |
72 |
73 | def import_matches(image_ids, database_path, pairs_path, matches_path,
74 | min_match_score=None, skip_geometric_verification=False):
75 | logging.info('Importing matches into the database...')
76 |
77 | with open(str(pairs_path), 'r') as f:
78 | pairs = [p.split() for p in f.readlines()]
79 |
80 | hfile = h5py.File(str(matches_path), 'r')
81 | db = COLMAPDatabase.connect(database_path)
82 |
83 | matched = set()
84 | for name0, name1 in tqdm(pairs):
85 | if name0 not in image_ids or name1 not in image_ids:
86 | continue
87 | id0, id1 = image_ids[name0], image_ids[name1]
88 | if len({(id0, id1), (id1, id0)} & matched) > 0:
89 | continue
90 | pair = names_to_pair(name0, name1)
91 | if pair not in hfile:
92 | raise ValueError(
93 | f'Could not find pair {(name0, name1)}... '
94 | 'Maybe you matched with a different list of pairs? '
95 |                 f'Reverse in file: {names_to_pair(name1, name0) in hfile}.')
96 |
97 | matches = hfile[pair]['matches0'].__array__()
98 | valid = matches > -1
99 | if min_match_score:
100 | scores = hfile[pair]['matching_scores0'].__array__()
101 | valid = valid & (scores > min_match_score)
102 | matches = np.stack([np.where(valid)[0], matches[valid]], -1)
103 |
104 | db.add_matches(id0, id1, matches)
105 | matched |= {(id0, id1), (id1, id0)}
106 |
107 | if skip_geometric_verification:
108 | db.add_two_view_geometry(id0, id1, matches)
109 |
110 | hfile.close()
111 | db.commit()
112 | db.close()
113 |
114 |
115 | def geometric_verification(colmap_path, database_path, pairs_path):
116 | logging.info('Performing geometric verification of the matches...')
117 | cmd = [
118 | str(colmap_path), 'matches_importer',
119 | '--database_path', str(database_path),
120 | '--match_list_path', str(pairs_path),
121 | '--match_type', 'pairs',
122 | '--SiftMatching.max_num_trials', str(20000),
123 | '--SiftMatching.min_inlier_ratio', str(0.1)]
124 | ret = subprocess.call(cmd)
125 | if ret != 0:
126 | logging.warning('Problem with matches_importer, exiting.')
127 | exit(ret)
128 |
129 |
130 | def run_triangulation(colmap_path, model_path, database_path, image_dir,
131 | empty_model):
132 | logging.info('Running the triangulation...')
133 | assert model_path.exists()
134 |
135 | cmd = [
136 | str(colmap_path), 'point_triangulator',
137 | '--database_path', str(database_path),
138 | '--image_path', str(image_dir),
139 | '--input_path', str(empty_model),
140 | '--output_path', str(model_path),
141 | '--Mapper.ba_refine_focal_length', '0',
142 | '--Mapper.ba_refine_principal_point', '0',
143 | '--Mapper.ba_refine_extra_params', '0']
144 | logging.info(' '.join(cmd))
145 | ret = subprocess.call(cmd)
146 | if ret != 0:
147 | logging.warning('Problem with point_triangulator, exiting.')
148 | exit(ret)
149 |
150 | stats_raw = subprocess.check_output(
151 | [str(colmap_path), 'model_analyzer', '--path', model_path])
152 | stats_raw = stats_raw.decode().split("\n")
153 | stats = dict()
154 | for stat in stats_raw:
155 | if stat.startswith("Registered images"):
156 | stats['num_reg_images'] = int(stat.split()[-1])
157 | elif stat.startswith("Points"):
158 | stats['num_sparse_points'] = int(stat.split()[-1])
159 | elif stat.startswith("Observations"):
160 | stats['num_observations'] = int(stat.split()[-1])
161 | elif stat.startswith("Mean track length"):
162 | stats['mean_track_length'] = float(stat.split()[-1])
163 | elif stat.startswith("Mean observations per image"):
164 | stats['num_observations_per_image'] = float(stat.split()[-1])
165 | elif stat.startswith("Mean reprojection error"):
166 | stats['mean_reproj_error'] = float(stat.split()[-1][:-2])
167 |
168 | return stats
169 |
170 |
171 | def main(sfm_dir, reference_sfm_model, image_dir, pairs, features, matches,
172 | colmap_path='colmap', skip_geometric_verification=False,
173 | min_match_score=None):
174 | assert reference_sfm_model.exists(), reference_sfm_model
175 | assert features.exists(), features
176 | assert pairs.exists(), pairs
177 | assert matches.exists(), matches
178 |
179 | sfm_dir.mkdir(parents=True, exist_ok=True)
180 | database = sfm_dir / 'database.db'
181 | model = sfm_dir / 'model'
182 | model.mkdir(exist_ok=True)
183 | empty_model = sfm_dir / 'empty'
184 |
185 | create_empty_model(reference_sfm_model, empty_model)
186 | image_ids = create_db_from_model(empty_model, database)
187 | import_features(image_ids, database, features)
188 | import_matches(image_ids, database, pairs, matches,
189 | min_match_score, skip_geometric_verification)
190 | if not skip_geometric_verification:
191 | geometric_verification(colmap_path, database, pairs)
192 | stats = run_triangulation(
193 | colmap_path, model, database, image_dir, empty_model)
194 |
195 | logging.info(f'Statistics:\n{pprint.pformat(stats)}')
196 | shutil.rmtree(empty_model)
197 |
198 |
199 | if __name__ == '__main__':
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument('--sfm_dir', type=Path, required=True)
202 | parser.add_argument('--reference_sfm_model', type=Path, required=True)
203 | parser.add_argument('--image_dir', type=Path, required=True)
204 |
205 | parser.add_argument('--pairs', type=Path, required=True)
206 | parser.add_argument('--features', type=Path, required=True)
207 | parser.add_argument('--matches', type=Path, required=True)
208 |
209 | parser.add_argument('--colmap_path', type=Path, default='colmap')
210 |
211 | parser.add_argument('--skip_geometric_verification', action='store_true')
212 | parser.add_argument('--min_match_score', type=float)
213 | args = parser.parse_args()
214 |
215 | main(**args.__dict__)
216 |
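217 | # A hedged example invocation (paths are placeholders; assumes the repository
218 | # root is on PYTHONPATH so the `localization.*` imports resolve):
219 | #   python -m localization.fine.triangulation \
220 | #       --sfm_dir outputs/sfm --reference_sfm_model /path/to/reference/sparse \
221 | #       --image_dir /path/to/images --pairs /path/to/pairs.txt \
222 | #       --features /path/to/feats.h5 --matches /path/to/matches.h5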
--------------------------------------------------------------------------------
/localization/tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time     : 2021/5/6 4:47 PM
3 | # @Author : Fei Xue
4 | # @Email : fx221@cam.ac.uk
5 | # @File : tools.py
6 | # @Software: PyCharm
7 |
8 | import numpy as np
9 | import cv2
10 | import torch
11 | from copy import copy
12 | from scipy.spatial.transform import Rotation as sciR
13 |
14 |
15 | def plot_keypoint(img_path, pts, scores=None, tag=None, save_path=None):
16 | if type(img_path) == str:
17 | img = cv2.imread(img_path)
18 | else:
19 | img = img_path.copy()
20 |
21 | img_out = img.copy()
22 | print(img.shape)
23 | r = 3
24 | for i in range(pts.shape[0]):
25 | pt = pts[i]
26 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4)
27 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2)
28 |
29 | if save_path is not None:
30 | cv2.imwrite(save_path, img_out)
31 | return img_out
32 |
33 |
34 | def sort_dict_by_value(data, reverse=False):
35 | return sorted(data.items(), key=lambda d: d[1], reverse=reverse)
36 |
37 |
38 | def read_retrieval_results(path):
39 | output = {}
40 | with open(path, "r") as f:
41 | lines = f.readlines()
42 | for p in lines:
43 | p = p.strip("\n").split(" ")
44 |
45 | if p[1] == "no_match":
46 | continue
47 | if p[0] in output.keys():
48 | output[p[0]].append(p[1])
49 | else:
50 | output[p[0]] = [p[1]]
51 | return output
52 |
53 |
54 | def nn_k(query_gps, db_gps, k=20):
55 | q = torch.from_numpy(query_gps) # [N 2]
56 | db = torch.from_numpy(db_gps) # [M, 2]
57 | # print (q.shape, db.shape)
58 | dist = q.unsqueeze(2) - db.t().unsqueeze(0)
59 | dist = dist[:, 0, :] ** 2 + dist[:, 1, :] ** 2
60 | print("dist: ", dist.shape)
61 | topk = torch.topk(dist, dim=1, k=k, largest=False)[1]
62 | return topk
63 |
64 |
65 | def plot_matches(img1, img2, pts1, pts2, inliers, horizon=False, plot_outlier=False, confs=None, plot_match=True):
66 | rows1 = img1.shape[0]
67 | cols1 = img1.shape[1]
68 | rows2 = img2.shape[0]
69 | cols2 = img2.shape[1]
70 | r = 3
71 | if horizon:
72 | img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8')
73 | # Place the first image to the left
74 | img_out[:rows1, :cols1] = img1
75 | # Place the next image to the right of it
76 | img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2])
77 |
78 | if not plot_match:
79 | return cv2.resize(img_out, None, fx=0.5, fy=0.5)
80 | # for idx, pt in enumerate(pts1):
81 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2)
82 | # for idx, pt in enumerate(pts2):
83 | # img_out = cv2.circle(img_out, (int(pt[0] + cols1), int(pt[1])), r, (0, 0, 255), 2)
84 | for idx in range(inliers.shape[0]):
85 | # if idx % 10 > 0:
86 | # continue
87 | if inliers[idx]:
88 | color = (0, 255, 0)
89 | else:
90 | if not plot_outlier:
91 | continue
92 | color = (0, 0, 255)
93 | pt1 = pts1[idx]
94 | pt2 = pts2[idx]
95 |
96 | if confs is not None:
97 | nr = int(r * confs[idx])
98 | else:
99 | nr = r
100 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2)
101 |
102 | img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2)
103 |
104 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color,
105 | 2)
106 | else:
107 | img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8')
108 | # Place the first image to the left
109 | img_out[:rows1, :cols1] = img1
110 | # Place the next image to the right of it
111 | img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2])
112 |
113 | if not plot_match:
114 | return cv2.resize(img_out, None, fx=0.5, fy=0.5)
115 | # for idx, pt in enumerate(pts1):
116 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2)
117 | # for idx, pt in enumerate(pts2):
118 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1] + rows1)), r, (0, 0, 255), 2)
119 | for idx in range(inliers.shape[0]):
120 | # print("idx: ", inliers[idx])
121 | # if idx % 10 > 0:
122 | # continue
123 | if inliers[idx]:
124 | color = (0, 255, 0)
125 | else:
126 | if not plot_outlier:
127 | continue
128 | color = (0, 0, 255)
129 |
130 | if confs is not None:
131 | nr = int(r * confs[idx])
132 | else:
133 | nr = r
134 |
135 | pt1 = pts1[idx]
136 | pt2 = pts2[idx]
137 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), r, color, 2)
138 |
139 | img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), r, color, 2)
140 |
141 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color,
142 | 2)
143 |
144 | img_rs = cv2.resize(img_out, None, fx=0.5, fy=0.5)
145 |
146 | # img_rs = cv2.putText(img_rs, 'matches: {:d}'.format(len(inliers.shape[0])), (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2,
147 | # (0, 0, 255), 2)
148 |
149 | # if save_fn is not None:
150 | # cv2.imwrite(save_fn, img_rs)
151 | # cv2.imshow("match", img_rs)
152 | # cv2.waitKey(0)
153 | return img_rs
154 |
155 |
156 | def plot_reprojpoint2D(img, points2D, reproj_points2D, confs=None):
157 | img_out = copy(img)
158 | r = 5
159 | for i in range(points2D.shape[0]):
160 | p = points2D[i]
161 | rp = reproj_points2D[i]
162 |
163 | if confs is not None:
164 | nr = int(r * confs[i])
165 | else:
166 | nr = r
167 |
168 | if nr >= 50:
169 | nr = 50
170 | # img_out = cv2.circle(img_out, (int(p[0]), int(p[1])), nr, color=(0, 255, 0), thickness=2)
171 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), nr, color=(0, 0, 255), thickness=3)
172 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), 2, color=(0, 0, 255), thickness=3)
173 | # img_out = cv2.line(img_out, pt1=(int(p[0]), int(p[1])), pt2=(int(rp[0]), int(rp[1])), color=(0, 0, 255),
174 | # thickness=2)
175 |
176 | return img_out
177 |
178 |
179 | def reproject_fromR(points3D, rot, tvec, camera):
180 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1)
181 |
182 | if camera['model'] == 'SIMPLE_RADIAL':
183 | f = camera['params'][0]
184 | cx = camera['params'][1]
185 | cy = camera['params'][2]
186 | k = camera['params'][3]
187 |
188 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :]
189 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2
190 | radial = r2 * k
191 | du = proj_2d[0, :] * radial
192 | dv = proj_2d[1, :] * radial
193 |
194 | u = proj_2d[0, :] + du
195 | v = proj_2d[1, :] + dv
196 | u = u * f + cx
197 | v = v * f + cy
198 | uvs = np.vstack([u, v]).transpose()
199 |
200 | return uvs
201 |
202 |
203 | def calc_depth(points3D, rvec, tvec, camera):
204 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm()
205 | # print('p3d: ', points3D.shape, rot.shape, rot)
206 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1)
207 |
208 | return proj_2d.transpose()[:, 2]
209 |
210 |
211 | def reproject(points3D, rvec, tvec, camera):
212 | '''
213 | Args:
214 | points3D: [N, 3]
215 | rvec: [w, x, y, z]
216 | tvec: [x, y, z]
217 | Returns:
218 | '''
219 | # print('camera: ', camera)
220 | # print('p3d: ', points3D.shape)
221 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm()
222 | # print('p3d: ', points3D.shape, rot.shape, rot)
223 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1)
224 |
225 | if camera['model'] == 'SIMPLE_RADIAL':
226 | f = camera['params'][0]
227 | cx = camera['params'][1]
228 | cy = camera['params'][2]
229 | k = camera['params'][3]
230 |
231 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :]
232 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2
233 | radial = r2 * k
234 | du = proj_2d[0, :] * radial
235 | dv = proj_2d[1, :] * radial
236 |
237 | u = proj_2d[0, :] + du
238 | v = proj_2d[1, :] + dv
239 | u = u * f + cx
240 | v = v * f + cy
241 | uvs = np.vstack([u, v]).transpose()
242 |
243 | return uvs
244 |
245 |
246 | def quaternion_angular_error(q1, q2):
247 | """
248 | angular error between two quaternions
249 | :param q1: (4, )
250 | :param q2: (4, )
251 | :return:
252 | """
253 | d = abs(np.dot(q1, q2))
254 | d = min(1.0, max(-1.0, d))
255 | theta = 2 * np.arccos(d) * 180 / np.pi
256 | return theta
257 |
258 |
259 | def ColmapQ2R(qvec):
260 | rot = sciR.from_quat(quat=[qvec[1], qvec[2], qvec[3], qvec[0]]).as_dcm()
261 | return rot
262 |
263 |
264 | def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw):
265 | pred_Rcw = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_dcm()
266 | pred_tcw = np.array(pred_tcw, float).reshape(3, 1)
267 | pred_Rwc = pred_Rcw.transpose()
268 | pred_twc = -pred_Rcw.transpose() @ pred_tcw
269 |
270 | gt_Rcw = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_dcm()
271 | gt_tcw = np.array(gt_tcw, float).reshape(3, 1)
272 | gt_Rwc = gt_Rcw.transpose()
273 | gt_twc = -gt_Rcw.transpose() @ gt_tcw
274 |
275 | t_error_xyz = pred_twc - gt_twc
276 | t_error = np.sqrt(np.sum(t_error_xyz ** 2))
277 |
278 | pred_qwc = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_quat()
279 | gt_qwc = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_quat()
280 |
281 | q_error = quaternion_angular_error(q1=pred_qwc, q2=gt_qwc)
282 |
283 | return q_error, t_error, (t_error_xyz[0, 0], t_error_xyz[1, 0], t_error_xyz[2, 0])
284 |
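285 | # A quick sanity check (illustrative): identical camera-to-world poses give
286 | # zero rotation and translation error.
287 | #   compute_pose_error(pred_qcw=[1, 0, 0, 0], pred_tcw=[0, 0, 0],
288 | #                      gt_qcw=[1, 0, 0, 0], gt_tcw=[0, 0, 0])
289 | #   -> (0.0, 0.0, (0.0, 0.0, 0.0))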
--------------------------------------------------------------------------------
/localization/utils/parsers.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import logging
3 | import numpy as np
4 | from collections import defaultdict
5 |
6 |
7 | def parse_image_lists_with_intrinsics(paths):
8 | results = []
9 | files = list(Path(paths.parent).glob(paths.name))
10 | assert len(files) > 0
11 |
12 | for lfile in files:
13 | with open(lfile, 'r') as f:
14 | raw_data = f.readlines()
15 |
16 | logging.info(f'Importing {len(raw_data)} queries in {lfile.name}')
17 | for data in raw_data:
18 | data = data.strip('\n').split(' ')
19 | name, camera_model, width, height = data[:4]
20 | params = np.array(data[4:], float)
21 | info = (camera_model, int(width), int(height), params)
22 | results.append((name, info))
23 |
24 | assert len(results) > 0
25 | return results
26 |
27 |
28 | def parse_img_lists_for_extended_cmu_seaons(paths):
29 | Ks = {
30 | "c0": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571",
31 | "c1": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571"
32 | }
33 |
34 | results = []
35 | files = list(Path(paths.parent).glob(paths.name))
36 | assert len(files) > 0
37 |
38 | for lfile in files:
39 | with open(lfile, 'r') as f:
40 | raw_data = f.readlines()
41 |
42 | logging.info(f'Importing {len(raw_data)} queries in {lfile.name}')
43 | for name in raw_data:
44 | name = name.strip('\n')
45 | camera = name.split('_')[2]
46 | K = Ks[camera].split(' ')
47 | camera_model, width, height = K[:3]
48 | params = np.array(K[3:], float)
49 | # print("camera: ", camera_model, width, height, params)
50 | info = (camera_model, int(width), int(height), params)
51 | results.append((name, info))
52 |
53 | assert len(results) > 0
54 | return results
55 |
56 |
57 | def parse_retrieval(path):
58 | retrieval = defaultdict(list)
59 | with open(path, 'r') as f:
60 | for p in f.read().rstrip('\n').split('\n'):
61 | q, r = p.split(' ')
62 | retrieval[q].append(r)
63 | return dict(retrieval)
64 |
65 |
66 | def names_to_pair(name0, name1):
67 | return '_'.join((name0.replace('/', '-'), name1.replace('/', '-')))
68 |
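69 | # Note on the expected retrieval-file format (illustrative file names):
70 | # parse_retrieval reads one "query reference" pair per line, e.g.
71 | #   query/day/img_0001.jpg db/img_1234.jpg
72 | # and returns {query_name: [reference_name, ...]}.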
--------------------------------------------------------------------------------
/loss/__pycache__/accuracy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/__pycache__/accuracy.cpython-37.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/aploss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/__pycache__/aploss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/aploss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/__pycache__/aploss.cpython-37.pyc
--------------------------------------------------------------------------------
/loss/accuracy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> accuracy
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 14/06/2021 17:11
7 | =================================================='''
8 |
9 | import torch.nn as nn
10 |
11 |
12 | def accuracy(pred, target, topk=1, thresh=None):
13 | """Calculate accuracy according to the prediction and target.
14 |
15 | Args:
16 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
17 |         target (torch.Tensor): The target of each prediction, shape (N, ...)
18 | topk (int | tuple[int], optional): If the predictions in ``topk``
19 | matches the target, the predictions will be regarded as
20 | correct ones. Defaults to 1.
21 | thresh (float, optional): If not None, predictions with scores under
22 | this threshold are considered incorrect. Default to None.
23 |
24 | Returns:
25 | float | tuple[float]: If the input ``topk`` is a single integer,
26 | the function will return a single float as accuracy. If
27 | ``topk`` is a tuple containing multiple integers, the
28 | function will return a tuple containing accuracies of
29 | each ``topk`` number.
30 | """
31 | assert isinstance(topk, (int, tuple))
32 | if isinstance(topk, int):
33 | topk = (topk, )
34 | return_single = True
35 | else:
36 | return_single = False
37 |
38 | maxk = max(topk)
39 | if pred.size(0) == 0:
40 | accu = [pred.new_tensor(0.) for i in range(len(topk))]
41 | return accu[0] if return_single else accu
42 | assert pred.ndim == target.ndim + 1
43 | # print(type(pred), type(target))
44 | assert pred.size(0) == target.size(0)
45 | assert maxk <= pred.size(1), \
46 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
47 | pred_value, pred_label = pred.topk(maxk, dim=1)
48 | # transpose to shape (maxk, N, ...)
49 | pred_label = pred_label.transpose(0, 1)
50 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
51 | if thresh is not None:
52 | # Only prediction values larger than thresh are counted as correct
53 | correct = correct & (pred_value > thresh).t()
54 | res = []
55 | for k in topk:
56 | # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
57 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
58 | res.append(correct_k.mul_(100.0 / target.numel()))
59 | return res[0] if return_single else res
60 |
61 |
62 | class Accuracy(nn.Module):
63 | """Accuracy calculation module."""
64 |
65 | def __init__(self, topk=(1, ), thresh=None):
66 | """Module to calculate the accuracy.
67 |
68 | Args:
69 | topk (tuple, optional): The criterion used to calculate the
70 | accuracy. Defaults to (1,).
71 | thresh (float, optional): If not None, predictions with scores
72 | under this threshold are considered incorrect. Default to None.
73 | """
74 | super().__init__()
75 | self.topk = topk
76 | self.thresh = thresh
77 |
78 | def forward(self, pred, target):
79 | """Forward function to calculate accuracy.
80 |
81 | Args:
82 | pred (torch.Tensor): Prediction of models.
83 | target (torch.Tensor): Target for each prediction.
84 |
85 | Returns:
86 | tuple[float]: The accuracies under different topk criterions.
87 | """
88 | return accuracy(pred, target, self.topk, self.thresh)
89 |
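90 | # A minimal sanity check (illustrative, assumes `import torch`):
91 | #   pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])   # (N=2, num_class=2)
92 | #   target = torch.tensor([1, 0])
93 | #   accuracy(pred, target, topk=1)                   # -> tensor([100.])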
--------------------------------------------------------------------------------
/loss/seg_loss/__pycache__/crossentropy_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/seg_loss/__pycache__/crossentropy_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/loss/seg_loss/__pycache__/segloss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/seg_loss/__pycache__/segloss.cpython-37.pyc
--------------------------------------------------------------------------------
/loss/seg_loss/crossentropy_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> crossentropy_loss
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 29/04/2021 20:19
7 | =================================================='''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | def _ohem_mask(loss, ohem_ratio):
15 | with torch.no_grad():
16 | values, _ = torch.topk(loss.reshape(-1),
17 | int(loss.nelement() * ohem_ratio))
18 | mask = loss >= values[-1]
19 | return mask.float()
20 |
21 |
22 | class BCEWithLogitsLossWithOHEM(nn.Module):
23 | def __init__(self, ohem_ratio=1.0, pos_weight=None, eps=1e-7):
24 | super(BCEWithLogitsLossWithOHEM, self).__init__()
25 | self.criterion = nn.BCEWithLogitsLoss(reduction='none',
26 | pos_weight=pos_weight)
27 | self.ohem_ratio = ohem_ratio
28 | self.eps = eps
29 |
30 | def forward(self, pred, target):
31 | loss = self.criterion(pred, target)
32 | mask = _ohem_mask(loss, self.ohem_ratio)
33 | loss = loss * mask
34 | return loss.sum() / (mask.sum() + self.eps)
35 |
36 | def set_ohem_ratio(self, ohem_ratio):
37 | self.ohem_ratio = ohem_ratio
38 |
39 |
40 | class CrossEntropyLossWithOHEM(nn.Module):
41 | def __init__(self,
42 | ohem_ratio=1.0,
43 | weight=None,
44 | ignore_index=-100,
45 | eps=1e-7):
46 | super(CrossEntropyLossWithOHEM, self).__init__()
47 | self.criterion = nn.CrossEntropyLoss(weight=weight,
48 | ignore_index=ignore_index,
49 | reduction='none')
50 | self.ohem_ratio = ohem_ratio
51 | self.eps = eps
52 |
53 | def forward(self, pred, target):
54 | loss = self.criterion(pred, target)
55 | mask = _ohem_mask(loss, self.ohem_ratio)
56 | loss = loss * mask
57 | return loss.sum() / (mask.sum() + self.eps)
58 |
59 | def set_ohem_ratio(self, ohem_ratio):
60 | self.ohem_ratio = ohem_ratio
61 |
62 |
63 | class DiceLoss(nn.Module):
64 | def __init__(self, eps=1e-7):
65 | super(DiceLoss, self).__init__()
66 | self.eps = eps
67 |
68 | def forward(self, pred, target):
69 | pred = torch.sigmoid(pred)
70 | intersection = (pred * target).sum()
71 | loss = 1 - (2. * intersection) / (pred.sum() + target.sum() + self.eps)
72 | return loss
73 |
74 |
75 | class SoftCrossEntropyLossWithOHEM(nn.Module):
76 | def __init__(self, ohem_ratio=1.0, eps=1e-7):
77 | super(SoftCrossEntropyLossWithOHEM, self).__init__()
78 | self.ohem_ratio = ohem_ratio
79 | self.eps = eps
80 |
81 | def forward(self, pred, target):
82 | pred = F.log_softmax(pred, dim=1)
83 | loss = -(pred * target).sum(1)
84 | mask = _ohem_mask(loss, self.ohem_ratio)
85 | loss = loss * mask
86 | return loss.sum() / (mask.sum() + self.eps)
87 |
88 | def set_ohem_ratio(self, ohem_ratio):
89 | self.ohem_ratio = ohem_ratio
90 |
91 |
92 | class FocalLoss(nn.Module):
93 |     def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
94 | super(FocalLoss, self).__init__()
95 | self.alpha = alpha
96 | self.gamma = gamma
97 | self.ignore_index = ignore_index
98 | self.size_average = size_average
99 |
100 | def forward(self, inputs, targets):
101 | ce_loss = F.cross_entropy(
102 | inputs, targets, reduction='none', ignore_index=self.ignore_index)
103 | pt = torch.exp(-ce_loss)
104 | focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
105 | if self.size_average:
106 | return focal_loss.mean()
107 | else:
108 | return focal_loss.sum()
109 |
110 |
111 | def cross_entropy2d(input, target, weight=None, size_average=True):
112 | # input: (n, c, h, w), target: (n, h, w)
113 | n, c, h, w = input.size()
114 | # log_p: (n, c, h, w)
115 |
116 | log_p = F.log_softmax(input, dim=1)
117 | # log_p: (n*h*w, c)
118 | log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
119 | log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
120 | log_p = log_p.view(-1, c)
121 | # target: (n*h*w,)
122 | mask = target >= 0
123 | target = target[mask]
124 | loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
125 | if size_average:
126 | loss /= mask.data.sum()
127 | return loss
128 |
129 |
130 | def cross_entropy_seg(input, target):
131 | idx = target.cpu().long()
132 | one_hot_key = torch.FloatTensor(input.shape).zero_()
133 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
134 | if one_hot_key.device != input.device:
135 | one_hot_key = one_hot_key.to(input.device)
136 | log_p = F.log_softmax(input=input, dim=1)
137 | loss = -(one_hot_key * log_p)
138 |
139 | loss = loss.sum(1)
140 | # print("ce_loss: ", loss.shape)
141 | # loss2 = cross_entropy2d(input=input, target=target.squeeze())
142 | # print("loss: ", loss.mean(), loss2)
143 |
144 | # loss = F.nll_loss(input=log_p, target=target.view(-1), reduction="sum")
145 | return loss.mean()
146 |
147 |
148 | class CrossEntropy(nn.Module):
149 | def __init__(self, weights=None):
150 | super(CrossEntropy, self).__init__()
151 | self.weights = weights
152 |
153 | def forward(self, input, target):
154 | idx = target.cpu().long()
155 | one_hot_key = torch.FloatTensor(input.shape).zero_()
156 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
157 | if one_hot_key.device != input.device:
158 | one_hot_key = one_hot_key.to(input.device)
159 | log_p = F.log_softmax(input=input, dim=1)
160 | loss = -(one_hot_key * log_p)
161 |
162 | if self.weights is not None:
163 | loss = (loss * self.weights[None, :, None, None].expand_as(loss)).sum(1) # + self.smooth [B, C, H, W]
164 | else:
165 | loss = loss.sum(1)
166 | # print("ce_loss: ", loss.shape)
167 | # loss2 = cross_entropy2d(input=input, target=target.squeeze())
168 | # print("loss: ", loss.mean(), loss2)
169 |
170 | # loss = F.nll_loss(input=log_p, target=target.view(-1), reduction="sum")
171 | return loss.mean()
172 |
173 |
174 | if __name__ == '__main__':
175 | target = torch.randint(0, 4, (1, 1, 4, 4)).cuda()
176 | input = torch.rand((1, 4, 4, 4)).cuda()
177 | print("target: ", target.shape)
178 |
179 | net = CrossEntropy().cuda()
180 | out = net(input, target)
181 | print(out)
182 |
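183 |     # Illustrative extension (a sketch, not required elsewhere in the repo):
184 |     # the OHEM variant keeps only the hardest 70% of per-pixel losses, the
185 |     # ratio used for the 'ceohem' option in loss/seg_loss/segloss.py.
186 |     ohem = CrossEntropyLossWithOHEM(ohem_ratio=0.7).cuda()
187 |     print(ohem(input, target.squeeze(1)))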
--------------------------------------------------------------------------------
/loss/seg_loss/focal_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> focal_loss
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 29/04/2021 20:18
7 | =================================================='''
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | class FocalLoss(nn.Module):
15 | """
16 |     Copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
17 |     This is an implementation of Focal Loss with support for label-smoothed cross entropy, as proposed in
18 |     'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002).
19 |     Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt)
20 |     :param num_class:
21 |     :param alpha: (tensor) 3D or 4D scalar factor applied to this criterion
22 |     :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
23 |     putting more focus on hard, misclassified examples
24 |     :param smooth: (float, double) smoothing value applied to the one-hot labels before the cross entropy
25 |     :param balance_index: (int) index of the class to balance; should be specified when alpha is a float
26 |     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
27 | """
28 |
29 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
30 | super(FocalLoss, self).__init__()
31 | self.apply_nonlin = apply_nonlin
32 | self.alpha = alpha
33 | self.gamma = gamma
34 | self.balance_index = balance_index
35 | self.smooth = smooth
36 | self.size_average = size_average
37 |
38 | if self.smooth is not None:
39 | if self.smooth < 0 or self.smooth > 1.0:
40 | raise ValueError('smooth value should be in [0,1]')
41 |
42 | def forward(self, logit, target):
43 | if self.apply_nonlin is not None:
44 | logit = self.apply_nonlin(logit)
45 | num_class = logit.shape[1]
46 |
47 | if logit.dim() > 2:
48 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
49 | logit = logit.view(logit.size(0), logit.size(1), -1)
50 | logit = logit.permute(0, 2, 1).contiguous()
51 | logit = logit.view(-1, logit.size(-1))
52 | target = torch.squeeze(target, 1)
53 | target = target.view(-1, 1)
54 | # print(logit.shape, target.shape)
55 | #
56 | alpha = self.alpha
57 |
58 | if alpha is None:
59 | alpha = torch.ones(num_class, 1)
60 | elif isinstance(alpha, (list, np.ndarray)):
61 | assert len(alpha) == num_class
62 | alpha = torch.FloatTensor(alpha).view(num_class, 1)
63 | alpha = alpha / alpha.sum()
64 | elif isinstance(alpha, float):
65 | alpha = torch.ones(num_class, 1)
66 | alpha = alpha * (1 - self.alpha)
67 | alpha[self.balance_index] = self.alpha
68 |
69 | else:
70 | raise TypeError('Not support alpha type')
71 |
72 | if alpha.device != logit.device:
73 | alpha = alpha.to(logit.device)
74 |
75 | idx = target.cpu().long()
76 |
77 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
78 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
79 | if one_hot_key.device != logit.device:
80 | one_hot_key = one_hot_key.to(logit.device)
81 |
82 | if self.smooth:
83 | one_hot_key = torch.clamp(
84 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
85 | pt = (one_hot_key * logit).sum(1) + self.smooth
86 | logpt = pt.log()
87 |
88 | gamma = self.gamma
89 |
90 | alpha = alpha[idx]
91 | alpha = torch.squeeze(alpha)
92 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
93 |
94 | if self.size_average:
95 | loss = loss.mean()
96 | else:
97 | loss = loss.sum()
98 | return loss
99 |
100 |
101 | if __name__ == '__main__':
102 | pass
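103 |     # A minimal smoke test (illustrative; shapes mirror the segmentation use
104 |     # elsewhere in this repo, with softmax supplied as the non-linearity):
105 |     logits = torch.rand(2, 4, 8, 8)
106 |     target = torch.randint(0, 4, (2, 1, 8, 8))
107 |     criterion = FocalLoss(apply_nonlin=nn.Softmax(dim=1), gamma=2)
108 |     print(criterion(logits, target))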
--------------------------------------------------------------------------------
/loss/seg_loss/segloss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time    : 2021/5/11 2:38 PM
3 | # @Author : Fei Xue
4 | # @Email : fx221@cam.ac.uk
5 | # @File : segloss.py
6 | # @Software: PyCharm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from loss.seg_loss.crossentropy_loss import CrossEntropy
12 |
13 |
14 | def _ohem_mask(loss, ohem_ratio):
15 | with torch.no_grad():
16 | values, _ = torch.topk(loss.reshape(-1),
17 | int(loss.nelement() * ohem_ratio))
18 | mask = loss >= values[-1]
19 | return mask.float()
20 |
21 |
22 | class BCEWithLogitsLossWithOHEM(nn.Module):
23 | def __init__(self, ohem_ratio=1.0, pos_weight=None, eps=1e-7):
24 | super(BCEWithLogitsLossWithOHEM, self).__init__()
25 | self.criterion = nn.BCEWithLogitsLoss(reduction='none',
26 | pos_weight=pos_weight)
27 | self.ohem_ratio = ohem_ratio
28 | self.eps = eps
29 |
30 | def forward(self, pred, target):
31 | loss = self.criterion(pred, target)
32 | mask = _ohem_mask(loss, self.ohem_ratio)
33 | loss = loss * mask
34 | return loss.sum() / (mask.sum() + self.eps)
35 |
36 | def set_ohem_ratio(self, ohem_ratio):
37 | self.ohem_ratio = ohem_ratio
38 |
39 |
40 | class CrossEntropyLossWithOHEM(nn.Module):
41 | def __init__(self,
42 | ohem_ratio=1.0,
43 | weight=None,
44 | ignore_index=-100,
45 | eps=1e-7):
46 | super(CrossEntropyLossWithOHEM, self).__init__()
47 | self.criterion = nn.CrossEntropyLoss(weight=weight,
48 | ignore_index=ignore_index,
49 | reduction='none')
50 | self.ohem_ratio = ohem_ratio
51 | self.eps = eps
52 |
53 | def forward(self, pred, target):
54 | loss = self.criterion(pred, target)
55 | mask = _ohem_mask(loss, self.ohem_ratio)
56 | loss = loss * mask
57 | return loss.sum() / (mask.sum() + self.eps)
58 |
59 | def set_ohem_ratio(self, ohem_ratio):
60 | self.ohem_ratio = ohem_ratio
61 |
62 |
63 | class DiceLoss(nn.Module):
64 | def __init__(self, eps=1e-7):
65 | super(DiceLoss, self).__init__()
66 | self.eps = eps
67 |
68 | def forward(self, pred, target):
69 | pred = torch.sigmoid(pred)
70 | intersection = (pred * target).sum()
71 | loss = 1 - (2. * intersection) / (pred.sum() + target.sum() + self.eps)
72 | return loss
73 |
74 |
75 | class SoftCrossEntropyLossWithOHEM(nn.Module):
76 | def __init__(self, ohem_ratio=1.0, eps=1e-7):
77 | super(SoftCrossEntropyLossWithOHEM, self).__init__()
78 | self.ohem_ratio = ohem_ratio
79 | self.eps = eps
80 |
81 | def forward(self, pred, target):
82 | pred = F.log_softmax(pred, dim=1)
83 | loss = -(pred * target).sum(1)
84 | mask = _ohem_mask(loss, self.ohem_ratio)
85 | loss = loss * mask
86 | return loss.sum() / (mask.sum() + self.eps)
87 |
88 | def set_ohem_ratio(self, ohem_ratio):
89 | self.ohem_ratio = ohem_ratio
90 |
91 |
92 | class SegLoss(nn.Module):
93 | def __init__(self, segloss_name, use_cls=None, use_hiera=None, use_seg=None,
94 | cls_weight=1., hiera_weight=1., label_weights=None):
95 | super(SegLoss, self).__init__()
96 |
97 |         # default to None so forward() can check the attribute even when use_seg is False
98 |         self.seg_loss = None
99 |         if use_seg:
100 |             if segloss_name == 'ce':
101 |                 self.seg_loss = CrossEntropy(weights=label_weights)
102 |             elif segloss_name == 'ceohem':
103 |                 self.seg_loss = CrossEntropyLossWithOHEM(ohem_ratio=0.7, weight=label_weights)
104 |             elif segloss_name == 'sceohem':
105 |                 self.seg_loss = SoftCrossEntropyLossWithOHEM(ohem_ratio=0.7)
106 | if use_cls:
107 | self.cls_loss = nn.BCEWithLogitsLoss(weight=label_weights)
108 | self.cls_weight = cls_weight
109 | else:
110 | self.cls_loss = None
111 |
112 | if use_hiera:
113 | # self.cls_hiera = nn.BCEWithLogitsLoss(weight=label_weights)
114 | self.cls_hiera = nn.CrossEntropyLoss(weight=label_weights)
115 | self.hiera_weight = hiera_weight
116 | else:
117 | self.cls_hiera = None
118 |
119 | def forward(self, pred_seg=None, gt_seg=None, pred_cls=None, gt_cls=None, pred_hiera=None, gt_hiera=None):
120 | total_loss = 0
121 | output = {
122 |
123 | }
124 | if self.seg_loss is not None:
125 | seg_error = self.seg_loss(pred_seg, gt_seg)
126 | total_loss = total_loss + seg_error
127 | output["seg_loss"] = seg_error
128 |
129 | if self.cls_loss is not None:
130 | cls_error = self.cls_loss(pred_cls, gt_cls)
131 | total_loss = total_loss + cls_error * self.cls_weight
132 | output["cls_loss"] = cls_error
133 |
134 | if self.cls_hiera is not None:
135 | # print (pred_hiera.shape, gt_hiera.shape)
136 | hiera_error = self.cls_hiera(pred_hiera, torch.argmax(gt_hiera, 1).long())
137 | total_loss = total_loss + hiera_error * self.hiera_weight
138 | output["hiera_loss"] = hiera_error
139 |
140 | output["loss"] = total_loss
141 | return output
142 |
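143 | 
144 | if __name__ == '__main__':
145 |     # Minimal smoke test (not part of the original repo): OHEM cross entropy on a
146 |     # random segmentation batch; only the hardest 70% of pixels contribute to the loss.
147 |     criterion = CrossEntropyLossWithOHEM(ohem_ratio=0.7)
148 |     pred = torch.randn(2, 5, 32, 32)               # N, C, H, W logits
149 |     target = torch.randint(0, 5, (2, 32, 32))      # N, H, W class indices
150 |     print(criterion(pred, target).item())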
--------------------------------------------------------------------------------
/loss/seg_loss/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> utils
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 29/04/2021 20:12
7 | =================================================='''
8 |
9 | import torch
10 | import numpy as np
11 |
12 |
13 | def make_one_hot(input, num_classes=None):
14 | if num_classes is None:
15 |         num_classes = int(input.max().item()) + 1
16 | shape = np.array(input.shape)
17 | shape[1] = num_classes
18 | shape = tuple(shape)
19 | result = torch.zeros(shape)
20 |     result = result.scatter_(1, input.cpu().long(), 1)
21 | return result
22 |
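23 | 
24 | if __name__ == '__main__':
25 |     # Minimal smoke test (not part of the original repo): converts integer labels of
26 |     # shape (N, 1, H, W) into a one-hot tensor of shape (N, num_classes, H, W).
27 |     labels = torch.randint(0, 4, (2, 1, 3, 3))
28 |     print(make_one_hot(labels, num_classes=4).shape)   # torch.Size([2, 4, 3, 3])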
--------------------------------------------------------------------------------
/net/layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class Flatten(nn.Module):
6 | def forward(self, input):
7 | return input.view(input.size(0), -1)
8 |
9 |
10 | class L2Norm(nn.Module):
11 | def __init__(self, dim=1):
12 | super().__init__()
13 | self.dim = dim
14 |
15 | def forward(self, input):
16 | return F.normalize(input, p=2, dim=self.dim)
17 |
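18 | 
19 | if __name__ == '__main__':
20 |     # Minimal smoke test (not part of the original repo).
21 |     import torch
22 |     x = torch.randn(2, 8, 4, 4)
23 |     print(Flatten()(x).shape)             # torch.Size([2, 128])
24 |     print(L2Norm(dim=1)(x).norm(dim=1))   # per-position norms are ~1 after normalization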
--------------------------------------------------------------------------------
/net/locnets/r2d2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> r2d2
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 2021-05-27 14:49
7 | =================================================='''
8 | import torch
9 | import torch.nn as nn
10 | import torchvision.transforms as tvf
11 | import torch.nn.functional as F
12 |
13 | RGB_mean = [0.485, 0.456, 0.406]
14 | RGB_std = [0.229, 0.224, 0.225]
15 |
16 | norm_RGB = tvf.Compose([tvf.ToTensor(), tvf.Normalize(mean=RGB_mean, std=RGB_std)])
17 |
18 |
19 | class BaseNet(nn.Module):
20 | """ Takes a list of images as input, and returns for each image:
21 | - a pixelwise descriptor
22 | - a pixelwise confidence
23 | """
24 |
25 | def softmax(self, ux):
26 | if ux.shape[1] == 1:
27 | x = F.softplus(ux)
28 | return x / (1 + x) # for sure in [0,1], much less plateaus than softmax
29 | elif ux.shape[1] == 2:
30 | return F.softmax(ux, dim=1)[:, 1:2]
31 |
32 | def normalize(self, x, ureliability, urepeatability):
33 | return dict(descriptors=F.normalize(x, p=2, dim=1),
34 | repeatability=self.softmax(urepeatability),
35 | reliability=self.softmax(ureliability))
36 |
37 | def forward_one(self, x):
38 | raise NotImplementedError()
39 |
40 | def forward(self, imgs, **kw):
41 | res = [self.forward_one(img) for img in imgs]
42 | # merge all dictionaries into one
43 | res = {k: [r[k] for r in res if k in r] for k in {k for r in res for k in r}}
44 | return dict(res, imgs=imgs, **kw)
45 |
46 |
47 | class PatchNet(BaseNet):
48 | """ Helper class to construct a fully-convolutional network that
49 |         extracts an L2-normalized patch descriptor.
50 | """
51 |
52 | def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
53 | BaseNet.__init__(self)
54 | self.inchan = inchan
55 | self.curchan = inchan
56 | self.dilated = dilated
57 | self.dilation = dilation
58 | self.bn = bn
59 | self.bn_affine = bn_affine
60 | self.ops = nn.ModuleList([])
61 |
62 | def _make_bn(self, outd):
63 | return nn.BatchNorm2d(outd, affine=self.bn_affine)
64 |
65 | def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True):
66 | d = self.dilation * dilation
67 | if self.dilated:
68 | conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=1)
69 | self.dilation *= stride
70 | else:
71 | conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=stride)
72 | self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
73 | if bn and self.bn: self.ops.append(self._make_bn(outd))
74 | if relu: self.ops.append(nn.ReLU(inplace=True))
75 | self.curchan = outd
76 |
77 | def forward_one(self, x):
78 | assert self.ops, "You need to add convolutions first"
79 | for n, op in enumerate(self.ops):
80 | x = op(x)
81 | return self.normalize(x)
82 |
83 |
84 | class L2_Net(PatchNet):
85 | """ Compute a 128D descriptor for all overlapping 32x32 patches.
86 | From the L2Net paper (CVPR'17).
87 | """
88 |
89 | def __init__(self, dim=128, **kw):
90 | PatchNet.__init__(self, **kw)
91 | add_conv = lambda n, **kw: self._add_conv((n * dim) // 128, **kw)
92 | add_conv(32)
93 | add_conv(32)
94 | add_conv(64, stride=2)
95 | add_conv(64)
96 | add_conv(128, stride=2)
97 | add_conv(128)
98 | add_conv(128, k=7, stride=8, bn=False, relu=False)
99 | self.out_dim = dim
100 |
101 |
102 | class Quad_L2Net(PatchNet):
103 |     """ Same as L2_Net, but replaces the final 8x8 conv with 3 successive 2x2 convs.
104 | """
105 |
106 | def __init__(self, dim=128, mchan=4, relu22=False, **kw):
107 | PatchNet.__init__(self, **kw)
108 | self._add_conv(8 * mchan)
109 | self._add_conv(8 * mchan)
110 | self._add_conv(16 * mchan, stride=2)
111 | self._add_conv(16 * mchan)
112 | self._add_conv(32 * mchan, stride=2)
113 | self._add_conv(32 * mchan)
114 | # replace last 8x8 convolution with 3 2x2 convolutions
115 | self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
116 | self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
117 | self._add_conv(dim, k=2, stride=2, bn=False, relu=False)
118 | self.out_dim = dim
119 |
120 |
121 | class Quad_L2Net_ConfCFS(Quad_L2Net):
122 |     """ Same as Quad_L2Net, with 2 confidence maps for repeatability and reliability.
123 | """
124 |
125 | def __init__(self, **kw):
126 | Quad_L2Net.__init__(self, **kw)
127 | # reliability classifier
128 | self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
129 |         # repeatability classifier: for some reason it's a softplus, not a softmax!
130 | # Why? I guess it's a mistake that was left unnoticed in the code for a long time...
131 | self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)
132 |
133 | def forward_one(self, x):
134 | assert self.ops, "You need to add convolutions first"
135 | for op in self.ops:
136 | x = op(x)
137 | # compute the confidence maps
138 | ureliability = self.clf(x ** 2)
139 | urepeatability = self.sal(x ** 2)
140 | return self.normalize(x, ureliability, urepeatability)
141 |
142 |
143 | def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, dilation=1):
144 | if not use_bn:
145 | return nn.Sequential(
146 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
147 |                       padding=padding, dilation=dilation),
148 | nn.ReLU(),
149 | )
150 | else:
151 | return nn.Sequential(
152 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
153 | padding=padding, dilation=dilation),
154 | nn.BatchNorm2d(out_channels, affine=False),
155 | nn.ReLU(),
156 | )
157 |
158 |
159 | class NonMaxSuppression(torch.nn.Module):
160 | def __init__(self, rel_thr=0.7, rep_thr=0.7):
161 | nn.Module.__init__(self)
162 | self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
163 | self.rel_thr = rel_thr
164 | self.rep_thr = rep_thr
165 |
166 | def forward(self, reliability, repeatability, **kw):
167 | assert len(reliability) == len(repeatability) == 1
168 | reliability, repeatability = reliability[0], repeatability[0]
169 |
170 | # local maxima
171 | maxima = (repeatability == self.max_filter(repeatability))
172 |
173 | # remove low peaks
174 | maxima *= (repeatability >= self.rep_thr)
175 | maxima *= (reliability >= self.rel_thr)
176 |
177 | return maxima.nonzero().t()[2:4]
178 |
179 |
180 | def extract_multiscale(net, img, detector, scale_f=2 ** 0.25,
181 | min_scale=0.0, max_scale=1,
182 | min_size=256, max_size=1024,
183 | verbose=False):
184 | old_bm = torch.backends.cudnn.benchmark
185 | torch.backends.cudnn.benchmark = False # speedup
186 |
187 | # extract keypoints at multiple scales
188 | B, three, H, W = img.shape
189 | assert B == 1 and three == 3, "should be a batch with a single RGB image"
190 |
191 | assert max_scale <= 1
192 | s = 1.0 # current scale factor
193 |
194 | X, Y, S, C, Q, D = [], [], [], [], [], []
195 | pts_list = []
196 | while s + 0.001 >= max(min_scale, min_size / max(H, W)):
197 | if s - 0.001 <= min(max_scale, max_size / max(H, W)):
198 | nh, nw = img.shape[2:]
199 | if verbose:
200 | # print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
201 | print("extracting at scale x{:.02f} = {:4d}x{:3d}".format(s, nw, nh))
202 | # extract descriptors
203 | with torch.no_grad():
204 | res = net(imgs=[img])
205 |
206 | # get output and reliability map
207 | descriptors = res['descriptors'][0]
208 | reliability = res['reliability'][0]
209 | repeatability = res['repeatability'][0]
210 |
211 | # normalize the reliability for nms
212 | # extract maxima and descs
213 | with torch.no_grad():
214 | y, x = detector(**res) # nms
215 | c = reliability[0, 0, y, x]
216 | q = repeatability[0, 0, y, x]
217 | d = descriptors[0, :, y, x].t()
218 | n = d.shape[0]
219 |
220 | # accumulate multiple scales
221 | X.append(x.float() * W / nw)
222 | Y.append(y.float() * H / nh)
223 | S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device))
224 | C.append(c)
225 | Q.append(q)
226 | D.append(d)
227 |
228 | pts_list.append(torch.stack([x.float() * W / nw, y.float() * H / nh], dim=-1))
229 |
230 | s /= scale_f
231 |
232 | # down-scale the image for next iteration
233 | nh, nw = round(H * s), round(W * s)
234 | img = F.interpolate(img, (nh, nw), mode='bilinear', align_corners=False)
235 |
236 | # restore value
237 | torch.backends.cudnn.benchmark = old_bm
238 |
239 | # print("Y: ", len(Y))
240 | # print("X: ", len(X))
241 | # print("S: ", len(S))
242 | Y = torch.cat(Y)
243 | X = torch.cat(X)
244 | S = torch.cat(S) # scale
245 | scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability
246 | XYS = torch.stack([X, Y, S], dim=-1)
247 | D = torch.cat(D)
248 |
249 | XYS = XYS.cpu().numpy()
250 | D = D.cpu().numpy()
251 | scores = scores.cpu().numpy()
252 | return XYS, D, scores, pts_list
253 |
254 |
255 | def extract(net, img, detector):
256 | old_bm = torch.backends.cudnn.benchmark
257 | torch.backends.cudnn.benchmark = False # speedup
258 |
259 | with torch.no_grad():
260 | res = net(imgs=[img])
261 |
262 | # get output and reliability map
263 | descriptors = res['descriptors'][0]
264 | reliability = res['reliability'][0]
265 | repeatability = res['repeatability'][0]
266 |
267 | print("rel: ", torch.min(reliability), torch.max(reliability), torch.median(reliability))
268 | print("rep: ", torch.min(repeatability), torch.max(repeatability), torch.median(repeatability))
269 |
270 | # normalize the reliability for nms
271 | # extract maxima and descs
272 | with torch.no_grad():
273 | y, x = detector(**res) # nms
274 | c = reliability[0, 0, y, x]
275 | q = repeatability[0, 0, y, x]
276 | d = descriptors[0, :, y, x].t()
277 | n = d.shape[0]
278 |
279 | print("after nms: ", n)
280 |
281 | X, Y, S, C, Q, D = [], [], [], [], [], []
282 |
283 | X.append(x.float())
284 | Y.append(y.float())
285 | C.append(c)
286 | Q.append(q)
287 | D.append(d)
288 |
289 | # restore value
290 | torch.backends.cudnn.benchmark = old_bm
291 |
292 | Y = torch.cat(Y)
293 | X = torch.cat(X)
294 | scores = torch.cat(C) * torch.cat(Q)
295 |
296 | XYS = torch.stack([X, Y], dim=-1)
297 | D = torch.cat(D)
298 |
299 | return XYS, D, scores
300 |
301 |
302 | def extract_r2d2_return(r2d2, img, need_nms=False, **kwargs):
303 | # img = Image.open(img_path).convert('RGB')
304 | # H, W = img.size
305 | img = norm_RGB(img)
306 | img = img[None]
307 | img = img.cuda()
308 |
309 | rel_thr = 0.99 # 0.99
310 | rep_thr = 0.995 # 0.995
311 | min_size = 256
312 | max_size = 9999
313 | detector = NonMaxSuppression(rel_thr=rel_thr, rep_thr=rep_thr).cuda().eval()
314 |
315 | xys, desc, scores, pts_list = extract_multiscale(net=r2d2, img=img, detector=detector, min_size=min_size,
316 | max_size=max_size, scale_f=1.2) # r2d2 mode
317 |
318 | return xys, desc, scores
319 |
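320 | 
321 | if __name__ == '__main__':
322 |     # Minimal usage sketch (not part of the original repo). It assumes a CUDA device
323 |     # is available and uses a randomly initialised Quad_L2Net_ConfCFS; in practice the
324 |     # network would be loaded from a trained checkpoint before extraction.
325 |     import numpy as np
326 |     net = Quad_L2Net_ConfCFS().cuda().eval()
327 |     img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # dummy H x W x 3 image
328 |     xys, desc, scores = extract_r2d2_return(net, img)
329 |     print(xys.shape, desc.shape, scores.shape)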
--------------------------------------------------------------------------------
/net/plceregnets/gem.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> gem
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 25/08/2021 21:00
7 | =================================================='''
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn.modules import Module
11 | from torch.nn.parameter import Parameter
12 | import torch.nn.functional as F
13 |
14 |
15 | class GeneralizedMeanPooling(Module):
16 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
17 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
18 | - At p = infinity, one gets Max Pooling
19 | - At p = 1, one gets Average Pooling
20 | The output is of size H x W, for any input size.
21 | The number of output features is equal to the number of input planes.
22 | Args:
23 | output_size: the target output size of the image of the form H x W.
24 | Can be a tuple (H, W) or a single H for a square image H x H
25 | H and W can be either a ``int``, or ``None`` which means the size will
26 | be the same as that of the input.
27 | """
28 |
29 | def __init__(self, norm, output_size=1, eps=1e-6):
30 | super(GeneralizedMeanPooling, self).__init__()
31 | assert norm > 0
32 | self.p = float(norm)
33 | self.output_size = output_size
34 | self.eps = eps
35 |
36 | def forward(self, x):
37 | x = x.clamp(min=self.eps).pow(self.p)
38 | return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
39 |
40 | def __repr__(self):
41 | return self.__class__.__name__ + '(' \
42 | + str(self.p) + ', ' \
43 | + 'output_size=' + str(self.output_size) + ')'
44 |
45 |
46 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
47 | """ Same, but norm is trainable
48 | """
49 |
50 | def __init__(self, norm=3, output_size=1, eps=1e-6):
51 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
52 | self.p = Parameter(torch.ones(1) * norm)
53 |
54 |
55 | def l2_normalize(x, axis=-1):
56 | x = F.normalize(x, p=2, dim=axis)
57 | return x
58 |
59 |
60 | class GEM(nn.Module):
61 | def __init__(self, in_dim=1024, out_dim=2048, norm_features=False, pooling='gem', gemp=3, center_bias=0,
62 | dropout_p=None, without_fc=False, projection=False, cls=False, n_classes=452):
63 | super(GEM, self).__init__()
64 | self.norm_features = norm_features
65 | self.without_fc = without_fc
66 | self.pooling = pooling
67 | self.center_bias = center_bias
68 | self.projection = projection
69 | self.cls = cls
70 | self.n_classes = n_classes
71 |
72 | if self.projection:
73 | self.proj = nn.Sequential(
74 | nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False),
75 | nn.BatchNorm2d(in_dim),
76 | nn.ReLU(inplace=True),
77 | nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False),
78 | nn.BatchNorm2d(in_dim),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(in_dim, in_dim, kernel_size=1, padding=0, bias=False),
81 | # nn.BatchNorm2d(in_dim),
82 | # nn.ReLU(inplace=True),
83 | )
84 | else:
85 | self.proj = nn.Sequential(
86 | nn.Conv2d(in_dim, in_dim, kernel_size=1, padding=0, bias=False),
87 | )
88 |
89 | if self.cls:
90 | self.cls_head = nn.Linear(in_features=in_dim, out_features=self.n_classes)
91 |
92 | if pooling == 'max':
93 | self.adpool = nn.AdaptiveMaxPool2d(output_size=1)
94 | elif pooling == 'avg':
95 | self.adpool = nn.AdaptiveAvgPool2d(output_size=1)
96 | elif pooling.startswith('gem'):
97 | self.adpool = GeneralizedMeanPoolingP(norm=gemp)
98 | else:
99 | raise ValueError(pooling)
100 |
101 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
102 | # self.fc = nn.Linear(512 * block.expansion, out_dim)
103 | # self.fc = nn.Linear(in_dim, out_dim)
104 | # self.fc_name = 'fc'
105 | # self.feat_dim = out_dim
106 |
107 | def forward(self, feat, atten=None):
108 | if atten is not None:
109 | with torch.no_grad():
110 | if atten.shape[2] != feat.shape[2] or atten.shape[3] != feat.shape[3]:
111 | atten = F.interpolate(atten, size=(feat.shape[2], feat.shape[3]), mode='bilinear')
112 | feat = feat * atten.expand_as(feat)
113 |
114 |         # self.proj is defined in __init__ for both settings (3x3 conv block or a single 1x1 conv)
115 |         x = self.proj(feat)
116 |
117 | bs, _, H, W = x.shape
118 | if self.dropout is not None:
119 | x = self.dropout(x)
120 |
121 | if self.center_bias > 0:
122 | b = self.center_bias
123 | bias = 1 + torch.FloatTensor([[[[0, 0, 0, 0], [0, b, b, 0], [0, b, b, 0], [0, 0, 0, 0]]]]).to(x.device)
124 | bias_resize = torch.nn.functional.interpolate(bias, size=x.shape[-2:], mode='bilinear', align_corners=True)
125 | x = x * bias_resize
126 | # global pooling
127 | x = self.adpool(x)
128 |
129 | if self.norm_features:
130 | x = l2_normalize(x, axis=1)
131 |
132 | x = x.view(x.shape[0], -1)
133 | # if not self.without_fc:
134 | # x = self.fc(x)
135 | if self.cls:
136 | cls_feat = self.cls_head(x)
137 | x = l2_normalize(x, axis=-1)
138 |
139 | output = {
140 | 'feat': x,
141 | }
142 |
143 | if self.cls:
144 | output['cls'] = cls_feat
145 | return output
146 |
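147 | 
148 | if __name__ == '__main__':
149 |     # Minimal smoke test (not part of the original repo): GEM pooling on a random
150 |     # feature map; 452 classes matches the Aachen global-instance setting used elsewhere.
151 |     gem = GEM(in_dim=1024, projection=True, cls=True, n_classes=452)
152 |     out = gem(torch.randn(2, 1024, 16, 16))
153 |     print(out['feat'].shape, out['cls'].shape)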
--------------------------------------------------------------------------------
/net/plceregnets/pregnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> pregnet
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 23/08/2021 21:55
7 | =================================================='''
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from sklearn.neighbors import NearestNeighbors
12 | import numpy as np
13 |
14 |
15 | class NetVLAD(nn.Module):
16 | """NetVLAD layer implementation"""
17 |
18 | def __init__(self, num_clusters=64, dim=128, in_dim=512, proj_layer=None,
19 | normalize_input=True, vladv2=False, projection=False, cls=False, n_classes=452):
20 | """
21 | Args:
22 | num_clusters : int
23 | The number of clusters
24 | dim : int
25 | Dimension of descriptors
26 | alpha : float
27 | Parameter of initialization. Larger value is harder assignment.
28 | normalize_input : bool
29 | If true, descriptor-wise L2 normalization is applied to input.
30 | vladv2 : bool
31 | If true, use vladv2 otherwise use vladv1
32 | """
33 | super(NetVLAD, self).__init__()
34 | self.num_clusters = num_clusters
35 | self.dim = dim
36 | self.alpha = 0
37 | self.vladv2 = vladv2
38 | self.normalize_input = normalize_input
39 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=vladv2)
40 | self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
41 | self.cls = cls
42 | self.n_classes = n_classes
43 |
44 | self.projection = projection
45 | if self.projection:
46 | if proj_layer is None:
47 | self.proj = nn.Sequential(
48 | nn.Conv2d(in_dim, 512, kernel_size=3, padding=1, bias=False),
49 | nn.BatchNorm2d(512),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
52 | # nn.BatchNorm2d(512),
53 | # nn.ReLU(inplace=True),
54 | )
55 | else:
56 | self.proj = nn.Sequential(*proj_layer)
57 |
58 | if self.cls:
59 | self.cls_head = nn.Linear(in_features=512, out_features=self.n_classes)
60 |
61 | def init_params(self, clsts, traindescs):
62 | print("Init centroids")
63 | # TODO replace numpy ops with pytorch ops
64 | if self.vladv2 == False:
65 | clstsAssign = clsts / (np.linalg.norm(clsts, axis=1, keepdims=True) + 1e-5)
66 | dots = np.dot(clstsAssign, traindescs.T)
67 | dots.sort(0)
68 | dots = dots[::-1, :] # sort, descending
69 |
70 | self.alpha = (-np.log(0.01) / np.mean(dots[0, :] - dots[1, :])).item()
71 | self.centroids = nn.Parameter(torch.from_numpy(clsts))
72 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * clstsAssign).unsqueeze(2).unsqueeze(3))
73 | self.conv.bias = None
74 | else:
75 | knn = NearestNeighbors(n_jobs=-1) # TODO faiss?
76 | knn.fit(traindescs)
77 | del traindescs
78 | dsSq = np.square(knn.kneighbors(clsts, 2)[1])
79 | del knn
80 | self.alpha = (-np.log(0.01) / np.mean(dsSq[:, 1] - dsSq[:, 0])).item()
81 | self.centroids = nn.Parameter(torch.from_numpy(clsts))
82 | del clsts, dsSq
83 |
84 | self.conv.weight = nn.Parameter(
85 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
86 | )
87 | self.conv.bias = nn.Parameter(
88 | - self.alpha * self.centroids.norm(dim=1)
89 | )
90 |
91 | def forward(self, x, atten=None):
92 | if atten is not None:
93 | if atten.shape[2] != x.shape[2] or atten.shape[3] != x.shape[3]:
94 | with torch.no_grad():
95 | atten = F.interpolate(atten, size=(x.shape[2], x.shape[3]), mode='bilinear')
96 | x = x * atten.expand_as(x)
97 |
98 |         if self.projection:
99 |             x_enc = self.proj(x)
100 |         else:
101 |             x_enc = x
102 |         if self.cls:
103 |             x_vec = F.adaptive_avg_pool2d(x_enc, output_size=1).reshape(x_enc.shape[0], -1)
104 | cls_feat = self.cls_head(x_vec)
105 |
106 | if self.normalize_input:
107 | x = F.normalize(x_enc, p=2, dim=1) # across descriptor dim
108 |
109 | N, C, H, W = x.shape
110 | # soft-assignment
111 | soft_assign = self.conv(x).view(N, self.num_clusters, -1)
112 | soft_assign = F.softmax(soft_assign, dim=1)
113 |
114 | x_flatten = x.view(N, C, -1)
115 |
116 | # calculate residuals to each clusters
117 | vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device)
118 | for C in range(self.num_clusters): # slower than non-looped, but lower memory usage
119 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \
120 | self.centroids[C:C + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
121 | residual *= soft_assign[:, C:C + 1, :].unsqueeze(2)
122 | vlad[:, C:C + 1, :] = residual.sum(dim=-1)
123 |
124 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
125 | vlad = vlad.view(x.size(0), -1) # flatten
126 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
127 |
128 | output = {
129 | 'feat': vlad,
130 | 'img_desc': x_enc,
131 | }
132 |
133 | if self.cls:
134 | output['cls'] = cls_feat
135 | return output
136 |
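137 | 
138 | if __name__ == '__main__':
139 |     # Minimal smoke test (not part of the original repo): NetVLAD over a random feature
140 |     # map; `dim` must match the projected channel size (512 here). Centroids are random,
141 |     # so this only checks shapes, not retrieval quality.
142 |     netvlad = NetVLAD(num_clusters=64, dim=512, in_dim=1024, projection=True, cls=False)
143 |     out = netvlad(torch.randn(1, 1024, 16, 16))
144 |     print(out['feat'].shape, out['img_desc'].shape)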
--------------------------------------------------------------------------------
/net/regnets/deeplab.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> deeplab
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 06/09/2021 13:36
7 | =================================================='''
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from segmentation_models_pytorch.encoders import get_encoder
12 | from segmentation_models_pytorch.deeplabv3.decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
13 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
14 | from segmentation_models_pytorch.base.modules import Flatten, Activation, Conv2dReLU
15 | from typing import Optional, Union
16 |
17 |
18 | class ASPPConv(nn.Sequential):
19 | def __init__(self, in_channels, out_channels, dilation):
20 | modules = [
21 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
22 | nn.BatchNorm2d(out_channels),
23 | nn.ReLU(inplace=True)
24 | ]
25 | super(ASPPConv, self).__init__(*modules)
26 |
27 |
28 | class ASPPPooling(nn.Sequential):
29 | def __init__(self, in_channels, out_channels):
30 | super(ASPPPooling, self).__init__(
31 | nn.AdaptiveAvgPool2d(1),
32 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
33 | nn.BatchNorm2d(out_channels),
34 | nn.ReLU(inplace=True))
35 |
36 | def forward(self, x):
37 | size = x.shape[-2:]
38 | x = super(ASPPPooling, self).forward(x)
39 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
40 |
41 |
42 | class ASPP(nn.Module):
43 | def __init__(self, in_channels, atrous_rates):
44 | super(ASPP, self).__init__()
45 | out_channels = 256
46 | modules = []
47 | modules.append(nn.Sequential(
48 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
49 | nn.BatchNorm2d(out_channels),
50 | nn.ReLU(inplace=True)))
51 |
52 | rate1, rate2, rate3 = tuple(atrous_rates)
53 | modules.append(ASPPConv(in_channels, out_channels, rate1))
54 | modules.append(ASPPConv(in_channels, out_channels, rate2))
55 | modules.append(ASPPConv(in_channels, out_channels, rate3))
56 | modules.append(ASPPPooling(in_channels, out_channels))
57 |
58 | self.convs = nn.ModuleList(modules)
59 |
60 | self.project = nn.Sequential(
61 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
62 | nn.BatchNorm2d(out_channels),
63 | nn.ReLU(inplace=True),
64 | nn.Dropout(0.1), )
65 |
66 | def forward(self, x):
67 | res = []
68 | for conv in self.convs:
69 | res.append(conv(x))
70 | res = torch.cat(res, dim=1)
71 | return self.project(res)
72 |
73 |
74 | class DeepLabHeadV3Plus(nn.Module):
75 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
76 | super(DeepLabHeadV3Plus, self).__init__()
77 | self.project = nn.Sequential(
78 | nn.Conv2d(low_level_channels, 48, 1, bias=False),
79 | nn.BatchNorm2d(48),
80 | nn.ReLU(inplace=True),
81 | )
82 |
83 | self.aspp = ASPP(in_channels, aspp_dilate)
84 |
85 | self.classifier = nn.Sequential(
86 | nn.Conv2d(304, 256, 3, padding=1, bias=False),
87 | nn.BatchNorm2d(256),
88 | nn.ReLU(inplace=True),
89 | nn.Conv2d(256, num_classes, 1)
90 | )
91 | self._init_weight()
92 |
93 | def forward(self, feature):
94 | low_level_feature = self.project(feature['low_level'])
95 | output_feature = self.aspp(feature['out'])
96 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear',
97 | align_corners=False)
98 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
99 |
100 | def _init_weight(self):
101 | for m in self.modules():
102 | if isinstance(m, nn.Conv2d):
103 | nn.init.kaiming_normal_(m.weight)
104 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
105 | nn.init.constant_(m.weight, 1)
106 | nn.init.constant_(m.bias, 0)
107 |
108 |
109 | class DeepLabV3Plus(nn.Module):
110 | def __init__(self,
111 | encoder_name: str = "resnet34",
112 | encoder_weights: Optional[str] = "imagenet",
113 | encoder_output_stride: int = 8,
114 | decoder_channels: int = 256,
115 | decoder_atrous_rates: tuple = (12, 24, 36),
116 | in_channels: int = 3,
117 | encoder_depth: int = 3,
118 | classes: int = 1,
119 | upsampling: int = 8,
120 | activation: Optional[str] = None,
121 | classification: bool = False,
122 | ):
123 | super(DeepLabV3Plus, self).__init__()
124 | self.classification = classification
125 |
126 | self.encoder = get_encoder(name=encoder_name,
127 | in_channels=in_channels,
128 | depth=encoder_depth,
129 | weights=encoder_weights,
130 | )
131 |
132 | if encoder_output_stride == 8:
133 | self.encoder.make_dilated(
134 | stage_list=[4, 5],
135 | dilation_list=[2, 4]
136 | )
137 |
138 | elif encoder_output_stride == 16:
139 | self.encoder.make_dilated(
140 | stage_list=[5],
141 | dilation_list=[2]
142 | )
143 | else:
144 | raise ValueError(
145 | "Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride))
146 |
147 | self.decoder = DeepLabV3PlusDecoder(
148 | encoder_channels=self.encoder.out_channels,
149 | out_channels=decoder_channels,
150 | atrous_rates=decoder_atrous_rates,
151 | output_stride=encoder_output_stride,
152 | )
153 |
154 | self.seghead = SegmentationHead(
155 | in_channels=self.decoder.out_channels,
156 | out_channels=classes,
157 | activation=activation,
158 | kernel_size=1,
159 | upsampling=upsampling,
160 | )
161 |
162 | if self.classification:
163 | self.cls_head = nn.Sequential(
164 | nn.AdaptiveAvgPool2d(1),
165 | Flatten(),
166 | nn.Linear(self.encoder.out_channels[-1], classes, bias=True)
167 | )
168 |
169 | def forward(self, x):
170 | features = self.encoder(x)
171 | # print('len: ', len(features))
172 | # for i in range(5):
173 | # print(i, features[i].shape)
174 | decoder_output = self.decoder(*features)
175 |
176 | masks = self.seghead(decoder_output)
177 | # print('mask: ', masks.shape)
178 | output = {"masks": [masks]}
179 | if self.classification:
180 | cls = self.cls_head(features[-1])
181 | output["cls"] = [cls]
182 |
183 | output['feats'] = features
184 |
185 | return output
186 |
--------------------------------------------------------------------------------
/net/regnets/pspnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> pspnet
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 2021-05-12 19:22
7 | =================================================='''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from segmentation_models_pytorch.encoders import get_encoder
13 | from segmentation_models_pytorch.pspnet.decoder import PSPDecoder
14 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
15 | from segmentation_models_pytorch.base.modules import Flatten, Activation, Conv2dReLU
16 | import segmentation_models_pytorch.base.initialization as init
17 | from loss.seg_loss.crossentropy_loss import cross_entropy2d
18 |
19 | from typing import Optional, Union
20 |
21 |
22 | class CondLayer(nn.Module):
23 | """
24 | implementation of the element-wise linear modulation layer
25 | """
26 |
27 | def __init__(self):
28 | super(CondLayer, self).__init__()
29 | self.relu = nn.ReLU(inplace=True)
30 |
31 | def forward(self, x, gammas, betas):
32 | return self.relu((x * gammas.expand_as(x)) + betas.expand_as(x))
33 |
34 |
35 | def conv(in_planes, out_planes, kernel_size=3, stride=1, bn=False):
36 | if not bn:
37 | return nn.Sequential(
38 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
39 | stride=stride, padding=(kernel_size - 1) // 2),
40 | nn.ReLU(inplace=True)
41 | )
42 | else:
43 | return nn.Sequential(
44 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
45 | stride=stride, padding=(kernel_size - 1) // 2),
46 | nn.BatchNorm2d(out_planes, affine=True),
47 | nn.ReLU(inplace=True)
48 | )
49 |
50 |
51 | def conv1x1(in_planes, out_planes):
52 | return nn.Sequential(
53 | nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0)
54 | )
55 |
56 |
57 | class PSPUpsample(nn.Module):
58 | def __init__(self, in_channels, out_channels):
59 | super().__init__()
60 | self.conv = nn.Sequential(
61 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
62 | nn.BatchNorm2d(out_channels),
63 | nn.PReLU()
64 | )
65 |
66 | def forward(self, x):
67 | h, w = 2 * x.size(2), 2 * x.size(3)
68 | p = F.upsample(input=x, size=(h, w), mode='bilinear')
69 | return self.conv(p)
70 |
71 |
72 | class PSPNetF(SegmentationModel):
73 | def __init__(self,
74 | encoder_name: str = "resnet34",
75 | encoder_weights: Optional[str] = "imagenet",
76 | encoder_depth: int = 3,
77 | psp_out_channels: int = 512,
78 | psp_use_batchnorm: bool = True,
79 | psp_dropout: float = 0.2,
80 | in_channels: int = 3,
81 | activation: Optional[Union[str, callable]] = None,
82 | upsampling: int = 8,
83 | hierarchical: bool = False,
84 | classification: bool = False,
85 | segmentation: bool = False,
86 | classes=21428,
87 | out_indices=(1, 2, 3),
88 | require_spp_feats=False
89 | ):
90 | super(PSPNetF, self).__init__()
91 | self.classification = classification
92 | self.require_spp_feats = require_spp_feats
93 |
94 | self.init_modules = []
95 |
96 | self.encoder = get_encoder(name=encoder_name,
97 | in_channels=3,
98 | depth=encoder_depth,
99 | weights=encoder_weights)
100 |
101 | self.decoder = PSPDecoder(
102 | encoder_channels=self.encoder.out_channels, # 3, 64, 256
103 | use_batchnorm=psp_use_batchnorm,
104 | out_channels=psp_out_channels,
105 | dropout=psp_dropout,
106 | )
107 |
108 | self.seghead = SegmentationHead(
109 | in_channels=psp_out_channels,
110 | out_channels=classes,
111 | kernel_size=3,
112 | activation=activation,
113 | upsampling=upsampling, # 8 for robotcar, 4 for obs9, 2 for obs6 & obs4
114 | )
115 |
116 | if classification:
117 | self.cls_head = nn.Sequential(
118 | nn.AdaptiveAvgPool2d(1),
119 | Flatten(),
120 | nn.Linear(self.encoder.out_channels[-1], classes, bias=True)
121 | )
122 |
123 | self.name = "psp-{}".format(encoder_name)
124 | self.initialize()
125 |
126 | def compute_seg_loss(self, pred_segs, gt_segs, weights=[1.0, 1.0, 1.0, 1.0]):
127 | # pred_segs = inputs["masks"]
128 | # gt_segs = outputs["label"]
129 |
130 | seg_loss = 0
131 |
132 | for pseg, gseg in zip(pred_segs, gt_segs):
133 | # print("pseg, gseg: ", pseg.shape, gseg.shape)
134 | gseg = gseg.cuda()
135 | if len(gseg.shape) == 3:
136 | gseg = gseg.unsqueeze(1)
137 | if pseg.shape[2] != gseg.shape[2] or pseg.shape[3] != gseg.shape[3]:
138 | gseg = F.interpolate(gseg.float(), size=(pseg.shape[2], pseg.shape[3]), mode="nearest")
139 |
140 | # seg_loss += cross_entropy_seg(input=pseg, target=gseg)
141 | seg_loss += cross_entropy2d(input=pseg, target=gseg.long())
142 |
143 | return seg_loss
144 |
145 | def compute_cls_loss(self, pred_cls, gt_cls, method="cel"):
146 | cls_loss = 0
147 | for pc, gc in zip(pred_cls, gt_cls):
148 | gc = gc.cuda()
149 | cls_loss += torch.nn.functional.binary_cross_entropy_with_logits(pc, gc)
150 | return cls_loss
151 |
152 | def forward(self, x):
153 |         """Sequentially pass `x` through the model's encoder, decoder and heads"""
154 | # batch = x.shape[0]
155 | # device = self.parameters().device
156 | # print('device: ', device)
157 | features = self.encoder(x)
158 | # for v in features:
159 | # print(v.shape)
160 | decode_feat = self.decoder(*features)
161 |
162 | masks = self.seghead(decode_feat)
163 |
164 | # seg_loss = self.compute_seg_loss(pred_segs=[masks], gt_segs=input['label'])
165 |
166 | output = {"masks": [masks]}
167 | if self.classification:
168 | cls = self.cls_head(features[-1])
169 | output["cls"] = [cls]
170 | output['feats'] = features
171 |
172 | return output
173 |
174 | def predict(self, x):
175 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
176 |
177 | Args:
178 | x: 4D torch tensor with shape (batch_size, channels, height, width)
179 |
180 | Return:
181 | prediction: 4D torch tensor with shape (batch_size, classes, height, width)
182 |
183 | """
184 | if self.training:
185 | self.eval()
186 |
187 | with torch.no_grad():
188 | x = self.forward(x)
189 |
190 | return x
191 |
192 | def initialize(self):
193 | init.initialize_decoder(self.decoder)
194 |
195 | init.initialize_head(self.seghead)
196 | if self.classification:
197 | init.initialize_head(self.cls_head)
198 |
199 |
200 | if __name__ == '__main__':
201 | net = PSPNetF(
202 | encoder_name="timm-resnest50d",
203 | encoder_weights="imagenet",
204 | # classes=256,
205 | # clusters=200,
206 | encoder_depth=4,
207 | # psp_out_channels=512,
208 | ).cuda()
209 |
210 | print(net)
211 | img = torch.ones((4, 3, 256, 256)).cuda()
212 | out = net(img)
213 |     if "masks" in out.keys():
214 |         masks = out["masks"]
215 |         print(masks[0].shape)
216 |     if "cls" in out.keys():
217 |         cls = out["cls"]
218 |         print(cls[0].shape)
219 | # print (v.shape for v in masks)
220 | # print (v.shape for v in cls)
221 |
--------------------------------------------------------------------------------
/net/segnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time    : 2021/5/11 11:36 AM
3 | # @Author : Fei Xue
4 | # @Email : fx221@cam.ac.uk
5 | # @File    : segnet.py
6 | # @Software: PyCharm
7 |
8 | from net.regnets.pspnet import PSPNetF
9 | from net.regnets.deeplab import DeepLabV3Plus
10 |
11 |
12 | def get_segnet(network,
13 | n_classes,
14 | encoder_name=None,
15 | encoder_weights=None,
16 | out_channels=512,
17 | classification=True,
18 | segmentation=True,
19 | encoder_depth=4,
20 | upsampling=8,
21 | ):
22 | if network == "pspf":
23 | net = PSPNetF(
24 | encoder_name=encoder_name,
25 | encoder_weights=encoder_weights,
26 |
27 | # aux_params=aux_params,
28 | encoder_depth=encoder_depth, # 3 for robotcar, 4 for obs 6 & 9
29 | psp_out_channels=out_channels,
30 | # hierarchical=hierarchical,
31 | classification=classification,
32 | # segmentation=segmentation,
33 | upsampling=upsampling,
34 | # classes=21428,
35 | # classes=3962,
36 | classes=n_classes,
37 | )
38 | elif network == 'deeplabv3p':
39 | net = DeepLabV3Plus(
40 | encoder_name=encoder_name,
41 | encoder_weights=encoder_weights,
42 | decoder_channels=out_channels,
43 | decoder_atrous_rates=(12, 24, 36),
44 | encoder_output_stride=8,
45 | encoder_depth=encoder_depth,
46 | classification=classification,
47 | upsampling=upsampling,
48 | classes=n_classes,
49 |         )
50 |     else:
51 |         raise ValueError('Unsupported network: {:s}'.format(network))
52 | 
53 |     return net
54 | 
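55 | 
56 | if __name__ == '__main__':
57 |     # Minimal usage sketch (not part of the original repo). encoder_weights=None avoids
58 |     # downloading ImageNet weights; 'resnet34' is only an example encoder name.
59 |     import torch
60 |     net = get_segnet('pspf', n_classes=452, encoder_name='resnet34',
61 |                      encoder_weights=None, out_channels=512,
62 |                      classification=True, segmentation=True,
63 |                      encoder_depth=4, upsampling=8)
64 |     out = net(torch.ones(1, 3, 256, 256))
65 |     print(out['masks'][0].shape, out['cls'][0].shape)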
--------------------------------------------------------------------------------
/robotcar.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/robotcar.npy
--------------------------------------------------------------------------------
/run_loc_aachen:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | root=/scratches/flyer_2
4 |
5 | dataset=$root/fx221/localization/aachen_v1_1
6 | outputs=$root/fx221/localization/outputs/aachen_v1_1
7 | image_dir=$root/fx221/localization/aachen_v1_1/images/images_upright
8 | seg_dir=$root/fx221/localization/aachen_v1_1/global_seg_instance
9 | #seg_dir=$root/fx221/localization/aachen_v1_1/global_seg_instance/remap
10 | save_root=$root/fx221/exp/shloc/aachen
11 |
12 | #feat=resnetxn-triv2-0001-n4096-r1600-mask
13 | #feat=resnetxn-triv2-0001-n1000-r1600-mask
14 | #feat=resnetxn-triv2-0001-n2000-r1600-mask
15 | feat=resnetxn-triv2-0001-n3000-r1600-mask
16 |
17 |
18 | #feat=resnetxn-triv2-ms-0001-n4096-r1600-mask
19 | #feat=resnetxn-triv2-ms-0001-n10000-r1600-mask
20 |
21 |
22 | weight_name=2021_08_29_12_49_48_aachen_pspf_resnext101_32x4d_d4_u8_b16_R256_E120_ceohem_adam_seg_cls_aug_stylized
23 | #weight_name=2021_09_06_15_13_26_aachen_deeplabv3p_resnet101_d5_u2_b8_R256_E120_ceohem_sgd_mlr_seg_cls_aug_stylized
24 | #weight_name=2021_09_08_19_27_29_aachen_deeplabv3p_resnext101_32x4d_d5_u2_b16_R256_E120_ceohem_sgd_poly_mlr_seg_cls_aug_stylized
25 | #weight_name=2021_09_11_22_12_37_aachen_deeplabv3p_resnext101_32x4d_d4_u2_b16_R256_E120_ceohem_sgd_poly_mlr_seg_cls_aug_stylized
26 | #save_dir=/data/cornucopia/fx221/exp/shloc/aachen/$weight_name/loc_by_seg
27 | save_dir=/scratches/flyer_2/fx221/exp/shloc/aachen/$weight_name/loc_by_seg
28 |
29 | #query_pair="/home/mifs/fx221/Research/Code/Hierarchical-Localization/pairs/aachen_v1.1/pairs-query-netvlad50.txt"
30 | #query_pair="/data/cornucopia/fx221/localization/outputs_hloc/aachen_v1.1/pairs-query-netvlad50.txt"
31 | query_pair=datasets/aachen/pairs-query-netvlad50.txt
32 | gt_pose_fn=/scratches/flyer_2/fx221/localization/outputs_hloc/aachen_v1_1/Aachen-v1.1_hloc_superpoint_n4096_r1600+superglue_netvlad50.txt
33 |
34 |
35 | #matcher=NNM
36 | matcher=NNML
37 | #matcher=NNR
38 | retrieval_type="lrnv"
39 | #retrieval_type="lrnv256"
40 | feature_type="feat_max"
41 | global_score_th=0.95
42 | rec_th=100
43 | nv_th=50
44 | ransac_thresh=12
45 | opt_thresh=12
46 | covisibility_frame=50
47 | init_type="clu"
48 | opt_type="clurefpos"
49 | k_seg=10
50 | k_can=5
51 | k_rec=30
52 | iters=5
53 | radius=20
54 | obs_thresh=3
55 |
56 | # with opt
57 | python3 -m localization.localizer \
58 | --image_dir $image_dir \
59 | --seg_dir $seg_dir \
60 | --save_root $save_root \
61 | --gt_pose_fn $gt_pose_fn \
62 | --dataset aachen \
63 | --map_gid_rgb_fn datasets/aachen/aachen_grgb_gid_v5.txt \
64 | --db_imglist_fn datasets/aachen/aachen_db_imglist.txt \
65 | --db_instance_fn aachen_452 \
66 | --k_seg $k_seg \
67 | --k_can $k_can \
68 | --k_rec $k_rec \
69 | --retrieval $query_pair \
70 | --retrieval_type $retrieval_type \
71 | --feature_type $feature_type \
72 | --init_type $init_type \
73 | --global_score_th $global_score_th \
74 | --weight_name $weight_name \
75 | --show_seg \
76 | --reference_sfm $outputs/sfm_$feat-$matcher/model \
77 | --queries $dataset/queries/day_night_time_queries_with_intrinsics.txt \
78 | --features $outputs/feats-$feat.h5 \
79 | --matcher_method $matcher \
80 | --ransac_thresh $ransac_thresh \
81 | --with_label \
82 | --with_match \
83 | --rec_th $rec_th \
84 | --nv_th $nv_th \
85 | --covisibility_frame $covisibility_frame \
86 | --iters $iters \
87 | --radius $radius \
88 | --obs_thresh $obs_thresh \
89 | --opt_thresh $opt_thresh \
90 | --opt_type $opt_type \
91 | --do_covisible_opt
--------------------------------------------------------------------------------
/run_loc_robotcar:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | root=/scratches/flyer_2
4 |
5 | outputs=$root/fx221/localization/outputs/robotcar
6 | dataset=$root/fx221/localization/RobotCar-Seasons
7 | save_root=$root/fx221/exp/shloc/robotcar
8 |
9 | image_dir=$root/fx221/localization/RobotCar-Seasons/images
10 | weight_name=2021_10_17_00_21_32_robotcar_pspf_resnext101_32x4d_d4_u8_b16_R256_E500_ceohem_adam_poly_mlr_seg_cls_aug_stylized
11 | seg_dir=$root/fx221/exp/shloc/robotcar/$weight_name/masks
12 |
13 |
14 | query_fn=$dataset/queries_with_intrinsics_rear.txt
15 | query_pair=datasets/robotcar/pairs-query-netvlad20-percam-perloc-rear.txt
16 | gt_pose_fn=/data/cornucopia/fx221/localization/RobotCar-Seasons/3D-models/query_poses_v2.txt
17 |
18 | #feat=resnetxn-triv2-0001-n4096-r1024-mask
19 | feat=resnetxn-triv2-ms-0001-n4096-r1024-mask
20 |
21 | #matcher=NNM
22 | matcher=NNML
23 |
24 | only_gt=0
25 | feature_type="feat_max"
26 | global_score_th=0.95
27 |
28 | rec_th=200
29 | nv_th=50
30 | ransac_thresh=12
31 | opt_thresh=8
32 | covisibility_frame=20
33 | init_type="clu"
34 | retrieval_type="lrnv"
35 | opt_type="clurefpos"
36 | k_seg=10
37 | k_can=1
38 | k_rec=20
39 | iters=5
40 | radius=20
41 | obs_thresh=3
42 |
43 | python3 -m localization.localizer \
44 | --only_gt $only_gt \
45 | --retrieval_type $retrieval_type \
46 | --gt_pose_fn $gt_pose_fn \
47 | --image_dir $image_dir \
48 | --seg_dir $seg_dir \
49 | --dataset robotcar \
50 | --map_gid_rgb_fn datasets/robotcar/robotcar_grgb_gid.txt \
51 | --db_imglist_fn datasets/robotcar/robotcar_rear_db_imglist.txt \
52 | --db_instance_fn robotcar \
53 | --save_root $save_root \
54 | --k_seg $k_seg \
55 | --k_can $k_can \
56 | --k_rec $k_rec \
57 | --feature_type $feature_type \
58 | --init_type $init_type \
59 | --global_score_th $global_score_th \
60 | --weight_name $weight_name \
61 | --show_seg \
62 | --reference_sfm $outputs/sfm_$feat-$matcher/model \
63 | --queries $query_fn \
64 | --retrieval $query_pair \
65 | --features $outputs/feats-$feat.h5 \
66 | --matcher_method $matcher \
67 | --ransac_thresh $ransac_thresh \
68 | --with_label \
69 | --with_match \
70 | --rec_th $rec_th \
71 | --nv_th $nv_th \
72 | --covisibility_frame $covisibility_frame \
73 | --iters $iters \
74 | --radius $radius \
75 | --obs_thresh $obs_thresh \
76 | --opt_thresh $opt_thresh \
77 | --opt_type $opt_type \
78 | --do_covisible_opt
--------------------------------------------------------------------------------
/run_reconstruct_aachen:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | colmap=/home/mifs/fx221/Research/Software/bin/colmap
3 |
4 | dataset=/scratches/flyer_2/fx221/localization/aachen_v1_1
5 | outputs=/scratches/flyer_2/fx221/localization/outputs/aachen_v1_1
6 |
7 | #feat=resnetxn-triv2-0001-n4096-r1600-mask
8 | #feat=resnetxn-triv2-0001-n3000-r1600-mask
9 | #feat=resnetxn-triv2-0001-n2000-r1600-mask
10 | feat=resnetxn-triv2-0001-n1000-r1600-mask
11 |
12 |
13 | image_dir=$dataset/images/images_upright
14 | mask_dir=$dataset/global_seg_instance
15 |
16 | matcher=NNML
17 | #matcher=NNM
18 |
19 | extract_feat=1
20 | match_db=1
21 | triangulate=1
22 |
23 | if [ "$extract_feat" -gt "0" ]; then
24 | python3 -m localization.fine.extractor --image_dir $image_dir --export_dir $outputs/ --conf $feat --mask_dir $mask_dir
25 | fi
26 |
27 | if [ "$match_db" -gt "0" ]; then
28 | python3 -m localization.fine.matcher --pairs datasets/aachen/pairs-db-covis20.txt --export_dir $outputs/ --conf $matcher --features feats-$feat
29 | fi
30 |
31 | if [ "$triangulate" -gt "0" ]; then
32 |     python3 -m localization.fine.triangulation \
33 | --sfm_dir $outputs/sfm_$feat-$matcher \
34 | --reference_sfm_model $dataset/3D-models \
35 | --image_dir $dataset/images/images_upright \
36 | --pairs datasets/aachen/pairs-db-covis20.txt \
37 | --features $outputs/feats-$feat.h5 \
38 | --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 \
39 | --colmap_path $colmap
40 | fi
41 |
42 |
--------------------------------------------------------------------------------
/run_reconstruct_robotcar:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | colmap=/home/mifs/fx221/Research/Software/bin/colmap
3 |
4 | root_dir=/scratch2
5 |
6 | dataset=$root_dir/fx221/localization/RobotCar-Seasons
7 | outputs=$root_dir/fx221/localization/outputs/robotcar
8 | db_pair=$root_dir/fx221/localization/outputs/robotcar/pairs-db-covis20.txt
9 |
10 | feat=resnetxn-triv2-ms-0001-n4096-r1024-mask
11 |
12 | mask_dir=$root_dir/fx221/exp/shloc/robotcar/2021_10_17_00_21_32_robotcar_pspf_resnext101_32x4d_d4_u8_b16_R256_E500_ceohem_adam_poly_mlr_seg_cls_aug_stylized/masks
13 |
14 | matcher=NNML
15 | #matcher=NNM
16 |
17 | extract_feat_db=1
18 | match_db=1
19 | triangulate=1
20 |
21 | if [ "$extract_feat_db" -gt "0" ]; then
22 | python3 -m localization.fine.extractor --image_list datasets/robotcar/robotcar_db_query_imglist.txt --image_dir $dataset/images --export_dir $outputs/ --conf $feat --mask_dir $mask_dir
23 | fi
24 |
25 | if [ "$match_db" -gt "0" ]; then
26 | python3 -m localization.fine.matcher --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat
27 | fi
28 |
29 | if [ "$triangulate" -gt "0" ]; then
30 |     python3 -m localization.fine.triangulation \
31 | --sfm_dir $outputs/sfm_$feat-$matcher \
32 |         --reference_sfm_model $dataset/3D-models/sfm-sift \
33 |         --image_dir $dataset/images \
34 | --pairs $outputs/pairs-db-covis20.txt \
35 | --features $outputs/feats-$feat.h5 \
36 | --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 \
37 | --colmap_path $colmap
38 | fi
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> test
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 2021-05-31 15:09
7 | =================================================='''
8 | import argparse
9 | import json
10 | import torchvision.transforms as tvf
11 | import torch
12 | import torch.nn.functional as F
13 | from tqdm import tqdm
14 | import cv2
15 | import os
16 | import numpy as np
17 | import os.path as osp
18 | from net.segnet import get_segnet
19 | from tools.common import torch_set_gpu
20 | from tools.seg_tools import read_seg_map_without_group, label_to_bgr
21 |
22 | val_transform = tvf.Compose(
23 | (
24 | # tvf.ToPILImage(),
25 | # tvf.Resize(224),
26 | tvf.ToTensor(),
27 | tvf.Normalize(mean=[0.485, 0.456, 0.406],
28 | std=[0.229, 0.224, 0.225])
29 | )
30 | )
31 |
32 |
33 | def predict(net, img):
34 | img_tensor = val_transform(img)
35 | img_tensor = img_tensor.cuda().unsqueeze(0)
36 | with torch.no_grad():
37 | prediction = net(img_tensor)
38 |
39 | return prediction
40 |
41 |
42 | def inference_rec(output, img, fn, map_gid_rgb, save_dir=None):
43 | cv2.namedWindow("out", cv2.WINDOW_NORMAL)
44 | with torch.no_grad():
45 | # output = predict(net=model, img=img)
46 | pred_mask = output["masks"][0]
47 | pred_label = torch.softmax(pred_mask, dim=1).max(1)[1].cpu().numpy()
48 | pred_conf = torch.softmax(pred_mask, dim=1).max(1)[0].cpu().numpy()
49 |
50 | last_feat = output['feats'][-1]
51 | pred_feat_max = F.adaptive_max_pool2d(last_feat, output_size=(1, 1))
52 | pred_feat_avg = F.adaptive_avg_pool2d(last_feat, output_size=(1, 1))
53 |
54 | if pred_label.shape[0] == 1:
55 | pred_label = pred_label[0]
56 | pred_conf = pred_conf[0]
57 |
58 | uids = np.unique(pred_label).tolist()
59 | pred_seg = label_to_bgr(label=pred_label, maps=map_gid_rgb)
60 | pred_conf_img = np.uint8(pred_conf * 255)
61 | pred_conf_img = cv2.applyColorMap(src=pred_conf_img, colormap=cv2.COLORMAP_PARULA)
62 |
63 | H = args.R # img.shape[0]
64 | W = args.R # img.shape[1]
65 | pred_seg = cv2.resize(pred_seg, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
66 | pred_conf_img = cv2.resize(pred_conf_img, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
67 | img = cv2.resize(img, dsize=(W, H))
68 |
69 | img_seg = (0.5 * img + 0.5 * pred_seg).astype(np.uint8)
70 | cat_img = np.hstack([img_seg, pred_seg, pred_conf_img])
71 |
72 | cv2.imshow("out", cat_img)
73 | key = cv2.waitKey()
74 | if key in (27, ord('q')): # exit by pressing key esc or q
75 | cv2.destroyAllWindows()
76 | exit(0)
77 | # return
78 |
79 | if save_dir is not None:
80 | conf_fn = osp.join(save_dir, "confidence", fn.split('.')[0] + ".npy")
81 | mask_fn = osp.join(save_dir, "masks", fn.replace("jpg", "png"))
82 | vis_fn = osp.join(save_dir, "vis", fn.replace("jpg", "png"))
83 | if not osp.exists(osp.dirname(conf_fn)):
84 | os.makedirs(osp.dirname(conf_fn), exist_ok=True)
85 | if not osp.exists(osp.dirname(vis_fn)):
86 | os.makedirs(osp.dirname(vis_fn), exist_ok=True)
87 | if not osp.exists(osp.dirname(mask_fn)):
88 | os.makedirs(osp.dirname(mask_fn), exist_ok=True)
89 |
90 | pred_confidence, pred_ids = torch.topk(torch.softmax(pred_mask, dim=1), k=10, largest=True, dim=1)
91 | conf_data = {"confidence": pred_confidence[0].cpu().numpy(),
92 | "ids": pred_ids[0].cpu().numpy(),
93 | 'feat_max': pred_feat_max.squeeze().cpu().numpy(),
94 | 'feat_avg': pred_feat_avg.squeeze().cpu().numpy(),
95 | }
96 |
97 | np.save(conf_fn, conf_data)
98 | cv2.imwrite(vis_fn, cat_img)
99 | cv2.imwrite(mask_fn, pred_seg)
100 |
101 |
102 | def main(args):
103 | map_gid_rgb = read_seg_map_without_group(args.grgb_gid_file)
104 |
105 | model = get_segnet(network=args.network,
106 | n_classes=args.classes,
107 | encoder_name=args.encoder_name,
108 | encoder_weights=args.encoder_weights,
109 | encoder_depth=args.encoder_depth,
110 | upsampling=args.upsampling,
111 | out_channels=args.out_channels,
112 | classification=args.classification,
113 | segmentation=args.segmentation, )
114 | print("model: ", model)
115 | if args.pretrained_weight is not None:
116 | model.load_state_dict(torch.load(args.pretrained_weight), strict=True)
117 | print("Load weight from {:s}".format(args.pretrained_weight))
118 | model.eval().cuda()
119 |
120 | img_path = args.image_path
121 | save_dir = args.save_dir
122 |
123 | print('Save results to ', save_dir)
124 |
125 | imglist = []
126 | with open(args.image_list, "r") as f:
127 | lines = f.readlines()
128 | for l in lines:
129 | imglist.append(l.strip())
130 |
131 | for fn in tqdm(imglist, total=len(imglist)):
132 | if fn.find('left') >= 0 or fn.find('right') >= 0:
133 | continue
134 | img = cv2.imread(osp.join(img_path, fn))
135 | img = cv2.resize(img, dsize=(args.R, args.R))
136 |
137 | with torch.no_grad():
138 | output = predict(net=model, img=img)
139 | inference_rec(output=output, img=img, fn=fn, map_gid_rgb=map_gid_rgb, save_dir=save_dir)
140 |
141 |
142 | if __name__ == '__main__':
143 | parser = argparse.ArgumentParser("Test Semantic localization Network")
144 | parser.add_argument("--config", type=str, required=True, help="configuration file")
145 | parser.add_argument("--pretrained_weight", type=str, default=None)
146 | parser.add_argument("--save_root", type=str, default="/home/mifs/fx221/fx221/exp/shloc/aachen")
147 | parser.add_argument("--image_path", type=str, default=None)
148 | parser.add_argument("--network", type=str, default="pspf")
149 | parser.add_argument("--save_dir", type=str, default=None)
150 | parser.add_argument("--encoder_name", type=str, default='timm-resnest50d')
151 | parser.add_argument("--encoder_weights", type=str, default='imagenet')
152 |     parser.add_argument("--out_channels", type=int, default=2048)
153 |     parser.add_argument("--upsampling", type=int, default=8)
154 | parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU')
155 | parser.add_argument("--R", type=int, default=256)
156 |
157 | args = parser.parse_args()
158 | with open(args.config, 'rt') as f:
159 | t_args = argparse.Namespace()
160 | t_args.__dict__.update(json.load(f))
161 | args = parser.parse_args(namespace=t_args)
162 |
163 | torch_set_gpu(gpus=args.gpu)
164 | main(args=args)
165 |
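166 | # A minimal sketch (not part of the original script) of how the recognition results saved by
167 | # inference_rec could be read back; the path below is illustrative. np.save stored a Python
168 | # dict, so allow_pickle=True and .item() are needed to recover it.
169 | #
170 | #   conf_data = np.load('path/to/save_dir/confidence/some_image.npy', allow_pickle=True).item()
171 | #   top1_ids = conf_data['ids'][0]           # top-1 predicted global instance ids
172 | #   top1_conf = conf_data['confidence'][0]   # matching softmax confidences
173 | #   feat_avg = conf_data['feat_avg']         # pooled global feature (see also conf_data['feat_max'])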
--------------------------------------------------------------------------------
/test_aachen:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python3 test.py --config configs/config_test_aachen.json
--------------------------------------------------------------------------------
/test_robotcar:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python3 test.py --config configs/config_test_robotcar.json
--------------------------------------------------------------------------------
/tools/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time    : 2021/3/15 3:54 PM
4 | @Auth : Fei Xue
5 | @File : common.py
6 | @Email: fx221@cam.ac.uk
7 | """
8 |
9 | import numpy as np
10 | import torch
11 | import json
12 | from collections import OrderedDict
13 | import cv2
14 | from torch._six import string_classes
15 | import collections.abc as collections
16 | import os
17 | import os.path as osp
18 |
19 |
20 | def get_recursive_file_list(root_dir, sub_dir="", patterns=[]):
21 | current_files = os.listdir(osp.join(root_dir, sub_dir))
22 | all_files = []
23 | for file_name in current_files:
24 | full_file_name = os.path.join(root_dir, sub_dir, file_name)
25 | print(file_name)
26 |
27 | if file_name.split('.')[-1] in patterns:
28 | all_files.append(osp.join(sub_dir, file_name))
29 |
30 | if os.path.isdir(full_file_name):
31 | next_level_files = get_recursive_file_list(root_dir, sub_dir=osp.join(sub_dir, file_name),
32 | patterns=patterns)
33 | all_files.extend(next_level_files)
34 |
35 | return all_files
36 |
37 |
38 | def sort_dict_by_value(data, reverse=False):
39 | return sorted(data.items(), key=lambda d: d[1], reverse=reverse)
40 |
41 |
42 | def mkdir_for(file_path):
43 | os.makedirs(os.path.split(file_path)[0], exist_ok=True)
44 |
45 |
46 | def model_size(model):
47 | ''' Computes the number of parameters of the model
48 | '''
49 | size = 0
50 | for weights in model.state_dict().values():
51 | size += np.prod(weights.shape)
52 | return size
53 |
54 |
55 | def torch_set_gpu(gpus):
56 | if type(gpus) is int:
57 | gpus = [gpus]
58 |
59 | cuda = all(gpu >= 0 for gpu in gpus)
60 |
61 | if cuda:
62 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus])
63 | assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % (
64 | os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES'])
65 | torch.backends.cudnn.benchmark = True # speed-up cudnn
66 | torch.backends.cudnn.fastest = True # even more speed-up?
67 | print('Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'])
68 |
69 | else:
70 | print('Launching on CPU')
71 |
72 | return cuda
73 |
74 |
75 | def save_args(args, save_path):
76 | with open(save_path, 'w') as f:
77 | json.dump(args.__dict__, f, indent=2)
78 |
79 |
80 | def load_args(args, save_path):
81 | with open(save_path, "r") as f:
82 | args.__dict__ = json.load(f)
83 |
84 |
85 | def adjust_learning_rate_poly(optimizer, epoch, num_epochs, base_lr, power):
86 | lr = base_lr * (1 - epoch / num_epochs) ** power
87 | for param_group in optimizer.param_groups:
88 | param_group['lr'] = lr
89 | return lr
90 |
91 |
92 | def read_json(fname):
93 | with fname.open('rt') as handle:
94 | return json.load(handle, object_hook=OrderedDict)
95 |
96 |
97 | def write_json(content, fname):
98 | with fname.open('wt') as handle:
99 | json.dump(content, handle, indent=4, sort_keys=False)
100 |
101 |
102 | def resize_img(img, nh=-1, nw=-1, mode=cv2.INTER_NEAREST):
103 | assert nh > 0 or nw > 0
104 | if nh == -1:
105 | return cv2.resize(img, dsize=(nw, int(img.shape[0] / img.shape[1] * nw)), interpolation=mode)
106 | if nw == -1:
107 | return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * nh), nh), interpolation=mode)
108 | return cv2.resize(img, dsize=(nw, nh), interpolation=mode)
109 |
110 |
111 | def map_tensor(input_, func):
112 | if isinstance(input_, torch.Tensor):
113 | return func(input_)
114 | elif isinstance(input_, string_classes):
115 | return input_
116 | elif isinstance(input_, collections.Mapping):
117 | return {k: map_tensor(sample, func) for k, sample in input_.items()}
118 | elif isinstance(input_, collections.Sequence):
119 | return [map_tensor(sample, func) for sample in input_]
120 | else:
121 | raise TypeError(
122 | f'input must be tensor, dict or list; found {type(input_)}')
123 |
124 |
125 | def imgs2video(im_dir, video_dir):
126 | img_fns = os.listdir(im_dir)
127 | # print(img_fns)
128 | img_fns = [v for v in img_fns if v.split('.')[-1] in ['jpg', 'png']]
129 | img_fns = sorted(img_fns)
130 | # print(img_fns)
131 | fps = 1
132 | img_size = (800, 492)
133 |
134 | # fourcc = cv2.cv.CV_FOURCC('M','J','P','G')#opencv2.4
135 | # fourcc = cv2.VideoWriter_fourcc('I','4','2','0')
136 |
137 | fourcc = cv2.VideoWriter_fourcc(*'MP4V')
138 | # fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
139 | videoWriter = cv2.VideoWriter(video_dir, fourcc, fps, img_size)
140 |
141 |     for i in range(min(500, len(img_fns))):  # avoid indexing past the available frames
142 | # fn = img_fns[i].split('-')
143 | im_name = os.path.join(im_dir, img_fns[i])
144 | print(im_name)
145 | frame = cv2.imread(im_name, 1)
146 | frame = cv2.resize(frame, dsize=img_size)
147 | # print(frame.shape)
148 | # exit(0)
149 | cv2.imshow("frame", frame)
150 | cv2.waitKey(5)
151 | videoWriter.write(frame)
152 |
153 | videoWriter.release()
154 |
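155 |
156 | # A minimal usage sketch (not part of the original module): map_tensor applies a function to
157 | # every tensor inside a (possibly nested) dict/list batch, e.g. to move a whole batch to the
158 | # GPU, while strings and other non-tensor leaves are left untouched.
159 | def _example_batch_to_gpu(batch):
160 |     # batch could be e.g. {"img": Tensor, "meta": {"name": "x.png", "scales": [Tensor, Tensor]}}
161 |     return map_tensor(batch, lambda t: t.cuda(non_blocking=True))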
--------------------------------------------------------------------------------
/tools/config_parser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> config_parser
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 20/07/2021 10:51
7 | =================================================='''
8 |
9 | import logging
10 | from pathlib import Path
11 | from functools import reduce
12 | from operator import getitem
13 | from datetime import datetime
14 | from tools.common import read_json, write_json, torch_set_gpu
15 |
16 |
17 | class ConfigParser:
18 | def __init__(self, args, options='', timestamp=True, test_only=False):
19 | # parse default and custom cli options
20 | for opt in options:
21 | args.add_argument(*opt.flags, default=None, type=opt.type)
22 | args = args.parse_args()
23 |
24 | # if args.device:
25 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.device
26 | torch_set_gpu(gpus=args.gpu)
27 |
28 | self.cfg_fname = Path(args.config)
29 | # load config file and apply custom cli options
30 | config = read_json(self.cfg_fname)
31 | self.__config = _update_config(config, options, args)
32 |
33 | # set save_dir where trained model and log will be saved.
34 | save_dir = Path(self.config['trainer']['save_dir'])
35 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
36 | self.timestamp = timestamp
37 |
38 | exper_name = self.config['name']
39 | self.__save_dir = save_dir / 'models' / exper_name / timestamp
40 | self.__log_dir = save_dir / 'log' / exper_name / timestamp
41 |
42 | self.save_dir.mkdir(parents=True, exist_ok=True)
43 | self.log_dir.mkdir(parents=True, exist_ok=True)
44 |
45 | if not test_only:
46 | # save updated config file to the checkpoint dir
47 | write_json(self.config, self.save_dir / 'config.json')
48 | else:
49 | write_json(self.config, self.resume.parent / 'config_test.json')
50 |
51 | def initialize(self, name, module, *args, **kwargs):
52 | """
53 | finds a function handle with the name given as 'type' in config, and returns the
54 | instance initialized with corresponding keyword args given as 'args'.
55 | """
56 | module_cfg = self[name]
57 | module_cfg["args"].update(kwargs)
58 | return getattr(module, module_cfg['type'])(*args, **module_cfg['args'])
59 |
60 | def __getitem__(self, name):
61 | return self.config[name]
62 |
63 | def get_logger(self, name, verbosity=2):
64 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
65 | self.log_levels.keys())
66 | assert verbosity in self.log_levels, msg_verbosity
67 | logger = logging.getLogger(name)
68 | logger.setLevel(self.log_levels[verbosity])
69 | return logger
70 |
71 | # setting read-only attributes
72 | @property
73 | def config(self):
74 | return self.__config
75 |
76 | @property
77 | def save_dir(self):
78 | return self.__save_dir
79 |
80 | @property
81 | def log_dir(self):
82 | return self.__log_dir
83 |
84 |
85 | # helper functions used to update config dict with custom cli options
86 | def _update_config(config, options, args):
87 | for opt in options:
88 | value = getattr(args, _get_opt_name(opt.flags))
89 | if value is not None:
90 | _set_by_path(config, opt.target, value)
91 | return config
92 |
93 |
94 | def _get_opt_name(flags):
95 | for flg in flags:
96 | if flg.startswith('--'):
97 | return flg.replace('--', '')
98 | return flags[0].replace('--', '')
99 |
100 |
101 | def _set_by_path(tree, keys, value):
102 | """Set a value in a nested object in tree by sequence of keys."""
103 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
104 |
105 |
106 | def _get_by_path(tree, keys):
107 | """Access a nested object in tree by sequence of keys."""
108 | return reduce(getitem, keys, tree)
109 |
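110 | # Illustration (not part of the original class) of how ``initialize`` resolves a config entry
111 | # of the form {"type": <class name>, "args": {...}}.  With a hypothetical entry
112 | #     "optimizer": {"type": "Adam", "args": {"lr": 0.0001, "weight_decay": 0.0005}}
113 | # the call ``config.initialize('optimizer', torch.optim, model.parameters())`` would be
114 | # equivalent to ``torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0005)``.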
--------------------------------------------------------------------------------
/tools/loc_tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File lbr -> loc_tools
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 06/04/2022 15:08
7 | =================================================='''
8 |
9 | import numpy as np
10 | import cv2
11 | import torch
12 | from copy import copy
13 | from scipy.spatial.transform import Rotation as sciR
14 |
15 |
16 | def plot_keypoint(img_path, pts, scores=None, tag=None, save_path=None):
17 | if type(img_path) == str:
18 | img = cv2.imread(img_path)
19 | else:
20 | img = img_path.copy()
21 |
22 | img_out = img.copy()
23 | print(img.shape)
24 | r = 3
25 | for i in range(pts.shape[0]):
26 | pt = pts[i]
27 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4)
28 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2)
29 |
30 | if save_path is not None:
31 | cv2.imwrite(save_path, img_out)
32 | return img_out
33 |
34 |
35 | def sort_dict_by_value(data, reverse=False):
36 | return sorted(data.items(), key=lambda d: d[1], reverse=reverse)
37 |
38 |
39 | def read_retrieval_results(path):
40 | output = {}
41 | with open(path, "r") as f:
42 | lines = f.readlines()
43 | for p in lines:
44 | p = p.strip("\n").split(" ")
45 |
46 | if p[1] == "no_match":
47 | continue
48 | if p[0] in output.keys():
49 | output[p[0]].append(p[1])
50 | else:
51 | output[p[0]] = [p[1]]
52 | return output
53 |
54 |
55 | def nn_k(query_gps, db_gps, k=20):
56 | q = torch.from_numpy(query_gps) # [N 2]
57 | db = torch.from_numpy(db_gps) # [M, 2]
58 | # print (q.shape, db.shape)
59 | dist = q.unsqueeze(2) - db.t().unsqueeze(0)
60 | dist = dist[:, 0, :] ** 2 + dist[:, 1, :] ** 2
61 | print("dist: ", dist.shape)
62 | topk = torch.topk(dist, dim=1, k=k, largest=False)[1]
63 | return topk
64 |
65 |
66 | def plot_matches(img1, img2, pts1, pts2, inliers, horizon=False, plot_outlier=False, confs=None, plot_match=True):
67 | rows1 = img1.shape[0]
68 | cols1 = img1.shape[1]
69 | rows2 = img2.shape[0]
70 | cols2 = img2.shape[1]
71 | r = 3
72 | if horizon:
73 | img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8')
74 | # Place the first image to the left
75 | img_out[:rows1, :cols1] = img1
76 | # Place the next image to the right of it
77 | img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2])
78 |
79 | if not plot_match:
80 | return cv2.resize(img_out, None, fx=0.5, fy=0.5)
81 | # for idx, pt in enumerate(pts1):
82 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2)
83 | # for idx, pt in enumerate(pts2):
84 | # img_out = cv2.circle(img_out, (int(pt[0] + cols1), int(pt[1])), r, (0, 0, 255), 2)
85 | for idx in range(inliers.shape[0]):
86 | # if idx % 10 > 0:
87 | # continue
88 | if inliers[idx]:
89 | color = (0, 255, 0)
90 | else:
91 | if not plot_outlier:
92 | continue
93 | color = (0, 0, 255)
94 | pt1 = pts1[idx]
95 | pt2 = pts2[idx]
96 |
97 | if confs is not None:
98 | nr = int(r * confs[idx])
99 | else:
100 | nr = r
101 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2)
102 |
103 | img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2)
104 |
105 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color,
106 | 2)
107 | else:
108 | img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8')
109 | # Place the first image to the left
110 | img_out[:rows1, :cols1] = img1
111 | # Place the next image to the right of it
112 | img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2])
113 |
114 | if not plot_match:
115 | return cv2.resize(img_out, None, fx=0.5, fy=0.5)
116 | # for idx, pt in enumerate(pts1):
117 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2)
118 | # for idx, pt in enumerate(pts2):
119 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1] + rows1)), r, (0, 0, 255), 2)
120 | for idx in range(inliers.shape[0]):
121 | # print("idx: ", inliers[idx])
122 | # if idx % 10 > 0:
123 | # continue
124 | if inliers[idx]:
125 | color = (0, 255, 0)
126 | else:
127 | if not plot_outlier:
128 | continue
129 | color = (0, 0, 255)
130 |
131 | if confs is not None:
132 | nr = int(r * confs[idx])
133 | else:
134 | nr = r
135 |
136 | pt1 = pts1[idx]
137 | pt2 = pts2[idx]
138 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), r, color, 2)
139 |
140 | img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), r, color, 2)
141 |
142 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color,
143 | 2)
144 |
145 | img_rs = cv2.resize(img_out, None, fx=0.5, fy=0.5)
146 |
147 | # img_rs = cv2.putText(img_rs, 'matches: {:d}'.format(len(inliers.shape[0])), (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2,
148 | # (0, 0, 255), 2)
149 |
150 | # if save_fn is not None:
151 | # cv2.imwrite(save_fn, img_rs)
152 | # cv2.imshow("match", img_rs)
153 | # cv2.waitKey(0)
154 | return img_rs
155 |
156 |
157 | def plot_reprojpoint2D(img, points2D, reproj_points2D, confs=None):
158 | img_out = copy(img)
159 | r = 5
160 | for i in range(points2D.shape[0]):
161 | p = points2D[i]
162 | rp = reproj_points2D[i]
163 |
164 | if confs is not None:
165 | nr = int(r * confs[i])
166 | else:
167 | nr = r
168 |
169 | if nr >= 50:
170 | nr = 50
171 | # img_out = cv2.circle(img_out, (int(p[0]), int(p[1])), nr, color=(0, 255, 0), thickness=2)
172 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), nr, color=(0, 0, 255), thickness=3)
173 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), 2, color=(0, 0, 255), thickness=3)
174 | # img_out = cv2.line(img_out, pt1=(int(p[0]), int(p[1])), pt2=(int(rp[0]), int(rp[1])), color=(0, 0, 255),
175 | # thickness=2)
176 |
177 | return img_out
178 |
179 |
180 | def reproject_fromR(points3D, rot, tvec, camera):
181 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1)
182 |
183 | if camera['model'] == 'SIMPLE_RADIAL':
184 | f = camera['params'][0]
185 | cx = camera['params'][1]
186 | cy = camera['params'][2]
187 | k = camera['params'][3]
188 |
189 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :]
190 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2
191 | radial = r2 * k
192 | du = proj_2d[0, :] * radial
193 | dv = proj_2d[1, :] * radial
194 |
195 | u = proj_2d[0, :] + du
196 | v = proj_2d[1, :] + dv
197 | u = u * f + cx
198 | v = v * f + cy
199 | uvs = np.vstack([u, v]).transpose()
200 |
201 | return uvs
202 |
203 |
204 | def calc_depth(points3D, rvec, tvec, camera):
205 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm()
206 | # print('p3d: ', points3D.shape, rot.shape, rot)
207 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1)
208 |
209 | return proj_2d.transpose()[:, 2]
210 |
211 |
212 | def reproject(points3D, rvec, tvec, camera):
213 | '''
214 | Args:
215 | points3D: [N, 3]
216 | rvec: [w, x, y, z]
217 | tvec: [x, y, z]
218 |     Returns: uvs [N, 2] reprojected pixel coordinates (only the SIMPLE_RADIAL camera model is handled)
219 | '''
220 | # print('camera: ', camera)
221 | # print('p3d: ', points3D.shape)
222 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm()
223 | # print('p3d: ', points3D.shape, rot.shape, rot)
224 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1)
225 |
226 | if camera['model'] == 'SIMPLE_RADIAL':
227 | f = camera['params'][0]
228 | cx = camera['params'][1]
229 | cy = camera['params'][2]
230 | k = camera['params'][3]
231 |
232 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :]
233 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2
234 | radial = r2 * k
235 | du = proj_2d[0, :] * radial
236 | dv = proj_2d[1, :] * radial
237 |
238 | u = proj_2d[0, :] + du
239 | v = proj_2d[1, :] + dv
240 | u = u * f + cx
241 | v = v * f + cy
242 | uvs = np.vstack([u, v]).transpose()
243 |
244 | return uvs
245 |
246 |
247 | def quaternion_angular_error(q1, q2):
248 | """
249 | angular error between two quaternions
250 | :param q1: (4, )
251 | :param q2: (4, )
252 | :return:
253 | """
254 | d = abs(np.dot(q1, q2))
255 | d = min(1.0, max(-1.0, d))
256 | theta = 2 * np.arccos(d) * 180 / np.pi
257 | return theta
258 |
259 |
260 | def ColmapQ2R(qvec):
261 | rot = sciR.from_quat(quat=[qvec[1], qvec[2], qvec[3], qvec[0]]).as_dcm()
262 | return rot
263 |
264 |
265 | def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw):
266 | pred_Rcw = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_dcm()
267 | pred_tcw = np.array(pred_tcw, float).reshape(3, 1)
268 | pred_Rwc = pred_Rcw.transpose()
269 | pred_twc = -pred_Rcw.transpose() @ pred_tcw
270 |
271 | gt_Rcw = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_dcm()
272 | gt_tcw = np.array(gt_tcw, float).reshape(3, 1)
273 | gt_Rwc = gt_Rcw.transpose()
274 | gt_twc = -gt_Rcw.transpose() @ gt_tcw
275 |
276 | t_error_xyz = pred_twc - gt_twc
277 | t_error = np.sqrt(np.sum(t_error_xyz ** 2))
278 |
279 | pred_qwc = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_quat()
280 | gt_qwc = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_quat()
281 |
282 | q_error = quaternion_angular_error(q1=pred_qwc, q2=gt_qwc)
283 |
284 | return q_error, t_error, (t_error_xyz[0, 0], t_error_xyz[1, 0], t_error_xyz[2, 0])
285 |
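286 |
287 | # A small sanity-check sketch (not part of the original module): a 10-degree rotation about z,
288 | # written in COLMAP's [w, x, y, z] quaternion convention, compared against the identity pose.
289 | def _example_pose_error():
290 |     half = np.deg2rad(10.0) / 2.0
291 |     pred_qcw = [np.cos(half), 0.0, 0.0, np.sin(half)]  # [w, x, y, z]
292 |     q_err, t_err, _ = compute_pose_error(pred_qcw, [0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
293 |     return q_err, t_err  # expected to be roughly (10.0, 0.0)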
--------------------------------------------------------------------------------
/tools/optim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> optim
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 08/09/2021 12:06
7 | =================================================='''
8 | from torch.optim.lr_scheduler import _LRScheduler, StepLR
9 |
10 |
11 | # class PolyLR(_LRScheduler):
12 | # def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6):
13 | # self.power = power
14 | # self.max_iters = max_iters # avoid zero lr
15 | # self.min_lr = min_lr
16 | # super(PolyLR, self).__init__(optimizer, last_epoch)
17 | #
18 | # def get_lr(self):
19 | # return [max(base_lr * (1 - self.last_epoch / self.max_iters) ** self.power, self.min_lr)]
20 |
21 |
22 | class PolyLR(_LRScheduler):
23 |     """Polynomial learning rate decay until the step count reaches max_decay_steps.
24 |
25 |     Args:
26 |         optimizer (Optimizer): Wrapped optimizer.
27 |         max_decay_steps: step after which the learning rate stops decreasing
28 |         end_learning_rate: the learning rate kept once decay has stopped
29 |         power: The power of the polynomial.
30 |     """
31 |
32 | def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0):
33 | if max_decay_steps <= 1.:
34 | raise ValueError('max_decay_steps should be greater than 1.')
35 | self.max_decay_steps = max_decay_steps
36 | self.end_learning_rate = end_learning_rate
37 | self.power = power
38 | self.last_step = 0
39 | super().__init__(optimizer)
40 |
41 | def get_lr(self):
42 | if self.last_step > self.max_decay_steps:
43 | return [self.end_learning_rate for _ in self.base_lrs]
44 |
45 | return [(base_lr - self.end_learning_rate) *
46 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) +
47 | self.end_learning_rate for base_lr in self.base_lrs]
48 |
49 | def step(self, step=None):
50 | if step is None:
51 | step = self.last_step + 1
52 | self.last_step = step if step != 0 else 1
53 | if self.last_step <= self.max_decay_steps:
54 | decay_lrs = [(base_lr - self.end_learning_rate) *
55 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) +
56 | self.end_learning_rate for base_lr in self.base_lrs]
57 | for param_group, lr in zip(self.optimizer.param_groups, decay_lrs):
58 | param_group['lr'] = lr
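59 |
60 |
61 | # A minimal usage sketch (not part of the original module); the model and optimizer below are
62 | # illustrative, any torch optimizer works with PolyLR.
63 | def _example_polylr():
64 |     import torch
65 |     model = torch.nn.Linear(8, 2)
66 |     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
67 |     scheduler = PolyLR(optimizer, max_decay_steps=100, end_learning_rate=1e-5, power=0.9)
68 |     for _ in range(120):
69 |         optimizer.step()
70 |         scheduler.step()  # lr decays polynomially, then stays at end_learning_rate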
--------------------------------------------------------------------------------
/tools/seg_tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> seg_tools
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 30/04/2021 15:47
7 | =================================================='''
8 |
9 | import numpy as np
10 |
11 |
12 | def seg_to_rgb(seg, maps):
13 | """
14 | Args:
15 | seg: [C, H, W]
16 | maps: [id, R, G, B]
17 | Returns: [H, W, R/G/B]
18 | """
19 |     pred_label = seg.max(0)[1].cpu().numpy()  # per-pixel label indices, not the max scores
20 | output = np.zeros(shape=(pred_label.shape[0], pred_label.shape[1], 3), dtype=np.uint8)
21 |
22 | for label in maps.keys():
23 | rgb = maps[label]
24 | output[pred_label == label] = np.uint8(rgb)
25 | return output
26 |
27 |
28 | def label_to_rgb(label, maps):
29 | output = np.zeros(shape=(label.shape[0], label.shape[1], 3), dtype=np.uint8)
30 | for k in maps.keys():
31 | rgb = maps[k]
32 | if type(rgb) is int:
33 | b = rgb % 256
34 | r = rgb // (256 * 256)
35 | g = (rgb - r * 256 * 256) // 256
36 | rgb = np.array([r, g, b], np.uint8)
37 | # bgr = np.array([b, g, r], np.uint8)
38 | output[label == k] = np.uint8(rgb)
39 | return output
40 |
41 |
42 | def label_to_bgr(label, maps):
43 | output = np.zeros(shape=(label.shape[0], label.shape[1], 3), dtype=np.uint8)
44 | for k in maps.keys():
45 | rgb = maps[k]
46 | if type(rgb) is int:
47 | b = rgb % 256
48 | r = rgb // (256 * 256)
49 | g = (rgb - r * 256 * 256) // 256
50 | # rgb = np.array([r, g, b], np.uint8)
51 | bgr = np.array([b, g, r], np.uint8)
52 | output[label == k] = np.uint8(bgr)
53 | return output
54 |
55 |
56 | def rgb_to_bgr(img):
57 | out = np.zeros_like(img)
58 | out[:, :, 0] = img[:, :, 2]
59 | out[:, :, 1] = img[:, :, 1]
60 | out[:, :, 2] = img[:, :, 0]
61 | return out
62 |
63 |
64 | def read_seg_map(path):
65 | map = {}
66 | with open(path, "r") as f:
67 | lines = f.readlines()
68 | for l in lines:
69 | l = l.strip("\n").split(' ')
70 | map[int(l[1])] = np.array([np.uint8(l[2]), np.uint8(l[3]), np.uint8(l[4])], np.uint8)
71 |
72 | return map
73 |
74 |
75 | def read_seg_map_with_group(path):
76 | map = {}
77 | with open(path, "r") as f:
78 | lines = f.readlines()
79 | for l in lines:
80 | l = l.strip().split(" ")
81 | rid = int(l[0])
82 | grgb = int(l[1])
83 | gid = int(l[2])
84 | if rid in map.keys():
85 | map[rid][gid] = grgb
86 | else:
87 | map[rid] = {}
88 | map[rid][gid] = grgb
89 | return map
90 |
91 |
92 | def read_seg_map_without_group(path):
93 | map = {}
94 | with open(path, "r") as f:
95 | lines = f.readlines()
96 | for l in lines:
97 | l = l.strip().split(" ")
98 | grgb = int(l[0])
99 | gid = int(l[1])
100 | map[gid] = grgb
101 | return map
102 |
103 |
104 | ## code is from https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
105 | def _fast_hist(label_true, label_pred, n_class):
106 | mask = (label_true >= 0) & (label_true < n_class)
107 | hist = np.bincount(
108 | n_class * label_true[mask].astype(int) +
109 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
110 | return hist
111 |
112 |
113 | def label_accuracy_score(label_trues, label_preds, n_class):
114 | """Returns accuracy score evaluation result.
115 | - overall accuracy
116 | - mean accuracy
117 | - mean IU
118 | - fwavacc
119 | """
120 | hist = np.zeros((n_class, n_class))
121 | for lt, lp in zip(label_trues, label_preds):
122 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
123 | acc = np.diag(hist).sum() / hist.sum()
124 | with np.errstate(divide='ignore', invalid='ignore'):
125 | acc_cls = np.diag(hist) / hist.sum(axis=1)
126 | acc_cls = np.nanmean(acc_cls)
127 | with np.errstate(divide='ignore', invalid='ignore'):
128 | iu = np.diag(hist) / (
129 | hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
130 | )
131 | mean_iu = np.nanmean(iu)
132 | freq = hist.sum(axis=1) / hist.sum()
133 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
134 | return acc, acc_cls, mean_iu, fwavacc
135 |
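136 |
137 | # Worked example (not part of the original module) of the packed colour code decoded by
138 | # label_to_rgb / label_to_bgr: a single integer stores R, G, B as r * 256 * 256 + g * 256 + b.
139 | def _example_unpack_grgb(grgb=16711935):  # illustrative value: 255 * 65536 + 0 * 256 + 255
140 |     r = grgb // (256 * 256)
141 |     g = (grgb - r * 256 * 256) // 256
142 |     b = grgb % 256
143 |     return r, g, b  # -> (255, 0, 255)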
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File shloc -> train
4 | @IDE PyCharm
5 | @Author fx221@cam.ac.uk
6 | @Date 20/07/2021 10:41
7 | =================================================='''
8 | import argparse
9 | import json
10 | from tools.seg_tools import read_seg_map_without_group
11 | from dataloader.robotcar import RobotCarSegFull
12 | from dataloader.aachen import AachenSegFull
13 | import os.path as osp
14 | import torch
15 | import torch.utils.data as Data
16 | from dataloader.augmentation import ToPILImage, RandomRotation, \
17 | RandomSizedCrop, RandomHorizontalFlip, \
18 | ToNumpy
19 | from net.segnet import get_segnet
20 | from loss.seg_loss.segloss import SegLoss
21 | from trainer_recog import RecogTrainer
22 | from tools.common import torch_set_gpu
23 |
24 | import torchvision.transforms as tvf
25 |
26 |
27 | def get_train_val_loader(args):
28 | train_transform = tvf.Compose(
29 | (
30 | tvf.ToTensor(),
31 | # tvf.ColorJitter(0.25, 0.25, 0.25, 0.15),
32 | tvf.ColorJitter(0.25, 0.25, 0.25, 0.15),
33 | tvf.Normalize(mean=[0.485, 0.456, 0.406],
34 | std=[0.229, 0.224, 0.225])
35 | )
36 | )
37 | val_transform = tvf.Compose(
38 | (
39 | tvf.ToTensor(),
40 | tvf.Normalize(mean=[0.485, 0.456, 0.406],
41 | std=[0.229, 0.224, 0.225])
42 | )
43 | )
44 |
45 | if args.dataset == "robotcar":
46 | grgb_gid_file = "./datasets/robotcar/robotcar_rear_grgb_gid.txt"
47 | map_gid_rgb = read_seg_map_without_group(grgb_gid_file)
48 | train_imglist = "./datasets/robotcar/robotcar_rear_train_file_list.txt"
49 | test_imglist = "./datasets/robotcar/robotcar_rear_test_file_list.txt"
50 |
51 | if args.aug:
52 | aug = [
53 | ToPILImage(),
54 | RandomRotation(degree=30),
55 | RandomSizedCrop(size=256),
56 | RandomHorizontalFlip(),
57 | ToNumpy(),
58 | ]
59 | else:
60 | aug = None
61 | trainset = RobotCarSegFull(image_path=osp.join(args.root, args.train_image_path),
62 | label_path=osp.join(args.root, args.train_label_path),
63 | n_classes=args.classes,
64 | transform=train_transform,
65 | grgb_gid_file=grgb_gid_file,
66 | use_cls=True,
67 | img_list=train_imglist,
68 | preload=False,
69 | aug=aug,
70 | train=True,
71 | cats=["overcast-reference", "night", "night-rain",
72 | "dusk", "dawn", "overcast-summer", "overcast-winter", "sun"]
73 | )
74 | if args.val > 0:
75 | valset = RobotCarSegFull(image_path=osp.join(args.root, args.train_image_path),
76 | label_path=osp.join(args.root, args.train_label_path),
77 | cats=["overcast-reference", "night", "night-rain",
78 | "dusk", "dawn", "overcast-summer", "overcast-winter", "sun"],
79 | n_classes=args.classes,
80 | transform=train_transform,
81 | grgb_gid_file=grgb_gid_file,
82 | use_cls=True,
83 | img_list=test_imglist,
84 | preload=False,
85 | train=False)
86 | elif args.dataset == "aachen":
87 | grgb_gid_file = args.grgb_gid_file
88 | train_imglist = args.train_imglist
89 | test_imglist = args.test_imglist
90 | map_gid_rgb = read_seg_map_without_group(grgb_gid_file)
91 |
92 | if args.aug:
93 | aug = [
94 | # RandomGaussianBlur(), # worse results, don't do it?
95 | ToPILImage(),
96 | # Resize(size=512),
97 | # RandomScale(low=0.5, high=2.0),
98 | # RandomCrop(size=256),
99 | RandomSizedCrop(size=args.R),
100 | RandomRotation(degree=45),
101 | RandomHorizontalFlip(),
102 | ToNumpy(),
103 | ]
104 | else:
105 | aug = None
106 | trainset = AachenSegFull(image_path=osp.join(args.root, args.train_image_path),
107 | label_path=osp.join(args.root, args.train_label_path),
108 | n_classes=args.classes,
109 | transform=train_transform,
110 | grgb_gid_file=grgb_gid_file,
111 | use_cls=True,
112 | img_list=train_imglist,
113 | preload=False,
114 | aug=aug,
115 | train=True,
116 | cats=args.train_cats,
117 | )
118 | if args.val > 0:
119 | valset = AachenSegFull(image_path=osp.join(args.root, args.train_image_path),
120 | label_path=osp.join(args.root, args.train_label_path),
121 | n_classes=args.classes,
122 | transform=val_transform,
123 | grgb_gid_file=grgb_gid_file,
124 | use_cls=True,
125 | img_list=test_imglist,
126 | preload=False,
127 | cats=args.val_cats,
128 | train=False)
129 | train_loader = Data.DataLoader(dataset=trainset,
130 | batch_size=args.bs,
131 | num_workers=args.workers,
132 | shuffle=True,
133 | pin_memory=True,
134 | drop_last=True,
135 | )
136 | if args.val:
137 | val_loader = Data.DataLoader(
138 | dataset=valset,
139 | batch_size=8,
140 | num_workers=args.workers,
141 | pin_memory=True,
142 | shuffle=False,
143 | drop_last=True,
144 | )
145 | else:
146 | val_loader = None
147 |
148 | return train_loader, val_loader, map_gid_rgb
149 |
150 |
151 | def main(args):
152 | model = get_segnet(network=args.network,
153 | n_classes=args.classes,
154 | encoder_name=args.encoder_name,
155 | encoder_weights=args.encoder_weights,
156 | encoder_depth=args.encoder_depth,
157 | upsampling=args.upsampling,
158 | out_channels=args.out_channels,
159 | classification=args.classification,
160 | segmentation=args.segmentation, )
161 | print(model)
162 | label_weights = torch.ones([args.classes]).cuda()
163 | label_weights[0] = 0.5
164 | loss_func = SegLoss(
165 | segloss_name=args.seg_loss,
166 | use_cls=True,
167 | use_seg=args.segmentation > 0,
168 | cls_weight=args.weight_cls,
169 | use_hiera=False,
170 | hiera_weight=0,
171 | label_weights=label_weights).cuda()
172 |
173 | train_loader, val_loader, map_gid_rgb = get_train_val_loader(args=args)
174 | trainer = RecogTrainer(model=model, train_loader=train_loader, eval_loader=val_loader if args.val else None,
175 | loss_func=loss_func, args=args, map=map_gid_rgb)
176 |
177 | if args.resume is not None:
178 | trainer.resume(checkpoint=args.resume)
179 | else:
180 | trainer.train(start_epoch=0)
181 |
182 | print("Training finished")
183 |
184 |
185 | if __name__ == '__main__':
186 | parser = argparse.ArgumentParser("Train Semantic localization Network")
187 | parser.add_argument("--config", type=str, required=True, help="configuration file")
188 |     parser.add_argument("--dataset", type=str, default="small", help="aachen, robotcar")
189 | parser.add_argument("--network", type=str, default="unet")
190 | parser.add_argument("--loss", type=str, default="ce")
191 | parser.add_argument("--classes", type=int, default=400)
192 | parser.add_argument("--out_channels", type=int, default=512)
193 | parser.add_argument("--root", type=str)
194 | parser.add_argument("--train_label_path", type=str)
195 | parser.add_argument("--train_image_path", type=str)
196 | parser.add_argument("--val_label_path", type=str)
197 | parser.add_argument("--val_image_path", type=str)
198 | parser.add_argument("--bs", type=int, default=4)
199 | parser.add_argument("--R", type=int, default=256)
200 | parser.add_argument("--weight_decay", type=float, default=5e-4)
201 | parser.add_argument("--epochs", type=int, default=120)
202 | parser.add_argument("--lr", type=float, default=1e-4)
203 | parser.add_argument("--workers", type=int, default=4)
204 | parser.add_argument("--log_interval", type=int, default=50)
205 | parser.add_argument("--optimizer", type=str, default=None)
206 | parser.add_argument("--resume", type=str, default=None)
207 | parser.add_argument("--segloss", type=str, default='ce')
208 | parser.add_argument("--classification", dest="classification", action="store_true", default=True)
209 | parser.add_argument("--segmentation", dest="segmentation", action="store_true", default=True)
210 | parser.add_argument("--val", dest="val", action="store_true", default=False)
211 | parser.add_argument("--aug", dest="aug", action="store_true", default=False)
212 | parser.add_argument("--preload", dest="preload", action="store_true", default=False)
213 | parser.add_argument("--ignore_bg", dest="ignore_bg", action="store_true", default=False)
214 | parser.add_argument("--weight_cls", type=float, default=1.0)
215 | parser.add_argument("--pretrained_weight", type=str, default=None)
216 | parser.add_argument("--encoder_name", type=str, default='timm-resnest50d')
217 | parser.add_argument("--encoder_weights", type=str, default='imagenet')
218 | parser.add_argument("--save_root", type=str, default="/home/mifs/fx221/fx221/exp/shloc/aachen")
219 | parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU')
220 | parser.add_argument("--milestones", type=list, default=[60, 80])
221 | parser.add_argument("--grgb_gid_file", type=str)
222 | parser.add_argument("--train_imglist", type=str)
223 | parser.add_argument("--test_imglist", type=str)
224 | parser.add_argument("--lr_policy", type=str, default='plateau', help='plateau, step')
225 | parser.add_argument("--multi_lr", type=int, default=1)
226 | parser.add_argument("--train_cats", type=list, default=None)
227 | parser.add_argument("--val_cats", type=list, default=None)
228 |
229 | args = parser.parse_args()
230 | with open(args.config, 'rt') as f:
231 | t_args = argparse.Namespace()
232 | t_args.__dict__.update(json.load(f))
233 | args = parser.parse_args(namespace=t_args)
234 |
235 | print('gpu: ', args.gpu)
236 |
237 | torch_set_gpu(gpus=args.gpu)
238 | main(args=args)
239 |
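240 | # How the config file interacts with the flags above (a sketch, not part of the script):
241 | # parse_args() is called once to obtain --config; the JSON is then loaded into a fresh
242 | # Namespace and parse_args(namespace=t_args) re-parses the command line on top of it, so the
243 | # effective precedence is command line > config file > argparse defaults.  An illustrative
244 | # (not verbatim) fragment of configs/config_train_aachen.json, whose keys match the argument
245 | # names above:
246 | #     {"dataset": "aachen", "network": "pspf", "classes": 452, "bs": 8, "lr": 0.0001}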
--------------------------------------------------------------------------------
/train_aachen:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python3 train.py --config configs/config_train_aachen.json
--------------------------------------------------------------------------------
/train_robotcar:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python3 train.py --config configs/config_train_robotcar.json
--------------------------------------------------------------------------------