├── .gitignore ├── README.md ├── aachen_452.npy ├── assets ├── overview.png └── samples │ ├── 1116.png │ └── 1417176916116788.png ├── configs ├── config_test_aachen.json ├── config_test_robotcar.json ├── config_train_aachen.json └── config_train_robotcar.json ├── dataloader ├── aachen.py ├── augmentation.py ├── place_recog.py └── robotcar.py ├── datasets ├── aachen │ ├── aachen_db_imglist.txt │ ├── aachen_db_query_imglist.txt │ ├── aachen_grgb_gid_v5.txt │ ├── aachen_query_imglist.txt │ ├── aachen_test_file_list.txt │ ├── aachen_test_file_list_v5.txt │ ├── aachen_train_file_list_v5.txt │ └── pairs-query-netvlad50.txt └── robotcar │ ├── pairs-query-netvlad20-percam-perloc-rear.txt │ ├── robotcar_db_imglist.txt │ ├── robotcar_db_query_imglist.txt │ ├── robotcar_grgb_gid.txt │ ├── robotcar_queries_with_intrinsics_rear.txt │ ├── robotcar_query_imglist.txt │ ├── robotcar_rear_grgb_gid.txt │ ├── robotcar_rear_test_file_list.txt │ ├── robotcar_rear_test_file_list_full.txt │ └── robotcar_rear_train_file_list.txt ├── localization ├── coarse │ ├── coarselocalization.py │ └── evaluate.py ├── fine │ ├── extractor.py │ ├── features │ │ ├── extract_d2net.py │ │ ├── extract_r2d2.py │ │ ├── extract_sift.py │ │ └── extract_spp.py │ ├── localize_cv2.py │ ├── matcher.py │ ├── test.py │ └── triangulation.py ├── localizer.py ├── tools.py └── utils │ ├── database.py │ ├── parsers.py │ └── read_write_model.py ├── loss ├── __pycache__ │ ├── accuracy.cpython-37.pyc │ ├── aploss.cpython-36.pyc │ └── aploss.cpython-37.pyc ├── accuracy.py ├── aploss.py └── seg_loss │ ├── __pycache__ │ ├── crossentropy_loss.cpython-37.pyc │ └── segloss.cpython-37.pyc │ ├── crossentropy_loss.py │ ├── focal_loss.py │ ├── segloss.py │ └── utils.py ├── net ├── layers.py ├── locnets │ ├── r2d2.py │ ├── resnet.py │ └── superpoint.py ├── plceregnets │ ├── gem.py │ └── pregnet.py ├── regnets │ ├── deeplab.py │ └── pspnet.py └── segnet.py ├── robotcar.npy ├── run_loc_aachen ├── run_loc_robotcar ├── run_reconstruct_aachen ├── run_reconstruct_robotcar ├── test.py ├── test_aachen ├── test_robotcar ├── tools ├── common.py ├── config_parser.py ├── loc_tools.py ├── optim.py └── seg_tools.py ├── train.py ├── train_aachen ├── train_place_recog.py ├── train_robotcar └── trainer_recog.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | localization/coarse/__pycache__ 4 | localization/__pycache__ 5 | localization/fine/__pycache__ 6 | localization/utils/__pycache__ 7 | tools/__pycache__ 8 | weights -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Large-scale Localization by Global Instance Recognition 2 | 3 |

4 | [figure: overview (assets/overview.png)] 5 |

6 | 7 | In this work, we propose to leverage global instances, which are robust to illumination and season changes, for both 8 | coarse and fine localization. For coarse localization, instead of performing a global reference search directly, we search 9 | for reference images progressively from the recognized global instances. The recognized instances are further utilized for 10 | instance-wise feature detection and matching to enhance the localization accuracy. 11 | 12 | * Full paper PDF: [Efficient Large-scale Localization by Global Instance Recognition](https://openaccess.thecvf.com/content/CVPR2022/papers/Xue_Efficient_Large-Scale_Localization_by_Global_Instance_Recognition_CVPR_2022_paper.pdf). 13 | 14 | 15 | 16 | * Authors: *Fei Xue, Ignas Budvytis, Daniel Olmeda Reino, Roberto Cipolla* 17 | 18 | * Website: [lbr](https://github.com/feixue94/feixue94.github.io/lbr) for videos, slides, recent updates, and datasets. 19 | 20 | ## Dependencies 21 | 22 | * Python 3 >= 3.6 23 | * PyTorch >= 1.8 24 | * OpenCV >= 3.4 25 | * NumPy >= 1.18 26 | * segmentation-models-pytorch = 0.1.3 27 | * colmap 28 | * pycolmap = 0.0.1 29 | 30 | ## Data preparation 31 | 32 | Please follow the instructions on the [VisualLocalization Benchmark](https://www.visuallocalization.net/datasets/) to 33 | download the images and reference 3D models of the Aachen_v1.1 and RobotCar-Seasons datasets: 34 | 35 | * [Images of Aachen_v1.1 dataset](https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/) 36 | * [Images of RobotCar-Seasons dataset](https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/RobotCar-Seasons/) 37 | 38 | You can download the global instance labels used in this work here: 39 | 40 | * [Global instances of Aachen_v1.1](https://drive.google.com/file/d/17qerRcU8Iemjwz7tUtlX9syfN-WVSs4K/view?usp=sharing) 41 | * [Global instances of RobotCar-Seasons](https://drive.google.com/file/d/1Ns5I3YGoMCBURzWKZTxqsugeG4jUcj4a/view?usp=sharing) 42 | 43 | Since the Aachen_v1.1 and RobotCar-Seasons databases contain only daytime images, which can cause large recognition 44 | errors on nighttime query images, we augment the training data with generated stylized images, 45 | which can also be downloaded along with the global instances. The file structure of the Aachen_v1.1 dataset should look like 46 | this: 47 | 48 | ``` 49 | - aachen_v1.1 50 | - global_seg_instance 51 | - db 52 | - query 53 | - sequences 54 | - images 55 | - 3D-models 56 | - images.bin 57 | - points3D.bin 58 | - cameras.bin 59 | - stylized 60 | - images (raw database images) 61 | - images_aug1 62 | - images_aug2 63 | ``` 64 | 65 | For the RobotCar-Seasons dataset, it should look like this: 66 | 67 | ``` 68 | - RobotCar-Seasons 69 | - global_seg_instance 70 | - overcast-reference 71 | - rear 72 | - images 73 | - overcast-reference 74 | - night 75 ...
76 | - night-rain 77 | - 3D-models 78 | - sfm-sift 79 | - cameras.bin 80 | - images.bin 81 | - points3D.bin 82 | - stylized 83 | - overcast-reference (raw database images) 84 | - overcast-reference_aug_d 85 | - overcast-reference_aug_r 86 | - overcast-reference_aug_nr 87 | ``` 88 | 89 | ## Pretrained weights 90 | 91 | We provide pretrained weights for local feature detection and extraction, and for global instance recognition on the Aachen_v1.1 92 | and RobotCar-Seasons datasets (one recognition model per dataset); they can be downloaded 93 | from [here](https://drive.google.com/file/d/1N4j7PkZoy2CkWhS7u6dFzMIoai3ShG9p/view?usp=sharing). 94 | 95 | ## Testing of global instance recognition 96 | 97 | Running the test scripts below produces predicted masks of global instances, confidence maps, global features, and visualization images. 98 | 99 | * testing recognition on Aachen_v1.1 100 | 101 | ``` 102 | ./test_aachen 103 | ``` 104 | 105 |

106 | [figure: example visualization on Aachen_v1.1 (assets/samples/1116.png)] 107 |

108 | 109 | * testing recognition on RobotCar-Seasons 110 | 111 | ``` 112 | ./test_robotcar 113 | ``` 114 | 115 |

116 | [figure: example visualization on RobotCar-Seasons (assets/samples/1417176916116788.png)] 117 |
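The global instance labels are colour-coded PNGs, and the files under `datasets/` (e.g. `datasets/aachen/aachen_grgb_gid_v5.txt`, listed further below) map each encoded colour to a global instance id; the dataloaders perform this decoding in `seg_to_gid` (see `dataloader/aachen.py` below). A minimal standalone sketch of the same decoding, with an illustrative label path, could look like this:

```python
import cv2
import numpy as np

# Each line of datasets/aachen/aachen_grgb_gid_v5.txt is "<encoded colour> <global id>",
# where the colour is encoded as R * 256 * 256 + G * 256 + B (cf. seg_to_gid in dataloader/aachen.py).
grgb_gid = {}
with open('datasets/aachen/aachen_grgb_gid_v5.txt', 'r') as f:
    for line in f:
        grgb, gid = line.strip().split(' ')
        grgb_gid[int(grgb)] = int(gid)

# 'label.png' is an illustrative path to one of the downloaded global instance label images.
seg = cv2.imread('label.png')  # OpenCV loads BGR, so channel 2 is R and channel 0 is B
enc = np.int32(seg[:, :, 2]) * 256 * 256 + np.int32(seg[:, :, 1]) * 256 + np.int32(seg[:, :, 0])

gid_img = np.zeros_like(enc)  # per-pixel global instance ids (0 where the colour is not in the mapping)
for grgb in np.unique(enc).tolist():
    if grgb in grgb_gid:
        gid_img[enc == grgb] = grgb_gid[grgb]
```

The predicted masks written by the test scripts can be inspected in the same way, assuming they use the same colour coding as the training labels.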

118 | 119 | ## 3D Reconstruction 120 | 121 | For fine localization, you also need a 3D map of the environment reconstructed by structure-from-motion. 122 | 123 | * feature extraction and 3D reconstruction for Aachen_v1.1 124 | 125 | ``` 126 | ./run_reconstruct_aachen 127 | ``` 128 | 129 | * feature extraction and 3D reconstruction for RobotCar-Seasons 130 | 131 | ``` 132 | ./run_reconstruct_robotcar 133 | ``` 134 | 135 | ## Localization with global instances 136 | 137 | Once you have the global instance masks of the query and database images and the 3D map of the scene, you can run the 138 | following commands for localization. 139 | 140 | * localization on Aachen_v1.1 141 | 142 | ``` 143 | ./run_loc_aachen 144 | ``` 145 | 146 | You will get results like this (each entry is the percentage of queries localized within the benchmark's (0.25m, 2°) / (0.5m, 5°) / (5m, 10°) error thresholds): 147 | 148 | | | Day | Night | 149 | | -------- | ------- | -------- | 150 | | CVPR | 89.1 / 96.1 / 99.3 | 77.0 / 90.1 / 99.5 | 151 | | post-CVPR | 88.8 / 95.8 / 99.2 | 75.4 / 91.6 / 100 | 152 | 153 | * localization on RobotCar-Seasons 154 | 155 | ``` 156 | ./run_loc_robotcar 157 | ``` 158 | 159 | You will get results like this: 160 | 161 | | | Night | Night-rain | 162 | | -------- | ----- | ------- | 163 | | CVPR | 24.9 / 62.3 / 86.1 | 47.5 / 73.4 / 90.0 | 164 | | post-CVPR | 28.1 / 66.9 / 91.8 | 46.1 / 73.6 / 92.5 | 165 | 166 | ## Training 167 | 168 | If you want to retrain the recognition network, you can run the following commands. 169 | 170 | * training recognition on Aachen_v1.1 171 | 172 | ``` 173 | ./train_aachen 174 | ``` 175 | 176 | * training recognition on RobotCar-Seasons 177 | 178 | ``` 179 | ./train_robotcar 180 | ``` 181 | 182 | ## BibTeX Citation 183 | 184 | If you use any ideas from the paper or code from this repo, please consider citing: 185 | 186 | ``` 187 | @inproceedings{xue2022efficient, 188 | author = {Fei Xue and Ignas Budvytis and Daniel Olmeda Reino and Roberto Cipolla}, 189 | title = {Efficient Large-scale Localization by Global Instance Recognition}, 190 | booktitle = {CVPR}, 191 | year = {2022} 192 | } 193 | ``` 194 | 195 | ## Acknowledgements 196 | 197 | Part of the code is borrowed from previous excellent works, 198 | including [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), [R2D2](https://github.com/naver/r2d2), 199 | and [HLoc](https://github.com/cvg/Hierarchical-Localization). Please see their released repositories for more details 200 | if you are interested in their works.
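As a side note on the coarse stage described above, the retrieval quality of the reference search can be checked with the helpers in `localization/coarse/evaluate.py` (shown later in this file). A small usage sketch, with made-up query and database image names and assuming the repository root is on `PYTHONPATH`:

```python
from localization.coarse.evaluate import evaluate_retrieval, evaluate_retrieval_by_query

# Made-up example: ranked retrieval candidates per query vs. ground-truth reference images.
preds = {'query/night/0001.jpg': ['db/0100.jpg', 'db/0101.jpg', 'db/0200.jpg']}
gts = {'query/night/0001.jpg': ['db/0101.jpg', 'db/0102.jpg']}

# Mean average precision and recall at each cut-off k (keys 'accuracy' and 'recall').
print(evaluate_retrieval(preds=preds, gts=gts, ks=[1, 10, 20, 50]))
# Fraction of queries with at least one correct candidate in the top k
# (queries that still fail at k=50 are returned under the 'failed_case' key).
print(evaluate_retrieval_by_query(preds=preds, gts=gts, ks=[1, 10, 20, 50]))
```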
201 | -------------------------------------------------------------------------------- /aachen_452.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/aachen_452.npy -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/assets/overview.png -------------------------------------------------------------------------------- /assets/samples/1116.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/assets/samples/1116.png -------------------------------------------------------------------------------- /assets/samples/1417176916116788.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/assets/samples/1417176916116788.png -------------------------------------------------------------------------------- /configs/config_test_aachen.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "SHLoc", 3 | "dataset": "aachen", 4 | "gpu": 0, 5 | "image_path": "/scratches/flyer_2/fx221/localization/aachen_v1_1/images/images_upright", 6 | "image_list": "datasets/aachen/aachen_db_query_imglist.txt", 7 | "grgb_gid_file": "datasets/aachen/aachen_grgb_gid_v5.txt", 8 | "classes": 452, 9 | "segmentation": 1, 10 | "classification": 1, 11 | "augmentation": 1, 12 | "network": "pspf", 13 | "encoder_name": "resnext101_32x4d", 14 | "encoder_weights": "ssl", 15 | "out_channels": 2048, 16 | "encoder_depth": 4, 17 | "R": 256, 18 | "upsampling": 8, 19 | "bs": 8, 20 | "save_dir": "/scratches/flyer_2/fx221/exp/shloc/data_release/aachen_v1.1", 21 | "pretrained_weight": "weights/2021_08_29_12_49_48_aachen_pspf_resnext101_32x4d_d4_u8_b16_R256_E120_ceohem_adam_seg_cls_aug.pth" 22 | } -------------------------------------------------------------------------------- /configs/config_test_robotcar.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "SHLoc", 3 | "dataset": "robotcar", 4 | "gpu": 0, 5 | "image_path": "/scratches/flyer_2/fx221/localization/RobotCar-Seasons/images", 6 | "image_list": "datasets/robotcar/robotcar_rear_test_file_list_full.txt", 7 | "grgb_gid_file": "datasets/robotcar/robotcar_rear_grgb_gid.txt", 8 | "classes": 692, 9 | "segmentation": 1, 10 | "classification": 1, 11 | "network": "pspf", 12 | "encoder_name": "resnext101_32x4d", 13 | "encoder_weights": "ssl", 14 | "out_channels": 2048, 15 | "encoder_depth": 4, 16 | "upsampling": 8, 17 | "bs": 8, 18 | "save_dir": "/scratches/flyer_2/fx221/exp/shloc/data_release/RobotCar-Seasons", 19 | "pretrained_weight": "weights/2021_10_17_00_21_32_robotcar_pspf_resnext101_32x4d_d4_u8_b16_R256_E500_ceohem_adam_poly_mlr_seg_cls_aug.pth" 20 | } -------------------------------------------------------------------------------- /configs/config_train_aachen.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "SHLoc", 3 | "dataset": "aachen", 4 | "gpu": [ 5 | 0 6 | ], 7 | "root": "/scratches/flyer_2/fx221/localization/aachen_v1_1", 8 | "save_root": 
"/scratches/flyer_2/fx221/exp/shloc/aachen", 9 | "train_label_path": "global_seg_instance", 10 | "train_image_path": "stylized", 11 | "val_label_path": "global_seg_instance", 12 | "val_image_path": "stylized", 13 | "tag": "stylized", 14 | "grgb_gid_file": "./datasets/aachen/aachen_grgb_gid_v5.txt", 15 | "train_imglist": "./datasets/aachen/aachen_train_file_list_v5.txt", 16 | "test_imglist": "./datasets/aachen/aachen_test_file_list_v5.txt", 17 | "train_cats": [ 18 | "images/images_upright", 19 | "images_aug1/images_upright", 20 | "images_aug2/images_upright" 21 | ], 22 | "val_cats": [ 23 | "images_aug/images_upright" 24 | ], 25 | "classes": 452, 26 | "weight_cls": 2.0, 27 | "segmentation": 1, 28 | "classification": 1, 29 | "augmentation": 1, 30 | "network": "pspf", 31 | "encoder_name": "resnext101_32x4d", 32 | "encoder_weights": "ssl", 33 | "seg_loss": "ceohem", 34 | "seg_loss_sce": "sceohem", 35 | "out_channels": 2048, 36 | "encoder_depth": 4, 37 | "upsampling": 8, 38 | "loss": "ce", 39 | "bs": 16, 40 | "R": 256, 41 | "optimizer": "adam", 42 | "weight_decay": 0.0005, 43 | "lr_policy": "step", 44 | "multi_lr": 1, 45 | "epochs": 120, 46 | "milestones": [ 47 | 80, 48 | 100 49 | ], 50 | "lr": 0.0001, 51 | "workers": 4, 52 | "log_interval": 50, 53 | "save_dir": "", 54 | "val": 1, 55 | "aug": 1 56 | } -------------------------------------------------------------------------------- /configs/config_train_robotcar.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "SHLoc", 3 | "dataset": "robotcar", 4 | "gpu": [ 5 | 0 6 | ], 7 | "root": "/scratches/flyer_2/fx221/localization/RobotCar-Seasons", 8 | "save_root": "/scratches/flyer_2/fx221/exp/shloc/robotcar", 9 | "train_label_path": "global_seg_instance/mixed", 10 | "train_image_path": "stylized", 11 | "val_label_path": "global_seg_instance", 12 | "val_image_path": "stylized", 13 | "tag": "stylized", 14 | "grgb_gid_file": "./datasets/robotcar/robotcar_grgb_gid.txt", 15 | "train_imglist": "./datasets/robotcar/robotcar_train_file_list_full.txt", 16 | "test_imglist": "./datasets/robotcar/robotcar_test_file_list.txt", 17 | "train_cats": [ 18 | "overcast-reference", 19 | "overcast-reference_aug_d", 20 | "overcast-reference_aug_n" 21 | ], 22 | "val_cats": [ 23 | "overcast-reference_aug_nr" 24 | ], 25 | "classes": 692, 26 | "weight_cls": 2.0, 27 | "segmentation": 1, 28 | "classification": 1, 29 | "augmentation": 1, 30 | "network": "pspf", 31 | "encoder_name": "resnext101_32x4d", 32 | "encoder_weights": "ssl", 33 | "seg_loss": "ceohem", 34 | "seg_loss_sce": "sceohem", 35 | "out_channels": 2048, 36 | "encoder_depth": 4, 37 | "upsampling": 8, 38 | "loss": "ce", 39 | "bs": 16, 40 | "R": 256, 41 | "optimizer": "adam", 42 | "weight_decay": 0.0005, 43 | "lr_policy": "poly", 44 | "multi_lr": 1, 45 | "epochs": 500, 46 | "milestones": [ 47 | 80, 48 | 100 49 | ], 50 | "lr": 0.0001, 51 | "workers": 4, 52 | "log_interval": 50, 53 | "val": 1, 54 | "aug": 1 55 | } -------------------------------------------------------------------------------- /dataloader/aachen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> aachen 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 18/07/2021 10:22 7 | ==================================================''' 8 | import torch.utils.data as data 9 | from torch.utils.data.dataloader import * 10 | import numpy as np 11 | import os.path as 
osp 12 | import cv2 13 | import random 14 | from tools.common import resize_img 15 | 16 | 17 | class AachenSegFull(data.Dataset): 18 | def __init__(self, 19 | image_path, 20 | label_path, 21 | grgb_gid_file, 22 | n_classes, 23 | cats, 24 | train=True, 25 | transform=None, 26 | use_cls=False, 27 | img_list=None, 28 | preload=False, 29 | aug=None, 30 | R=256, 31 | keep_ratio=False, 32 | ): 33 | super(AachenSegFull, self).__init__() 34 | self.image_path = image_path 35 | self.label_path = label_path 36 | self.transform = transform 37 | self.train = train 38 | self.cls = use_cls 39 | self.aug = aug 40 | self.cats = cats 41 | self.R = R 42 | print("augmentation: ", self.aug) 43 | self.classes = n_classes 44 | self.keep_ratio = keep_ratio 45 | 46 | self.valid_fns = [] 47 | with open(img_list, "r") as f: 48 | lines = f.readlines() 49 | for l in lines: 50 | l = l.strip() 51 | if not osp.exists(osp.join(self.label_path, l.replace("jpg", "png"))): 52 | continue 53 | 54 | keep = True 55 | for c in self.cats: 56 | if not osp.exists(osp.join(self.image_path, c, l)): 57 | keep = False 58 | if not keep: 59 | continue 60 | self.valid_fns.append(l) 61 | # 62 | self.grgb_gid = {} 63 | with open(grgb_gid_file, "r") as f: 64 | lines = f.readlines() 65 | for l in lines: 66 | l = l.strip().split(" ") 67 | self.grgb_gid[int(l[0])] = int(l[1]) 68 | 69 | print("Load {:d} valid samples.".format(len(self.valid_fns))) 70 | print("No. of gids: {:d}".format(len(self.grgb_gid.keys()))) 71 | 72 | def seg_to_gid(self, seg_img, grgb_gid_map): 73 | id_img = np.int32(seg_img[:, :, 2]) * 256 * 256 + np.int32(seg_img[:, :, 1]) * 256 + np.int32( 74 | seg_img[:, :, 0]) 75 | luids = np.unique(id_img).tolist() 76 | # print("id_img: ", id_img.shape) 77 | out_img = np.zeros_like(seg_img) 78 | gid_img = np.zeros_like(id_img) 79 | for id in luids: 80 | if id in grgb_gid_map.keys(): 81 | gid = grgb_gid_map[id] 82 | mask = (id_img == id) 83 | gid_img[mask] = gid 84 | 85 | out_img[mask] = seg_img[mask] 86 | 87 | return out_img, gid_img 88 | 89 | def get_item_seg(self, idx): 90 | fn = self.valid_fns[idx] 91 | 92 | if len(self.cats) > 1: 93 | cat_id = random.randint(1, len(self.cats)) 94 | else: 95 | cat_id = 0 96 | # print(osp.join(self.image_path, self.cats[cat_id - 1], fn)) 97 | # exit(0) 98 | # print("fn: ", fn) 99 | img = cv2.imread(osp.join(self.image_path, self.cats[cat_id - 1], fn)) 100 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 101 | # print(osp.join(self.label_path, fn.replace('jpg', 'png'))) 102 | seg_img = cv2.imread(osp.join(self.label_path, fn.replace('jpg', 'png'))) 103 | # print(img.shape, seg_img.shape) 104 | seg_img = cv2.resize(seg_img, dsize=(img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST) 105 | # print(img.shape, seg_img.shape) 106 | # exit(0) 107 | 108 | if self.aug is not None: 109 | for a in self.aug: 110 | img, seg_img = a((img, seg_img)) 111 | else: 112 | if self.keep_ratio: 113 | img = resize_img(img=img, nh=self.R, mode=cv2.INTER_NEAREST) 114 | else: 115 | img = cv2.resize(img.astype(np.uint8), dsize=(self.R, self.R)) 116 | seg_img = cv2.resize(seg_img.astype(np.uint8), dsize=(img.shape[1], img.shape[0]), 117 | interpolation=cv2.INTER_NEAREST) 118 | if self.transform is not None: 119 | img = self.transform(img) 120 | seg_img = np.array(seg_img) 121 | 122 | filtered_seg, gids = self.seg_to_gid(seg_img=seg_img, grgb_gid_map=self.grgb_gid) 123 | gids = np.asarray(gids, np.int64) 124 | gids = torch.from_numpy(gids) 125 | gids = torch.LongTensor(gids) 126 | 127 | output = { 128 | "img": img, 129 | 
"label": [gids], 130 | "label_img": filtered_seg, 131 | } 132 | 133 | if self.cls: 134 | cls_label = np.zeros(shape=(self.classes), dtype=np.float) 135 | uids = np.unique(gids).tolist() 136 | for id in uids: 137 | cls_label[id] = 1.0 138 | cls_label = torch.from_numpy(cls_label) 139 | 140 | output["cls"] = [cls_label] 141 | 142 | return output 143 | 144 | def __getitem__(self, idx): 145 | return self.get_item_seg(idx=idx) 146 | 147 | def __len__(self): 148 | return len(self.valid_fns) 149 | -------------------------------------------------------------------------------- /dataloader/robotcar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> robotcar 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 18/07/2021 10:22 7 | ==================================================''' 8 | import torch.utils.data as data 9 | from torch.utils.data.dataloader import * 10 | import numpy as np 11 | import os.path as osp 12 | import cv2 13 | import random 14 | from tools.common import resize_img 15 | 16 | 17 | class RobotCarSegFull(data.Dataset): 18 | def __init__(self, 19 | image_path, 20 | label_path, 21 | grgb_gid_file, 22 | n_classes, 23 | cats, 24 | train=True, 25 | transform=None, 26 | use_cls=False, 27 | img_list=None, 28 | preload=False, 29 | aug=None, 30 | R=256, 31 | keep_ratio=False, 32 | ): 33 | super(RobotCarSegFull, self).__init__() 34 | self.image_path = image_path 35 | self.label_path = label_path 36 | self.transform = transform 37 | self.train = train 38 | self.cls = use_cls 39 | self.aug = aug 40 | self.cats = cats 41 | self.R = R 42 | print("augmentation: ", self.aug) 43 | self.classes = n_classes 44 | self.keep_ratio = keep_ratio 45 | 46 | self.valid_fns = [] 47 | with open(img_list, "r") as f: 48 | lines = f.readlines() 49 | for l in lines: 50 | l = l.strip() 51 | if not osp.exists(osp.join(self.label_path, l.replace("jpg", "png"))): 52 | continue 53 | 54 | keep = True 55 | for c in self.cats: 56 | if not osp.exists(osp.join(self.image_path, c, l)): 57 | keep = False 58 | if not keep: 59 | continue 60 | self.valid_fns.append(l) 61 | 62 | self.grgb_gid = {} 63 | with open(grgb_gid_file, "r") as f: 64 | lines = f.readlines() 65 | for l in lines: 66 | l = l.strip().split(" ") 67 | self.grgb_gid[int(l[0])] = int(l[1]) 68 | 69 | print("Load {:d} valid samples.".format(len(self.valid_fns))) 70 | print("No. 
of gids: {:d}".format(len(self.grgb_gid.keys()))) 71 | 72 | def seg_to_gid(self, seg_img, grgb_gid_map): 73 | id_img = np.int32(seg_img[:, :, 2]) * 256 * 256 + np.int32(seg_img[:, :, 1]) * 256 + np.int32( 74 | seg_img[:, :, 0]) 75 | luids = np.unique(id_img).tolist() 76 | # print("id_img: ", id_img.shape) 77 | out_img = np.zeros_like(seg_img) 78 | gid_img = np.zeros_like(id_img) 79 | for id in luids: 80 | if id in grgb_gid_map.keys(): 81 | gid = grgb_gid_map[id] 82 | mask = (id_img == id) 83 | gid_img[mask] = gid 84 | 85 | out_img[mask] = seg_img[mask] 86 | 87 | return out_img, gid_img 88 | 89 | def get_item_seg(self, idx): 90 | fn = self.valid_fns[idx] 91 | 92 | cat_id = random.randint(1, len(self.cats)) 93 | # print(osp.join(self.image_path, self.cats[cat_id - 1], fn)) 94 | img = cv2.imread(osp.join(self.image_path, self.cats[cat_id - 1], fn)) 95 | # raw_img = cv2.resize(img, dsize=(self.R, self.R)) 96 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 97 | 98 | seg_img = cv2.imread(osp.join(self.label_path, fn.replace('jpg', 'png'))) 99 | 100 | if self.aug is not None: 101 | for a in self.aug: 102 | img, seg_img = a((img, seg_img)) 103 | 104 | else: 105 | if self.keep_ratio: 106 | img = resize_img(img=img, nh=self.R, mode=cv2.INTER_NEAREST) 107 | else: 108 | img = cv2.resize(img.astype(np.uint8), dsize=(self.R, self.R)) 109 | seg_img = cv2.resize(seg_img.astype(np.uint8), dsize=(img.shape[1], img.shape[0]), 110 | interpolation=cv2.INTER_NEAREST) 111 | raw_img = img.astype(np.uint8) 112 | 113 | if self.transform is not None: 114 | img = self.transform(img) 115 | seg_img = np.array(seg_img) 116 | 117 | filtered_seg, gids = self.seg_to_gid(seg_img=seg_img, grgb_gid_map=self.grgb_gid) 118 | gids = np.asarray(gids, np.int64) 119 | gids = torch.from_numpy(gids) 120 | gids = torch.LongTensor(gids) 121 | 122 | output = { 123 | "raw_img": raw_img, 124 | "img": img, 125 | "label": [gids], 126 | "label_img": filtered_seg, 127 | } 128 | 129 | if self.cls: 130 | cls_label = np.zeros(shape=(self.classes), dtype=np.float) 131 | uids = np.unique(gids).tolist() 132 | for id in uids: 133 | cls_label[id] = 1.0 134 | cls_label = torch.from_numpy(cls_label) 135 | 136 | output["cls"] = [cls_label] 137 | 138 | return output 139 | 140 | def __getitem__(self, idx): 141 | return self.get_item_seg(idx=idx) 142 | 143 | def __len__(self): 144 | return len(self.valid_fns) 145 | -------------------------------------------------------------------------------- /datasets/aachen/aachen_grgb_gid_v5.txt: -------------------------------------------------------------------------------- 1 | 0 0 2 | 3278090 1 3 | 3278100 2 4 | 3278115 3 5 | 3278120 4 6 | 3278135 5 7 | 3278150 6 8 | 3278155 7 9 | 3278160 8 10 | 3278165 9 11 | 3278175 10 12 | 3278180 11 13 | 3278185 12 14 | 3278195 13 15 | 3278205 14 16 | 3278215 15 17 | 3278225 16 18 | 3278230 17 19 | 3278235 18 20 | 3278245 19 21 | 3278255 20 22 | 3278260 21 23 | 3278265 22 24 | 3278280 23 25 | 3279365 24 26 | 3279380 25 27 | 3279385 26 28 | 3279390 27 29 | 3279430 28 30 | 3279450 29 31 | 3279455 30 32 | 3279460 31 33 | 3279465 32 34 | 3279470 33 35 | 3279475 34 36 | 3279485 35 37 | 3279490 36 38 | 3279500 37 39 | 3279515 38 40 | 3280805 39 41 | 3280815 40 42 | 3280820 41 43 | 3280825 42 44 | 3281960 43 45 | 3281975 44 46 | 3283215 45 47 | 3285785 46 48 | 3285800 47 49 | 3285830 48 50 | 3285850 49 51 | 3285860 50 52 | 3285865 51 53 | 3285885 52 54 | 3285895 53 55 | 3285910 54 56 | 3285925 55 57 | 3287205 56 58 | 3287260 57 59 | 3288385 58 60 | 3288485 59 61 | 3288490 60 62 | 
3288500 61 63 | 3288505 62 64 | 3288510 63 65 | 3288520 64 66 | 3288530 65 67 | 3289615 66 68 | 3289630 67 69 | 3289650 68 70 | 3289660 69 71 | 3289665 70 72 | 3289670 71 73 | 3289680 72 74 | 3289725 73 75 | 3289735 74 76 | 3289820 75 77 | 3289825 76 78 | 3290960 77 79 | 3290970 78 80 | 3290980 79 81 | 3291000 80 82 | 3291010 81 83 | 3291015 82 84 | 3291020 83 85 | 3291025 84 86 | 3291035 85 87 | 3291060 86 88 | 3292190 87 89 | 3292195 88 90 | 3292200 89 91 | 3292210 90 92 | 3292215 91 93 | 3292225 92 94 | 3292230 93 95 | 3292235 94 96 | 3292245 95 97 | 3292260 96 98 | 3292275 97 99 | 3292285 98 100 | 3292290 99 101 | 3292295 100 102 | 3292300 101 103 | 3292310 102 104 | 3292320 103 105 | 3292335 104 106 | 3292340 105 107 | 3292345 106 108 | 3292350 107 109 | 3292355 108 110 | 3292360 109 111 | 3292370 110 112 | 3292380 111 113 | 3292385 112 114 | 3292390 113 115 | 3293445 114 116 | 3293455 115 117 | 3293460 116 118 | 3293470 117 119 | 3293490 118 120 | 3293505 119 121 | 3293510 120 122 | 3293515 121 123 | 3293520 122 124 | 3293525 123 125 | 3293530 124 126 | 3293590 125 127 | 3293595 126 128 | 3293650 127 129 | 3293655 128 130 | 3293660 129 131 | 3293665 130 132 | 3294725 131 133 | 3294730 132 134 | 3294735 133 135 | 3294740 134 136 | 3294750 135 137 | 3294805 136 138 | 3294850 137 139 | 3294860 138 140 | 3294865 139 141 | 3294870 140 142 | 3294875 141 143 | 3294885 142 144 | 3294895 143 145 | 3294900 144 146 | 3297295 145 147 | 3297485 146 148 | 3297500 147 149 | 3297525 148 150 | 3298775 149 151 | 3298780 150 152 | 3298790 151 153 | 3300065 152 154 | 3301170 153 155 | 3301290 154 156 | 3301310 155 157 | 3301315 156 158 | 3301325 157 159 | 3301330 158 160 | 3302470 159 161 | 3302475 160 162 | 3302480 161 163 | 3302490 162 164 | 3302545 163 165 | 3302570 164 166 | 3302595 165 167 | 3302605 166 168 | 3303850 167 169 | 3303870 168 170 | 3303915 169 171 | 3303920 170 172 | 3303925 171 173 | 3304995 172 174 | 3306370 173 175 | 3307600 174 176 | 3307615 175 177 | 3307620 176 178 | 3307630 177 179 | 3307635 178 180 | 3307675 179 181 | 3307700 180 182 | 3307705 181 183 | 3307710 182 184 | 3307715 183 185 | 3307735 184 186 | 3308975 185 187 | 3311395 186 188 | 3313985 187 189 | 3316630 188 190 | 3316635 189 191 | 3316640 190 192 | 3316645 191 193 | 3316695 192 194 | 3319100 193 195 | 3319110 194 196 | 3319120 195 197 | 3321615 196 198 | 3321825 197 199 | 3324240 198 200 | 3325525 199 201 | 3325545 200 202 | 3326760 201 203 | 3326840 202 204 | 3326870 203 205 | 3326875 204 206 | 3328015 205 207 | 3330570 206 208 | 3330665 207 209 | 3330680 208 210 | 3330755 209 211 | 3330760 210 212 | 3331930 211 213 | 3331945 212 214 | 3331950 213 215 | 3332005 214 216 | 3333175 215 217 | 3333290 216 218 | 3333305 217 219 | 3334545 218 220 | 3334560 219 221 | 3334575 220 222 | 3334585 221 223 | 3334595 222 224 | 3334600 223 225 | 3334620 224 226 | 3336985 225 227 | 3338330 226 228 | 3339625 227 229 | 3339695 228 230 | 3339700 229 231 | 3474813 230 232 | 3474873 231 233 | 3477278 232 234 | 3477293 233 235 | 3477303 234 236 | 3478538 235 237 | 3479883 236 238 | 3479983 237 239 | 3481328 238 240 | 3488928 239 241 | 3491403 240 242 | 3492753 241 243 | 3492773 242 244 | 3495268 243 245 | 3496653 244 246 | 3497763 245 247 | 3500388 246 248 | 3501768 247 249 | 3502908 248 250 | 3502918 249 251 | 3502973 250 252 | 3503018 251 253 | 3504168 252 254 | 3504173 253 255 | 3504218 254 256 | 3504253 255 257 | 3504258 256 258 | 3504288 257 259 | 3504308 258 260 | 3504338 259 261 | 3505423 260 262 | 3505433 261 263 | 3505508 262 
264 | 3506748 263 265 | 3506758 264 266 | 3506793 265 267 | 3506858 266 268 | 3510613 267 269 | 3510643 268 270 | 3510648 269 271 | 3511823 270 272 | 3513293 271 273 | 3515853 272 274 | 3517038 273 275 | 3517153 274 276 | 3518393 275 277 | 3518403 276 278 | 3518438 277 279 | 3519538 278 280 | 3519563 279 281 | 3519608 280 282 | 3519613 281 283 | 3519638 282 284 | 3519643 283 285 | 3519648 284 286 | 3519653 285 287 | 3519713 286 288 | 3522128 287 289 | 3522203 288 290 | 3522208 289 291 | 3523468 290 292 | 3526123 291 293 | 3527253 292 294 | 3527268 293 295 | 3527298 294 296 | 3527308 295 297 | 3527333 296 298 | 3528453 297 299 | 3529798 298 300 | 3529818 299 301 | 3529843 300 302 | 3529903 301 303 | 3529973 302 304 | 3531013 303 305 | 3531033 304 306 | 3531053 305 307 | 3531073 306 308 | 3531108 307 309 | 3531148 308 310 | 3531183 309 311 | 3531228 310 312 | 3531243 311 313 | 3533618 312 314 | 3533698 313 315 | 3533713 314 316 | 3533718 315 317 | 3533738 316 318 | 3533803 317 319 | 3671501 318 320 | 3672671 319 321 | 3675261 320 322 | 3677806 321 323 | 3677831 322 324 | 3677836 323 325 | 3677846 324 326 | 3677876 325 327 | 3677921 326 328 | 3677926 327 329 | 3677936 328 330 | 3677941 329 331 | 3678981 330 332 | 3678996 331 333 | 3679001 332 334 | 3679006 333 335 | 3679036 334 336 | 3679056 335 337 | 3679061 336 338 | 3679066 337 339 | 3679081 338 340 | 3679091 339 341 | 3679096 340 342 | 3679101 341 343 | 3679111 342 344 | 3679116 343 345 | 3679121 344 346 | 3679131 345 347 | 3679146 346 348 | 3679186 347 349 | 3679216 348 350 | 3680381 349 351 | 3680396 350 352 | 3680461 351 353 | 3680476 352 354 | 3681591 353 355 | 3681606 354 356 | 3681656 355 357 | 3681666 356 358 | 3681671 357 359 | 3682856 358 360 | 3682891 359 361 | 3682896 360 362 | 3682911 361 363 | 3682916 362 364 | 3682936 363 365 | 3682971 364 366 | 3683016 365 367 | 3683021 366 368 | 3683031 367 369 | 3683041 368 370 | 3683046 369 371 | 3683051 370 372 | 3683056 371 373 | 3683061 372 374 | 3684121 373 375 | 3684136 374 376 | 3684331 375 377 | 3685471 376 378 | 3685486 377 379 | 3685506 378 380 | 3685541 379 381 | 3685551 380 382 | 3685556 381 383 | 3685571 382 384 | 3685601 383 385 | 3685616 384 386 | 3686716 385 387 | 3686721 386 388 | 3686726 387 389 | 3686736 388 390 | 3686741 389 391 | 3686751 390 392 | 3686786 391 393 | 3686821 392 394 | 3688101 393 395 | 3688141 394 396 | 3688176 395 397 | 3689311 396 398 | 3689321 397 399 | 3689341 398 400 | 3689366 399 401 | 3689436 400 402 | 3690676 401 403 | 3690706 402 404 | 3691806 403 405 | 3691831 404 406 | 3691846 405 407 | 3691871 406 408 | 3691901 407 409 | 3691906 408 410 | 3691911 409 411 | 3691916 410 412 | 3691921 411 413 | 3693191 412 414 | 3693231 413 415 | 3693291 414 416 | 3693301 415 417 | 3694351 416 418 | 3694376 417 419 | 3694396 418 420 | 3694526 419 421 | 3694556 420 422 | 3694571 421 423 | 3695811 422 424 | 3704806 423 425 | 3705926 424 426 | 3707236 425 427 | 3707256 426 428 | 3709721 427 429 | 3709751 428 430 | 3709791 429 431 | 3709796 430 432 | 3709801 431 433 | 3709861 432 434 | 3709881 433 435 | 3709906 434 436 | 3709916 435 437 | 3712281 436 438 | 3712301 437 439 | 3712306 438 440 | 3712316 439 441 | 3713601 440 442 | 3713611 441 443 | 3713701 442 444 | 3713711 443 445 | 3713716 444 446 | 3714936 445 447 | 3725081 446 448 | 3725091 447 449 | 3727651 448 450 | 3727656 449 451 | 3727681 450 452 | 3694917 451 -------------------------------------------------------------------------------- /datasets/robotcar/robotcar_grgb_gid.txt: 
-------------------------------------------------------------------------------- 1 | 0 0 2 | 6364 1 3 | 23394 2 4 | 57615 3 5 | 60798 4 6 | 150085 5 7 | 175434 6 8 | 176037 7 9 | 181106 8 10 | 199936 9 11 | 210177 10 12 | 212877 11 13 | 213559 12 14 | 247058 13 15 | 270776 14 16 | 288895 15 17 | 319083 16 18 | 320757 17 19 | 325851 18 20 | 353253 19 21 | 381759 20 22 | 437887 21 23 | 453785 22 24 | 498638 23 25 | 526846 24 26 | 531348 25 27 | 543518 26 28 | 543812 27 29 | 545929 28 30 | 564003 29 31 | 576434 30 32 | 580974 31 33 | 590030 32 34 | 610805 33 35 | 702863 34 36 | 709265 35 37 | 709540 36 38 | 765311 37 39 | 782213 38 40 | 810012 39 41 | 819088 40 42 | 824305 41 43 | 850422 42 44 | 884397 43 45 | 889520 44 46 | 912819 45 47 | 978590 46 48 | 995266 47 49 | 995626 48 50 | 999483 49 51 | 1008184 50 52 | 1009542 51 53 | 1019753 52 54 | 1071053 53 55 | 1085329 54 56 | 1119230 55 57 | 1167652 56 58 | 1225038 57 59 | 1238518 58 60 | 1294830 59 61 | 1314892 60 62 | 1327411 61 63 | 1352622 62 64 | 1366495 63 65 | 1395676 64 66 | 1407940 65 67 | 1440984 66 68 | 1464343 67 69 | 1479206 68 70 | 1490275 69 71 | 1497017 70 72 | 1523442 71 73 | 1549427 72 74 | 1654677 73 75 | 1663294 74 76 | 1668363 75 77 | 1673874 76 78 | 1702843 77 79 | 1717718 78 80 | 1739889 79 81 | 1772714 80 82 | 1817813 81 83 | 1829001 82 84 | 1874620 83 85 | 1875618 84 86 | 1889053 85 87 | 1895727 86 88 | 1912346 87 89 | 1912579 88 90 | 1985512 89 91 | 2072595 90 92 | 2102710 91 93 | 2113846 92 94 | 2126589 93 95 | 2131005 94 96 | 2131153 95 97 | 2138826 96 98 | 2174275 97 99 | 2203459 98 100 | 2213499 99 101 | 2219083 100 102 | 2262220 101 103 | 2264533 102 104 | 2277579 103 105 | 2305834 104 106 | 2326657 105 107 | 2332441 106 108 | 2348870 107 109 | 2359303 108 110 | 2405604 109 111 | 2424837 110 112 | 2490602 111 113 | 2514362 112 114 | 2516488 113 115 | 2527590 114 116 | 2542277 115 117 | 2576385 116 118 | 2583196 117 119 | 2587429 118 120 | 2596827 119 121 | 2597173 120 122 | 2601229 121 123 | 2634996 122 124 | 2663514 123 125 | 2715215 124 126 | 2728917 125 127 | 2733281 126 128 | 2766150 127 129 | 2802044 128 130 | 2827180 129 131 | 2847475 130 132 | 2854690 131 133 | 2871057 132 134 | 2903521 133 135 | 2925661 134 136 | 2927162 135 137 | 2929354 136 138 | 2996188 137 139 | 2998206 138 140 | 3014581 139 141 | 3082380 140 142 | 3110452 141 143 | 3116084 142 144 | 3141376 143 145 | 3152933 144 146 | 3181625 145 147 | 3204717 146 148 | 3233591 147 149 | 3244623 148 150 | 3253176 149 151 | 3262203 150 152 | 3305564 151 153 | 3311652 152 154 | 3320315 153 155 | 3325135 154 156 | 3370285 155 157 | 3412498 156 158 | 3417898 157 159 | 3433598 158 160 | 3544405 159 161 | 3578076 160 162 | 3595339 161 163 | 3650993 162 164 | 3685436 163 165 | 3750221 164 166 | 3798912 165 167 | 3861052 166 168 | 3873159 167 169 | 3896946 168 170 | 3903459 169 171 | 3926388 170 172 | 3999718 171 173 | 4026091 172 174 | 4030735 173 175 | 4067188 174 176 | 4069818 175 177 | 4153347 176 178 | 4184294 177 179 | 4234676 178 180 | 4252117 179 181 | 4266213 180 182 | 4309693 181 183 | 4338284 182 184 | 4372118 183 185 | 4384678 184 186 | 4400102 185 187 | 4453066 186 188 | 4462363 187 189 | 4482266 188 190 | 4491681 189 191 | 4523381 190 192 | 4555476 191 193 | 4566179 192 194 | 4569433 193 195 | 4604363 194 196 | 4606003 195 197 | 4616685 196 198 | 4630582 197 199 | 4646095 198 200 | 4647883 199 201 | 4664923 200 202 | 4665749 201 203 | 4671493 202 204 | 4673808 203 205 | 4689288 204 206 | 4777859 205 207 | 4799179 206 208 | 4801898 207 209 | 
4811530 208 210 | 4817074 209 211 | 4817818 210 212 | 4858907 211 213 | 4863345 212 214 | 4890813 213 215 | 4898226 214 216 | 4910943 215 217 | 4935911 216 218 | 4936408 217 219 | 4936902 218 220 | 4940944 219 221 | 4971984 220 222 | 5027601 221 223 | 5035672 222 224 | 5039127 223 225 | 5046480 224 226 | 5111995 225 227 | 5138567 226 228 | 5147419 227 229 | 5151530 228 230 | 5205192 229 231 | 5217593 230 232 | 5241512 231 233 | 5242836 232 234 | 5284001 233 235 | 5286259 234 236 | 5315600 235 237 | 5345921 236 238 | 5386746 237 239 | 5395385 238 240 | 5442999 239 241 | 5475736 240 242 | 5510360 241 243 | 5514032 242 244 | 5553568 243 245 | 5569361 244 246 | 5580445 245 247 | 5589550 246 248 | 5615807 247 249 | 5620240 248 250 | 5700847 249 251 | 5731987 250 252 | 5751326 251 253 | 5766574 252 254 | 5787866 253 255 | 6013635 254 256 | 6037937 255 257 | 6041614 256 258 | 6046315 257 259 | 6056857 258 260 | 6063872 259 261 | 6118186 260 262 | 6131942 261 263 | 6179666 262 264 | 6189326 263 265 | 6215098 264 266 | 6279671 265 267 | 6287660 266 268 | 6309752 267 269 | 6460881 268 270 | 6489020 269 271 | 6504191 270 272 | 6509113 271 273 | 6561575 272 274 | 6610995 273 275 | 6613578 274 276 | 6643602 275 277 | 6649335 276 278 | 6658223 277 279 | 6728999 278 280 | 6730265 279 281 | 6791232 280 282 | 6881234 281 283 | 6946263 282 284 | 6990915 283 285 | 6996546 284 286 | 7035810 285 287 | 7052088 286 288 | 7132222 287 289 | 7144087 288 290 | 7154880 289 291 | 7162659 290 292 | 7170321 291 293 | 7211586 292 294 | 7228156 293 295 | 7249365 294 296 | 7253543 295 297 | 7254915 296 298 | 7292702 297 299 | 7308500 298 300 | 7333343 299 301 | 7339663 300 302 | 7366751 301 303 | 7375562 302 304 | 7430941 303 305 | 7497254 304 306 | 7549429 305 307 | 7551394 306 308 | 7595620 307 309 | 7609170 308 310 | 7650576 309 311 | 7666026 310 312 | 7674197 311 313 | 7680288 312 314 | 7680575 313 315 | 7707348 314 316 | 7725442 315 317 | 7744645 316 318 | 7827654 317 319 | 7868719 318 320 | 7892243 319 321 | 7946736 320 322 | 7969433 321 323 | 7973717 322 324 | 7985090 323 325 | 7996993 324 326 | 8016201 325 327 | 8020153 326 328 | 8023008 327 329 | 8048764 328 330 | 8057055 329 331 | 8062996 330 332 | 8108598 331 333 | 8112789 332 334 | 8116651 333 335 | 8123456 334 336 | 8125414 335 337 | 8129018 336 338 | 8174750 337 339 | 8175269 338 340 | 8196328 339 341 | 8228754 340 342 | 8273920 341 343 | 8350040 342 344 | 8403142 343 345 | 8437876 344 346 | 8526921 345 347 | 8540988 346 348 | 8548323 347 349 | 8553226 348 350 | 8562994 349 351 | 8569045 350 352 | 8581435 351 353 | 8606250 352 354 | 8726120 353 355 | 8730736 354 356 | 8749543 355 357 | 8761546 356 358 | 8767900 357 359 | 8777866 358 360 | 8787062 359 361 | 8792788 360 362 | 8818625 361 363 | 8841999 362 364 | 8874784 363 365 | 8892815 364 366 | 8906037 365 367 | 8945531 366 368 | 8980418 367 369 | 9021538 368 370 | 9022943 369 371 | 9041144 370 372 | 9076416 371 373 | 9102346 372 374 | 9107106 373 375 | 9149554 374 376 | 9169103 375 377 | 9169494 376 378 | 9170693 377 379 | 9216740 378 380 | 9282799 379 381 | 9283630 380 382 | 9289720 381 383 | 9310627 382 384 | 9324749 383 385 | 9396141 384 386 | 9436438 385 387 | 9512503 386 388 | 9560712 387 389 | 9562931 388 390 | 9600695 389 391 | 9628261 390 392 | 9629210 391 393 | 9632827 392 394 | 9645554 393 395 | 9719535 394 396 | 9747429 395 397 | 9750312 396 398 | 9782122 397 399 | 9784114 398 400 | 9811084 399 401 | 9811643 400 402 | 9818447 401 403 | 9820147 402 404 | 9847195 403 405 | 9852829 404 406 | 9870536 
405 407 | 9878298 406 408 | 9883630 407 409 | 9884893 408 410 | 9902591 409 411 | 9903263 410 412 | 9965108 411 413 | 10002187 412 414 | 10005604 413 415 | 10026878 414 416 | 10058738 415 417 | 10058929 416 418 | 10064207 417 419 | 10103837 418 420 | 10104710 419 421 | 10112543 420 422 | 10137990 421 423 | 10146387 422 424 | 10164310 423 425 | 10183214 424 426 | 10206610 425 427 | 10216704 426 428 | 10230878 427 429 | 10281694 428 430 | 10288524 429 431 | 10385002 430 432 | 10395596 431 433 | 10409534 432 434 | 10439224 433 435 | 10459801 434 436 | 10488652 435 437 | 10493062 436 438 | 10522066 437 439 | 10549138 438 440 | 10601801 439 441 | 10637756 440 442 | 10665783 441 443 | 10677004 442 444 | 10705015 443 445 | 10728864 444 446 | 10731818 445 447 | 10745640 446 448 | 10750859 447 449 | 10800238 448 450 | 10889446 449 451 | 10973279 450 452 | 10982006 451 453 | 10996763 452 454 | 11051513 453 455 | 11068074 454 456 | 11077388 455 457 | 11108289 456 458 | 11123316 457 459 | 11136986 458 460 | 11161653 459 461 | 11175276 460 462 | 11176137 461 463 | 11214427 462 464 | 11260563 463 465 | 11324726 464 466 | 11363293 465 467 | 11371057 466 468 | 11410265 467 469 | 11418527 468 470 | 11421580 469 471 | 11502725 470 472 | 11554693 471 473 | 11561343 472 474 | 11561455 473 475 | 11642146 474 476 | 11647519 475 477 | 11647837 476 478 | 11657865 477 479 | 11667561 478 480 | 11685376 479 481 | 11693576 480 482 | 11723261 481 483 | 11739127 482 484 | 11743117 483 485 | 11763511 484 486 | 11786771 485 487 | 11787975 486 488 | 11789999 487 489 | 11875511 488 490 | 11885186 489 491 | 11957138 490 492 | 11966246 491 493 | 11972745 492 494 | 11998374 493 495 | 12020652 494 496 | 12060872 495 497 | 12086366 496 498 | 12116689 497 499 | 12158955 498 500 | 12203366 499 501 | 12207263 500 502 | 12245620 501 503 | 12252926 502 504 | 12291624 503 505 | 12310383 504 506 | 12371102 505 507 | 12378159 506 508 | 12380007 507 509 | 12385219 508 510 | 12404930 509 511 | 12415030 510 512 | 12434276 511 513 | 12434905 512 514 | 12485852 513 515 | 12493645 514 516 | 12506905 515 517 | 12536153 516 518 | 12572058 517 519 | 12577856 518 520 | 12682672 519 521 | 12689876 520 522 | 12748234 521 523 | 12763375 522 524 | 12780903 523 525 | 12816035 524 526 | 12832517 525 527 | 12871622 526 528 | 12906613 527 529 | 12910727 528 530 | 12958419 529 531 | 12965337 530 532 | 13022185 531 533 | 13026292 532 534 | 13061705 533 535 | 13103964 534 536 | 13142750 535 537 | 13164169 536 538 | 13165124 537 539 | 13171812 538 540 | 13235093 539 541 | 13250309 540 542 | 13252825 541 543 | 13268351 542 544 | 13289735 543 545 | 13316278 544 546 | 13413520 545 547 | 13457974 546 548 | 13458949 547 549 | 13522410 548 550 | 13536838 549 551 | 13545290 550 552 | 13552287 551 553 | 13561967 552 554 | 13575036 553 555 | 13581511 554 556 | 13625816 555 557 | 13627766 556 558 | 13663787 557 559 | 13678685 558 560 | 13730783 559 561 | 13743558 560 562 | 13759183 561 563 | 13792649 562 564 | 13808428 563 565 | 13837615 564 566 | 13877006 565 567 | 13893522 566 568 | 13916954 567 569 | 13917067 568 570 | 13954961 569 571 | 13959251 570 572 | 13974664 571 573 | 13978665 572 574 | 13982792 573 575 | 14003903 574 576 | 14009246 575 577 | 14015053 576 578 | 14017347 577 579 | 14036513 578 580 | 14064916 579 581 | 14073178 580 582 | 14107123 581 583 | 14107593 582 584 | 14151535 583 585 | 14160619 584 586 | 14181127 585 587 | 14189518 586 588 | 14190848 587 589 | 14198842 588 590 | 14208832 589 591 | 14223933 590 592 | 14269518 591 593 | 14312345 592 594 
| 14330348 593 595 | 14336072 594 596 | 14339301 595 597 | 14343754 596 598 | 14353177 597 599 | 14355049 598 600 | 14356993 599 601 | 14371602 600 602 | 14384955 601 603 | 14414539 602 604 | 14486186 603 605 | 14543264 604 606 | 14609565 605 607 | 14641169 606 608 | 14665141 607 609 | 14667266 608 610 | 14702863 609 611 | 14708337 610 612 | 14712810 611 613 | 14734180 612 614 | 14758763 613 615 | 14780185 614 616 | 14792991 615 617 | 14817755 616 618 | 14832907 617 619 | 14837073 618 620 | 14862318 619 621 | 14866436 620 622 | 14956803 621 623 | 14974364 622 624 | 14995550 623 625 | 15067389 624 626 | 15067545 625 627 | 15082052 626 628 | 15090153 627 629 | 15093644 628 630 | 15102510 629 631 | 15102568 630 632 | 15139030 631 633 | 15187198 632 634 | 15234898 633 635 | 15250459 634 636 | 15251280 635 637 | 15251388 636 638 | 15260258 637 639 | 15278895 638 640 | 15333752 639 641 | 15389971 640 642 | 15394374 641 643 | 15425435 642 644 | 15441351 643 645 | 15451584 644 646 | 15455566 645 647 | 15480518 646 648 | 15481639 647 649 | 15618964 648 650 | 15654742 649 651 | 15703719 650 652 | 15705344 651 653 | 15817148 652 654 | 15891947 653 655 | 15894739 654 656 | 15908759 655 657 | 15918523 656 658 | 15936857 657 659 | 15946266 658 660 | 15985673 659 661 | 16093211 660 662 | 16101439 661 663 | 16132747 662 664 | 16158626 663 665 | 16175272 664 666 | 16179482 665 667 | 16247824 666 668 | 16252359 667 669 | 16259523 668 670 | 16272330 669 671 | 16278899 670 672 | 16325425 671 673 | 16327790 672 674 | 16328129 673 675 | 16350436 674 676 | 16352135 675 677 | 16381584 676 678 | 16387858 677 679 | 16392621 678 680 | 16393228 679 681 | 16395957 680 682 | 16396044 681 683 | 16416396 682 684 | 16442592 683 685 | 16589510 684 686 | 16600464 685 687 | 16620753 686 688 | 16669188 687 689 | 16705451 688 690 | 16710611 689 691 | 16730133 690 692 | 16740108 691 693 | -------------------------------------------------------------------------------- /localization/coarse/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File lbr -> evaluate 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 06/04/2022 11:56 7 | ==================================================''' 8 | import numpy as np 9 | 10 | 11 | def ap_k(query_labels, k=1): 12 | ap = 0.0 13 | positives = 0 14 | for idx, i in enumerate(query_labels[:k]): 15 | if i == 1: 16 | positives += 1 17 | ap_at_count = positives / (idx + 1) 18 | ap += ap_at_count 19 | return ap / k 20 | 21 | 22 | def recall_k(query_labels, k=1): 23 | return np.sum(query_labels) / k 24 | 25 | 26 | def eval_k(preds, gts, k=1): 27 | output = {} 28 | mean_ap = [] 29 | mean_recall = [] 30 | for query in preds.keys(): 31 | pred_cands = preds[query] 32 | gt_cands = gts[query] 33 | pred = [] 34 | 35 | for c in pred_cands[:k]: 36 | if c in gt_cands: 37 | pred.append(1) 38 | else: 39 | pred.append(0) 40 | ap = ap_k(query_labels=pred, k=k) 41 | recall = recall_k(query_labels=pred, k=len(gt_cands)) 42 | mean_ap.append(ap) 43 | mean_recall.append(recall) 44 | 45 | # print("{:s} topk: {:d} ap: {:.4f} recall: {:.4f}".format(query, k, ap, recall)) 46 | 47 | output[query] = (ap, recall) 48 | 49 | return np.mean(mean_ap), np.mean(mean_recall), output 50 | 51 | 52 | def evaluate_retrieval(preds, gts, ks=[1, 10, 20, 50]): 53 | output = {} 54 | for k in ks: 55 | mean_ap, mean_recall, _ = eval_k(preds=preds, gts=gts, k=k) 56 | output[k] = { 57 | 'accuracy': 
mean_ap, 58 | 'recall': mean_recall, 59 | } 60 | return output 61 | 62 | 63 | def evaluate_retrieval_by_query(preds, gts, ks=[1, 10, 20, 50]): 64 | output = {} 65 | for k in ks: 66 | output[k] = 0 67 | 68 | failed_cases = [] 69 | for q in preds.keys(): 70 | gt_cans = gts[q] 71 | for k in ks: 72 | pred_cans = preds[q][:k] 73 | overlap = [v for v in pred_cans if v in gt_cans] 74 | if len(overlap) >= 1: 75 | output[k] += 1 76 | 77 | if k == 50 and len(overlap) == 0: 78 | failed_cases.append(q) 79 | 80 | for k in ks: 81 | output[k] = output[k] / len(preds.keys()) 82 | 83 | output['failed_case'] = failed_cases 84 | return output 85 | -------------------------------------------------------------------------------- /localization/fine/extractor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> extractor 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 26/07/2021 16:49 7 | ==================================================''' 8 | import torch 9 | import torch.utils.data as Data 10 | 11 | import argparse 12 | import os 13 | import os.path as osp 14 | import h5py 15 | from tqdm import tqdm 16 | from types import SimpleNamespace 17 | import logging 18 | from pathlib import Path 19 | import cv2 20 | import numpy as np 21 | import pprint 22 | 23 | from localization.fine.features.extract_spp import extract_spp_return 24 | from localization.fine.features.extract_r2d2 import extract_r2d2_return, load_network 25 | from net.locnets.superpoint import SuperPointNet 26 | from net.locnets.resnet import ResNetXN, extract_resnet_return 27 | 28 | confs = { 29 | 'resnetxn-triv2-0001-n4096-r1600-mask': { 30 | 'output': 'feats-resnetxn-triv2-0001-n4096-r1600-mask', 31 | 'model': { 32 | 'name': 'resnetxn', 33 | 'max_keypoints': 4096, 34 | 'conf_th': 0.001, 35 | 'multiscale': True, 36 | 'scales': [1.0], 37 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"), 38 | }, 39 | 'preprocessing': { 40 | 'grayscale': False, 41 | 'resize_max': 1600, 42 | }, 43 | 'mask': True, 44 | }, 45 | 46 | 'resnetxn-triv2-0001-n3000-r1600-mask': { 47 | 'output': 'feats-resnetxn-triv2-0001-n3000-r1600-mask', 48 | 'model': { 49 | 'name': 'resnetxn', 50 | 'max_keypoints': 3000, 51 | 'conf_th': 0.001, 52 | 'multiscale': True, 53 | 'scales': [1.0], 54 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"), 55 | }, 56 | 'preprocessing': { 57 | 'grayscale': False, 58 | 'resize_max': 1600, 59 | }, 60 | 'mask': True, 61 | }, 62 | 'resnetxn-triv2-0001-n2000-r1600-mask': { 63 | 'output': 'feats-resnetxn-triv2-0001-n2000-r1600-mask', 64 | 'model': { 65 | 'name': 'resnetxn', 66 | 'max_keypoints': 2000, 67 | 'conf_th': 0.001, 68 | 'multiscale': True, 69 | 'scales': [1.0], 70 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"), 71 | }, 72 | 'preprocessing': { 73 | 'grayscale': False, 74 | 'resize_max': 1600, 75 | }, 76 | 'mask': True, 77 | }, 78 | 79 | 'resnetxn-triv2-0001-n1000-r1600-mask': { 80 | 'output': 'feats-resnetxn-triv2-0001-n1000-r1600-mask', 81 | 'model': { 82 | 'name': 'resnetxn', 83 | 'max_keypoints': 1000, 84 | 'conf_th': 0.001, 85 | 'multiscale': True, 86 | 'scales': [1.0], 87 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"), 88 | }, 89 | 'preprocessing': { 90 | 'grayscale': False, 91 | 'resize_max': 1600, 92 | }, 93 | 'mask': True, 94 | }, 95 | 96 | 'resnetxn-triv2-0001-n4096-r1024-mask': { 97 | 'output': 
'feats-resnetxn-triv2-0001-n4096-r1024-mask', 98 | 'model': { 99 | 'name': 'resnetxn', 100 | 'max_keypoints': 4096, 101 | 'conf_th': 0.001, 102 | 'multiscale': True, 103 | 'scales': [1.0], 104 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"), 105 | }, 106 | 'preprocessing': { 107 | 'grayscale': False, 108 | 'resize_max': 1024, 109 | }, 110 | 'mask': True, 111 | }, 112 | 113 | 'resnetxn-triv2-ms-0001-n4096-r1024-mask': { 114 | 'output': 'feats-resnetxn-triv2-ms-0001-n4096-r1024-mask', 115 | 'model': { 116 | 'name': 'resnetxn', 117 | 'max_keypoints': 4096, 118 | 'conf_th': 0.001, 119 | 'multiscale': True, 120 | 'scales': [1.2, 1.0, 0.8], 121 | 'model_fn': osp.join(os.getcwd(), "models/resnetxn_triv2_d128_38.pth"), 122 | }, 123 | 'preprocessing': { 124 | 'grayscale': False, 125 | 'resize_max': 1024, 126 | }, 127 | 'mask': True, 128 | }, 129 | 130 | } 131 | 132 | confs_matcher = { 133 | 'superglue': { 134 | 'output': 'matches-superglue', 135 | 'model': { 136 | 'name': 'superglue', 137 | 'weights': 'outdoor', 138 | 'sinkhorn_iterations': 20, 139 | }, 140 | }, 141 | 'NNM': { 142 | 'output': 'NNM', 143 | 'model': { 144 | 'name': 'nnm', 145 | 'do_mutual_check': True, 146 | 'distance_threshold': None, 147 | }, 148 | }, 149 | 'NNML': { 150 | 'output': 'NNML', 151 | 'model': { 152 | 'name': 'nnml', 153 | 'do_mutual_check': True, 154 | 'distance_threshold': None, 155 | }, 156 | }, 157 | 158 | 'ONN': { 159 | 'output': 'ONN', 160 | 'model': { 161 | 'name': 'nn', 162 | 'do_mutual_check': False, 163 | 'distance_threshold': None, 164 | }, 165 | }, 166 | 'NNR': { 167 | 'output': 'NNR', 168 | 'model': { 169 | 'name': 'nnr', 170 | 'do_mutual_check': True, 171 | 'distance_threshold': 0.9, 172 | }, 173 | } 174 | } 175 | 176 | 177 | class ImageDataset(Data.Dataset): 178 | default_conf = { 179 | 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'], 180 | 'grayscale': False, 181 | 'resize_max': None, 182 | 'resize_force': False, 183 | } 184 | 185 | def __init__(self, root, conf, image_list=None, 186 | mask_root=None): 187 | self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) 188 | self.root = root 189 | 190 | self.paths = [] 191 | if image_list is None: 192 | for g in conf.globs: 193 | self.paths += list(Path(root).glob('**/' + g)) 194 | if len(self.paths) == 0: 195 | raise ValueError(f'Could not find any image in root: {root}.') 196 | self.paths = [i.relative_to(root) for i in self.paths] 197 | else: 198 | with open(image_list, "r") as f: 199 | lines = f.readlines() 200 | for l in lines: 201 | l = l.strip() 202 | self.paths.append(Path(l)) 203 | 204 | logging.info(f'Found {len(self.paths)} images in root {root}.') 205 | 206 | if mask_root is not None: 207 | self.mask_root = mask_root 208 | else: 209 | self.mask_root = None 210 | 211 | print("mask_root: ", self.mask_root) 212 | 213 | def __getitem__(self, idx): 214 | path = self.paths[idx] 215 | if self.conf.grayscale: 216 | mode = cv2.IMREAD_GRAYSCALE 217 | else: 218 | mode = cv2.IMREAD_COLOR 219 | image = cv2.imread(str(self.root / path), mode) 220 | if not self.conf.grayscale: 221 | image = image[:, :, ::-1] # BGR to RGB 222 | if image is None: 223 | raise ValueError(f'Cannot read image {str(path)}.') 224 | image = image.astype(np.float32) 225 | size = image.shape[:2][::-1] 226 | w, h = size 227 | 228 | if self.conf.resize_max and (self.conf.resize_force 229 | or max(w, h) > self.conf.resize_max): 230 | scale = self.conf.resize_max / max(h, w) 231 | h_new, w_new = int(round(h * scale)), int(round(w * scale)) 232 | 
image = cv2.resize( 233 | image, (w_new, h_new), interpolation=cv2.INTER_LINEAR) 234 | 235 | if self.conf.grayscale: 236 | image = image[None] 237 | else: 238 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 239 | image = image / 255. 240 | 241 | data = { 242 | 'name': str(path), 243 | 'image': image, 244 | 'original_size': np.array(size), 245 | } 246 | 247 | if self.mask_root is not None: 248 | mask_path = Path(str(path).replace("jpg", "png")) 249 | if osp.exists(mask_path): 250 | mask = cv2.imread(str(self.mask_root / mask_path)) 251 | mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST) 252 | else: 253 | mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8) 254 | 255 | data['mask'] = mask 256 | 257 | return data 258 | 259 | def __len__(self): 260 | return len(self.paths) 261 | 262 | 263 | def get_model(model_name, weight_path): 264 | if model_name == "superpoint": 265 | model = SuperPointNet().eval() 266 | model.load_state_dict(torch.load(weight_path)) 267 | extractor = extract_spp_return 268 | elif model_name == "r2d2": 269 | model = load_network(model_fn=weight_path).eval() 270 | extractor = extract_r2d2_return 271 | elif model_name == 'resnetxn-ori': 272 | model = ResNetXN(encoder_depth=2, outdim=128).eval() 273 | model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True) 274 | extractor = extract_resnet_return 275 | elif model_name == 'resnetxn': 276 | model = ResNetXN(encoder_depth=2, outdim=128).eval() 277 | model.load_state_dict(torch.load(weight_path)['model'], strict=True) 278 | extractor = extract_resnet_return 279 | 280 | return model, extractor 281 | 282 | 283 | @torch.no_grad() 284 | def main(conf, image_dir, export_dir, mask_dir=None, tag=None): 285 | logging.info('Extracting local features with configuration:' 286 | f'\n{pprint.pformat(conf)}') 287 | 288 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 289 | # Model = dynamic_load(extractors, conf['model']['name']) 290 | # model = Model(conf['model']).eval().to(device) 291 | model, extractor = get_model(model_name=conf['model']['name'], weight_path=conf["model"]["model_fn"]) 292 | model = model.cuda() 293 | print("model: ", model) 294 | 295 | loader = ImageDataset(image_dir, conf['preprocessing'], 296 | image_list=args.image_list, 297 | mask_root=mask_dir) 298 | loader = torch.utils.data.DataLoader(loader, num_workers=4) 299 | 300 | feature_path = Path(export_dir, conf['output'] + '.h5') 301 | feature_path.parent.mkdir(exist_ok=True, parents=True) 302 | feature_file = h5py.File(str(feature_path), 'a') 303 | 304 | with tqdm(total=len(loader)) as t: 305 | for idx, data in enumerate(loader): 306 | t.update() 307 | if tag is not None: 308 | if data['name'][0].find(tag) < 0: 309 | continue 310 | pred = extractor(model, img=data["image"], 311 | topK=conf["model"]["max_keypoints"], 312 | mask=data["mask"][0].numpy().astype(np.uint8) if "mask" in data.keys() else None, 313 | conf_th=conf["model"]["conf_th"], 314 | scales=conf["model"]["scales"], 315 | ) 316 | 317 | # pred = {k: v[0].cpu().numpy() for k, v in pred.items()} 318 | pred['descriptors'] = pred['descriptors'].transpose() 319 | 320 | t.set_postfix(npoints=pred['keypoints'].shape[0]) 321 | # print(pred['keypoints'].shape) 322 | 323 | pred['image_size'] = original_size = data['original_size'][0].numpy() 324 | # pred['descriptors'] = pred['descriptors'].T 325 | if 'keypoints' in pred.keys(): 326 | size = np.array(data['image'].shape[-2:][::-1]) 327 | scales = (original_size / 
size).astype(np.float32) 328 | pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5 329 | 330 | grp = feature_file.create_group(data['name'][0]) 331 | for k, v in pred.items(): 332 | # print(k, v.shape) 333 | grp.create_dataset(k, data=v) 334 | 335 | del pred 336 | 337 | feature_file.close() 338 | logging.info('Finished exporting features.') 339 | 340 | return feature_path 341 | 342 | 343 | if __name__ == '__main__': 344 | parser = argparse.ArgumentParser() 345 | parser.add_argument('--image_dir', type=Path, required=True) 346 | parser.add_argument('--image_list', type=str, default=None) 347 | parser.add_argument('--tag', type=str, default=None) 348 | parser.add_argument('--mask_dir', type=Path, default=None) 349 | parser.add_argument('--export_dir', type=Path, required=True) 350 | parser.add_argument('--conf', type=str, default='superpoint_aachen', 351 | choices=list(confs.keys())) 352 | args = parser.parse_args() 353 | main(confs[args.conf], args.image_dir, args.export_dir, 354 | mask_dir=args.mask_dir if confs[args.conf]["mask"] else None, tag=args.tag) 355 | -------------------------------------------------------------------------------- /localization/fine/features/extract_d2net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2021/2/22 下午2:23 4 | @Auth : Fei Xue 5 | @File : extract_d2net.py 6 | @Email: fx221@cam.ac.uk 7 | """ 8 | 9 | import imageio 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | class EmptyTensorError(Exception): 16 | pass 17 | 18 | 19 | class NoGradientError(Exception): 20 | pass 21 | 22 | 23 | def interpolate_dense_features(pos, dense_features, return_corners=False): 24 | device = pos.device 25 | 26 | ids = torch.arange(0, pos.size(1), device=device) 27 | 28 | _, h, w = dense_features.size() 29 | 30 | i = pos[0, :] 31 | j = pos[1, :] 32 | 33 | # Valid corners 34 | i_top_left = torch.floor(i).long() 35 | j_top_left = torch.floor(j).long() 36 | valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0) 37 | 38 | i_top_right = torch.floor(i).long() 39 | j_top_right = torch.ceil(j).long() 40 | valid_top_right = torch.min(i_top_right >= 0, j_top_right < w) 41 | 42 | i_bottom_left = torch.ceil(i).long() 43 | j_bottom_left = torch.floor(j).long() 44 | valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0) 45 | 46 | i_bottom_right = torch.ceil(i).long() 47 | j_bottom_right = torch.ceil(j).long() 48 | valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w) 49 | 50 | valid_corners = torch.min( 51 | torch.min(valid_top_left, valid_top_right), 52 | torch.min(valid_bottom_left, valid_bottom_right) 53 | ) 54 | 55 | i_top_left = i_top_left[valid_corners] 56 | j_top_left = j_top_left[valid_corners] 57 | 58 | i_top_right = i_top_right[valid_corners] 59 | j_top_right = j_top_right[valid_corners] 60 | 61 | i_bottom_left = i_bottom_left[valid_corners] 62 | j_bottom_left = j_bottom_left[valid_corners] 63 | 64 | i_bottom_right = i_bottom_right[valid_corners] 65 | j_bottom_right = j_bottom_right[valid_corners] 66 | 67 | ids = ids[valid_corners] 68 | if ids.size(0) == 0: 69 | raise EmptyTensorError 70 | 71 | # Interpolation 72 | i = i[ids] 73 | j = j[ids] 74 | dist_i_top_left = i - i_top_left.float() 75 | dist_j_top_left = j - j_top_left.float() 76 | w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) 77 | w_top_right = (1 - dist_i_top_left) * dist_j_top_left 78 | w_bottom_left = dist_i_top_left * (1 - 
dist_j_top_left) 79 | w_bottom_right = dist_i_top_left * dist_j_top_left 80 | 81 | descriptors = ( 82 | w_top_left * dense_features[:, i_top_left, j_top_left] + 83 | w_top_right * dense_features[:, i_top_right, j_top_right] + 84 | w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] + 85 | w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right] 86 | ) 87 | 88 | pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0) 89 | 90 | if not return_corners: 91 | return [descriptors, pos, ids] 92 | else: 93 | corners = torch.stack([ 94 | torch.stack([i_top_left, j_top_left], dim=0), 95 | torch.stack([i_top_right, j_top_right], dim=0), 96 | torch.stack([i_bottom_left, j_bottom_left], dim=0), 97 | torch.stack([i_bottom_right, j_bottom_right], dim=0) 98 | ], dim=0) 99 | return [descriptors, pos, ids, corners] 100 | 101 | 102 | def upscale_positions(pos, scaling_steps=0): 103 | for _ in range(scaling_steps): 104 | pos = pos * 2 + 0.5 105 | return pos 106 | 107 | 108 | def process_multiscale(image, model, scales=[.5, 1, 2]): 109 | b, _, h_init, w_init = image.size() 110 | device = image.device 111 | assert (b == 1) 112 | 113 | all_keypoints = torch.zeros([3, 0]) 114 | all_descriptors = torch.zeros([ 115 | model.dense_feature_extraction.num_channels, 0 116 | ]) 117 | all_scores = torch.zeros(0) 118 | 119 | previous_dense_features = None 120 | banned = None 121 | for idx, scale in enumerate(scales): 122 | current_image = F.interpolate( 123 | image, scale_factor=scale, 124 | mode='bilinear', align_corners=True 125 | ) 126 | _, _, h_level, w_level = current_image.size() 127 | 128 | dense_features = model.dense_feature_extraction(current_image) 129 | del current_image 130 | 131 | _, _, h, w = dense_features.size() 132 | 133 | # Sum the feature maps. 134 | if previous_dense_features is not None: 135 | dense_features += F.interpolate( 136 | previous_dense_features, size=[h, w], 137 | mode='bilinear', align_corners=True 138 | ) 139 | del previous_dense_features 140 | 141 | # Recover detections. 142 | detections = model.detection(dense_features) 143 | if banned is not None: 144 | banned = F.interpolate(banned.float(), size=[h, w]).bool() 145 | detections = torch.min(detections, ~banned) 146 | banned = torch.max( 147 | torch.max(detections, dim=1)[0].unsqueeze(1), banned 148 | ) 149 | else: 150 | banned = torch.max(detections, dim=1)[0].unsqueeze(1) 151 | fmap_pos = torch.nonzero(detections[0].cpu()).t() 152 | del detections 153 | 154 | # Recover displacements. 
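# The localization head predicts a sub-pixel offset for every detected cell; offsets with
# magnitude >= 0.5 are discarded by the mask below, and the remaining ones refine the integer
# feature-map positions before descriptors are interpolated at the refined locations.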
155 | displacements = model.localization(dense_features)[0].cpu() 156 | displacements_i = displacements[ 157 | 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :] 158 | ] 159 | displacements_j = displacements[ 160 | 1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :] 161 | ] 162 | del displacements 163 | 164 | mask = torch.min( 165 | torch.abs(displacements_i) < 0.5, 166 | torch.abs(displacements_j) < 0.5 167 | ) 168 | fmap_pos = fmap_pos[:, mask] 169 | valid_displacements = torch.stack([ 170 | displacements_i[mask], 171 | displacements_j[mask] 172 | ], dim=0) 173 | del mask, displacements_i, displacements_j 174 | 175 | fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements 176 | del valid_displacements 177 | 178 | try: 179 | raw_descriptors, _, ids = interpolate_dense_features( 180 | fmap_keypoints.to(device), 181 | dense_features[0] 182 | ) 183 | except EmptyTensorError: 184 | continue 185 | fmap_pos = fmap_pos[:, ids] 186 | fmap_keypoints = fmap_keypoints[:, ids] 187 | del ids 188 | 189 | keypoints = upscale_positions(fmap_keypoints, scaling_steps=2) 190 | del fmap_keypoints 191 | 192 | descriptors = F.normalize(raw_descriptors, dim=0).cpu() 193 | del raw_descriptors 194 | 195 | keypoints[0, :] *= h_init / h_level 196 | keypoints[1, :] *= w_init / w_level 197 | 198 | fmap_pos = fmap_pos.cpu() 199 | keypoints = keypoints.cpu() 200 | 201 | keypoints = torch.cat([ 202 | keypoints, 203 | torch.ones([1, keypoints.size(1)]) * 1 / scale, 204 | ], dim=0) 205 | 206 | scores = dense_features[ 207 | 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :] 208 | ].cpu() / (idx + 1) 209 | del fmap_pos 210 | 211 | all_keypoints = torch.cat([all_keypoints, keypoints], dim=1) 212 | all_descriptors = torch.cat([all_descriptors, descriptors], dim=1) 213 | all_scores = torch.cat([all_scores, scores], dim=0) 214 | del keypoints, descriptors 215 | 216 | previous_dense_features = dense_features 217 | del dense_features 218 | del previous_dense_features, banned 219 | 220 | keypoints = all_keypoints.t().numpy() 221 | del all_keypoints 222 | scores = all_scores.numpy() 223 | del all_scores 224 | descriptors = all_descriptors.t().numpy() 225 | del all_descriptors 226 | return keypoints, scores, descriptors 227 | 228 | 229 | def preprocess_image(image, preprocessing='caffe'): 230 | image = image.astype(np.float32) 231 | image = np.transpose(image, [2, 0, 1]) 232 | if preprocessing is None: 233 | pass 234 | elif preprocessing == 'caffe': 235 | # RGB -> BGR 236 | image = image[:: -1, :, :] 237 | # Zero-center by mean pixel 238 | mean = np.array([103.939, 116.779, 123.68]) 239 | image = image - mean.reshape([3, 1, 1]) 240 | elif preprocessing == 'torch': 241 | image /= 255.0 242 | mean = np.array([0.485, 0.456, 0.406]) 243 | std = np.array([0.229, 0.224, 0.225]) 244 | image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1]) 245 | else: 246 | raise ValueError('Unknown preprocessing parameter.') 247 | return image 248 | 249 | 250 | def extract_d2net_return(d2net, img_path, multi_scale=False, need_nms=False): 251 | image = imageio.imread(img_path) 252 | if len(image.shape) == 2: 253 | image = image[:, :, np.newaxis] 254 | image = np.repeat(image, 3, -1) 255 | 256 | # TODO: switch to PIL.Image due to deprecation of scipy.misc.imresize. 
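# No resizing is applied here, so fact_i / fact_j below evaluate to 1.0; they are kept so that
# keypoints could be mapped back to the input resolution if resizing were re-enabled.
# A possible PIL-based resize (sketch only; assumes `from PIL import Image` and a chosen max_edge):
#   pil = Image.fromarray(image.astype(np.uint8))
#   s = max_edge / max(pil.size)
#   resized_image = np.array(pil.resize((round(pil.size[0] * s), round(pil.size[1] * s)), Image.BILINEAR))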
257 | resized_image = image 258 | 259 | fact_i = image.shape[0] / resized_image.shape[0] 260 | fact_j = image.shape[1] / resized_image.shape[1] 261 | 262 | input_image = preprocess_image( 263 | resized_image, 264 | preprocessing='caffe', 265 | ) 266 | with torch.no_grad(): 267 | if multi_scale: 268 | keypoints, scores, descriptors = process_multiscale( 269 | torch.tensor( 270 | input_image[np.newaxis, :, :, :].astype(np.float32), 271 | device=torch.device("cuda:0"), 272 | ), 273 | model=d2net, 274 | ) 275 | else: 276 | keypoints, scores, descriptors = process_multiscale( 277 | torch.tensor( 278 | input_image[np.newaxis, :, :, :].astype(np.float32), 279 | device=torch.device("cuda:0"), 280 | ), 281 | d2net, 282 | scales=[1] 283 | ) 284 | 285 | # Input image coordinates 286 | keypoints[:, 0] *= fact_i 287 | keypoints[:, 1] *= fact_j 288 | # i, j -> u, v 289 | keypoints = keypoints[:, [1, 0, 2]] 290 | 291 | return keypoints, descriptors, scores 292 | -------------------------------------------------------------------------------- /localization/fine/features/extract_sift.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2021/2/22 上午10:10 4 | @Auth : Fei Xue 5 | @File : extract_sift.py 6 | @Email: fx221@cam.ac.uk 7 | """ 8 | 9 | import argparse 10 | import os 11 | import os.path as osp 12 | from tqdm import tqdm 13 | import cv2 14 | import numpy as np 15 | 16 | 17 | def plot_keypoint(img_path, pts, scores=None): 18 | img = cv2.imread(img_path) 19 | img_out = img.copy() 20 | r = 3 21 | if scores is None: 22 | for i in range(pts.shape[0]): 23 | pt = pts[i] 24 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4) 25 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 4) 26 | else: 27 | scores_norm = scores / np.linalg.norm(scores, ord=2) 28 | print("score median: ", np.median(scores_norm)) 29 | for i in range(pts.shape[0]): 30 | pt = pts[i] 31 | s = scores_norm[i] 32 | if s < np.median(scores_norm): 33 | continue 34 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4) 35 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 4) 36 | 37 | cv2.imshow("img", img_out) 38 | cv2.waitKey(0) 39 | 40 | 41 | def extract_sift_return(sift, img_path, need_nms=False): 42 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 43 | # sift = cv2.xfeatures2d.SIFT_create(nOctaveLayers=1, contrastThreshold=0.03, edgeThreshold=8) 44 | # kpts = sift.detect(img) 45 | # descs = np.zeros((10000, 128), np.float) 46 | kpts, descs = sift.detectAndCompute(img, None) 47 | 48 | scores = np.array([kp.response for kp in kpts], np.float32) 49 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) 50 | 51 | return kpts, descs, scores 52 | 53 | 54 | if __name__ == '__main__': 55 | # Argument parsing 56 | parser = argparse.ArgumentParser(description='super point extraction on haptches') 57 | parser.add_argument('--hpatches_path', type=str, required=True, 58 | help='path to a file containing a list of images to process') 59 | parser.add_argument("--output_path", type=str, required=True, 60 | help='path to save descriptors') 61 | parser.add_argument("--image_list_file", type=str, default="img_list_hpatches_list.txt", 62 | help='path to save descriptors') 63 | 64 | args = parser.parse_args() 65 | hpatches_path = args.hpatches_path 66 | output_dir = args.output_path 67 | os.makedirs(output_dir, exist_ok=True) 68 | 69 | with open(args.image_list_file, 
'r') as f: 70 | lines = f.readlines() 71 | 72 | sift = cv2.xfeatures2d.SIFT_create(4000) 73 | for line in tqdm(lines, total=len(lines)): 74 | path = line.strip() 75 | img_path = osp.join(args.hpatches_path, path) 76 | print("img_path: ", img_path) 77 | 78 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 79 | kpts = sift.detect(img) 80 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) 81 | print(kpts.shape) 82 | plot_keypoint(img_path=img_path, pts=kpts) 83 | -------------------------------------------------------------------------------- /localization/fine/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> test 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 03/08/2021 12:12 7 | ==================================================''' 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | import torch 12 | import cv2 13 | from localization.fine.extractor import get_model 14 | from localization.fine.matcher import Matcher 15 | from localization.tools import plot_keypoint, plot_matches 16 | from tools.common import resize_img 17 | 18 | match_confs = { 19 | 'superglue': { 20 | 'output': 'matches-superglue', 21 | 'model': { 22 | 'name': 'superglue', 23 | 'weights': 'outdoor', 24 | 'sinkhorn_iterations': 50, 25 | }, 26 | }, 27 | 'NNM': { 28 | 'output': 'NNM', 29 | 'model': { 30 | 'name': 'nnm', 31 | 'do_mutual_check': True, 32 | 'distance_threshold': None, 33 | }, 34 | }, 35 | 'NNML': { 36 | 'output': 'NNML', 37 | 'model': { 38 | 'name': 'nnml', 39 | 'do_mutual_check': True, 40 | 'distance_threshold': None, 41 | }, 42 | }, 43 | 44 | 'ONN': { 45 | 'output': 'ONN', 46 | 'model': { 47 | 'name': 'nn', 48 | 'do_mutual_check': False, 49 | 'distance_threshold': None, 50 | }, 51 | }, 52 | 'NNR': { 53 | 'output': 'NNR', 54 | 'model': { 55 | 'name': 'nnr', 56 | 'do_mutual_check': True, 57 | 'distance_threshold': 0.9, 58 | }, 59 | } 60 | } 61 | 62 | extract_confs = { 63 | # 'superpoint_aachen': { 64 | 'superpoint-n2000-r1024-mask': { 65 | 'output': 'feats-superpoint-n2000-r1024-mask', 66 | 'model': { 67 | 'name': 'superpoint', 68 | 'nms_radius': 4, 69 | 'max_keypoints': 2000, 70 | 'model_fn': osp.join(os.getcwd(), "models/superpoint_v1.pth"), 71 | }, 72 | 'preprocessing': { 73 | 'grayscale': True, 74 | 'resize_max': 1024, 75 | }, 76 | }, 77 | 78 | 'superpoint-n4096-r1024-mask': { 79 | 'output': 'feats-superpoint-n4096-r1024-mask', 80 | 'model': { 81 | 'name': 'superpoint', 82 | 'nms_radius': 4, 83 | 'max_keypoints': 4096, 84 | 'model_fn': osp.join(os.getcwd(), "models/superpoint_v1.pth"), 85 | }, 86 | 'preprocessing': { 87 | 'grayscale': True, 88 | 'resize_max': 1024, 89 | }, 90 | 'mask': True, 91 | }, 92 | 93 | 'r2d2-n2000-r1024-mask': { 94 | 'output': 'feats-r2d2-n2000-r1024-mask', 95 | 'model': { 96 | 'name': 'r2d2', 97 | 'nms_radius': 4, 98 | 'max_keypoints': 2000, 99 | 'model_fn': osp.join(os.getcwd(), "models/r2d2_WASF_N16.pt"), 100 | }, 101 | 'preprocessing': { 102 | 'grayscale': False, 103 | 'resize_max': 1024, 104 | }, 105 | 'mask': True, 106 | }, 107 | 108 | 'r2d2-n4096-r1024-mask': { 109 | 'output': 'feats-r2d2-n2000-r1024-mask', 110 | 'model': { 111 | 'name': 'r2d2', 112 | 'nms_radius': 4, 113 | 'max_keypoints': 4096, 114 | 'model_fn': osp.join(os.getcwd(), "models/r2d2_WASF_N16.pt"), 115 | }, 116 | 'preprocessing': { 117 | 'grayscale': False, 118 | 'resize_max': 1024, 119 | }, 120 | 'mask': True, 121 | }, 122 | } 
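# Each entry above mirrors the extractor configs in localization/fine/extractor.py: 'output' names
# the exported feature file, 'model' selects the backbone and its weights (consumed by get_model()),
# 'preprocessing' controls grayscale conversion and the maximum image edge, and 'mask' toggles
# instance-mask-guided extraction. A minimal usage sketch (names taken from this file):
#   e_conf = extract_confs['superpoint-n4096-r1024-mask']
#   model, extractor = get_model(model_name=e_conf['model']['name'], weight_path=e_conf['model']['model_fn'])
#   matcher = Matcher(conf=match_confs['NNM']).eval().cuda()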
123 | 124 | 125 | def test(use_mask=True): 126 | img_dir = "/home/mifs/fx221/fx221/localization/aachen_v1_1/images/images_upright" 127 | mask_dir = "/home/mifs/fx221/fx221/localization/aachen_v1_1/global_seg_instance" 128 | # pair_fn = "datasets/aachen/pairs-db-covis20.txt" 129 | # pair_fn = '/home/mifs/fx221/fx221/exp/shloc/aachen/2021_08_05_23_29_59_aachen_pspf_resnext101_32x4d_d4_u8_ce_b8_seg_cls_aug_stylized/loc_by_seg/loc_by_sec2_top30.txt' 130 | pair_fn = '/home/mifs/fx221/fx221/exp/shloc/aachen/2021_08_05_23_29_59_aachen_pspf_resnext101_32x4d_d4_u8_ce_b8_seg_cls_aug_stylized/loc_by_seg/loc_by_sec_top30_fail_list.txt' 131 | all_pairs = [] 132 | with open(pair_fn, "r") as f: 133 | lines = f.readlines() 134 | for l in lines: 135 | l = l.strip().split(" ") 136 | all_pairs.append((l[0], l[1])) 137 | 138 | if use_mask: 139 | m_conf = match_confs["NNML"] 140 | else: 141 | m_conf = match_confs["NNM"] 142 | 143 | # e_conf = extract_confs["superpoint-n4096-r1024-mask"] 144 | e_conf = extract_confs["r2d2-n4096-r1024-mask"] 145 | matcher = Matcher(conf=m_conf) 146 | matcher = matcher.eval().cuda() 147 | 148 | model, extractor = get_model(model_name=e_conf['model']['name'], weight_path=e_conf["model"]["model_fn"]) 149 | model = model.cuda() 150 | 151 | # matcher = Matcher[] 152 | 153 | cv2.namedWindow("pt0", cv2.WINDOW_NORMAL) 154 | cv2.namedWindow("pt1", cv2.WINDOW_NORMAL) 155 | cv2.namedWindow("match", cv2.WINDOW_NORMAL) 156 | 157 | next_q = False 158 | for p in all_pairs: 159 | fn0 = p[0] 160 | fn1 = p[1] 161 | img0 = cv2.imread(osp.join(img_dir, fn0)) 162 | img1 = cv2.imread(osp.join(img_dir, fn1)) 163 | if e_conf["preprocessing"]["grayscale"]: 164 | img0_gray = cv2.cvtColor(img0, cv2.COLOR_BGR2GRAY)[None] 165 | img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)[None] 166 | else: 167 | img0_gray = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB) 168 | img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB) 169 | img0_gray = img0_gray.transpose(2, 0, 1) 170 | img1_gray = img1_gray.transpose(2, 0, 1) 171 | mask0 = cv2.imread(osp.join(mask_dir, fn0.replace("jpg", "png"))) 172 | mask0 = cv2.resize(mask0, dsize=(img0.shape[1], img0.shape[0]), interpolation=cv2.INTER_NEAREST) 173 | # label0 = np.int32(mask0[:, :, 2]) * 256 * 256 + np.int32(mask0[:, :, 1]) * 256 + np.int32(mask0[:, :, 0]) 174 | mask1 = cv2.imread(osp.join(mask_dir, fn1.replace("jpg", "png"))) 175 | mask1 = cv2.resize(mask1, dsize=(img1.shape[1], img1.shape[0]), interpolation=cv2.INTER_NEAREST) 176 | # label1 = np.int32(mask1[:, :, 2]) * 256 * 256 + np.int32(mask1[:, :, 1]) * 256 + np.int32(mask1[:, :, 0]) 177 | print(img0.shape, img1.shape) 178 | print(mask0.shape, mask1.shape) 179 | 180 | pred0 = extractor(model, img=torch.from_numpy(img0_gray / 255.).unsqueeze(0).float(), 181 | topK=e_conf["model"]["max_keypoints"], mask=mask0 if use_mask else None) 182 | pred1 = extractor(model, img=torch.from_numpy(img1_gray / 255.).unsqueeze(0).float(), 183 | topK=e_conf["model"]["max_keypoints"], mask=mask1 if use_mask else None) 184 | 185 | img_pt0 = plot_keypoint(img_path=img0, pts=pred0["keypoints"]) 186 | img_pt1 = plot_keypoint(img_path=img1, pts=pred1["keypoints"]) 187 | 188 | if use_mask: 189 | match_data = { 190 | "descriptors0": pred0["descriptors"], 191 | "labels0": pred0["labels"], 192 | "descriptors1": pred1["descriptors"], 193 | "labels1": pred1["labels"], 194 | } 195 | else: 196 | match_data = { 197 | "descriptors0": pred0["descriptors"], 198 | # "labels0": pred0["labels"], 199 | "descriptors1": pred1["descriptors"], 200 | # "labels1": 
pred1["labels"], 201 | } 202 | matches = matcher(match_data)["matches0"] 203 | # matches = pred['matches0'] # [0].cpu().short().numpy() 204 | valid_matches = [] 205 | for i in range(matches.shape[0]): 206 | if matches[i] > 0: 207 | valid_matches.append([i, matches[i]]) 208 | valid_matches = np.array(valid_matches, np.int) 209 | img_matches = plot_matches(img1=img0, img2=img1, 210 | pts1=pred0["keypoints"], 211 | pts2=pred1["keypoints"], 212 | matches=valid_matches, 213 | ) 214 | 215 | img_pt0 = resize_img(img_pt0, nh=512) 216 | mask0 = resize_img(mask0, nh=512) 217 | img_pt1 = resize_img(img_pt1, nh=512) 218 | mask1 = resize_img(mask1, nh=512) 219 | img_matches = resize_img(img_matches, nh=512) 220 | 221 | cv2.imshow("pt0", np.hstack([img_pt0, mask0])) 222 | cv2.imshow("pt1", np.hstack([img_pt1, mask1])) 223 | cv2.imshow("match", img_matches) 224 | cv2.waitKey(0) 225 | 226 | 227 | if __name__ == '__main__': 228 | test(use_mask=True) 229 | -------------------------------------------------------------------------------- /localization/fine/triangulation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | import h5py 6 | import numpy as np 7 | import subprocess 8 | import pprint 9 | import shutil 10 | 11 | from localization.utils.read_write_model import (read_model, read_images_binary, 12 | CAMERA_MODEL_NAMES, read_cameras_binary, 13 | write_points3d_binary, write_images_binary, Image) 14 | from localization.utils.database import COLMAPDatabase 15 | from localization.utils.parsers import ( 16 | parse_image_lists_with_intrinsics, parse_retrieval, names_to_pair) 17 | 18 | 19 | def create_empty_model(reference_model, empty_model): 20 | logging.info('Creating an empty model.') 21 | empty_model.mkdir(exist_ok=True) 22 | shutil.copy(reference_model / 'cameras.bin', empty_model) 23 | write_points3d_binary(dict(), empty_model / 'points3D.bin') 24 | images = read_images_binary(str(reference_model / 'images.bin')) 25 | images_empty = dict() 26 | for id_, image in images.items(): 27 | image = image._asdict() 28 | image['xys'] = np.zeros((0, 2), float) 29 | image['point3D_ids'] = np.full(0, -1, int) 30 | images_empty[id_] = Image(**image) 31 | write_images_binary(images_empty, empty_model / 'images.bin') 32 | 33 | 34 | def create_db_from_model(empty_model, database_path): 35 | if database_path.exists(): 36 | logging.warning('Database already exists.') 37 | 38 | cameras = read_cameras_binary(str(empty_model / 'cameras.bin')) 39 | images = read_images_binary(str(empty_model / 'images.bin')) 40 | 41 | db = COLMAPDatabase.connect(database_path) 42 | db.create_tables() 43 | 44 | for i, camera in cameras.items(): 45 | model_id = CAMERA_MODEL_NAMES[camera.model].model_id 46 | db.add_camera( 47 | model_id, camera.width, camera.height, camera.params, camera_id=i, 48 | prior_focal_length=True) 49 | 50 | for i, image in images.items(): 51 | db.add_image(image.name, image.camera_id, image_id=i) 52 | 53 | db.commit() 54 | db.close() 55 | return {image.name: i for i, image in images.items()} 56 | 57 | 58 | def import_features(image_ids, database_path, features_path): 59 | logging.info('Importing features into the database...') 60 | hfile = h5py.File(str(features_path), 'r') 61 | db = COLMAPDatabase.connect(database_path) 62 | 63 | for image_name, image_id in tqdm(image_ids.items()): 64 | keypoints = hfile[image_name]['keypoints'].__array__() 65 | keypoints += 0.5 # COLMAP origin 66 | 
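# COLMAP places the image origin at the corner of the top-left pixel (so its center is at
# (0.5, 0.5)), whereas the extracted keypoints use 0-based pixel-center coordinates; the
# +0.5 shift above converts them before they are written into the database.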
db.add_keypoints(image_id, keypoints) 67 | 68 | hfile.close() 69 | db.commit() 70 | db.close() 71 | 72 | 73 | def import_matches(image_ids, database_path, pairs_path, matches_path, 74 | min_match_score=None, skip_geometric_verification=False): 75 | logging.info('Importing matches into the database...') 76 | 77 | with open(str(pairs_path), 'r') as f: 78 | pairs = [p.split() for p in f.readlines()] 79 | 80 | hfile = h5py.File(str(matches_path), 'r') 81 | db = COLMAPDatabase.connect(database_path) 82 | 83 | matched = set() 84 | for name0, name1 in tqdm(pairs): 85 | if name0 not in image_ids or name1 not in image_ids: 86 | continue 87 | id0, id1 = image_ids[name0], image_ids[name1] 88 | if len({(id0, id1), (id1, id0)} & matched) > 0: 89 | continue 90 | pair = names_to_pair(name0, name1) 91 | if pair not in hfile: 92 | raise ValueError( 93 | f'Could not find pair {(name0, name1)}... ' 94 | 'Maybe you matched with a different list of pairs? ' 95 | f'Reverse in file: {names_to_pair(name0, name1) in hfile}.') 96 | 97 | matches = hfile[pair]['matches0'].__array__() 98 | valid = matches > -1 99 | if min_match_score: 100 | scores = hfile[pair]['matching_scores0'].__array__() 101 | valid = valid & (scores > min_match_score) 102 | matches = np.stack([np.where(valid)[0], matches[valid]], -1) 103 | 104 | db.add_matches(id0, id1, matches) 105 | matched |= {(id0, id1), (id1, id0)} 106 | 107 | if skip_geometric_verification: 108 | db.add_two_view_geometry(id0, id1, matches) 109 | 110 | hfile.close() 111 | db.commit() 112 | db.close() 113 | 114 | 115 | def geometric_verification(colmap_path, database_path, pairs_path): 116 | logging.info('Performing geometric verification of the matches...') 117 | cmd = [ 118 | str(colmap_path), 'matches_importer', 119 | '--database_path', str(database_path), 120 | '--match_list_path', str(pairs_path), 121 | '--match_type', 'pairs', 122 | '--SiftMatching.max_num_trials', str(20000), 123 | '--SiftMatching.min_inlier_ratio', str(0.1)] 124 | ret = subprocess.call(cmd) 125 | if ret != 0: 126 | logging.warning('Problem with matches_importer, exiting.') 127 | exit(ret) 128 | 129 | 130 | def run_triangulation(colmap_path, model_path, database_path, image_dir, 131 | empty_model): 132 | logging.info('Running the triangulation...') 133 | assert model_path.exists() 134 | 135 | cmd = [ 136 | str(colmap_path), 'point_triangulator', 137 | '--database_path', str(database_path), 138 | '--image_path', str(image_dir), 139 | '--input_path', str(empty_model), 140 | '--output_path', str(model_path), 141 | '--Mapper.ba_refine_focal_length', '0', 142 | '--Mapper.ba_refine_principal_point', '0', 143 | '--Mapper.ba_refine_extra_params', '0'] 144 | logging.info(' '.join(cmd)) 145 | ret = subprocess.call(cmd) 146 | if ret != 0: 147 | logging.warning('Problem with point_triangulator, exiting.') 148 | exit(ret) 149 | 150 | stats_raw = subprocess.check_output( 151 | [str(colmap_path), 'model_analyzer', '--path', model_path]) 152 | stats_raw = stats_raw.decode().split("\n") 153 | stats = dict() 154 | for stat in stats_raw: 155 | if stat.startswith("Registered images"): 156 | stats['num_reg_images'] = int(stat.split()[-1]) 157 | elif stat.startswith("Points"): 158 | stats['num_sparse_points'] = int(stat.split()[-1]) 159 | elif stat.startswith("Observations"): 160 | stats['num_observations'] = int(stat.split()[-1]) 161 | elif stat.startswith("Mean track length"): 162 | stats['mean_track_length'] = float(stat.split()[-1]) 163 | elif stat.startswith("Mean observations per image"): 164 | 
stats['num_observations_per_image'] = float(stat.split()[-1]) 165 | elif stat.startswith("Mean reprojection error"): 166 | stats['mean_reproj_error'] = float(stat.split()[-1][:-2]) 167 | 168 | return stats 169 | 170 | 171 | def main(sfm_dir, reference_sfm_model, image_dir, pairs, features, matches, 172 | colmap_path='colmap', skip_geometric_verification=False, 173 | min_match_score=None): 174 | assert reference_sfm_model.exists(), reference_sfm_model 175 | assert features.exists(), features 176 | assert pairs.exists(), pairs 177 | assert matches.exists(), matches 178 | 179 | sfm_dir.mkdir(parents=True, exist_ok=True) 180 | database = sfm_dir / 'database.db' 181 | model = sfm_dir / 'model' 182 | model.mkdir(exist_ok=True) 183 | empty_model = sfm_dir / 'empty' 184 | 185 | create_empty_model(reference_sfm_model, empty_model) 186 | image_ids = create_db_from_model(empty_model, database) 187 | import_features(image_ids, database, features) 188 | import_matches(image_ids, database, pairs, matches, 189 | min_match_score, skip_geometric_verification) 190 | if not skip_geometric_verification: 191 | geometric_verification(colmap_path, database, pairs) 192 | stats = run_triangulation( 193 | colmap_path, model, database, image_dir, empty_model) 194 | 195 | logging.info(f'Statistics:\n{pprint.pformat(stats)}') 196 | shutil.rmtree(empty_model) 197 | 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument('--sfm_dir', type=Path, required=True) 202 | parser.add_argument('--reference_sfm_model', type=Path, required=True) 203 | parser.add_argument('--image_dir', type=Path, required=True) 204 | 205 | parser.add_argument('--pairs', type=Path, required=True) 206 | parser.add_argument('--features', type=Path, required=True) 207 | parser.add_argument('--matches', type=Path, required=True) 208 | 209 | parser.add_argument('--colmap_path', type=Path, default='colmap') 210 | 211 | parser.add_argument('--skip_geometric_verification', action='store_true') 212 | parser.add_argument('--min_match_score', type=float) 213 | args = parser.parse_args() 214 | 215 | main(**args.__dict__) 216 | -------------------------------------------------------------------------------- /localization/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/5/6 下午4:47 3 | # @Author : Fei Xue 4 | # @Email : fx221@cam.ac.uk 5 | # @File : tools.py 6 | # @Software: PyCharm 7 | 8 | import numpy as np 9 | import cv2 10 | import torch 11 | from copy import copy 12 | from scipy.spatial.transform import Rotation as sciR 13 | 14 | 15 | def plot_keypoint(img_path, pts, scores=None, tag=None, save_path=None): 16 | if type(img_path) == str: 17 | img = cv2.imread(img_path) 18 | else: 19 | img = img_path.copy() 20 | 21 | img_out = img.copy() 22 | print(img.shape) 23 | r = 3 24 | for i in range(pts.shape[0]): 25 | pt = pts[i] 26 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4) 27 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2) 28 | 29 | if save_path is not None: 30 | cv2.imwrite(save_path, img_out) 31 | return img_out 32 | 33 | 34 | def sort_dict_by_value(data, reverse=False): 35 | return sorted(data.items(), key=lambda d: d[1], reverse=reverse) 36 | 37 | 38 | def read_retrieval_results(path): 39 | output = {} 40 | with open(path, "r") as f: 41 | lines = f.readlines() 42 | for p in lines: 43 | p = p.strip("\n").split(" ") 44 | 45 | if p[1] == "no_match": 46 | 
continue 47 | if p[0] in output.keys(): 48 | output[p[0]].append(p[1]) 49 | else: 50 | output[p[0]] = [p[1]] 51 | return output 52 | 53 | 54 | def nn_k(query_gps, db_gps, k=20): 55 | q = torch.from_numpy(query_gps) # [N 2] 56 | db = torch.from_numpy(db_gps) # [M, 2] 57 | # print (q.shape, db.shape) 58 | dist = q.unsqueeze(2) - db.t().unsqueeze(0) 59 | dist = dist[:, 0, :] ** 2 + dist[:, 1, :] ** 2 60 | print("dist: ", dist.shape) 61 | topk = torch.topk(dist, dim=1, k=k, largest=False)[1] 62 | return topk 63 | 64 | 65 | def plot_matches(img1, img2, pts1, pts2, inliers, horizon=False, plot_outlier=False, confs=None, plot_match=True): 66 | rows1 = img1.shape[0] 67 | cols1 = img1.shape[1] 68 | rows2 = img2.shape[0] 69 | cols2 = img2.shape[1] 70 | r = 3 71 | if horizon: 72 | img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8') 73 | # Place the first image to the left 74 | img_out[:rows1, :cols1] = img1 75 | # Place the next image to the right of it 76 | img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2]) 77 | 78 | if not plot_match: 79 | return cv2.resize(img_out, None, fx=0.5, fy=0.5) 80 | # for idx, pt in enumerate(pts1): 81 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2) 82 | # for idx, pt in enumerate(pts2): 83 | # img_out = cv2.circle(img_out, (int(pt[0] + cols1), int(pt[1])), r, (0, 0, 255), 2) 84 | for idx in range(inliers.shape[0]): 85 | # if idx % 10 > 0: 86 | # continue 87 | if inliers[idx]: 88 | color = (0, 255, 0) 89 | else: 90 | if not plot_outlier: 91 | continue 92 | color = (0, 0, 255) 93 | pt1 = pts1[idx] 94 | pt2 = pts2[idx] 95 | 96 | if confs is not None: 97 | nr = int(r * confs[idx]) 98 | else: 99 | nr = r 100 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2) 101 | 102 | img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2) 103 | 104 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color, 105 | 2) 106 | else: 107 | img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8') 108 | # Place the first image to the left 109 | img_out[:rows1, :cols1] = img1 110 | # Place the next image to the right of it 111 | img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2]) 112 | 113 | if not plot_match: 114 | return cv2.resize(img_out, None, fx=0.5, fy=0.5) 115 | # for idx, pt in enumerate(pts1): 116 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2) 117 | # for idx, pt in enumerate(pts2): 118 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1] + rows1)), r, (0, 0, 255), 2) 119 | for idx in range(inliers.shape[0]): 120 | # print("idx: ", inliers[idx]) 121 | # if idx % 10 > 0: 122 | # continue 123 | if inliers[idx]: 124 | color = (0, 255, 0) 125 | else: 126 | if not plot_outlier: 127 | continue 128 | color = (0, 0, 255) 129 | 130 | if confs is not None: 131 | nr = int(r * confs[idx]) 132 | else: 133 | nr = r 134 | 135 | pt1 = pts1[idx] 136 | pt2 = pts2[idx] 137 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), r, color, 2) 138 | 139 | img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), r, color, 2) 140 | 141 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color, 142 | 2) 143 | 144 | img_rs = cv2.resize(img_out, None, fx=0.5, fy=0.5) 145 | 146 | # img_rs = cv2.putText(img_rs, 'matches: {:d}'.format(len(inliers.shape[0])), (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, 147 | # (0, 0, 255), 2) 148 | 149 | # if save_fn is 
not None: 150 | # cv2.imwrite(save_fn, img_rs) 151 | # cv2.imshow("match", img_rs) 152 | # cv2.waitKey(0) 153 | return img_rs 154 | 155 | 156 | def plot_reprojpoint2D(img, points2D, reproj_points2D, confs=None): 157 | img_out = copy(img) 158 | r = 5 159 | for i in range(points2D.shape[0]): 160 | p = points2D[i] 161 | rp = reproj_points2D[i] 162 | 163 | if confs is not None: 164 | nr = int(r * confs[i]) 165 | else: 166 | nr = r 167 | 168 | if nr >= 50: 169 | nr = 50 170 | # img_out = cv2.circle(img_out, (int(p[0]), int(p[1])), nr, color=(0, 255, 0), thickness=2) 171 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), nr, color=(0, 0, 255), thickness=3) 172 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), 2, color=(0, 0, 255), thickness=3) 173 | # img_out = cv2.line(img_out, pt1=(int(p[0]), int(p[1])), pt2=(int(rp[0]), int(rp[1])), color=(0, 0, 255), 174 | # thickness=2) 175 | 176 | return img_out 177 | 178 | 179 | def reproject_fromR(points3D, rot, tvec, camera): 180 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1) 181 | 182 | if camera['model'] == 'SIMPLE_RADIAL': 183 | f = camera['params'][0] 184 | cx = camera['params'][1] 185 | cy = camera['params'][2] 186 | k = camera['params'][3] 187 | 188 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :] 189 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2 190 | radial = r2 * k 191 | du = proj_2d[0, :] * radial 192 | dv = proj_2d[1, :] * radial 193 | 194 | u = proj_2d[0, :] + du 195 | v = proj_2d[1, :] + dv 196 | u = u * f + cx 197 | v = v * f + cy 198 | uvs = np.vstack([u, v]).transpose() 199 | 200 | return uvs 201 | 202 | 203 | def calc_depth(points3D, rvec, tvec, camera): 204 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm() 205 | # print('p3d: ', points3D.shape, rot.shape, rot) 206 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1) 207 | 208 | return proj_2d.transpose()[:, 2] 209 | 210 | 211 | def reproject(points3D, rvec, tvec, camera): 212 | ''' 213 | Args: 214 | points3D: [N, 3] 215 | rvec: [w, x, y, z] 216 | tvec: [x, y, z] 217 | Returns: 218 | ''' 219 | # print('camera: ', camera) 220 | # print('p3d: ', points3D.shape) 221 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm() 222 | # print('p3d: ', points3D.shape, rot.shape, rot) 223 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1) 224 | 225 | if camera['model'] == 'SIMPLE_RADIAL': 226 | f = camera['params'][0] 227 | cx = camera['params'][1] 228 | cy = camera['params'][2] 229 | k = camera['params'][3] 230 | 231 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :] 232 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2 233 | radial = r2 * k 234 | du = proj_2d[0, :] * radial 235 | dv = proj_2d[1, :] * radial 236 | 237 | u = proj_2d[0, :] + du 238 | v = proj_2d[1, :] + dv 239 | u = u * f + cx 240 | v = v * f + cy 241 | uvs = np.vstack([u, v]).transpose() 242 | 243 | return uvs 244 | 245 | 246 | def quaternion_angular_error(q1, q2): 247 | """ 248 | angular error between two quaternions 249 | :param q1: (4, ) 250 | :param q2: (4, ) 251 | :return: 252 | """ 253 | d = abs(np.dot(q1, q2)) 254 | d = min(1.0, max(-1.0, d)) 255 | theta = 2 * np.arccos(d) * 180 / np.pi 256 | return theta 257 | 258 | 259 | def ColmapQ2R(qvec): 260 | rot = sciR.from_quat(quat=[qvec[1], qvec[2], qvec[3], qvec[0]]).as_dcm() 261 | return rot 262 | 263 | 264 | def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw): 265 | pred_Rcw = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_dcm() 266 | pred_tcw = np.array(pred_tcw, 
float).reshape(3, 1) 267 | pred_Rwc = pred_Rcw.transpose() 268 | pred_twc = -pred_Rcw.transpose() @ pred_tcw 269 | 270 | gt_Rcw = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_dcm() 271 | gt_tcw = np.array(gt_tcw, float).reshape(3, 1) 272 | gt_Rwc = gt_Rcw.transpose() 273 | gt_twc = -gt_Rcw.transpose() @ gt_tcw 274 | 275 | t_error_xyz = pred_twc - gt_twc 276 | t_error = np.sqrt(np.sum(t_error_xyz ** 2)) 277 | 278 | pred_qwc = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_quat() 279 | gt_qwc = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_quat() 280 | 281 | q_error = quaternion_angular_error(q1=pred_qwc, q2=gt_qwc) 282 | 283 | return q_error, t_error, (t_error_xyz[0, 0], t_error_xyz[1, 0], t_error_xyz[2, 0]) 284 | -------------------------------------------------------------------------------- /localization/utils/parsers.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import logging 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | 7 | def parse_image_lists_with_intrinsics(paths): 8 | results = [] 9 | files = list(Path(paths.parent).glob(paths.name)) 10 | assert len(files) > 0 11 | 12 | for lfile in files: 13 | with open(lfile, 'r') as f: 14 | raw_data = f.readlines() 15 | 16 | logging.info(f'Importing {len(raw_data)} queries in {lfile.name}') 17 | for data in raw_data: 18 | data = data.strip('\n').split(' ') 19 | name, camera_model, width, height = data[:4] 20 | params = np.array(data[4:], float) 21 | info = (camera_model, int(width), int(height), params) 22 | results.append((name, info)) 23 | 24 | assert len(results) > 0 25 | return results 26 | 27 | 28 | def parse_img_lists_for_extended_cmu_seaons(paths): 29 | Ks = { 30 | "c0": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571", 31 | "c1": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571" 32 | } 33 | 34 | results = [] 35 | files = list(Path(paths.parent).glob(paths.name)) 36 | assert len(files) > 0 37 | 38 | for lfile in files: 39 | with open(lfile, 'r') as f: 40 | raw_data = f.readlines() 41 | 42 | logging.info(f'Importing {len(raw_data)} queries in {lfile.name}') 43 | for name in raw_data: 44 | name = name.strip('\n') 45 | camera = name.split('_')[2] 46 | K = Ks[camera].split(' ') 47 | camera_model, width, height = K[:3] 48 | params = np.array(K[3:], float) 49 | # print("camera: ", camera_model, width, height, params) 50 | info = (camera_model, int(width), int(height), params) 51 | results.append((name, info)) 52 | 53 | assert len(results) > 0 54 | return results 55 | 56 | 57 | def parse_retrieval(path): 58 | retrieval = defaultdict(list) 59 | with open(path, 'r') as f: 60 | for p in f.read().rstrip('\n').split('\n'): 61 | q, r = p.split(' ') 62 | retrieval[q].append(r) 63 | return dict(retrieval) 64 | 65 | 66 | def names_to_pair(name0, name1): 67 | return '_'.join((name0.replace('/', '-'), name1.replace('/', '-'))) 68 | -------------------------------------------------------------------------------- /loss/__pycache__/accuracy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/__pycache__/accuracy.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/aploss.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/__pycache__/aploss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/aploss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/__pycache__/aploss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/accuracy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> accuracy 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 14/06/2021 17:11 7 | ==================================================''' 8 | 9 | import torch.nn as nn 10 | 11 | 12 | def accuracy(pred, target, topk=1, thresh=None): 13 | """Calculate accuracy according to the prediction and target. 14 | 15 | Args: 16 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 17 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 18 | topk (int | tuple[int], optional): If the predictions in ``topk`` 19 | matches the target, the predictions will be regarded as 20 | correct ones. Defaults to 1. 21 | thresh (float, optional): If not None, predictions with scores under 22 | this threshold are considered incorrect. Default to None. 23 | 24 | Returns: 25 | float | tuple[float]: If the input ``topk`` is a single integer, 26 | the function will return a single float as accuracy. If 27 | ``topk`` is a tuple containing multiple integers, the 28 | function will return a tuple containing accuracies of 29 | each ``topk`` number. 30 | """ 31 | assert isinstance(topk, (int, tuple)) 32 | if isinstance(topk, int): 33 | topk = (topk, ) 34 | return_single = True 35 | else: 36 | return_single = False 37 | 38 | maxk = max(topk) 39 | if pred.size(0) == 0: 40 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 41 | return accu[0] if return_single else accu 42 | assert pred.ndim == target.ndim + 1 43 | # print(type(pred), type(target)) 44 | assert pred.size(0) == target.size(0) 45 | assert maxk <= pred.size(1), \ 46 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 47 | pred_value, pred_label = pred.topk(maxk, dim=1) 48 | # transpose to shape (maxk, N, ...) 49 | pred_label = pred_label.transpose(0, 1) 50 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 51 | if thresh is not None: 52 | # Only prediction values larger than thresh are counted as correct 53 | correct = correct & (pred_value > thresh).t() 54 | res = [] 55 | for k in topk: 56 | # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 57 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 58 | res.append(correct_k.mul_(100.0 / target.numel())) 59 | return res[0] if return_single else res 60 | 61 | 62 | class Accuracy(nn.Module): 63 | """Accuracy calculation module.""" 64 | 65 | def __init__(self, topk=(1, ), thresh=None): 66 | """Module to calculate the accuracy. 67 | 68 | Args: 69 | topk (tuple, optional): The criterion used to calculate the 70 | accuracy. Defaults to (1,). 71 | thresh (float, optional): If not None, predictions with scores 72 | under this threshold are considered incorrect. Default to None. 
73 | """ 74 | super().__init__() 75 | self.topk = topk 76 | self.thresh = thresh 77 | 78 | def forward(self, pred, target): 79 | """Forward function to calculate accuracy. 80 | 81 | Args: 82 | pred (torch.Tensor): Prediction of models. 83 | target (torch.Tensor): Target for each prediction. 84 | 85 | Returns: 86 | tuple[float]: The accuracies under different topk criterions. 87 | """ 88 | return accuracy(pred, target, self.topk, self.thresh) 89 | -------------------------------------------------------------------------------- /loss/seg_loss/__pycache__/crossentropy_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/seg_loss/__pycache__/crossentropy_loss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/seg_loss/__pycache__/segloss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/loss/seg_loss/__pycache__/segloss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/seg_loss/crossentropy_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> crossentropy_loss 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 29/04/2021 20:19 7 | ==================================================''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def _ohem_mask(loss, ohem_ratio): 15 | with torch.no_grad(): 16 | values, _ = torch.topk(loss.reshape(-1), 17 | int(loss.nelement() * ohem_ratio)) 18 | mask = loss >= values[-1] 19 | return mask.float() 20 | 21 | 22 | class BCEWithLogitsLossWithOHEM(nn.Module): 23 | def __init__(self, ohem_ratio=1.0, pos_weight=None, eps=1e-7): 24 | super(BCEWithLogitsLossWithOHEM, self).__init__() 25 | self.criterion = nn.BCEWithLogitsLoss(reduction='none', 26 | pos_weight=pos_weight) 27 | self.ohem_ratio = ohem_ratio 28 | self.eps = eps 29 | 30 | def forward(self, pred, target): 31 | loss = self.criterion(pred, target) 32 | mask = _ohem_mask(loss, self.ohem_ratio) 33 | loss = loss * mask 34 | return loss.sum() / (mask.sum() + self.eps) 35 | 36 | def set_ohem_ratio(self, ohem_ratio): 37 | self.ohem_ratio = ohem_ratio 38 | 39 | 40 | class CrossEntropyLossWithOHEM(nn.Module): 41 | def __init__(self, 42 | ohem_ratio=1.0, 43 | weight=None, 44 | ignore_index=-100, 45 | eps=1e-7): 46 | super(CrossEntropyLossWithOHEM, self).__init__() 47 | self.criterion = nn.CrossEntropyLoss(weight=weight, 48 | ignore_index=ignore_index, 49 | reduction='none') 50 | self.ohem_ratio = ohem_ratio 51 | self.eps = eps 52 | 53 | def forward(self, pred, target): 54 | loss = self.criterion(pred, target) 55 | mask = _ohem_mask(loss, self.ohem_ratio) 56 | loss = loss * mask 57 | return loss.sum() / (mask.sum() + self.eps) 58 | 59 | def set_ohem_ratio(self, ohem_ratio): 60 | self.ohem_ratio = ohem_ratio 61 | 62 | 63 | class DiceLoss(nn.Module): 64 | def __init__(self, eps=1e-7): 65 | super(DiceLoss, self).__init__() 66 | self.eps = eps 67 | 68 | def forward(self, pred, target): 69 | pred = torch.sigmoid(pred) 70 | intersection = (pred * target).sum() 71 | loss = 1 - (2. 
* intersection) / (pred.sum() + target.sum() + self.eps) 72 | return loss 73 | 74 | 75 | class SoftCrossEntropyLossWithOHEM(nn.Module): 76 | def __init__(self, ohem_ratio=1.0, eps=1e-7): 77 | super(SoftCrossEntropyLossWithOHEM, self).__init__() 78 | self.ohem_ratio = ohem_ratio 79 | self.eps = eps 80 | 81 | def forward(self, pred, target): 82 | pred = F.log_softmax(pred, dim=1) 83 | loss = -(pred * target).sum(1) 84 | mask = _ohem_mask(loss, self.ohem_ratio) 85 | loss = loss * mask 86 | return loss.sum() / (mask.sum() + self.eps) 87 | 88 | def set_ohem_ratio(self, ohem_ratio): 89 | self.ohem_ratio = ohem_ratio 90 | 91 | 92 | class FocalLoss(nn.Module): 93 | def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=None): 94 | super(FocalLoss, self).__init__() 95 | self.alpha = alpha 96 | self.gamma = gamma 97 | self.ignore_index = ignore_index 98 | self.size_average = size_average 99 | 100 | def forward(self, inputs, targets): 101 | ce_loss = F.cross_entropy( 102 | inputs, targets, reduction='none', ignore_index=self.ignore_index) 103 | pt = torch.exp(-ce_loss) 104 | focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss 105 | if self.size_average: 106 | return focal_loss.mean() 107 | else: 108 | return focal_loss.sum() 109 | 110 | 111 | def cross_entropy2d(input, target, weight=None, size_average=True): 112 | # input: (n, c, h, w), target: (n, h, w) 113 | n, c, h, w = input.size() 114 | # log_p: (n, c, h, w) 115 | 116 | log_p = F.log_softmax(input, dim=1) 117 | # log_p: (n*h*w, c) 118 | log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous() 119 | log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] 120 | log_p = log_p.view(-1, c) 121 | # target: (n*h*w,) 122 | mask = target >= 0 123 | target = target[mask] 124 | loss = F.nll_loss(log_p, target, weight=weight, reduction='sum') 125 | if size_average: 126 | loss /= mask.data.sum() 127 | return loss 128 | 129 | 130 | def cross_entropy_seg(input, target): 131 | idx = target.cpu().long() 132 | one_hot_key = torch.FloatTensor(input.shape).zero_() 133 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 134 | if one_hot_key.device != input.device: 135 | one_hot_key = one_hot_key.to(input.device) 136 | log_p = F.log_softmax(input=input, dim=1) 137 | loss = -(one_hot_key * log_p) 138 | 139 | loss = loss.sum(1) 140 | # print("ce_loss: ", loss.shape) 141 | # loss2 = cross_entropy2d(input=input, target=target.squeeze()) 142 | # print("loss: ", loss.mean(), loss2) 143 | 144 | # loss = F.nll_loss(input=log_p, target=target.view(-1), reduction="sum") 145 | return loss.mean() 146 | 147 | 148 | class CrossEntropy(nn.Module): 149 | def __init__(self, weights=None): 150 | super(CrossEntropy, self).__init__() 151 | self.weights = weights 152 | 153 | def forward(self, input, target): 154 | idx = target.cpu().long() 155 | one_hot_key = torch.FloatTensor(input.shape).zero_() 156 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 157 | if one_hot_key.device != input.device: 158 | one_hot_key = one_hot_key.to(input.device) 159 | log_p = F.log_softmax(input=input, dim=1) 160 | loss = -(one_hot_key * log_p) 161 | 162 | if self.weights is not None: 163 | loss = (loss * self.weights[None, :, None, None].expand_as(loss)).sum(1) # + self.smooth [B, C, H, W] 164 | else: 165 | loss = loss.sum(1) 166 | # print("ce_loss: ", loss.shape) 167 | # loss2 = cross_entropy2d(input=input, target=target.squeeze()) 168 | # print("loss: ", loss.mean(), loss2) 169 | 170 | # loss = F.nll_loss(input=log_p, target=target.view(-1), reduction="sum") 171 | return 
loss.mean() 172 | 173 | 174 | if __name__ == '__main__': 175 | target = torch.randint(0, 4, (1, 1, 4, 4)).cuda() 176 | input = torch.rand((1, 4, 4, 4)).cuda() 177 | print("target: ", target.shape) 178 | 179 | net = CrossEntropy().cuda() 180 | out = net(input, target) 181 | print(out) 182 | -------------------------------------------------------------------------------- /loss/seg_loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> focal_loss 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 29/04/2021 20:18 7 | ==================================================''' 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class FocalLoss(nn.Module): 15 | """ 16 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 17 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 18 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 19 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 20 | :param num_class: 21 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 22 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 23 | focus on hard misclassified example 24 | :param smooth: (float,double) smooth value when cross entropy 25 | :param balance_index: (int) balance class index, should be specific when alpha is float 26 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 27 | """ 28 | 29 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 30 | super(FocalLoss, self).__init__() 31 | self.apply_nonlin = apply_nonlin 32 | self.alpha = alpha 33 | self.gamma = gamma 34 | self.balance_index = balance_index 35 | self.smooth = smooth 36 | self.size_average = size_average 37 | 38 | if self.smooth is not None: 39 | if self.smooth < 0 or self.smooth > 1.0: 40 | raise ValueError('smooth value should be in [0,1]') 41 | 42 | def forward(self, logit, target): 43 | if self.apply_nonlin is not None: 44 | logit = self.apply_nonlin(logit) 45 | num_class = logit.shape[1] 46 | 47 | if logit.dim() > 2: 48 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
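# The logits are flattened so that every spatial position becomes an independent classification
# sample of shape (N*d1*d2*..., C); the target is reshaped to (N*d1*d2*..., 1) to match, and the
# per-class alpha weights and the (1 - pt)**gamma modulation are then applied per sample.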
49 | logit = logit.view(logit.size(0), logit.size(1), -1) 50 | logit = logit.permute(0, 2, 1).contiguous() 51 | logit = logit.view(-1, logit.size(-1)) 52 | target = torch.squeeze(target, 1) 53 | target = target.view(-1, 1) 54 | # print(logit.shape, target.shape) 55 | # 56 | alpha = self.alpha 57 | 58 | if alpha is None: 59 | alpha = torch.ones(num_class, 1) 60 | elif isinstance(alpha, (list, np.ndarray)): 61 | assert len(alpha) == num_class 62 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 63 | alpha = alpha / alpha.sum() 64 | elif isinstance(alpha, float): 65 | alpha = torch.ones(num_class, 1) 66 | alpha = alpha * (1 - self.alpha) 67 | alpha[self.balance_index] = self.alpha 68 | 69 | else: 70 | raise TypeError('Not support alpha type') 71 | 72 | if alpha.device != logit.device: 73 | alpha = alpha.to(logit.device) 74 | 75 | idx = target.cpu().long() 76 | 77 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 78 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 79 | if one_hot_key.device != logit.device: 80 | one_hot_key = one_hot_key.to(logit.device) 81 | 82 | if self.smooth: 83 | one_hot_key = torch.clamp( 84 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 85 | pt = (one_hot_key * logit).sum(1) + self.smooth 86 | logpt = pt.log() 87 | 88 | gamma = self.gamma 89 | 90 | alpha = alpha[idx] 91 | alpha = torch.squeeze(alpha) 92 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 93 | 94 | if self.size_average: 95 | loss = loss.mean() 96 | else: 97 | loss = loss.sum() 98 | return loss 99 | 100 | 101 | if __name__ == '__main__': 102 | pass -------------------------------------------------------------------------------- /loss/seg_loss/segloss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/5/11 下午2:38 3 | # @Author : Fei Xue 4 | # @Email : fx221@cam.ac.uk 5 | # @File : segloss.py 6 | # @Software: PyCharm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from loss.seg_loss.crossentropy_loss import CrossEntropy 12 | 13 | 14 | def _ohem_mask(loss, ohem_ratio): 15 | with torch.no_grad(): 16 | values, _ = torch.topk(loss.reshape(-1), 17 | int(loss.nelement() * ohem_ratio)) 18 | mask = loss >= values[-1] 19 | return mask.float() 20 | 21 | 22 | class BCEWithLogitsLossWithOHEM(nn.Module): 23 | def __init__(self, ohem_ratio=1.0, pos_weight=None, eps=1e-7): 24 | super(BCEWithLogitsLossWithOHEM, self).__init__() 25 | self.criterion = nn.BCEWithLogitsLoss(reduction='none', 26 | pos_weight=pos_weight) 27 | self.ohem_ratio = ohem_ratio 28 | self.eps = eps 29 | 30 | def forward(self, pred, target): 31 | loss = self.criterion(pred, target) 32 | mask = _ohem_mask(loss, self.ohem_ratio) 33 | loss = loss * mask 34 | return loss.sum() / (mask.sum() + self.eps) 35 | 36 | def set_ohem_ratio(self, ohem_ratio): 37 | self.ohem_ratio = ohem_ratio 38 | 39 | 40 | class CrossEntropyLossWithOHEM(nn.Module): 41 | def __init__(self, 42 | ohem_ratio=1.0, 43 | weight=None, 44 | ignore_index=-100, 45 | eps=1e-7): 46 | super(CrossEntropyLossWithOHEM, self).__init__() 47 | self.criterion = nn.CrossEntropyLoss(weight=weight, 48 | ignore_index=ignore_index, 49 | reduction='none') 50 | self.ohem_ratio = ohem_ratio 51 | self.eps = eps 52 | 53 | def forward(self, pred, target): 54 | loss = self.criterion(pred, target) 55 | mask = _ohem_mask(loss, self.ohem_ratio) 56 | loss = loss * mask 57 | return loss.sum() / (mask.sum() + self.eps) 58 | 59 | def set_ohem_ratio(self, 
ohem_ratio): 60 | self.ohem_ratio = ohem_ratio 61 | 62 | 63 | class DiceLoss(nn.Module): 64 | def __init__(self, eps=1e-7): 65 | super(DiceLoss, self).__init__() 66 | self.eps = eps 67 | 68 | def forward(self, pred, target): 69 | pred = torch.sigmoid(pred) 70 | intersection = (pred * target).sum() 71 | loss = 1 - (2. * intersection) / (pred.sum() + target.sum() + self.eps) 72 | return loss 73 | 74 | 75 | class SoftCrossEntropyLossWithOHEM(nn.Module): 76 | def __init__(self, ohem_ratio=1.0, eps=1e-7): 77 | super(SoftCrossEntropyLossWithOHEM, self).__init__() 78 | self.ohem_ratio = ohem_ratio 79 | self.eps = eps 80 | 81 | def forward(self, pred, target): 82 | pred = F.log_softmax(pred, dim=1) 83 | loss = -(pred * target).sum(1) 84 | mask = _ohem_mask(loss, self.ohem_ratio) 85 | loss = loss * mask 86 | return loss.sum() / (mask.sum() + self.eps) 87 | 88 | def set_ohem_ratio(self, ohem_ratio): 89 | self.ohem_ratio = ohem_ratio 90 | 91 | 92 | class SegLoss(nn.Module): 93 | def __init__(self, segloss_name, use_cls=None, use_hiera=None, use_seg=None, 94 | cls_weight=1., hiera_weight=1., label_weights=None): 95 | super(SegLoss, self).__init__() 96 | 97 | if use_seg: 98 | if segloss_name == 'ce': 99 | self.seg_loss = CrossEntropy(weights=label_weights) 100 | elif segloss_name == 'ceohem': 101 | self.seg_loss = CrossEntropyLossWithOHEM(ohem_ratio=0.7, weight=label_weights) 102 | elif segloss_name == 'sceohem': 103 | self.seg_loss = SoftCrossEntropyLossWithOHEM(ohem_ratio=0.7) 104 | else: 105 | self.seg_loss = None 106 | if use_cls: 107 | self.cls_loss = nn.BCEWithLogitsLoss(weight=label_weights) 108 | self.cls_weight = cls_weight 109 | else: 110 | self.cls_loss = None 111 | 112 | if use_hiera: 113 | # self.cls_hiera = nn.BCEWithLogitsLoss(weight=label_weights) 114 | self.cls_hiera = nn.CrossEntropyLoss(weight=label_weights) 115 | self.hiera_weight = hiera_weight 116 | else: 117 | self.cls_hiera = None 118 | 119 | def forward(self, pred_seg=None, gt_seg=None, pred_cls=None, gt_cls=None, pred_hiera=None, gt_hiera=None): 120 | total_loss = 0 121 | output = { 122 | 123 | } 124 | if self.seg_loss is not None: 125 | seg_error = self.seg_loss(pred_seg, gt_seg) 126 | total_loss = total_loss + seg_error 127 | output["seg_loss"] = seg_error 128 | 129 | if self.cls_loss is not None: 130 | cls_error = self.cls_loss(pred_cls, gt_cls) 131 | total_loss = total_loss + cls_error * self.cls_weight 132 | output["cls_loss"] = cls_error 133 | 134 | if self.cls_hiera is not None: 135 | # print (pred_hiera.shape, gt_hiera.shape) 136 | hiera_error = self.cls_hiera(pred_hiera, torch.argmax(gt_hiera, 1).long()) 137 | total_loss = total_loss + hiera_error * self.hiera_weight 138 | output["hiera_loss"] = hiera_error 139 | 140 | output["loss"] = total_loss 141 | return output 142 | -------------------------------------------------------------------------------- /loss/seg_loss/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> utils 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 29/04/2021 20:12 7 | ==================================================''' 8 | 9 | import torch 10 | import numpy as np 11 | 12 | 13 | def make_one_hot(input, num_classes=None): 14 | if num_classes is None: 15 | num_classes = input.max() + 1 16 | shape = np.array(input.shape) 17 | shape[1] = num_classes 18 | shape = tuple(shape) 19 | result = torch.zeros(shape) 20 | result = result.scatter_(1, 
input.cpu().long(), 1)
21 |     return result
22 |
--------------------------------------------------------------------------------
/net/layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class Flatten(nn.Module):
6 |     def forward(self, input):
7 |         return input.view(input.size(0), -1)
8 |
9 |
10 | class L2Norm(nn.Module):
11 |     def __init__(self, dim=1):
12 |         super().__init__()
13 |         self.dim = dim
14 |
15 |     def forward(self, input):
16 |         return F.normalize(input, p=2, dim=self.dim)
17 |
--------------------------------------------------------------------------------
/net/locnets/r2d2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | '''=================================================
3 | @Project -> File   shloc -> r2d2
4 | @IDE     PyCharm
5 | @Author  fx221@cam.ac.uk
6 | @Date    2021-05-27 14:49
7 | =================================================='''
8 | import torch
9 | import torch.nn as nn
10 | import torchvision.transforms as tvf
11 | import torch.nn.functional as F
12 |
13 | RGB_mean = [0.485, 0.456, 0.406]
14 | RGB_std = [0.229, 0.224, 0.225]
15 |
16 | norm_RGB = tvf.Compose([tvf.ToTensor(), tvf.Normalize(mean=RGB_mean, std=RGB_std)])
17 |
18 |
19 | class BaseNet(nn.Module):
20 |     """ Takes a list of images as input, and returns for each image:
21 |         - a pixelwise descriptor
22 |         - a pixelwise confidence
23 |     """
24 |
25 |     def softmax(self, ux):
26 |         if ux.shape[1] == 1:
27 |             x = F.softplus(ux)
28 |             return x / (1 + x)  # for sure in [0,1], much less plateaus than softmax
29 |         elif ux.shape[1] == 2:
30 |             return F.softmax(ux, dim=1)[:, 1:2]
31 |
32 |     def normalize(self, x, ureliability, urepeatability):
33 |         return dict(descriptors=F.normalize(x, p=2, dim=1),
34 |                     repeatability=self.softmax(urepeatability),
35 |                     reliability=self.softmax(ureliability))
36 |
37 |     def forward_one(self, x):
38 |         raise NotImplementedError()
39 |
40 |     def forward(self, imgs, **kw):
41 |         res = [self.forward_one(img) for img in imgs]
42 |         # merge all dictionaries into one
43 |         res = {k: [r[k] for r in res if k in r] for k in {k for r in res for k in r}}
44 |         return dict(res, imgs=imgs, **kw)
45 |
46 |
47 | class PatchNet(BaseNet):
48 |     """ Helper class to construct a fully-convolutional network that
49 |         extracts an L2-normalized patch descriptor.
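        Subclasses build the trunk by calling _add_conv; with dilated=True the stride of
        each strided layer is folded into the dilation factor instead, so the descriptor
        map keeps the input resolution.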
50 | """ 51 | 52 | def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): 53 | BaseNet.__init__(self) 54 | self.inchan = inchan 55 | self.curchan = inchan 56 | self.dilated = dilated 57 | self.dilation = dilation 58 | self.bn = bn 59 | self.bn_affine = bn_affine 60 | self.ops = nn.ModuleList([]) 61 | 62 | def _make_bn(self, outd): 63 | return nn.BatchNorm2d(outd, affine=self.bn_affine) 64 | 65 | def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True): 66 | d = self.dilation * dilation 67 | if self.dilated: 68 | conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=1) 69 | self.dilation *= stride 70 | else: 71 | conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=stride) 72 | self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params)) 73 | if bn and self.bn: self.ops.append(self._make_bn(outd)) 74 | if relu: self.ops.append(nn.ReLU(inplace=True)) 75 | self.curchan = outd 76 | 77 | def forward_one(self, x): 78 | assert self.ops, "You need to add convolutions first" 79 | for n, op in enumerate(self.ops): 80 | x = op(x) 81 | return self.normalize(x) 82 | 83 | 84 | class L2_Net(PatchNet): 85 | """ Compute a 128D descriptor for all overlapping 32x32 patches. 86 | From the L2Net paper (CVPR'17). 87 | """ 88 | 89 | def __init__(self, dim=128, **kw): 90 | PatchNet.__init__(self, **kw) 91 | add_conv = lambda n, **kw: self._add_conv((n * dim) // 128, **kw) 92 | add_conv(32) 93 | add_conv(32) 94 | add_conv(64, stride=2) 95 | add_conv(64) 96 | add_conv(128, stride=2) 97 | add_conv(128) 98 | add_conv(128, k=7, stride=8, bn=False, relu=False) 99 | self.out_dim = dim 100 | 101 | 102 | class Quad_L2Net(PatchNet): 103 | """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs. 104 | """ 105 | 106 | def __init__(self, dim=128, mchan=4, relu22=False, **kw): 107 | PatchNet.__init__(self, **kw) 108 | self._add_conv(8 * mchan) 109 | self._add_conv(8 * mchan) 110 | self._add_conv(16 * mchan, stride=2) 111 | self._add_conv(16 * mchan) 112 | self._add_conv(32 * mchan, stride=2) 113 | self._add_conv(32 * mchan) 114 | # replace last 8x8 convolution with 3 2x2 convolutions 115 | self._add_conv(32 * mchan, k=2, stride=2, relu=relu22) 116 | self._add_conv(32 * mchan, k=2, stride=2, relu=relu22) 117 | self._add_conv(dim, k=2, stride=2, bn=False, relu=False) 118 | self.out_dim = dim 119 | 120 | 121 | class Quad_L2Net_ConfCFS(Quad_L2Net): 122 | """ Same than Quad_L2Net, with 2 confidence maps for repeatability and reliability. 123 | """ 124 | 125 | def __init__(self, **kw): 126 | Quad_L2Net.__init__(self, **kw) 127 | # reliability classifier 128 | self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1) 129 | # repeatability classifier: for some reasons it's a softplus, not a softmax! 130 | # Why? I guess it's a mistake that was left unnoticed in the code for a long time... 
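        # self.clf is a 2-channel head, so BaseNet.softmax turns it into a reliability score
        # with a true softmax over its two channels; self.sal below is 1-channel, so the same
        # helper instead squashes it with softplus(x) / (1 + softplus(x)). Both heads are
        # applied to the squared descriptor map in forward_one.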
131 | self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) 132 | 133 | def forward_one(self, x): 134 | assert self.ops, "You need to add convolutions first" 135 | for op in self.ops: 136 | x = op(x) 137 | # compute the confidence maps 138 | ureliability = self.clf(x ** 2) 139 | urepeatability = self.sal(x ** 2) 140 | return self.normalize(x, ureliability, urepeatability) 141 | 142 | 143 | def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, dilation=1): 144 | if not use_bn: 145 | return nn.Sequential( 146 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 147 | padding=1, dilation=dilation), 148 | nn.ReLU(), 149 | ) 150 | else: 151 | return nn.Sequential( 152 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 153 | padding=padding, dilation=dilation), 154 | nn.BatchNorm2d(out_channels, affine=False), 155 | nn.ReLU(), 156 | ) 157 | 158 | 159 | class NonMaxSuppression(torch.nn.Module): 160 | def __init__(self, rel_thr=0.7, rep_thr=0.7): 161 | nn.Module.__init__(self) 162 | self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 163 | self.rel_thr = rel_thr 164 | self.rep_thr = rep_thr 165 | 166 | def forward(self, reliability, repeatability, **kw): 167 | assert len(reliability) == len(repeatability) == 1 168 | reliability, repeatability = reliability[0], repeatability[0] 169 | 170 | # local maxima 171 | maxima = (repeatability == self.max_filter(repeatability)) 172 | 173 | # remove low peaks 174 | maxima *= (repeatability >= self.rep_thr) 175 | maxima *= (reliability >= self.rel_thr) 176 | 177 | return maxima.nonzero().t()[2:4] 178 | 179 | 180 | def extract_multiscale(net, img, detector, scale_f=2 ** 0.25, 181 | min_scale=0.0, max_scale=1, 182 | min_size=256, max_size=1024, 183 | verbose=False): 184 | old_bm = torch.backends.cudnn.benchmark 185 | torch.backends.cudnn.benchmark = False # speedup 186 | 187 | # extract keypoints at multiple scales 188 | B, three, H, W = img.shape 189 | assert B == 1 and three == 3, "should be a batch with a single RGB image" 190 | 191 | assert max_scale <= 1 192 | s = 1.0 # current scale factor 193 | 194 | X, Y, S, C, Q, D = [], [], [], [], [], [] 195 | pts_list = [] 196 | while s + 0.001 >= max(min_scale, min_size / max(H, W)): 197 | if s - 0.001 <= min(max_scale, max_size / max(H, W)): 198 | nh, nw = img.shape[2:] 199 | if verbose: 200 | # print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}") 201 | print("extracting at scale x{:.02f} = {:4d}x{:3d}".format(s, nw, nh)) 202 | # extract descriptors 203 | with torch.no_grad(): 204 | res = net(imgs=[img]) 205 | 206 | # get output and reliability map 207 | descriptors = res['descriptors'][0] 208 | reliability = res['reliability'][0] 209 | repeatability = res['repeatability'][0] 210 | 211 | # normalize the reliability for nms 212 | # extract maxima and descs 213 | with torch.no_grad(): 214 | y, x = detector(**res) # nms 215 | c = reliability[0, 0, y, x] 216 | q = repeatability[0, 0, y, x] 217 | d = descriptors[0, :, y, x].t() 218 | n = d.shape[0] 219 | 220 | # accumulate multiple scales 221 | X.append(x.float() * W / nw) 222 | Y.append(y.float() * H / nh) 223 | S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device)) 224 | C.append(c) 225 | Q.append(q) 226 | D.append(d) 227 | 228 | pts_list.append(torch.stack([x.float() * W / nw, y.float() * H / nh], dim=-1)) 229 | 230 | s /= scale_f 231 | 232 | # down-scale the image for next iteration 233 | 
nh, nw = round(H * s), round(W * s) 234 | img = F.interpolate(img, (nh, nw), mode='bilinear', align_corners=False) 235 | 236 | # restore value 237 | torch.backends.cudnn.benchmark = old_bm 238 | 239 | # print("Y: ", len(Y)) 240 | # print("X: ", len(X)) 241 | # print("S: ", len(S)) 242 | Y = torch.cat(Y) 243 | X = torch.cat(X) 244 | S = torch.cat(S) # scale 245 | scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability 246 | XYS = torch.stack([X, Y, S], dim=-1) 247 | D = torch.cat(D) 248 | 249 | XYS = XYS.cpu().numpy() 250 | D = D.cpu().numpy() 251 | scores = scores.cpu().numpy() 252 | return XYS, D, scores, pts_list 253 | 254 | 255 | def extract(net, img, detector): 256 | old_bm = torch.backends.cudnn.benchmark 257 | torch.backends.cudnn.benchmark = False # speedup 258 | 259 | with torch.no_grad(): 260 | res = net(imgs=[img]) 261 | 262 | # get output and reliability map 263 | descriptors = res['descriptors'][0] 264 | reliability = res['reliability'][0] 265 | repeatability = res['repeatability'][0] 266 | 267 | print("rel: ", torch.min(reliability), torch.max(reliability), torch.median(reliability)) 268 | print("rep: ", torch.min(repeatability), torch.max(repeatability), torch.median(repeatability)) 269 | 270 | # normalize the reliability for nms 271 | # extract maxima and descs 272 | with torch.no_grad(): 273 | y, x = detector(**res) # nms 274 | c = reliability[0, 0, y, x] 275 | q = repeatability[0, 0, y, x] 276 | d = descriptors[0, :, y, x].t() 277 | n = d.shape[0] 278 | 279 | print("after nms: ", n) 280 | 281 | X, Y, S, C, Q, D = [], [], [], [], [], [] 282 | 283 | X.append(x.float()) 284 | Y.append(y.float()) 285 | C.append(c) 286 | Q.append(q) 287 | D.append(d) 288 | 289 | # restore value 290 | torch.backends.cudnn.benchmark = old_bm 291 | 292 | Y = torch.cat(Y) 293 | X = torch.cat(X) 294 | scores = torch.cat(C) * torch.cat(Q) 295 | 296 | XYS = torch.stack([X, Y], dim=-1) 297 | D = torch.cat(D) 298 | 299 | return XYS, D, scores 300 | 301 | 302 | def extract_r2d2_return(r2d2, img, need_nms=False, **kwargs): 303 | # img = Image.open(img_path).convert('RGB') 304 | # H, W = img.size 305 | img = norm_RGB(img) 306 | img = img[None] 307 | img = img.cuda() 308 | 309 | rel_thr = 0.99 # 0.99 310 | rep_thr = 0.995 # 0.995 311 | min_size = 256 312 | max_size = 9999 313 | detector = NonMaxSuppression(rel_thr=rel_thr, rep_thr=rep_thr).cuda().eval() 314 | 315 | xys, desc, scores, pts_list = extract_multiscale(net=r2d2, img=img, detector=detector, min_size=min_size, 316 | max_size=max_size, scale_f=1.2) # r2d2 mode 317 | 318 | return xys, desc, scores 319 | -------------------------------------------------------------------------------- /net/plceregnets/gem.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> gem 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 25/08/2021 21:00 7 | ==================================================''' 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.modules import Module 11 | from torch.nn.parameter import Parameter 12 | import torch.nn.functional as F 13 | 14 | 15 | class GeneralizedMeanPooling(Module): 16 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 
17 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 18 | - At p = infinity, one gets Max Pooling 19 | - At p = 1, one gets Average Pooling 20 | The output is of size H x W, for any input size. 21 | The number of output features is equal to the number of input planes. 22 | Args: 23 | output_size: the target output size of the image of the form H x W. 24 | Can be a tuple (H, W) or a single H for a square image H x H 25 | H and W can be either a ``int``, or ``None`` which means the size will 26 | be the same as that of the input. 27 | """ 28 | 29 | def __init__(self, norm, output_size=1, eps=1e-6): 30 | super(GeneralizedMeanPooling, self).__init__() 31 | assert norm > 0 32 | self.p = float(norm) 33 | self.output_size = output_size 34 | self.eps = eps 35 | 36 | def forward(self, x): 37 | x = x.clamp(min=self.eps).pow(self.p) 38 | return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p) 39 | 40 | def __repr__(self): 41 | return self.__class__.__name__ + '(' \ 42 | + str(self.p) + ', ' \ 43 | + 'output_size=' + str(self.output_size) + ')' 44 | 45 | 46 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 47 | """ Same, but norm is trainable 48 | """ 49 | 50 | def __init__(self, norm=3, output_size=1, eps=1e-6): 51 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 52 | self.p = Parameter(torch.ones(1) * norm) 53 | 54 | 55 | def l2_normalize(x, axis=-1): 56 | x = F.normalize(x, p=2, dim=axis) 57 | return x 58 | 59 | 60 | class GEM(nn.Module): 61 | def __init__(self, in_dim=1024, out_dim=2048, norm_features=False, pooling='gem', gemp=3, center_bias=0, 62 | dropout_p=None, without_fc=False, projection=False, cls=False, n_classes=452): 63 | super(GEM, self).__init__() 64 | self.norm_features = norm_features 65 | self.without_fc = without_fc 66 | self.pooling = pooling 67 | self.center_bias = center_bias 68 | self.projection = projection 69 | self.cls = cls 70 | self.n_classes = n_classes 71 | 72 | if self.projection: 73 | self.proj = nn.Sequential( 74 | nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False), 75 | nn.BatchNorm2d(in_dim), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False), 78 | nn.BatchNorm2d(in_dim), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(in_dim, in_dim, kernel_size=1, padding=0, bias=False), 81 | # nn.BatchNorm2d(in_dim), 82 | # nn.ReLU(inplace=True), 83 | ) 84 | else: 85 | self.proj = nn.Sequential( 86 | nn.Conv2d(in_dim, in_dim, kernel_size=1, padding=0, bias=False), 87 | ) 88 | 89 | if self.cls: 90 | self.cls_head = nn.Linear(in_features=in_dim, out_features=self.n_classes) 91 | 92 | if pooling == 'max': 93 | self.adpool = nn.AdaptiveMaxPool2d(output_size=1) 94 | elif pooling == 'avg': 95 | self.adpool = nn.AdaptiveAvgPool2d(output_size=1) 96 | elif pooling.startswith('gem'): 97 | self.adpool = GeneralizedMeanPoolingP(norm=gemp) 98 | else: 99 | raise ValueError(pooling) 100 | 101 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 102 | # self.fc = nn.Linear(512 * block.expansion, out_dim) 103 | # self.fc = nn.Linear(in_dim, out_dim) 104 | # self.fc_name = 'fc' 105 | # self.feat_dim = out_dim 106 | 107 | def forward(self, feat, atten=None): 108 | if atten is not None: 109 | with torch.no_grad(): 110 | if atten.shape[2] != feat.shape[2] or atten.shape[3] != feat.shape[3]: 111 | atten = F.interpolate(atten, size=(feat.shape[2], feat.shape[3]), mode='bilinear') 112 | feat = feat * atten.expand_as(feat) 113 | 114 | if self.projection: 115 | x = 
self.proj(feat) 116 | 117 | bs, _, H, W = x.shape 118 | if self.dropout is not None: 119 | x = self.dropout(x) 120 | 121 | if self.center_bias > 0: 122 | b = self.center_bias 123 | bias = 1 + torch.FloatTensor([[[[0, 0, 0, 0], [0, b, b, 0], [0, b, b, 0], [0, 0, 0, 0]]]]).to(x.device) 124 | bias_resize = torch.nn.functional.interpolate(bias, size=x.shape[-2:], mode='bilinear', align_corners=True) 125 | x = x * bias_resize 126 | # global pooling 127 | x = self.adpool(x) 128 | 129 | if self.norm_features: 130 | x = l2_normalize(x, axis=1) 131 | 132 | x = x.view(x.shape[0], -1) 133 | # if not self.without_fc: 134 | # x = self.fc(x) 135 | if self.cls: 136 | cls_feat = self.cls_head(x) 137 | x = l2_normalize(x, axis=-1) 138 | 139 | output = { 140 | 'feat': x, 141 | } 142 | 143 | if self.cls: 144 | output['cls'] = cls_feat 145 | return output 146 | -------------------------------------------------------------------------------- /net/plceregnets/pregnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> pregnet 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 23/08/2021 21:55 7 | ==================================================''' 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from sklearn.neighbors import NearestNeighbors 12 | import numpy as np 13 | 14 | 15 | class NetVLAD(nn.Module): 16 | """NetVLAD layer implementation""" 17 | 18 | def __init__(self, num_clusters=64, dim=128, in_dim=512, proj_layer=None, 19 | normalize_input=True, vladv2=False, projection=False, cls=False, n_classes=452): 20 | """ 21 | Args: 22 | num_clusters : int 23 | The number of clusters 24 | dim : int 25 | Dimension of descriptors 26 | alpha : float 27 | Parameter of initialization. Larger value is harder assignment. 28 | normalize_input : bool 29 | If true, descriptor-wise L2 normalization is applied to input. 
30 | vladv2 : bool 31 | If true, use vladv2 otherwise use vladv1 32 | """ 33 | super(NetVLAD, self).__init__() 34 | self.num_clusters = num_clusters 35 | self.dim = dim 36 | self.alpha = 0 37 | self.vladv2 = vladv2 38 | self.normalize_input = normalize_input 39 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=vladv2) 40 | self.centroids = nn.Parameter(torch.rand(num_clusters, dim)) 41 | self.cls = cls 42 | self.n_classes = n_classes 43 | 44 | self.projection = projection 45 | if self.projection: 46 | if proj_layer is None: 47 | self.proj = nn.Sequential( 48 | nn.Conv2d(in_dim, 512, kernel_size=3, padding=1, bias=False), 49 | nn.BatchNorm2d(512), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False), 52 | # nn.BatchNorm2d(512), 53 | # nn.ReLU(inplace=True), 54 | ) 55 | else: 56 | self.proj = nn.Sequential(*proj_layer) 57 | 58 | if self.cls: 59 | self.cls_head = nn.Linear(in_features=512, out_features=self.n_classes) 60 | 61 | def init_params(self, clsts, traindescs): 62 | print("Init centroids") 63 | # TODO replace numpy ops with pytorch ops 64 | if self.vladv2 == False: 65 | clstsAssign = clsts / (np.linalg.norm(clsts, axis=1, keepdims=True) + 1e-5) 66 | dots = np.dot(clstsAssign, traindescs.T) 67 | dots.sort(0) 68 | dots = dots[::-1, :] # sort, descending 69 | 70 | self.alpha = (-np.log(0.01) / np.mean(dots[0, :] - dots[1, :])).item() 71 | self.centroids = nn.Parameter(torch.from_numpy(clsts)) 72 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * clstsAssign).unsqueeze(2).unsqueeze(3)) 73 | self.conv.bias = None 74 | else: 75 | knn = NearestNeighbors(n_jobs=-1) # TODO faiss? 76 | knn.fit(traindescs) 77 | del traindescs 78 | dsSq = np.square(knn.kneighbors(clsts, 2)[1]) 79 | del knn 80 | self.alpha = (-np.log(0.01) / np.mean(dsSq[:, 1] - dsSq[:, 0])).item() 81 | self.centroids = nn.Parameter(torch.from_numpy(clsts)) 82 | del clsts, dsSq 83 | 84 | self.conv.weight = nn.Parameter( 85 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) 86 | ) 87 | self.conv.bias = nn.Parameter( 88 | - self.alpha * self.centroids.norm(dim=1) 89 | ) 90 | 91 | def forward(self, x, atten=None): 92 | if atten is not None: 93 | if atten.shape[2] != x.shape[2] or atten.shape[3] != x.shape[3]: 94 | with torch.no_grad(): 95 | atten = F.interpolate(atten, size=(x.shape[2], x.shape[3]), mode='bilinear') 96 | x = x * atten.expand_as(x) 97 | 98 | # if self.projection: 99 | x_enc = self.proj(x) 100 | # else: 101 | # x_enc = x 102 | if self.cls: 103 | x_vec = F.adaptive_avg_pool2d(x_enc, output_size=1).reshape(N, -1) 104 | cls_feat = self.cls_head(x_vec) 105 | 106 | if self.normalize_input: 107 | x = F.normalize(x_enc, p=2, dim=1) # across descriptor dim 108 | 109 | N, C, H, W = x.shape 110 | # soft-assignment 111 | soft_assign = self.conv(x).view(N, self.num_clusters, -1) 112 | soft_assign = F.softmax(soft_assign, dim=1) 113 | 114 | x_flatten = x.view(N, C, -1) 115 | 116 | # calculate residuals to each clusters 117 | vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device) 118 | for C in range(self.num_clusters): # slower than non-looped, but lower memory usage 119 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ 120 | self.centroids[C:C + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 121 | residual *= soft_assign[:, C:C + 1, :].unsqueeze(2) 122 | vlad[:, C:C + 1, :] = residual.sum(dim=-1) 123 | 124 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 125 | 
vlad = vlad.view(x.size(0), -1) # flatten 126 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 127 | 128 | output = { 129 | 'feat': vlad, 130 | 'img_desc': x_enc, 131 | } 132 | 133 | if self.cls: 134 | output['cls'] = cls_feat 135 | return output 136 | -------------------------------------------------------------------------------- /net/regnets/deeplab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> deeplab 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 06/09/2021 13:36 7 | ==================================================''' 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from segmentation_models_pytorch.encoders import get_encoder 12 | from segmentation_models_pytorch.deeplabv3.decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder 13 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead 14 | from segmentation_models_pytorch.base.modules import Flatten, Activation, Conv2dReLU 15 | from typing import Optional, Union 16 | 17 | 18 | class ASPPConv(nn.Sequential): 19 | def __init__(self, in_channels, out_channels, dilation): 20 | modules = [ 21 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 22 | nn.BatchNorm2d(out_channels), 23 | nn.ReLU(inplace=True) 24 | ] 25 | super(ASPPConv, self).__init__(*modules) 26 | 27 | 28 | class ASPPPooling(nn.Sequential): 29 | def __init__(self, in_channels, out_channels): 30 | super(ASPPPooling, self).__init__( 31 | nn.AdaptiveAvgPool2d(1), 32 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 33 | nn.BatchNorm2d(out_channels), 34 | nn.ReLU(inplace=True)) 35 | 36 | def forward(self, x): 37 | size = x.shape[-2:] 38 | x = super(ASPPPooling, self).forward(x) 39 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 40 | 41 | 42 | class ASPP(nn.Module): 43 | def __init__(self, in_channels, atrous_rates): 44 | super(ASPP, self).__init__() 45 | out_channels = 256 46 | modules = [] 47 | modules.append(nn.Sequential( 48 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 49 | nn.BatchNorm2d(out_channels), 50 | nn.ReLU(inplace=True))) 51 | 52 | rate1, rate2, rate3 = tuple(atrous_rates) 53 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 54 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 55 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 56 | modules.append(ASPPPooling(in_channels, out_channels)) 57 | 58 | self.convs = nn.ModuleList(modules) 59 | 60 | self.project = nn.Sequential( 61 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 62 | nn.BatchNorm2d(out_channels), 63 | nn.ReLU(inplace=True), 64 | nn.Dropout(0.1), ) 65 | 66 | def forward(self, x): 67 | res = [] 68 | for conv in self.convs: 69 | res.append(conv(x)) 70 | res = torch.cat(res, dim=1) 71 | return self.project(res) 72 | 73 | 74 | class DeepLabHeadV3Plus(nn.Module): 75 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 76 | super(DeepLabHeadV3Plus, self).__init__() 77 | self.project = nn.Sequential( 78 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 79 | nn.BatchNorm2d(48), 80 | nn.ReLU(inplace=True), 81 | ) 82 | 83 | self.aspp = ASPP(in_channels, aspp_dilate) 84 | 85 | self.classifier = nn.Sequential( 86 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 87 | nn.BatchNorm2d(256), 88 | 
nn.ReLU(inplace=True), 89 | nn.Conv2d(256, num_classes, 1) 90 | ) 91 | self._init_weight() 92 | 93 | def forward(self, feature): 94 | low_level_feature = self.project(feature['low_level']) 95 | output_feature = self.aspp(feature['out']) 96 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', 97 | align_corners=False) 98 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1)) 99 | 100 | def _init_weight(self): 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.kaiming_normal_(m.weight) 104 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 105 | nn.init.constant_(m.weight, 1) 106 | nn.init.constant_(m.bias, 0) 107 | 108 | 109 | class DeepLabV3Plus(nn.Module): 110 | def __init__(self, 111 | encoder_name: str = "resnet34", 112 | encoder_weights: Optional[str] = "imagenet", 113 | encoder_output_stride: int = 8, 114 | decoder_channels: int = 256, 115 | decoder_atrous_rates: tuple = (12, 24, 36), 116 | in_channels: int = 3, 117 | encoder_depth: int = 3, 118 | classes: int = 1, 119 | upsampling: int = 8, 120 | activation: Optional[str] = None, 121 | classification: bool = False, 122 | ): 123 | super(DeepLabV3Plus, self).__init__() 124 | self.classification = classification 125 | 126 | self.encoder = get_encoder(name=encoder_name, 127 | in_channels=in_channels, 128 | depth=encoder_depth, 129 | weights=encoder_weights, 130 | ) 131 | 132 | if encoder_output_stride == 8: 133 | self.encoder.make_dilated( 134 | stage_list=[4, 5], 135 | dilation_list=[2, 4] 136 | ) 137 | 138 | elif encoder_output_stride == 16: 139 | self.encoder.make_dilated( 140 | stage_list=[5], 141 | dilation_list=[2] 142 | ) 143 | else: 144 | raise ValueError( 145 | "Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)) 146 | 147 | self.decoder = DeepLabV3PlusDecoder( 148 | encoder_channels=self.encoder.out_channels, 149 | out_channels=decoder_channels, 150 | atrous_rates=decoder_atrous_rates, 151 | output_stride=encoder_output_stride, 152 | ) 153 | 154 | self.seghead = SegmentationHead( 155 | in_channels=self.decoder.out_channels, 156 | out_channels=classes, 157 | activation=activation, 158 | kernel_size=1, 159 | upsampling=upsampling, 160 | ) 161 | 162 | if self.classification: 163 | self.cls_head = nn.Sequential( 164 | nn.AdaptiveAvgPool2d(1), 165 | Flatten(), 166 | nn.Linear(self.encoder.out_channels[-1], classes, bias=True) 167 | ) 168 | 169 | def forward(self, x): 170 | features = self.encoder(x) 171 | # print('len: ', len(features)) 172 | # for i in range(5): 173 | # print(i, features[i].shape) 174 | decoder_output = self.decoder(*features) 175 | 176 | masks = self.seghead(decoder_output) 177 | # print('mask: ', masks.shape) 178 | output = {"masks": [masks]} 179 | if self.classification: 180 | cls = self.cls_head(features[-1]) 181 | output["cls"] = [cls] 182 | 183 | output['feats'] = features 184 | 185 | return output 186 | -------------------------------------------------------------------------------- /net/regnets/pspnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> pspnet 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 2021-05-12 19:22 7 | ==================================================''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from segmentation_models_pytorch.encoders import 
get_encoder 13 | from segmentation_models_pytorch.pspnet.decoder import PSPDecoder 14 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead 15 | from segmentation_models_pytorch.base.modules import Flatten, Activation, Conv2dReLU 16 | import segmentation_models_pytorch.base.initialization as init 17 | from loss.seg_loss.crossentropy_loss import cross_entropy2d 18 | 19 | from typing import Optional, Union 20 | 21 | 22 | class CondLayer(nn.Module): 23 | """ 24 | implementation of the element-wise linear modulation layer 25 | """ 26 | 27 | def __init__(self): 28 | super(CondLayer, self).__init__() 29 | self.relu = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x, gammas, betas): 32 | return self.relu((x * gammas.expand_as(x)) + betas.expand_as(x)) 33 | 34 | 35 | def conv(in_planes, out_planes, kernel_size=3, stride=1, bn=False): 36 | if not bn: 37 | return nn.Sequential( 38 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 39 | stride=stride, padding=(kernel_size - 1) // 2), 40 | nn.ReLU(inplace=True) 41 | ) 42 | else: 43 | return nn.Sequential( 44 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 45 | stride=stride, padding=(kernel_size - 1) // 2), 46 | nn.BatchNorm2d(out_planes, affine=True), 47 | nn.ReLU(inplace=True) 48 | ) 49 | 50 | 51 | def conv1x1(in_planes, out_planes): 52 | return nn.Sequential( 53 | nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0) 54 | ) 55 | 56 | 57 | class PSPUpsample(nn.Module): 58 | def __init__(self, in_channels, out_channels): 59 | super().__init__() 60 | self.conv = nn.Sequential( 61 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 62 | nn.BatchNorm2d(out_channels), 63 | nn.PReLU() 64 | ) 65 | 66 | def forward(self, x): 67 | h, w = 2 * x.size(2), 2 * x.size(3) 68 | p = F.upsample(input=x, size=(h, w), mode='bilinear') 69 | return self.conv(p) 70 | 71 | 72 | class PSPNetF(SegmentationModel): 73 | def __init__(self, 74 | encoder_name: str = "resnet34", 75 | encoder_weights: Optional[str] = "imagenet", 76 | encoder_depth: int = 3, 77 | psp_out_channels: int = 512, 78 | psp_use_batchnorm: bool = True, 79 | psp_dropout: float = 0.2, 80 | in_channels: int = 3, 81 | activation: Optional[Union[str, callable]] = None, 82 | upsampling: int = 8, 83 | hierarchical: bool = False, 84 | classification: bool = False, 85 | segmentation: bool = False, 86 | classes=21428, 87 | out_indices=(1, 2, 3), 88 | require_spp_feats=False 89 | ): 90 | super(PSPNetF, self).__init__() 91 | self.classification = classification 92 | self.require_spp_feats = require_spp_feats 93 | 94 | self.init_modules = [] 95 | 96 | self.encoder = get_encoder(name=encoder_name, 97 | in_channels=3, 98 | depth=encoder_depth, 99 | weights=encoder_weights) 100 | 101 | self.decoder = PSPDecoder( 102 | encoder_channels=self.encoder.out_channels, # 3, 64, 256 103 | use_batchnorm=psp_use_batchnorm, 104 | out_channels=psp_out_channels, 105 | dropout=psp_dropout, 106 | ) 107 | 108 | self.seghead = SegmentationHead( 109 | in_channels=psp_out_channels, 110 | out_channels=classes, 111 | kernel_size=3, 112 | activation=activation, 113 | upsampling=upsampling, # 8 for robotcar, 4 for obs9, 2 for obs6 & obs4 114 | ) 115 | 116 | if classification: 117 | self.cls_head = nn.Sequential( 118 | nn.AdaptiveAvgPool2d(1), 119 | Flatten(), 120 | nn.Linear(self.encoder.out_channels[-1], classes, bias=True) 121 | ) 122 | 123 | self.name = "psp-{}".format(encoder_name) 124 | self.initialize() 125 | 126 | def compute_seg_loss(self, pred_segs, gt_segs, 
weights=[1.0, 1.0, 1.0, 1.0]): 127 | # pred_segs = inputs["masks"] 128 | # gt_segs = outputs["label"] 129 | 130 | seg_loss = 0 131 | 132 | for pseg, gseg in zip(pred_segs, gt_segs): 133 | # print("pseg, gseg: ", pseg.shape, gseg.shape) 134 | gseg = gseg.cuda() 135 | if len(gseg.shape) == 3: 136 | gseg = gseg.unsqueeze(1) 137 | if pseg.shape[2] != gseg.shape[2] or pseg.shape[3] != gseg.shape[3]: 138 | gseg = F.interpolate(gseg.float(), size=(pseg.shape[2], pseg.shape[3]), mode="nearest") 139 | 140 | # seg_loss += cross_entropy_seg(input=pseg, target=gseg) 141 | seg_loss += cross_entropy2d(input=pseg, target=gseg.long()) 142 | 143 | return seg_loss 144 | 145 | def compute_cls_loss(self, pred_cls, gt_cls, method="cel"): 146 | cls_loss = 0 147 | for pc, gc in zip(pred_cls, gt_cls): 148 | gc = gc.cuda() 149 | cls_loss += torch.nn.functional.binary_cross_entropy_with_logits(pc, gc) 150 | return cls_loss 151 | 152 | def forward(self, x): 153 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 154 | # batch = x.shape[0] 155 | # device = self.parameters().device 156 | # print('device: ', device) 157 | features = self.encoder(x) 158 | # for v in features: 159 | # print(v.shape) 160 | decode_feat = self.decoder(*features) 161 | 162 | masks = self.seghead(decode_feat) 163 | 164 | # seg_loss = self.compute_seg_loss(pred_segs=[masks], gt_segs=input['label']) 165 | 166 | output = {"masks": [masks]} 167 | if self.classification: 168 | cls = self.cls_head(features[-1]) 169 | output["cls"] = [cls] 170 | output['feats'] = features 171 | 172 | return output 173 | 174 | def predict(self, x): 175 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` 176 | 177 | Args: 178 | x: 4D torch tensor with shape (batch_size, channels, height, width) 179 | 180 | Return: 181 | prediction: 4D torch tensor with shape (batch_size, classes, height, width) 182 | 183 | """ 184 | if self.training: 185 | self.eval() 186 | 187 | with torch.no_grad(): 188 | x = self.forward(x) 189 | 190 | return x 191 | 192 | def initialize(self): 193 | init.initialize_decoder(self.decoder) 194 | 195 | init.initialize_head(self.seghead) 196 | if self.classification: 197 | init.initialize_head(self.cls_head) 198 | 199 | 200 | if __name__ == '__main__': 201 | net = PSPNetF( 202 | encoder_name="timm-resnest50d", 203 | encoder_weights="imagenet", 204 | # classes=256, 205 | # clusters=200, 206 | encoder_depth=4, 207 | # psp_out_channels=512, 208 | ).cuda() 209 | 210 | print(net) 211 | img = torch.ones((4, 3, 256, 256)).cuda() 212 | out = net(img) 213 | if "masks" in out.keys(): 214 | masks = out["masks"] 215 | print(masks[0].shape, masks[1].shape, masks[2].shape) 216 | if "cls" in out.keys(): 217 | cls = out["cls"] 218 | print(cls[0].shape, cls[1].shape, cls[2].shape) 219 | # print (v.shape for v in masks) 220 | # print (v.shape for v in cls) 221 | -------------------------------------------------------------------------------- /net/segnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/5/11 上午11:36 3 | # @Author : Fei Xue 4 | # @Email : fx221@cam.ac.uk 5 | # @File : regnet.py 6 | # @Software: PyCharm 7 | 8 | from net.regnets.pspnet import PSPNetF 9 | from net.regnets.deeplab import DeepLabV3Plus 10 | 11 | 12 | def get_segnet(network, 13 | n_classes, 14 | encoder_name=None, 15 | encoder_weights=None, 16 | out_channels=512, 17 | classification=True, 18 | segmentation=True, 19 | encoder_depth=4, 20 | upsampling=8, 
21 | ): 22 | if network == "pspf": 23 | net = PSPNetF( 24 | encoder_name=encoder_name, 25 | encoder_weights=encoder_weights, 26 | 27 | # aux_params=aux_params, 28 | encoder_depth=encoder_depth, # 3 for robotcar, 4 for obs 6 & 9 29 | psp_out_channels=out_channels, 30 | # hierarchical=hierarchical, 31 | classification=classification, 32 | # segmentation=segmentation, 33 | upsampling=upsampling, 34 | # classes=21428, 35 | # classes=3962, 36 | classes=n_classes, 37 | ) 38 | elif network == 'deeplabv3p': 39 | net = DeepLabV3Plus( 40 | encoder_name=encoder_name, 41 | encoder_weights=encoder_weights, 42 | decoder_channels=out_channels, 43 | decoder_atrous_rates=(12, 24, 36), 44 | encoder_output_stride=8, 45 | encoder_depth=encoder_depth, 46 | classification=classification, 47 | upsampling=upsampling, 48 | classes=n_classes, 49 | ) 50 | 51 | return net 52 | -------------------------------------------------------------------------------- /robotcar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feixue94/lbr/bdd126b79da9e75759a364a0d8e010d76658f1af/robotcar.npy -------------------------------------------------------------------------------- /run_loc_aachen: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | root=/scratches/flyer_2 4 | 5 | dataset=$root/fx221/localization/aachen_v1_1 6 | outputs=$root/fx221/localization/outputs/aachen_v1_1 7 | image_dir=$root/fx221/localization/aachen_v1_1/images/images_upright 8 | seg_dir=$root/fx221/localization/aachen_v1_1/global_seg_instance 9 | #seg_dir=$root/fx221/localization/aachen_v1_1/global_seg_instance/remap 10 | save_root=$root/fx221/exp/shloc/aachen 11 | 12 | #feat=resnetxn-triv2-0001-n4096-r1600-mask 13 | #feat=resnetxn-triv2-0001-n1000-r1600-mask 14 | #feat=resnetxn-triv2-0001-n2000-r1600-mask 15 | feat=resnetxn-triv2-0001-n3000-r1600-mask 16 | 17 | 18 | #feat=resnetxn-triv2-ms-0001-n4096-r1600-mask 19 | #feat=resnetxn-triv2-ms-0001-n10000-r1600-mask 20 | 21 | 22 | weight_name=2021_08_29_12_49_48_aachen_pspf_resnext101_32x4d_d4_u8_b16_R256_E120_ceohem_adam_seg_cls_aug_stylized 23 | #weight_name=2021_09_06_15_13_26_aachen_deeplabv3p_resnet101_d5_u2_b8_R256_E120_ceohem_sgd_mlr_seg_cls_aug_stylized 24 | #weight_name=2021_09_08_19_27_29_aachen_deeplabv3p_resnext101_32x4d_d5_u2_b16_R256_E120_ceohem_sgd_poly_mlr_seg_cls_aug_stylized 25 | #weight_name=2021_09_11_22_12_37_aachen_deeplabv3p_resnext101_32x4d_d4_u2_b16_R256_E120_ceohem_sgd_poly_mlr_seg_cls_aug_stylized 26 | #save_dir=/data/cornucopia/fx221/exp/shloc/aachen/$weight_name/loc_by_seg 27 | save_dir=/scratches/flyer_2/fx221/exp/shloc/aachen/$weight_name/loc_by_seg 28 | 29 | #query_pair="/home/mifs/fx221/Research/Code/Hierarchical-Localization/pairs/aachen_v1.1/pairs-query-netvlad50.txt" 30 | #query_pair="/data/cornucopia/fx221/localization/outputs_hloc/aachen_v1.1/pairs-query-netvlad50.txt" 31 | query_pair=datasets/aachen/pairs-query-netvlad50.txt 32 | gt_pose_fn=/scratches/flyer_2/fx221/localization/outputs_hloc/aachen_v1_1/Aachen-v1.1_hloc_superpoint_n4096_r1600+superglue_netvlad50.txt 33 | 34 | 35 | #matcher=NNM 36 | matcher=NNML 37 | #matcher=NNR 38 | retrieval_type="lrnv" 39 | #retrieval_type="lrnv256" 40 | feature_type="feat_max" 41 | global_score_th=0.95 42 | rec_th=100 43 | nv_th=50 44 | ransac_thresh=12 45 | opt_thresh=12 46 | covisibility_frame=50 47 | init_type="clu" 48 | opt_type="clurefpos" 49 | k_seg=10 50 | k_can=5 51 | k_rec=30 52 | iters=5 53 | radius=20 54 | 
obs_thresh=3 55 | 56 | # with opt 57 | python3 -m localization.localizer \ 58 | --image_dir $image_dir \ 59 | --seg_dir $seg_dir \ 60 | --save_root $save_root \ 61 | --gt_pose_fn $gt_pose_fn \ 62 | --dataset aachen \ 63 | --map_gid_rgb_fn datasets/aachen/aachen_grgb_gid_v5.txt \ 64 | --db_imglist_fn datasets/aachen/aachen_db_imglist.txt \ 65 | --db_instance_fn aachen_452 \ 66 | --k_seg $k_seg \ 67 | --k_can $k_can \ 68 | --k_rec $k_rec \ 69 | --retrieval $query_pair \ 70 | --retrieval_type $retrieval_type \ 71 | --feature_type $feature_type \ 72 | --init_type $init_type \ 73 | --global_score_th $global_score_th \ 74 | --weight_name $weight_name \ 75 | --show_seg \ 76 | --reference_sfm $outputs/sfm_$feat-$matcher/model \ 77 | --queries $dataset/queries/day_night_time_queries_with_intrinsics.txt \ 78 | --features $outputs/feats-$feat.h5 \ 79 | --matcher_method $matcher \ 80 | --ransac_thresh $ransac_thresh \ 81 | --with_label \ 82 | --with_match \ 83 | --rec_th $rec_th \ 84 | --nv_th $nv_th \ 85 | --covisibility_frame $covisibility_frame \ 86 | --iters $iters \ 87 | --radius $radius \ 88 | --obs_thresh $obs_thresh \ 89 | --opt_thresh $opt_thresh \ 90 | --opt_type $opt_type \ 91 | --do_covisible_opt -------------------------------------------------------------------------------- /run_loc_robotcar: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | root=/scratches/flyer_2 4 | 5 | outputs=$root/fx221/localization/outputs/robotcar 6 | dataset=$root/fx221/localization/RobotCar-Seasons 7 | save_root=$root/fx221/exp/shloc/robotcar 8 | 9 | image_dir=$root/fx221/localization/RobotCar-Seasons/images 10 | weight_name=2021_10_17_00_21_32_robotcar_pspf_resnext101_32x4d_d4_u8_b16_R256_E500_ceohem_adam_poly_mlr_seg_cls_aug_stylized 11 | seg_dir=$root/fx221/exp/shloc/robotcar/$weight_name/masks 12 | 13 | 14 | query_fn=$dataset/queries_with_intrinsics_rear.txt 15 | query_pair=datasets/robotcar/pairs-query-netvlad20-percam-perloc-rear.txt 16 | gt_pose_fn=/data/cornucopia/fx221/localization/RobotCar-Seasons/3D-models/query_poses_v2.txt 17 | 18 | #feat=resnetxn-triv2-0001-n4096-r1024-mask 19 | feat=resnetxn-triv2-ms-0001-n4096-r1024-mask 20 | 21 | #matcher=NNM 22 | matcher=NNML 23 | 24 | only_gt=0 25 | feature_type="feat_max" 26 | global_score_th=0.95 27 | 28 | rec_th=200 29 | nv_th=50 30 | ransac_thresh=12 31 | opt_thresh=8 32 | covisibility_frame=20 33 | init_type="clu" 34 | retrieval_type="lrnv" 35 | opt_type="clurefpos" 36 | k_seg=10 37 | k_can=1 38 | k_rec=20 39 | iters=5 40 | radius=20 41 | obs_thresh=3 42 | 43 | python3 -m localization.localizer \ 44 | --only_gt $only_gt \ 45 | --retrieval_type $retrieval_type \ 46 | --gt_pose_fn $gt_pose_fn \ 47 | --image_dir $image_dir \ 48 | --seg_dir $seg_dir \ 49 | --dataset robotcar \ 50 | --map_gid_rgb_fn datasets/robotcar/robotcar_grgb_gid.txt \ 51 | --db_imglist_fn datasets/robotcar/robotcar_rear_db_imglist.txt \ 52 | --db_instance_fn robotcar \ 53 | --save_root $save_root \ 54 | --k_seg $k_seg \ 55 | --k_can $k_can \ 56 | --k_rec $k_rec \ 57 | --feature_type $feature_type \ 58 | --init_type $init_type \ 59 | --global_score_th $global_score_th \ 60 | --weight_name $weight_name \ 61 | --show_seg \ 62 | --reference_sfm $outputs/sfm_$feat-$matcher/model \ 63 | --queries $query_fn \ 64 | --retrieval $query_pair \ 65 | --features $outputs/feats-$feat.h5 \ 66 | --matcher_method $matcher \ 67 | --ransac_thresh $ransac_thresh \ 68 | --with_label \ 69 | --with_match \ 70 | --rec_th $rec_th \ 71 | --nv_th $nv_th \ 72 | 
--covisibility_frame $covisibility_frame \
73 |     --iters $iters \
74 |     --radius $radius \
75 |     --obs_thresh $obs_thresh \
76 |     --opt_thresh $opt_thresh \
77 |     --opt_type $opt_type \
78 |     --do_covisible_opt
--------------------------------------------------------------------------------
/run_reconstruct_aachen:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | colmap=/home/mifs/fx221/Research/Software/bin/colmap
3 |
4 | dataset=/scratches/flyer_2/fx221/localization/aachen_v1_1
5 | outputs=/scratches/flyer_2/fx221/localization/outputs/aachen_v1_1
6 |
7 | #feat=resnetxn-triv2-0001-n4096-r1600-mask
8 | #feat=resnetxn-triv2-0001-n3000-r1600-mask
9 | #feat=resnetxn-triv2-0001-n2000-r1600-mask
10 | feat=resnetxn-triv2-0001-n1000-r1600-mask
11 |
12 |
13 | image_dir=$dataset/images/images_upright
14 | mask_dir=$dataset/global_seg_instance
15 |
16 | matcher=NNML
17 | #matcher=NNM
18 |
19 | extract_feat=1
20 | match_db=1
21 | triangulate=1
22 |
23 | if [ "$extract_feat" -gt "0" ]; then
24 |   python3 -m localization.fine.extractor --image_dir $image_dir --export_dir $outputs/ --conf $feat --mask_dir $mask_dir
25 | fi
26 |
27 | if [ "$match_db" -gt "0" ]; then
28 |   python3 -m localization.fine.matcher --pairs datasets/aachen/pairs-db-covis20.txt --export_dir $outputs/ --conf $matcher --features feats-$feat
29 | fi
30 |
31 | if [ "$triangulate" -gt "0" ]; then
32 |   python3 -m localization.fine.triangulation \
33 |     --sfm_dir $outputs/sfm_$feat-$matcher \
34 |     --reference_sfm_model $dataset/3D-models \
35 |     --image_dir $dataset/images/images_upright \
36 |     --pairs datasets/aachen/pairs-db-covis20.txt \
37 |     --features $outputs/feats-$feat.h5 \
38 |     --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 \
39 |     --colmap_path $colmap
40 | fi
41 |
42 |
--------------------------------------------------------------------------------
/run_reconstruct_robotcar:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | colmap=/home/mifs/fx221/Research/Software/bin/colmap
3 |
4 | root_dir=/scratch2
5 |
6 | dataset=$root_dir/fx221/localization/RobotCar-Seasons
7 | outputs=$root_dir/fx221/localization/outputs/robotcar
8 | db_pair=$root_dir/fx221/localization/outputs/robotcar/pairs-db-covis20.txt
9 |
10 | feat=resnetxn-triv2-ms-0001-n4096-r1024-mask
11 |
12 | mask_dir=$root_dir/fx221/exp/shloc/robotcar/2021_10_17_00_21_32_robotcar_pspf_resnext101_32x4d_d4_u8_b16_R256_E500_ceohem_adam_poly_mlr_seg_cls_aug_stylized/masks
13 |
14 | matcher=NNML
15 | #matcher=NNM
16 |
17 | extract_feat=1
18 | match_db=1
19 | triangulate=1
20 |
21 | if [ "$extract_feat" -gt "0" ]; then
22 |   python3 -m localization.fine.extractor --image_list datasets/robotcar/robotcar_db_query_imglist.txt --image_dir $dataset/images --export_dir $outputs/ --conf $feat --mask_dir $mask_dir
23 | fi
24 |
25 | if [ "$match_db" -gt "0" ]; then
26 |   python3 -m localization.fine.matcher --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat
27 | fi
28 |
29 | if [ "$triangulate" -gt "0" ]; then
30 |   python3 -m localization.fine.triangulation \
31 |     --sfm_dir $outputs/sfm_$feat-$matcher \
32 |     --reference_sfm_model $dataset/3D-models/sfm-sift \
33 |     --image_dir $dataset/images \
34 |     --pairs $outputs/pairs-db-covis20.txt \
35 |     --features $outputs/feats-$feat.h5 \
36 |     --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5 \
37 |     --colmap_path $colmap
38 | fi
--------------------------------------------------------------------------------
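The run_reconstruct_* scripts above chain three stages, each gated by a 0/1 flag: local-feature extraction, database pair matching, and triangulation of the reference model with COLMAP. Below is a minimal sketch, not part of the repository, of how one might verify that every stage left its expected output behind before moving on to the run_loc_* scripts; the paths simply mirror the bash variables above and are assumptions about this particular setup.

# Hypothetical sanity check (assumption, not repo code): confirm the outputs of
# run_reconstruct_robotcar exist before launching run_loc_robotcar.
import os.path as osp

outputs = '/scratch2/fx221/localization/outputs/robotcar'   # $outputs in the script
feat = 'resnetxn-triv2-ms-0001-n4096-r1024-mask'            # $feat
matcher = 'NNML'                                            # $matcher

expected = [
    osp.join(outputs, 'feats-{}.h5'.format(feat)),                                  # extractor output
    osp.join(outputs, 'feats-{}-{}-pairs-db-covis20.h5'.format(feat, matcher)),     # matcher output
    osp.join(outputs, 'sfm_{}-{}'.format(feat, matcher)),                           # triangulated model
]
for path in expected:
    status = 'ok' if osp.exists(path) else 'missing'
    print('{:8s} {}'.format(status, path))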
/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> test 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 2021-05-31 15:09 7 | ==================================================''' 8 | import argparse 9 | import json 10 | import torchvision.transforms as tvf 11 | import torch 12 | import torch.nn.functional as F 13 | from tqdm import tqdm 14 | import cv2 15 | import os 16 | import numpy as np 17 | import os.path as osp 18 | from net.segnet import get_segnet 19 | from tools.common import torch_set_gpu 20 | from tools.seg_tools import read_seg_map_without_group, label_to_bgr 21 | 22 | val_transform = tvf.Compose( 23 | ( 24 | # tvf.ToPILImage(), 25 | # tvf.Resize(224), 26 | tvf.ToTensor(), 27 | tvf.Normalize(mean=[0.485, 0.456, 0.406], 28 | std=[0.229, 0.224, 0.225]) 29 | ) 30 | ) 31 | 32 | 33 | def predict(net, img): 34 | img_tensor = val_transform(img) 35 | img_tensor = img_tensor.cuda().unsqueeze(0) 36 | with torch.no_grad(): 37 | prediction = net(img_tensor) 38 | 39 | return prediction 40 | 41 | 42 | def inference_rec(output, img, fn, map_gid_rgb, save_dir=None): 43 | cv2.namedWindow("out", cv2.WINDOW_NORMAL) 44 | with torch.no_grad(): 45 | # output = predict(net=model, img=img) 46 | pred_mask = output["masks"][0] 47 | pred_label = torch.softmax(pred_mask, dim=1).max(1)[1].cpu().numpy() 48 | pred_conf = torch.softmax(pred_mask, dim=1).max(1)[0].cpu().numpy() 49 | 50 | last_feat = output['feats'][-1] 51 | pred_feat_max = F.adaptive_max_pool2d(last_feat, output_size=(1, 1)) 52 | pred_feat_avg = F.adaptive_avg_pool2d(last_feat, output_size=(1, 1)) 53 | 54 | if pred_label.shape[0] == 1: 55 | pred_label = pred_label[0] 56 | pred_conf = pred_conf[0] 57 | 58 | uids = np.unique(pred_label).tolist() 59 | pred_seg = label_to_bgr(label=pred_label, maps=map_gid_rgb) 60 | pred_conf_img = np.uint8(pred_conf * 255) 61 | pred_conf_img = cv2.applyColorMap(src=pred_conf_img, colormap=cv2.COLORMAP_PARULA) 62 | 63 | H = args.R # img.shape[0] 64 | W = args.R # img.shape[1] 65 | pred_seg = cv2.resize(pred_seg, dsize=(W, H), interpolation=cv2.INTER_NEAREST) 66 | pred_conf_img = cv2.resize(pred_conf_img, dsize=(W, H), interpolation=cv2.INTER_NEAREST) 67 | img = cv2.resize(img, dsize=(W, H)) 68 | 69 | img_seg = (0.5 * img + 0.5 * pred_seg).astype(np.uint8) 70 | cat_img = np.hstack([img_seg, pred_seg, pred_conf_img]) 71 | 72 | cv2.imshow("out", cat_img) 73 | key = cv2.waitKey() 74 | if key in (27, ord('q')): # exit by pressing key esc or q 75 | cv2.destroyAllWindows() 76 | exit(0) 77 | # return 78 | 79 | if save_dir is not None: 80 | conf_fn = osp.join(save_dir, "confidence", fn.split('.')[0] + ".npy") 81 | mask_fn = osp.join(save_dir, "masks", fn.replace("jpg", "png")) 82 | vis_fn = osp.join(save_dir, "vis", fn.replace("jpg", "png")) 83 | if not osp.exists(osp.dirname(conf_fn)): 84 | os.makedirs(osp.dirname(conf_fn), exist_ok=True) 85 | if not osp.exists(osp.dirname(vis_fn)): 86 | os.makedirs(osp.dirname(vis_fn), exist_ok=True) 87 | if not osp.exists(osp.dirname(mask_fn)): 88 | os.makedirs(osp.dirname(mask_fn), exist_ok=True) 89 | 90 | pred_confidence, pred_ids = torch.topk(torch.softmax(pred_mask, dim=1), k=10, largest=True, dim=1) 91 | conf_data = {"confidence": pred_confidence[0].cpu().numpy(), 92 | "ids": pred_ids[0].cpu().numpy(), 93 | 'feat_max': pred_feat_max.squeeze().cpu().numpy(), 94 | 'feat_avg': pred_feat_avg.squeeze().cpu().numpy(), 95 | } 96 
| 97 | np.save(conf_fn, conf_data) 98 | cv2.imwrite(vis_fn, cat_img) 99 | cv2.imwrite(mask_fn, pred_seg) 100 | 101 | 102 | def main(args): 103 | map_gid_rgb = read_seg_map_without_group(args.grgb_gid_file) 104 | 105 | model = get_segnet(network=args.network, 106 | n_classes=args.classes, 107 | encoder_name=args.encoder_name, 108 | encoder_weights=args.encoder_weights, 109 | encoder_depth=args.encoder_depth, 110 | upsampling=args.upsampling, 111 | out_channels=args.out_channels, 112 | classification=args.classification, 113 | segmentation=args.segmentation, ) 114 | print("model: ", model) 115 | if args.pretrained_weight is not None: 116 | model.load_state_dict(torch.load(args.pretrained_weight), strict=True) 117 | print("Load weight from {:s}".format(args.pretrained_weight)) 118 | model.eval().cuda() 119 | 120 | img_path = args.image_path 121 | save_dir = args.save_dir 122 | 123 | print('Save results to ', save_dir) 124 | 125 | imglist = [] 126 | with open(args.image_list, "r") as f: 127 | lines = f.readlines() 128 | for l in lines: 129 | imglist.append(l.strip()) 130 | 131 | for fn in tqdm(imglist, total=len(imglist)): 132 | if fn.find('left') >= 0 or fn.find('right') >= 0: 133 | continue 134 | img = cv2.imread(osp.join(img_path, fn)) 135 | img = cv2.resize(img, dsize=(args.R, args.R)) 136 | 137 | with torch.no_grad(): 138 | output = predict(net=model, img=img) 139 | inference_rec(output=output, img=img, fn=fn, map_gid_rgb=map_gid_rgb, save_dir=save_dir) 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser("Test Semantic localization Network") 144 | parser.add_argument("--config", type=str, required=True, help="configuration file") 145 | parser.add_argument("--pretrained_weight", type=str, default=None) 146 | parser.add_argument("--save_root", type=str, default="/home/mifs/fx221/fx221/exp/shloc/aachen") 147 | parser.add_argument("--image_path", type=str, default=None) 148 | parser.add_argument("--network", type=str, default="pspf") 149 | parser.add_argument("--save_dir", type=str, default=None) 150 | parser.add_argument("--encoder_name", type=str, default='timm-resnest50d') 151 | parser.add_argument("--encoder_weights", type=str, default='imagenet') 152 | parser.add_argument("--out_channels", type=int, default='2048') 153 | parser.add_argument("--upsampling", type=int, default='8') 154 | parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU') 155 | parser.add_argument("--R", type=int, default=256) 156 | 157 | args = parser.parse_args() 158 | with open(args.config, 'rt') as f: 159 | t_args = argparse.Namespace() 160 | t_args.__dict__.update(json.load(f)) 161 | args = parser.parse_args(namespace=t_args) 162 | 163 | torch_set_gpu(gpus=args.gpu) 164 | main(args=args) 165 | -------------------------------------------------------------------------------- /test_aachen: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 test.py --config configs/config_test_aachen.json -------------------------------------------------------------------------------- /test_robotcar: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 test.py --config configs/config_test_robotcar.json -------------------------------------------------------------------------------- /tools/common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2021/3/15 下午3:54 4 | @Auth : Fei Xue 5 | 
@File : common.py 6 | @Email: fx221@cam.ac.uk 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import json 12 | from collections import OrderedDict 13 | import cv2 14 | from torch._six import string_classes 15 | import collections.abc as collections 16 | import os 17 | import os.path as osp 18 | 19 | 20 | def get_recursive_file_list(root_dir, sub_dir="", patterns=[]): 21 | current_files = os.listdir(osp.join(root_dir, sub_dir)) 22 | all_files = [] 23 | for file_name in current_files: 24 | full_file_name = os.path.join(root_dir, sub_dir, file_name) 25 | print(file_name) 26 | 27 | if file_name.split('.')[-1] in patterns: 28 | all_files.append(osp.join(sub_dir, file_name)) 29 | 30 | if os.path.isdir(full_file_name): 31 | next_level_files = get_recursive_file_list(root_dir, sub_dir=osp.join(sub_dir, file_name), 32 | patterns=patterns) 33 | all_files.extend(next_level_files) 34 | 35 | return all_files 36 | 37 | 38 | def sort_dict_by_value(data, reverse=False): 39 | return sorted(data.items(), key=lambda d: d[1], reverse=reverse) 40 | 41 | 42 | def mkdir_for(file_path): 43 | os.makedirs(os.path.split(file_path)[0], exist_ok=True) 44 | 45 | 46 | def model_size(model): 47 | ''' Computes the number of parameters of the model 48 | ''' 49 | size = 0 50 | for weights in model.state_dict().values(): 51 | size += np.prod(weights.shape) 52 | return size 53 | 54 | 55 | def torch_set_gpu(gpus): 56 | if type(gpus) is int: 57 | gpus = [gpus] 58 | 59 | cuda = all(gpu >= 0 for gpu in gpus) 60 | 61 | if cuda: 62 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus]) 63 | assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % ( 64 | os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES']) 65 | torch.backends.cudnn.benchmark = True # speed-up cudnn 66 | torch.backends.cudnn.fastest = True # even more speed-up? 
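        # `cudnn.benchmark = True` lets cuDNN auto-tune its convolution algorithms, which pays off
        # here because inputs are resized to a fixed R x R resolution by train.py/test.py;
        # `cudnn.fastest` appears to be a legacy flag with no documented effect in recent
        # PyTorch releases, so it is likely safe to drop.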
67 | print('Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES']) 68 | 69 | else: 70 | print('Launching on CPU') 71 | 72 | return cuda 73 | 74 | 75 | def save_args(args, save_path): 76 | with open(save_path, 'w') as f: 77 | json.dump(args.__dict__, f, indent=2) 78 | 79 | 80 | def load_args(args, save_path): 81 | with open(save_path, "r") as f: 82 | args.__dict__ = json.load(f) 83 | 84 | 85 | def adjust_learning_rate_poly(optimizer, epoch, num_epochs, base_lr, power): 86 | lr = base_lr * (1 - epoch / num_epochs) ** power 87 | for param_group in optimizer.param_groups: 88 | param_group['lr'] = lr 89 | return lr 90 | 91 | 92 | def read_json(fname): 93 | with fname.open('rt') as handle: 94 | return json.load(handle, object_hook=OrderedDict) 95 | 96 | 97 | def write_json(content, fname): 98 | with fname.open('wt') as handle: 99 | json.dump(content, handle, indent=4, sort_keys=False) 100 | 101 | 102 | def resize_img(img, nh=-1, nw=-1, mode=cv2.INTER_NEAREST): 103 | assert nh > 0 or nw > 0 104 | if nh == -1: 105 | return cv2.resize(img, dsize=(nw, int(img.shape[0] / img.shape[1] * nw)), interpolation=mode) 106 | if nw == -1: 107 | return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * nh), nh), interpolation=mode) 108 | return cv2.resize(img, dsize=(nw, nh), interpolation=mode) 109 | 110 | 111 | def map_tensor(input_, func): 112 | if isinstance(input_, torch.Tensor): 113 | return func(input_) 114 | elif isinstance(input_, string_classes): 115 | return input_ 116 | elif isinstance(input_, collections.Mapping): 117 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 118 | elif isinstance(input_, collections.Sequence): 119 | return [map_tensor(sample, func) for sample in input_] 120 | else: 121 | raise TypeError( 122 | f'input must be tensor, dict or list; found {type(input_)}') 123 | 124 | 125 | def imgs2video(im_dir, video_dir): 126 | img_fns = os.listdir(im_dir) 127 | # print(img_fns) 128 | img_fns = [v for v in img_fns if v.split('.')[-1] in ['jpg', 'png']] 129 | img_fns = sorted(img_fns) 130 | # print(img_fns) 131 | fps = 1 132 | img_size = (800, 492) 133 | 134 | # fourcc = cv2.cv.CV_FOURCC('M','J','P','G')#opencv2.4 135 | # fourcc = cv2.VideoWriter_fourcc('I','4','2','0') 136 | 137 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 138 | # fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') 139 | videoWriter = cv2.VideoWriter(video_dir, fourcc, fps, img_size) 140 | 141 | for i in range(0, 500): 142 | # fn = img_fns[i].split('-') 143 | im_name = os.path.join(im_dir, img_fns[i]) 144 | print(im_name) 145 | frame = cv2.imread(im_name, 1) 146 | frame = cv2.resize(frame, dsize=img_size) 147 | # print(frame.shape) 148 | # exit(0) 149 | cv2.imshow("frame", frame) 150 | cv2.waitKey(5) 151 | videoWriter.write(frame) 152 | 153 | videoWriter.release() 154 | -------------------------------------------------------------------------------- /tools/config_parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> config_parser 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 20/07/2021 10:51 7 | ==================================================''' 8 | 9 | import logging 10 | from pathlib import Path 11 | from functools import reduce 12 | from operator import getitem 13 | from datetime import datetime 14 | from tools.common import read_json, write_json, torch_set_gpu 15 | 16 | 17 | class ConfigParser: 18 | def __init__(self, args, 
options='', timestamp=True, test_only=False): 19 | # parse default and custom cli options 20 | for opt in options: 21 | args.add_argument(*opt.flags, default=None, type=opt.type) 22 | args = args.parse_args() 23 | 24 | # if args.device: 25 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.device 26 | torch_set_gpu(gpus=args.gpu) 27 | 28 | self.cfg_fname = Path(args.config) 29 | # load config file and apply custom cli options 30 | config = read_json(self.cfg_fname) 31 | self.__config = _update_config(config, options, args) 32 | 33 | # set save_dir where trained model and log will be saved. 34 | save_dir = Path(self.config['trainer']['save_dir']) 35 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 36 | self.timestamp = timestamp 37 | 38 | exper_name = self.config['name'] 39 | self.__save_dir = save_dir / 'models' / exper_name / timestamp 40 | self.__log_dir = save_dir / 'log' / exper_name / timestamp 41 | 42 | self.save_dir.mkdir(parents=True, exist_ok=True) 43 | self.log_dir.mkdir(parents=True, exist_ok=True) 44 | 45 | if not test_only: 46 | # save updated config file to the checkpoint dir 47 | write_json(self.config, self.save_dir / 'config.json') 48 | else: 49 | write_json(self.config, self.resume.parent / 'config_test.json') 50 | 51 | def initialize(self, name, module, *args, **kwargs): 52 | """ 53 | finds a function handle with the name given as 'type' in config, and returns the 54 | instance initialized with corresponding keyword args given as 'args'. 55 | """ 56 | module_cfg = self[name] 57 | module_cfg["args"].update(kwargs) 58 | return getattr(module, module_cfg['type'])(*args, **module_cfg['args']) 59 | 60 | def __getitem__(self, name): 61 | return self.config[name] 62 | 63 | def get_logger(self, name, verbosity=2): 64 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, 65 | self.log_levels.keys()) 66 | assert verbosity in self.log_levels, msg_verbosity 67 | logger = logging.getLogger(name) 68 | logger.setLevel(self.log_levels[verbosity]) 69 | return logger 70 | 71 | # setting read-only attributes 72 | @property 73 | def config(self): 74 | return self.__config 75 | 76 | @property 77 | def save_dir(self): 78 | return self.__save_dir 79 | 80 | @property 81 | def log_dir(self): 82 | return self.__log_dir 83 | 84 | 85 | # helper functions used to update config dict with custom cli options 86 | def _update_config(config, options, args): 87 | for opt in options: 88 | value = getattr(args, _get_opt_name(opt.flags)) 89 | if value is not None: 90 | _set_by_path(config, opt.target, value) 91 | return config 92 | 93 | 94 | def _get_opt_name(flags): 95 | for flg in flags: 96 | if flg.startswith('--'): 97 | return flg.replace('--', '') 98 | return flags[0].replace('--', '') 99 | 100 | 101 | def _set_by_path(tree, keys, value): 102 | """Set a value in a nested object in tree by sequence of keys.""" 103 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 104 | 105 | 106 | def _get_by_path(tree, keys): 107 | """Access a nested object in tree by sequence of keys.""" 108 | return reduce(getitem, keys, tree) 109 | -------------------------------------------------------------------------------- /tools/loc_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File lbr -> loc_tools 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 06/04/2022 15:08 7 | ==================================================''' 8 | 9 | import numpy as np 10 | import cv2 11 | import torch 12 | from copy import copy 13 | from scipy.spatial.transform import Rotation as sciR 14 | 15 | 16 | def plot_keypoint(img_path, pts, scores=None, tag=None, save_path=None): 17 | if type(img_path) == str: 18 | img = cv2.imread(img_path) 19 | else: 20 | img = img_path.copy() 21 | 22 | img_out = img.copy() 23 | print(img.shape) 24 | r = 3 25 | for i in range(pts.shape[0]): 26 | pt = pts[i] 27 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), int(r * s), (0, 0, 255), 4) 28 | img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2) 29 | 30 | if save_path is not None: 31 | cv2.imwrite(save_path, img_out) 32 | return img_out 33 | 34 | 35 | def sort_dict_by_value(data, reverse=False): 36 | return sorted(data.items(), key=lambda d: d[1], reverse=reverse) 37 | 38 | 39 | def read_retrieval_results(path): 40 | output = {} 41 | with open(path, "r") as f: 42 | lines = f.readlines() 43 | for p in lines: 44 | p = p.strip("\n").split(" ") 45 | 46 | if p[1] == "no_match": 47 | continue 48 | if p[0] in output.keys(): 49 | output[p[0]].append(p[1]) 50 | else: 51 | output[p[0]] = [p[1]] 52 | return output 53 | 54 | 55 | def nn_k(query_gps, db_gps, k=20): 56 | q = torch.from_numpy(query_gps) # [N 2] 57 | db = torch.from_numpy(db_gps) # [M, 2] 58 | # print (q.shape, db.shape) 59 | dist = q.unsqueeze(2) - db.t().unsqueeze(0) 60 | dist = dist[:, 0, :] ** 2 + dist[:, 1, :] ** 2 61 | print("dist: ", dist.shape) 62 | topk = torch.topk(dist, dim=1, k=k, largest=False)[1] 63 | return topk 64 | 65 | 66 | def plot_matches(img1, img2, pts1, pts2, inliers, horizon=False, plot_outlier=False, confs=None, plot_match=True): 67 | rows1 = img1.shape[0] 68 | cols1 = img1.shape[1] 69 | rows2 = img2.shape[0] 70 | cols2 = img2.shape[1] 71 | r = 3 72 | if 
horizon: 73 | img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8') 74 | # Place the first image to the left 75 | img_out[:rows1, :cols1] = img1 76 | # Place the next image to the right of it 77 | img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2]) 78 | 79 | if not plot_match: 80 | return cv2.resize(img_out, None, fx=0.5, fy=0.5) 81 | # for idx, pt in enumerate(pts1): 82 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2) 83 | # for idx, pt in enumerate(pts2): 84 | # img_out = cv2.circle(img_out, (int(pt[0] + cols1), int(pt[1])), r, (0, 0, 255), 2) 85 | for idx in range(inliers.shape[0]): 86 | # if idx % 10 > 0: 87 | # continue 88 | if inliers[idx]: 89 | color = (0, 255, 0) 90 | else: 91 | if not plot_outlier: 92 | continue 93 | color = (0, 0, 255) 94 | pt1 = pts1[idx] 95 | pt2 = pts2[idx] 96 | 97 | if confs is not None: 98 | nr = int(r * confs[idx]) 99 | else: 100 | nr = r 101 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2) 102 | 103 | img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2) 104 | 105 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color, 106 | 2) 107 | else: 108 | img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8') 109 | # Place the first image to the left 110 | img_out[:rows1, :cols1] = img1 111 | # Place the next image to the right of it 112 | img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2]) 113 | 114 | if not plot_match: 115 | return cv2.resize(img_out, None, fx=0.5, fy=0.5) 116 | # for idx, pt in enumerate(pts1): 117 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1])), r, (0, 0, 255), 2) 118 | # for idx, pt in enumerate(pts2): 119 | # img_out = cv2.circle(img_out, (int(pt[0]), int(pt[1] + rows1)), r, (0, 0, 255), 2) 120 | for idx in range(inliers.shape[0]): 121 | # print("idx: ", inliers[idx]) 122 | # if idx % 10 > 0: 123 | # continue 124 | if inliers[idx]: 125 | color = (0, 255, 0) 126 | else: 127 | if not plot_outlier: 128 | continue 129 | color = (0, 0, 255) 130 | 131 | if confs is not None: 132 | nr = int(r * confs[idx]) 133 | else: 134 | nr = r 135 | 136 | pt1 = pts1[idx] 137 | pt2 = pts2[idx] 138 | img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), r, color, 2) 139 | 140 | img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), r, color, 2) 141 | 142 | img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color, 143 | 2) 144 | 145 | img_rs = cv2.resize(img_out, None, fx=0.5, fy=0.5) 146 | 147 | # img_rs = cv2.putText(img_rs, 'matches: {:d}'.format(len(inliers.shape[0])), (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, 148 | # (0, 0, 255), 2) 149 | 150 | # if save_fn is not None: 151 | # cv2.imwrite(save_fn, img_rs) 152 | # cv2.imshow("match", img_rs) 153 | # cv2.waitKey(0) 154 | return img_rs 155 | 156 | 157 | def plot_reprojpoint2D(img, points2D, reproj_points2D, confs=None): 158 | img_out = copy(img) 159 | r = 5 160 | for i in range(points2D.shape[0]): 161 | p = points2D[i] 162 | rp = reproj_points2D[i] 163 | 164 | if confs is not None: 165 | nr = int(r * confs[i]) 166 | else: 167 | nr = r 168 | 169 | if nr >= 50: 170 | nr = 50 171 | # img_out = cv2.circle(img_out, (int(p[0]), int(p[1])), nr, color=(0, 255, 0), thickness=2) 172 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), nr, color=(0, 0, 255), thickness=3) 173 | img_out = cv2.circle(img_out, (int(rp[0]), int(rp[1])), 2, color=(0, 0, 255), thickness=3) 174 
| # img_out = cv2.line(img_out, pt1=(int(p[0]), int(p[1])), pt2=(int(rp[0]), int(rp[1])), color=(0, 0, 255), 175 | # thickness=2) 176 | 177 | return img_out 178 | 179 | 180 | def reproject_fromR(points3D, rot, tvec, camera): 181 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1) 182 | 183 | if camera['model'] == 'SIMPLE_RADIAL': 184 | f = camera['params'][0] 185 | cx = camera['params'][1] 186 | cy = camera['params'][2] 187 | k = camera['params'][3] 188 | 189 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :] 190 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2 191 | radial = r2 * k 192 | du = proj_2d[0, :] * radial 193 | dv = proj_2d[1, :] * radial 194 | 195 | u = proj_2d[0, :] + du 196 | v = proj_2d[1, :] + dv 197 | u = u * f + cx 198 | v = v * f + cy 199 | uvs = np.vstack([u, v]).transpose() 200 | 201 | return uvs 202 | 203 | 204 | def calc_depth(points3D, rvec, tvec, camera): 205 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm() 206 | # print('p3d: ', points3D.shape, rot.shape, rot) 207 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1) 208 | 209 | return proj_2d.transpose()[:, 2] 210 | 211 | 212 | def reproject(points3D, rvec, tvec, camera): 213 | ''' 214 | Args: 215 | points3D: [N, 3] 216 | rvec: [w, x, y, z] 217 | tvec: [x, y, z] 218 | Returns: 219 | ''' 220 | # print('camera: ', camera) 221 | # print('p3d: ', points3D.shape) 222 | rot = sciR.from_quat(quat=[rvec[1], rvec[2], rvec[3], rvec[0]]).as_dcm() 223 | # print('p3d: ', points3D.shape, rot.shape, rot) 224 | proj_2d = rot @ points3D.transpose() + tvec.reshape(3, 1) 225 | 226 | if camera['model'] == 'SIMPLE_RADIAL': 227 | f = camera['params'][0] 228 | cx = camera['params'][1] 229 | cy = camera['params'][2] 230 | k = camera['params'][3] 231 | 232 | proj_2d = proj_2d[0:2, :] / proj_2d[2, :] 233 | r2 = proj_2d[0, :] ** 2 + proj_2d[1, :] ** 2 234 | radial = r2 * k 235 | du = proj_2d[0, :] * radial 236 | dv = proj_2d[1, :] * radial 237 | 238 | u = proj_2d[0, :] + du 239 | v = proj_2d[1, :] + dv 240 | u = u * f + cx 241 | v = v * f + cy 242 | uvs = np.vstack([u, v]).transpose() 243 | 244 | return uvs 245 | 246 | 247 | def quaternion_angular_error(q1, q2): 248 | """ 249 | angular error between two quaternions 250 | :param q1: (4, ) 251 | :param q2: (4, ) 252 | :return: 253 | """ 254 | d = abs(np.dot(q1, q2)) 255 | d = min(1.0, max(-1.0, d)) 256 | theta = 2 * np.arccos(d) * 180 / np.pi 257 | return theta 258 | 259 | 260 | def ColmapQ2R(qvec): 261 | rot = sciR.from_quat(quat=[qvec[1], qvec[2], qvec[3], qvec[0]]).as_dcm() 262 | return rot 263 | 264 | 265 | def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw): 266 | pred_Rcw = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_dcm() 267 | pred_tcw = np.array(pred_tcw, float).reshape(3, 1) 268 | pred_Rwc = pred_Rcw.transpose() 269 | pred_twc = -pred_Rcw.transpose() @ pred_tcw 270 | 271 | gt_Rcw = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_dcm() 272 | gt_tcw = np.array(gt_tcw, float).reshape(3, 1) 273 | gt_Rwc = gt_Rcw.transpose() 274 | gt_twc = -gt_Rcw.transpose() @ gt_tcw 275 | 276 | t_error_xyz = pred_twc - gt_twc 277 | t_error = np.sqrt(np.sum(t_error_xyz ** 2)) 278 | 279 | pred_qwc = sciR.from_quat(quat=[pred_qcw[1], pred_qcw[2], pred_qcw[3], pred_qcw[0]]).as_quat() 280 | gt_qwc = sciR.from_quat(quat=[gt_qcw[1], gt_qcw[2], gt_qcw[3], gt_qcw[0]]).as_quat() 281 | 282 | q_error = quaternion_angular_error(q1=pred_qwc, q2=gt_qwc) 283 | 284 | return q_error, t_error, (t_error_xyz[0, 0], t_error_xyz[1, 
0], t_error_xyz[2, 0]) 285 | -------------------------------------------------------------------------------- /tools/optim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> optim 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 08/09/2021 12:06 7 | ==================================================''' 8 | from torch.optim.lr_scheduler import _LRScheduler, StepLR 9 | 10 | 11 | # class PolyLR(_LRScheduler): 12 | # def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6): 13 | # self.power = power 14 | # self.max_iters = max_iters # avoid zero lr 15 | # self.min_lr = min_lr 16 | # super(PolyLR, self).__init__(optimizer, last_epoch) 17 | # 18 | # def get_lr(self): 19 | # return [max(base_lr * (1 - self.last_epoch / self.max_iters) ** self.power, self.min_lr)] 20 | 21 | 22 | class PolyLR(_LRScheduler): 23 | """Polynomial learning rate decay until step reach to max_decay_step 24 | 25 | Args: 26 | optimizer (Optimizer): Wrapped optimizer. 27 | max_decay_steps: after this step, we stop decreasing learning rate 28 | end_learning_rate: scheduler stoping learning rate decay, value of learning rate must be this value 29 | power: The power of the polynomial. 30 | """ 31 | 32 | def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0): 33 | if max_decay_steps <= 1.: 34 | raise ValueError('max_decay_steps should be greater than 1.') 35 | self.max_decay_steps = max_decay_steps 36 | self.end_learning_rate = end_learning_rate 37 | self.power = power 38 | self.last_step = 0 39 | super().__init__(optimizer) 40 | 41 | def get_lr(self): 42 | if self.last_step > self.max_decay_steps: 43 | return [self.end_learning_rate for _ in self.base_lrs] 44 | 45 | return [(base_lr - self.end_learning_rate) * 46 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 47 | self.end_learning_rate for base_lr in self.base_lrs] 48 | 49 | def step(self, step=None): 50 | if step is None: 51 | step = self.last_step + 1 52 | self.last_step = step if step != 0 else 1 53 | if self.last_step <= self.max_decay_steps: 54 | decay_lrs = [(base_lr - self.end_learning_rate) * 55 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 56 | self.end_learning_rate for base_lr in self.base_lrs] 57 | for param_group, lr in zip(self.optimizer.param_groups, decay_lrs): 58 | param_group['lr'] = lr -------------------------------------------------------------------------------- /tools/seg_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> seg_tools 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 30/04/2021 15:47 7 | ==================================================''' 8 | 9 | import numpy as np 10 | 11 | 12 | def seg_to_rgb(seg, maps): 13 | """ 14 | Args: 15 | seg: [C, H, W] 16 | maps: [id, R, G, B] 17 | Returns: [H, W, R/G/B] 18 | """ 19 | pred_label = seg.max(0).cpu().numpy() 20 | output = np.zeros(shape=(pred_label.shape[0], pred_label.shape[1], 3), dtype=np.uint8) 21 | 22 | for label in maps.keys(): 23 | rgb = maps[label] 24 | output[pred_label == label] = np.uint8(rgb) 25 | return output 26 | 27 | 28 | def label_to_rgb(label, maps): 29 | output = np.zeros(shape=(label.shape[0], label.shape[1], 3), dtype=np.uint8) 30 | for k in maps.keys(): 31 | rgb = maps[k] 32 | 
if type(rgb) is int: 33 | b = rgb % 256 34 | r = rgb // (256 * 256) 35 | g = (rgb - r * 256 * 256) // 256 36 | rgb = np.array([r, g, b], np.uint8) 37 | # bgr = np.array([b, g, r], np.uint8) 38 | output[label == k] = np.uint8(rgb) 39 | return output 40 | 41 | 42 | def label_to_bgr(label, maps): 43 | output = np.zeros(shape=(label.shape[0], label.shape[1], 3), dtype=np.uint8) 44 | for k in maps.keys(): 45 | rgb = maps[k] 46 | if type(rgb) is int: 47 | b = rgb % 256 48 | r = rgb // (256 * 256) 49 | g = (rgb - r * 256 * 256) // 256 50 | # rgb = np.array([r, g, b], np.uint8) 51 | bgr = np.array([b, g, r], np.uint8) 52 | output[label == k] = np.uint8(bgr) 53 | return output 54 | 55 | 56 | def rgb_to_bgr(img): 57 | out = np.zeros_like(img) 58 | out[:, :, 0] = img[:, :, 2] 59 | out[:, :, 1] = img[:, :, 1] 60 | out[:, :, 2] = img[:, :, 0] 61 | return out 62 | 63 | 64 | def read_seg_map(path): 65 | map = {} 66 | with open(path, "r") as f: 67 | lines = f.readlines() 68 | for l in lines: 69 | l = l.strip("\n").split(' ') 70 | map[int(l[1])] = np.array([np.uint8(l[2]), np.uint8(l[3]), np.uint8(l[4])], np.uint8) 71 | 72 | return map 73 | 74 | 75 | def read_seg_map_with_group(path): 76 | map = {} 77 | with open(path, "r") as f: 78 | lines = f.readlines() 79 | for l in lines: 80 | l = l.strip().split(" ") 81 | rid = int(l[0]) 82 | grgb = int(l[1]) 83 | gid = int(l[2]) 84 | if rid in map.keys(): 85 | map[rid][gid] = grgb 86 | else: 87 | map[rid] = {} 88 | map[rid][gid] = grgb 89 | return map 90 | 91 | 92 | def read_seg_map_without_group(path): 93 | map = {} 94 | with open(path, "r") as f: 95 | lines = f.readlines() 96 | for l in lines: 97 | l = l.strip().split(" ") 98 | grgb = int(l[0]) 99 | gid = int(l[1]) 100 | map[gid] = grgb 101 | return map 102 | 103 | 104 | ## code is from https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 105 | def _fast_hist(label_true, label_pred, n_class): 106 | mask = (label_true >= 0) & (label_true < n_class) 107 | hist = np.bincount( 108 | n_class * label_true[mask].astype(int) + 109 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) 110 | return hist 111 | 112 | 113 | def label_accuracy_score(label_trues, label_preds, n_class): 114 | """Returns accuracy score evaluation result. 
115 | - overall accuracy 116 | - mean accuracy 117 | - mean IU 118 | - fwavacc 119 | """ 120 | hist = np.zeros((n_class, n_class)) 121 | for lt, lp in zip(label_trues, label_preds): 122 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 123 | acc = np.diag(hist).sum() / hist.sum() 124 | with np.errstate(divide='ignore', invalid='ignore'): 125 | acc_cls = np.diag(hist) / hist.sum(axis=1) 126 | acc_cls = np.nanmean(acc_cls) 127 | with np.errstate(divide='ignore', invalid='ignore'): 128 | iu = np.diag(hist) / ( 129 | hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) 130 | ) 131 | mean_iu = np.nanmean(iu) 132 | freq = hist.sum(axis=1) / hist.sum() 133 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 134 | return acc, acc_cls, mean_iu, fwavacc 135 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | '''================================================= 3 | @Project -> File shloc -> train 4 | @IDE PyCharm 5 | @Author fx221@cam.ac.uk 6 | @Date 20/07/2021 10:41 7 | ==================================================''' 8 | import argparse 9 | import json 10 | from tools.seg_tools import read_seg_map_without_group 11 | from dataloader.robotcar import RobotCarSegFull 12 | from dataloader.aachen import AachenSegFull 13 | import os.path as osp 14 | import torch 15 | import torch.utils.data as Data 16 | from dataloader.augmentation import ToPILImage, RandomRotation, \ 17 | RandomSizedCrop, RandomHorizontalFlip, \ 18 | ToNumpy 19 | from net.segnet import get_segnet 20 | from loss.seg_loss.segloss import SegLoss 21 | from trainer_recog import RecogTrainer 22 | from tools.common import torch_set_gpu 23 | 24 | import torchvision.transforms as tvf 25 | 26 | 27 | def get_train_val_loader(args): 28 | train_transform = tvf.Compose( 29 | ( 30 | tvf.ToTensor(), 31 | # tvf.ColorJitter(0.25, 0.25, 0.25, 0.15), 32 | tvf.ColorJitter(0.25, 0.25, 0.25, 0.15), 33 | tvf.Normalize(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]) 35 | ) 36 | ) 37 | val_transform = tvf.Compose( 38 | ( 39 | tvf.ToTensor(), 40 | tvf.Normalize(mean=[0.485, 0.456, 0.406], 41 | std=[0.229, 0.224, 0.225]) 42 | ) 43 | ) 44 | 45 | if args.dataset == "robotcar": 46 | grgb_gid_file = "./datasets/robotcar/robotcar_rear_grgb_gid.txt" 47 | map_gid_rgb = read_seg_map_without_group(grgb_gid_file) 48 | train_imglist = "./datasets/robotcar/robotcar_rear_train_file_list.txt" 49 | test_imglist = "./datasets/robotcar/robotcar_rear_test_file_list.txt" 50 | 51 | if args.aug: 52 | aug = [ 53 | ToPILImage(), 54 | RandomRotation(degree=30), 55 | RandomSizedCrop(size=256), 56 | RandomHorizontalFlip(), 57 | ToNumpy(), 58 | ] 59 | else: 60 | aug = None 61 | trainset = RobotCarSegFull(image_path=osp.join(args.root, args.train_image_path), 62 | label_path=osp.join(args.root, args.train_label_path), 63 | n_classes=args.classes, 64 | transform=train_transform, 65 | grgb_gid_file=grgb_gid_file, 66 | use_cls=True, 67 | img_list=train_imglist, 68 | preload=False, 69 | aug=aug, 70 | train=True, 71 | cats=["overcast-reference", "night", "night-rain", 72 | "dusk", "dawn", "overcast-summer", "overcast-winter", "sun"] 73 | ) 74 | if args.val > 0: 75 | valset = RobotCarSegFull(image_path=osp.join(args.root, args.train_image_path), 76 | label_path=osp.join(args.root, args.train_label_path), 77 | cats=["overcast-reference", "night", "night-rain", 78 | "dusk", "dawn", "overcast-summer", "overcast-winter", "sun"], 
79 | n_classes=args.classes, 80 | transform=train_transform, 81 | grgb_gid_file=grgb_gid_file, 82 | use_cls=True, 83 | img_list=test_imglist, 84 | preload=False, 85 | train=False) 86 | elif args.dataset == "aachen": 87 | grgb_gid_file = args.grgb_gid_file 88 | train_imglist = args.train_imglist 89 | test_imglist = args.test_imglist 90 | map_gid_rgb = read_seg_map_without_group(grgb_gid_file) 91 | 92 | if args.aug: 93 | aug = [ 94 | # RandomGaussianBlur(), # worse results, don't do it? 95 | ToPILImage(), 96 | # Resize(size=512), 97 | # RandomScale(low=0.5, high=2.0), 98 | # RandomCrop(size=256), 99 | RandomSizedCrop(size=args.R), 100 | RandomRotation(degree=45), 101 | RandomHorizontalFlip(), 102 | ToNumpy(), 103 | ] 104 | else: 105 | aug = None 106 | trainset = AachenSegFull(image_path=osp.join(args.root, args.train_image_path), 107 | label_path=osp.join(args.root, args.train_label_path), 108 | n_classes=args.classes, 109 | transform=train_transform, 110 | grgb_gid_file=grgb_gid_file, 111 | use_cls=True, 112 | img_list=train_imglist, 113 | preload=False, 114 | aug=aug, 115 | train=True, 116 | cats=args.train_cats, 117 | ) 118 | if args.val > 0: 119 | valset = AachenSegFull(image_path=osp.join(args.root, args.train_image_path), 120 | label_path=osp.join(args.root, args.train_label_path), 121 | n_classes=args.classes, 122 | transform=val_transform, 123 | grgb_gid_file=grgb_gid_file, 124 | use_cls=True, 125 | img_list=test_imglist, 126 | preload=False, 127 | cats=args.val_cats, 128 | train=False) 129 | train_loader = Data.DataLoader(dataset=trainset, 130 | batch_size=args.bs, 131 | num_workers=args.workers, 132 | shuffle=True, 133 | pin_memory=True, 134 | drop_last=True, 135 | ) 136 | if args.val: 137 | val_loader = Data.DataLoader( 138 | dataset=valset, 139 | batch_size=8, 140 | num_workers=args.workers, 141 | pin_memory=True, 142 | shuffle=False, 143 | drop_last=True, 144 | ) 145 | else: 146 | val_loader = None 147 | 148 | return train_loader, val_loader, map_gid_rgb 149 | 150 | 151 | def main(args): 152 | model = get_segnet(network=args.network, 153 | n_classes=args.classes, 154 | encoder_name=args.encoder_name, 155 | encoder_weights=args.encoder_weights, 156 | encoder_depth=args.encoder_depth, 157 | upsampling=args.upsampling, 158 | out_channels=args.out_channels, 159 | classification=args.classification, 160 | segmentation=args.segmentation, ) 161 | print(model) 162 | label_weights = torch.ones([args.classes]).cuda() 163 | label_weights[0] = 0.5 164 | loss_func = SegLoss( 165 | segloss_name=args.seg_loss, 166 | use_cls=True, 167 | use_seg=args.segmentation > 0, 168 | cls_weight=args.weight_cls, 169 | use_hiera=False, 170 | hiera_weight=0, 171 | label_weights=label_weights).cuda() 172 | 173 | train_loader, val_loader, map_gid_rgb = get_train_val_loader(args=args) 174 | trainer = RecogTrainer(model=model, train_loader=train_loader, eval_loader=val_loader if args.val else None, 175 | loss_func=loss_func, args=args, map=map_gid_rgb) 176 | 177 | if args.resume is not None: 178 | trainer.resume(checkpoint=args.resume) 179 | else: 180 | trainer.train(start_epoch=0) 181 | 182 | print("Training finished") 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = argparse.ArgumentParser("Train Semantic localization Network") 187 | parser.add_argument("--config", type=str, required=True, help="configuration file") 188 | parser.add_argument("--dataset", type=str, default="small", help="small, large, robotcar") 189 | parser.add_argument("--network", type=str, default="unet") 190 | 
parser.add_argument("--loss", type=str, default="ce") 191 | parser.add_argument("--classes", type=int, default=400) 192 | parser.add_argument("--out_channels", type=int, default=512) 193 | parser.add_argument("--root", type=str) 194 | parser.add_argument("--train_label_path", type=str) 195 | parser.add_argument("--train_image_path", type=str) 196 | parser.add_argument("--val_label_path", type=str) 197 | parser.add_argument("--val_image_path", type=str) 198 | parser.add_argument("--bs", type=int, default=4) 199 | parser.add_argument("--R", type=int, default=256) 200 | parser.add_argument("--weight_decay", type=float, default=5e-4) 201 | parser.add_argument("--epochs", type=int, default=120) 202 | parser.add_argument("--lr", type=float, default=1e-4) 203 | parser.add_argument("--workers", type=int, default=4) 204 | parser.add_argument("--log_interval", type=int, default=50) 205 | parser.add_argument("--optimizer", type=str, default=None) 206 | parser.add_argument("--resume", type=str, default=None) 207 | parser.add_argument("--segloss", type=str, default='ce') 208 | parser.add_argument("--classification", dest="classification", action="store_true", default=True) 209 | parser.add_argument("--segmentation", dest="segmentation", action="store_true", default=True) 210 | parser.add_argument("--val", dest="val", action="store_true", default=False) 211 | parser.add_argument("--aug", dest="aug", action="store_true", default=False) 212 | parser.add_argument("--preload", dest="preload", action="store_true", default=False) 213 | parser.add_argument("--ignore_bg", dest="ignore_bg", action="store_true", default=False) 214 | parser.add_argument("--weight_cls", type=float, default=1.0) 215 | parser.add_argument("--pretrained_weight", type=str, default=None) 216 | parser.add_argument("--encoder_name", type=str, default='timm-resnest50d') 217 | parser.add_argument("--encoder_weights", type=str, default='imagenet') 218 | parser.add_argument("--save_root", type=str, default="/home/mifs/fx221/fx221/exp/shloc/aachen") 219 | parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU') 220 | parser.add_argument("--milestones", type=list, default=[60, 80]) 221 | parser.add_argument("--grgb_gid_file", type=str) 222 | parser.add_argument("--train_imglist", type=str) 223 | parser.add_argument("--test_imglist", type=str) 224 | parser.add_argument("--lr_policy", type=str, default='plateau', help='plateau, step') 225 | parser.add_argument("--multi_lr", type=int, default=1) 226 | parser.add_argument("--train_cats", type=list, default=None) 227 | parser.add_argument("--val_cats", type=list, default=None) 228 | 229 | args = parser.parse_args() 230 | with open(args.config, 'rt') as f: 231 | t_args = argparse.Namespace() 232 | t_args.__dict__.update(json.load(f)) 233 | args = parser.parse_args(namespace=t_args) 234 | 235 | print('gpu: ', args.gpu) 236 | 237 | torch_set_gpu(gpus=args.gpu) 238 | main(args=args) 239 | -------------------------------------------------------------------------------- /train_aachen: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 train.py --config configs/config_train_aachen.json -------------------------------------------------------------------------------- /train_robotcar: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 train.py --config configs/config_train_robotcar.json --------------------------------------------------------------------------------
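
Both train.py and test.py above read their settings with the same two-pass argparse pattern: the command line is parsed once to obtain `--config`, the JSON file is loaded into a `Namespace`, and the parser is run a second time on top of that namespace so that values from the JSON override the argparse defaults while explicit command-line flags still win. A minimal, self-contained sketch of this pattern (the `--bs`/`--lr` options are only stand-ins for the real arguments of the scripts above):

```
import argparse
import json

parser = argparse.ArgumentParser("Config-override sketch")
parser.add_argument("--config", type=str, required=True, help="JSON configuration file")
parser.add_argument("--bs", type=int, default=4)       # overridden by a "bs" entry in the JSON
parser.add_argument("--lr", type=float, default=1e-4)

args = parser.parse_args()                      # pass 1: only used to discover --config
with open(args.config, "rt") as f:
    t_args = argparse.Namespace()
    t_args.__dict__.update(json.load(f))        # JSON entries become attributes of the namespace
    args = parser.parse_args(namespace=t_args)  # pass 2: CLI flags > JSON values > argparse defaults
```

The train_aachen / train_robotcar and test_aachen / test_robotcar wrappers above simply invoke these scripts with the matching JSON file from configs/.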