├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── __init__.py ├── augment.py └── tartanair.py ├── models ├── __init__.py ├── featurenet.py ├── loss.py ├── match.py └── tool.py ├── train.py └── utils ├── __init__.py ├── evaluation.py ├── geometry.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode 132 | saved_models 133 | logs 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Chen Wang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Feature matching with FGN 2 | 3 | This repo contains the source code for the feature matching application (Sec. 7) in ["Lifelong Graph Learning." Chen Wang, Yuheng Qiu, Dasong Gao, Sebastian Scherer. *CVPR 2022*.](https://arxiv.org/abs/2009.00647) 4 | 5 | ## Usage 6 | ### Dependencies 7 | 8 | - Python >= 3.5 9 | - PyTorch >= 1.7 10 | - OpenCV >= 3.4 11 | - NumPy 12 | - TensorBoard 13 | - Matplotlib 14 | - ArgParse 15 | - tqdm 16 | 17 | ### Data 18 | The [TartanAir](https://theairlab.org/tartanair-dataset/) dataset is required for both training and testing. The dataset should be arranged as follows: 19 | ``` 20 | $DATASET_ROOT/ 21 | └── tartanair/ 22 | ├── abandonedfactory_night/ 23 | └── ... 24 | ``` 25 | 26 | ### Commandline 27 | Train and evaluate the method with the default settings: 28 | ```sh 29 | $ python train.py --data-root $DATASET_ROOT/tartanair --method FGN 30 | ``` 31 | - The `--method` option switches between the FGN-based (ours) and the GAT-based (SuperGlue) graph matcher. 32 | - Because TartanAir is very large, evaluation runs every 5000 training steps by default (override with `--eval-freq`). Results are logged to the console. 33 | - If `--log-dir` is specified, TensorBoard is launched instead to show visualizations and evaluation results (under the "TEXT" tab). 34 | - A detailed description of all settings is available via `$ python train.py -h`. 
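For example, a fuller invocation that also enables TensorBoard logging (the paths are placeholders, and the flag values other than `--log-dir` are simply the defaults defined in `train.py`):
```sh
$ python train.py \
    --data-root $DATASET_ROOT/tartanair \
    --method FGN \
    --scale 0.5 \
    --batch-size 8 \
    --eval-freq 5000 \
    --save ./saved_models/featurenet.pth \
    --log-dir ./logs
```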
35 | 36 | ## Citation 37 | ```bibtex 38 | @inproceedings{wang2022lifelong, 39 | title={Lifelong graph learning}, 40 | author={Wang, Chen and Qiu, Yuheng and Gao, Dasong and Scherer, Sebastian}, 41 | booktitle={2022 Conference on Computer Vision and Pattern Recognition (CVPR)}, 42 | year={2022} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .tartanair import TartanAir 4 | from .tartanair import AirSampler 5 | from .tartanair import TartanAirTest 6 | from .tartanair import AirAugment -------------------------------------------------------------------------------- /datasets/augment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import numpy as np 5 | import kornia as kn 6 | from torch import nn 7 | from PIL import Image 8 | from torchvision import transforms as T 9 | from torchvision.transforms import functional as F 10 | 11 | 12 | class AirAugment(nn.Module): 13 | def __init__(self, scale=1, size=[480, 640], resize_only=False): 14 | super().__init__() 15 | self.img_size = (np.array(size) * scale).round().astype(np.int32) 16 | self.resize_totensor = T.Compose([T.Resize(self.img_size.tolist()), np.array, T.ToTensor()]) 17 | self.rand_crop = T.RandomResizedCrop(self.img_size.tolist(), scale=(0.1, 1.0)) 18 | self.rand_rotate = T.RandomRotation(45, resample=Image.BILINEAR) 19 | self.rand_color = T.ColorJitter(0.8, 0.8, 0.8) 20 | self.p = [1, 0, 0, 0] if resize_only else [0.25]*4 21 | 22 | def apply_affine(self, K, translation=[0, 0], center=[0, 0], scale=[1, 1], angle=0): 23 | """Applies transformation to K in the order: (R, S), T. All coordinates are in (h, w) order. 24 | Center is for both scale and rotation. 
25 | """ 26 | translation = torch.tensor(translation[::-1].copy(), dtype=torch.float32) 27 | center = torch.tensor(center[::-1].copy(), dtype=torch.float32) 28 | scale = torch.tensor(scale[::-1].copy(), dtype=torch.float32) 29 | angle = torch.tensor([angle], dtype=torch.float32) 30 | 31 | scaled_rotation = torch.block_diag(kn.angle_to_rotation_matrix(angle)[0] @ torch.diag(scale), torch.ones(1)) 32 | scaled_rotation[:2, 2] = center - scaled_rotation[:2, :2] @ center + translation 33 | 34 | return scaled_rotation @ K 35 | 36 | def forward(self, image, K, depth=None): 37 | in_size = np.array(image.size[::-1]) 38 | center, scale, angle = in_size/2, self.img_size/in_size, 0 39 | 40 | transform = np.random.choice(np.arange(len(self.p)), p=self.p) 41 | if transform == 1: 42 | trans = self.rand_crop 43 | i, j, h, w = T.RandomResizedCrop.get_params(image, trans.scale, trans.ratio) 44 | center = np.array([i + h / 2, j + w / 2]) 45 | scale = self.img_size / np.array([h, w]) 46 | image = F.resized_crop(image, i, j, h, w, trans.size, trans.interpolation) 47 | depth = depth if depth is None else F.resized_crop(depth, i, j, h, w, trans.size, trans.interpolation) 48 | 49 | elif transform == 2: 50 | trans = self.rand_rotate 51 | angle = T.RandomRotation.get_params(trans.degrees) 52 | # fill oob pix with reflection so that model can't detect rotation with boundary 53 | image = F.pad(image, padding=tuple(in_size // 2), padding_mode='reflect') 54 | image = F.rotate(image, angle, trans.resample, trans.expand, trans.center, trans.fill) 55 | image = F.center_crop(image, tuple(in_size)) 56 | # fill oob depth with inf so that projector can mask them out 57 | depth = depth if depth is None else F.rotate(depth, angle, trans.resample, trans.expand, trans.center, float('inf')) 58 | 59 | elif transform == 3: 60 | image = self.rand_color(image) 61 | 62 | image = self.resize_totensor(image) 63 | translation = self.img_size / 2 - center 64 | K = self.apply_affine(K, translation, center, scale, angle) 65 | 66 | return (image, K) if depth is None else (image, K, self.resize_totensor(depth)) 67 | -------------------------------------------------------------------------------- /datasets/tartanair.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import re 5 | import bz2 6 | import glob 7 | import torch 8 | import pickle 9 | import numpy as np 10 | import kornia as kn 11 | from os import path 12 | from PIL import Image 13 | from copy import copy 14 | from torch.utils.data import Sampler 15 | from torch.utils.data import Dataset 16 | from torchvision import transforms as T 17 | from scipy.spatial.transform import Rotation as R 18 | from torchvision.transforms import functional as F 19 | 20 | from .augment import AirAugment 21 | 22 | 23 | class TartanAir(Dataset): 24 | def __init__(self, root, scale=1, augment=True, catalog_path=None, exclude=None, include=None): 25 | super().__init__() 26 | self.augment = AirAugment(scale, size=[480, 640], resize_only=not augment) 27 | if catalog_path is not None and os.path.exists(catalog_path): 28 | with bz2.BZ2File(catalog_path, 'rb') as f: 29 | self.sequences, self.image, self.depth, self.poses, self.sizes = pickle.load(f) 30 | else: 31 | self.sequences = glob.glob(os.path.join(root,'*','[EH]a[sr][yd]','*')) 32 | self.image, self.depth, self.poses, self.sizes = {}, {}, {}, [] 33 | ned2den = torch.FloatTensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) 34 | for seq in self.sequences: 35 | quaternion = 
np.loadtxt(path.join(seq, 'pose_left.txt'), dtype=np.float32) 36 | self.poses[seq] = ned2den @ pose2mat(quaternion) 37 | self.image[seq] = sorted(glob.glob(path.join(seq,'image_left','*.png'))) 38 | self.depth[seq] = sorted(glob.glob(path.join(seq,'depth_left','*.npy'))) 39 | assert(len(self.image[seq])==len(self.depth[seq])==self.poses[seq].shape[0]) 40 | self.sizes.append(len(self.image[seq])) 41 | os.makedirs(os.path.dirname(catalog_path), exist_ok=True) 42 | with bz2.BZ2File(catalog_path, 'wb') as f: 43 | pickle.dump((self.sequences, self.image, self.depth, self.poses, self.sizes), f) 44 | # Camera Intrinsics of TartanAir Dataset 45 | fx, fy, cx, cy = 320, 320, 320, 240 46 | self.K = torch.FloatTensor([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 47 | 48 | # include/exclude seq with regex 49 | incl_pattern = re.compile(include) if include is not None else None 50 | excl_pattern = re.compile(exclude) if exclude is not None else None 51 | final_list = [] 52 | for seq, size in zip(self.sequences, self.sizes): 53 | if (incl_pattern and incl_pattern.search(seq) is None) or \ 54 | (excl_pattern and excl_pattern.search(seq) is not None): 55 | del self.poses[seq], self.image[seq], self.depth[seq] 56 | else: 57 | final_list.append((seq, size)) 58 | self.sequences, self.sizes = zip(*final_list) if len(final_list) > 0 else ([], []) 59 | 60 | def __len__(self): 61 | return sum(self.sizes) 62 | 63 | def __getitem__(self, ret): 64 | i, frame = ret 65 | seq, K = self.sequences[i], self.K 66 | image = Image.open(self.image[seq][frame]) 67 | depth = F.to_pil_image(np.load(self.depth[seq][frame]), mode='F') 68 | pose = self.poses[seq][frame] 69 | image, K, depth = self.augment(image, self.K, depth) 70 | return image, depth, pose, K, seq.split(os.path.sep)[-3:] 71 | 72 | def rand_split(self, ratio, seed=42): 73 | total, ratio = len(self.sequences), np.array(ratio) 74 | split_idx = np.cumsum(np.round(ratio / sum(ratio) * total), dtype=np.int)[:-1] 75 | subsets = [] 76 | for perm in np.split(np.random.default_rng(seed=seed).permutation(total), split_idx): 77 | subset = copy(self) 78 | subset.sequences = np.take(self.sequences, perm).tolist() 79 | subset.sizes = np.take(self.sizes, perm).tolist() 80 | subsets.append(subset) 81 | return subsets 82 | 83 | 84 | class TartanAirTest(Dataset): 85 | def __init__(self, root, scale=1, augment=False, catalog_path=None): 86 | super().__init__() 87 | self.augment = AirAugment(scale, size=[480, 640], resize_only=not augment) 88 | if catalog_path is not None and os.path.exists(catalog_path): 89 | with bz2.BZ2File(catalog_path, 'rb') as f: 90 | self.sequences, self.image, self.poses, self.sizes = pickle.load(f) 91 | else: 92 | self.sequences = sorted(glob.glob(os.path.join(root,'mono','*'))) 93 | self.pose_file = sorted(glob.glob(os.path.join(root,'mono_gt','*.txt'))) 94 | self.image, self.poses, self.sizes = {}, {}, [] 95 | ned2den = torch.FloatTensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) 96 | for seq, pose in zip(self.sequences, self.pose_file): 97 | quaternion = np.loadtxt(pose, dtype=np.float32) 98 | self.poses[seq] = ned2den @ pose2mat(quaternion) 99 | self.image[seq] = sorted(glob.glob(path.join(seq, '*.png'))) 100 | assert(len(self.image[seq])==self.poses[seq].shape[0]) 101 | self.sizes.append(len(self.image[seq])) 102 | os.makedirs(os.path.dirname(catalog_path), exist_ok=True) 103 | with bz2.BZ2File(catalog_path, 'wb') as f: 104 | pickle.dump((self.sequences, self.image, self.poses, self.sizes), f) 105 | # Camera Intrinsics of TartanAir Dataset 106 | fx, fy, cx, cy = 320, 
320, 320, 240 107 | self.K = torch.FloatTensor([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 108 | 109 | def __len__(self): 110 | return sum(self.sizes) 111 | 112 | def __getitem__(self, ret): 113 | i, frame = ret 114 | seq, K = self.sequences[i], self.K 115 | image = Image.open(self.image[seq][frame]) 116 | pose = self.poses[seq][frame] 117 | image, K = self.augment(image, self.K) 118 | return image, pose, K 119 | 120 | 121 | class AirSampler(Sampler): 122 | def __init__(self, data, batch_size, shuffle=True, overlap=True): 123 | self.data_sizes = data.sizes 124 | self.bs = batch_size 125 | self.shuffle = shuffle 126 | self.batches = [] 127 | for i, size in enumerate(self.data_sizes): 128 | b_start = np.arange(0, size - self.bs, 1 if overlap else self.bs) 129 | self.batches += [list(zip([i]*self.bs, range(st, st+self.bs))) for st in b_start] 130 | if self.shuffle: np.random.shuffle(self.batches) 131 | 132 | def __iter__(self): 133 | return iter(self.batches) 134 | 135 | def __len__(self): 136 | return len(self.batches) 137 | 138 | 139 | def pose2mat(pose): 140 | """Converts pose vectors to matrices. 141 | Args: 142 | pose: [tx, ty, tz, qx, qy, qz, qw] (N, 7). 143 | Returns: 144 | [R t] (N, 3, 4). 145 | """ 146 | t = pose[:, 0:3, None] 147 | rot = R.from_quat(pose[:, 3:7]).as_matrix().astype(np.float32).transpose(0, 2, 1) 148 | t = -rot @ t 149 | return torch.cat([torch.from_numpy(rot), torch.from_numpy(t)], dim=2) 150 | 151 | 152 | if __name__ == "__main__": 153 | from torch.utils.data import Dataset, DataLoader 154 | from torchvision import transforms as T 155 | 156 | data = TartanAir('/data/datasets/tartanair', scale=1, augment=True) 157 | sampler = AirSampler(data, batch_size=4, shuffle=True) 158 | loader = DataLoader(data, batch_sampler=sampler, num_workers=4, pin_memory=True) 159 | 160 | test_data = TartanAirTest('/data/datasets/tartanair_test', scale=1, augment=True) 161 | test_sampler = AirSampler(test_data, batch_size=4, shuffle=True) 162 | test_loader = DataLoader(test_data, batch_sampler=test_sampler, num_workers=4, pin_memory=True) 163 | 164 | for i, (image, depth, pose, K) in enumerate(loader): 165 | print(i, image.shape, depth.shape, pose.shape, K.shape) 166 | 167 | for i, (image, pose, K) in enumerate(test_loader): 168 | print(i, image.shape, pose.shape, K.shape) 169 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .tool import Timer 4 | from .tool import count_parameters 5 | from .tool import GlobalStepCounter 6 | from .tool import EarlyStopScheduler 7 | 8 | from .match import ConsecutiveMatch 9 | 10 | from .featurenet import FeatureNet, GridSample 11 | from .loss import FeatureNetLoss, PairwiseCosine 12 | -------------------------------------------------------------------------------- /models/featurenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from kornia.feature import nms 7 | from torchvision import models 8 | import torch.nn.functional as F 9 | import kornia.geometry.conversions as C 10 | 11 | 12 | class IndexSelect(nn.Module): 13 | def __init__(self, dim, index): 14 | super().__init__() 15 | self.dim, self.index = dim, index 16 | 17 | def forward(self, x): 18 | self.index = self.index.to(x.device) 19 | return x.index_select(self.dim, self.index) 20 | 21 | 22 | class 
ConstantBorder(nn.Module): 23 | ''' 24 | Set Boarders to Constant 25 | ''' 26 | def __init__(self, border=4, value=-math.inf): 27 | super().__init__() 28 | self.pad1 = nn.ConstantPad2d(-border, value=value) 29 | self.pad2 = nn.ConstantPad2d(border, value=value) 30 | 31 | def forward(self, x): 32 | return self.pad2(self.pad1(x)) 33 | 34 | 35 | class GridSample(nn.Module): 36 | def __init__(self, mode='bilinear'): 37 | super().__init__() 38 | self.mode = mode 39 | 40 | def forward(self, inputs): 41 | features, points = inputs 42 | dim = len(points.shape) 43 | points = points.view(features.size(0), 1, -1, 2) if dim == 3 else points 44 | output = F.grid_sample(features, points, self.mode, align_corners=True).permute(0, 2, 3, 1) 45 | return output.squeeze(1) if dim == 3 else output 46 | 47 | 48 | class PixelUnshuffle(nn.Module): 49 | def __init__(self, downscale_factor): 50 | super().__init__() 51 | self.factor = downscale_factor 52 | 53 | def forward(self, x): 54 | (N, C, H, W), S = x.shape, self.factor 55 | H, W = H // S, W // S 56 | x = x.reshape(N, C, H, S, W, S).permute(0, 1, 3, 5, 2, 4) 57 | x = x.reshape(N, C * S**2, H, W) 58 | return x 59 | 60 | 61 | class GraphAttn(nn.Module): 62 | def __init__(self, in_features, out_features, alpha=0.9, beta=0.2): 63 | super().__init__() 64 | self.alpha = alpha 65 | self.tran = nn.Linear(in_features, out_features) 66 | self.att1 = nn.Linear(out_features, 1) 67 | self.att2 = nn.Linear(out_features, 1) 68 | self.actv = nn.Sequential(nn.LeakyReLU(beta), nn.Softmax(dim=-1)) 69 | 70 | def forward(self, x): 71 | h = self.tran(x) 72 | att = self.att1(h) + self.att2(h).permute(0, 2, 1) 73 | adj = self.actv(att.squeeze()) 74 | return self.alpha * h + (1-self.alpha) * adj @ h 75 | 76 | 77 | class FGN(nn.Module): 78 | def __init__(self, feat_dim, feat_num, alpha=0.9, beta=0.2): 79 | super(FGN, self).__init__() 80 | # Taking advantage of parallel efficiency of nn.Conv1d 81 | self.tran = nn.Linear(feat_dim, feat_dim) 82 | self.att1 = nn.Linear(feat_dim, 1) 83 | self.att2 = nn.Linear(feat_dim, 1) 84 | self.norm = nn.Sequential(nn.LeakyReLU(beta), nn.Softmax(-1)) 85 | self.alpha = alpha 86 | 87 | def forward(self, x): 88 | adj = self.feature_adjacency(x) 89 | x_ = self.tran(torch.einsum('bde,bne->bnd', adj, x)) 90 | return self.alpha * x + (1 - self.alpha) * x_ 91 | 92 | def feature_adjacency(self, x): 93 | # w_ij = f(x_i, x_j) 94 | w = self.norm(self.att1(x) + self.att2(x).permute(0, 2, 1)) 95 | return self.row_normalize(self.sgnroot(x.transpose(-1, -2) @ w @ x)) 96 | 97 | def sgnroot(self, x): 98 | return x.sign()*(x.abs().sqrt().clamp(min=1e-7)) 99 | 100 | def row_normalize(self, x): 101 | x = x / (x.abs().sum(-1, keepdim=True) + 1e-7) 102 | x[torch.isnan(x)] = 0 103 | return x 104 | 105 | 106 | class BatchNorm2dwC(nn.Module): 107 | def __init__(self, in_features): 108 | super().__init__() 109 | self.bn = nn.BatchNorm3d(1) 110 | 111 | def forward(self, x): 112 | return self.bn(x.unsqueeze(1)).squeeze(1) 113 | 114 | 115 | class ScoreHead(nn.Module): 116 | def __init__(self, in_scale): 117 | super().__init__() 118 | self.scores_vgg = nn.Sequential(make_layer(256, 128), make_layer(128, 64, bn=BatchNorm2dwC)) 119 | self.scores_img = nn.Sequential(make_layer(3, 8), make_layer(8, 16, bn=BatchNorm2dwC), 120 | PixelUnshuffle(downscale_factor=in_scale)) 121 | self.combine = nn.Sequential( 122 | make_layer(64 + 16 * in_scale**2, in_scale**2 + 1, bn=BatchNorm2dwC, activation=nn.Softmax(dim=1)), 123 | IndexSelect(dim=1, index=torch.arange(in_scale**2)), 124 | 
nn.PixelShuffle(upscale_factor=in_scale), 125 | ConstantBorder(border=4, value=0)) 126 | 127 | def forward(self, images, features): 128 | scores_vgg, scores_img = self.scores_vgg(features), self.scores_img(images) 129 | return self.combine(torch.cat([scores_vgg, scores_img], dim=1)) 130 | 131 | 132 | class DescriptorHead(nn.Module): 133 | def __init__(self, feat_dim, feat_num, sample_pass): 134 | super().__init__() 135 | self.feat_dim, self.feat_num, self.sample_pass = feat_dim, feat_num, sample_pass 136 | 137 | self.descriptor = nn.Sequential( 138 | make_layer(256, self.feat_dim), 139 | make_layer(self.feat_dim, self.feat_dim, bn=None, activation=None)) 140 | self.sample = nn.Sequential(GridSample(), nn.BatchNorm1d(self.feat_num)) 141 | self.residual = nn.Sequential(make_layer(3, 128, kernel_size=9, padding=4), make_layer(128, self.feat_dim)) 142 | 143 | def forward(self, images, features, points, scores): 144 | descriptors, residual = self.descriptor(features), self.residual(images) 145 | n_group = 1 + self.sample_pass if self.training else 1 146 | descriptors, residual = _repeat_flatten(descriptors, n_group), _repeat_flatten(residual, n_group) 147 | descriptors = self.sample((descriptors, points)) + self.sample((residual, points)) 148 | return descriptors 149 | 150 | 151 | class FeatureNet(models.VGG): 152 | def __init__(self, feat_dim=256, feat_num=500, sample_pass=1, graph="FGN"): 153 | super().__init__(models.vgg13().features) 154 | self.feat_dim, self.feat_num, self.sample_pass = feat_dim, feat_num, sample_pass 155 | # Only adopt the first 15 layers of pre-trained vgg13. Feature Map: (512, H/8, W/8) 156 | self.load_state_dict(models.vgg13(pretrained=True).state_dict()) 157 | self.features = nn.Sequential(*list(self.features.children())[:15]) 158 | del self.classifier 159 | 160 | self.scores = ScoreHead(8) 161 | self.descriptors = DescriptorHead(feat_dim, feat_num, sample_pass) 162 | if graph == "GAT": 163 | self.graph = nn.Sequential( 164 | GraphAttn(self.feat_dim, self.feat_dim), 165 | nn.BatchNorm1d(feat_num), nn.LeakyReLU(0.2), 166 | GraphAttn(self.feat_dim, self.feat_dim)) 167 | elif graph == "FGN": 168 | self.graph = nn.Sequential( 169 | FGN(self.feat_dim, self.feat_num), 170 | nn.BatchNorm1d(feat_num), nn.LeakyReLU(0.2), 171 | FGN(self.feat_dim, self.feat_num)) 172 | else: 173 | raise ValueError(f"Unknown graph network structure: {graph}") 174 | self.nms = nms.NonMaximaSuppression2d((5, 5)) 175 | 176 | def forward(self, inputs): 177 | 178 | B, _, H, W = inputs.shape 179 | 180 | features = self.features(inputs) 181 | 182 | pointness = self.scores(inputs, features) 183 | 184 | scores, points = self.nms(pointness).view(B, -1, 1).topk(self.feat_num, dim=1) 185 | 186 | points = torch.cat((points % W, points // W), dim=-1) 187 | 188 | n_group = 1 189 | if self.training: 190 | n_group += self.sample_pass 191 | scores_flat_dup = _repeat_flatten(pointness.view(B, H * W), self.sample_pass) 192 | points_rand = torch.multinomial(torch.ones_like(scores_flat_dup), self.feat_num) 193 | scores_rand = torch.gather(scores_flat_dup, 1, points_rand).unsqueeze(-1) 194 | points_rand = torch.stack((points_rand % W, points_rand // W), dim=-1) 195 | points = self._append_group(points_rand, self.sample_pass, points).reshape(B * n_group, self.feat_num, 2) 196 | scores = self._append_group(scores_rand, self.sample_pass, scores).reshape(B * n_group, self.feat_num, 1) 197 | 198 | points = C.normalize_pixel_coordinates(points, H, W) 199 | 200 | descriptors = self.descriptors(inputs, features, points, 
scores) 201 | 202 | descriptors = self.graph(descriptors) 203 | 204 | N = n_group * self.feat_num 205 | return descriptors.reshape(B, N, self.feat_dim), points.reshape(B, N, 2), pointness, scores.reshape(B, N) 206 | 207 | @staticmethod 208 | def _append_group(grouped_samples, sample_pass, new_group): 209 | """(B*S, N, *) + (B, N, *) -> (B*(S+1), N, *)""" 210 | BS, *_shape = grouped_samples.shape 211 | raveled = grouped_samples.reshape(BS // sample_pass, sample_pass, *_shape) 212 | return torch.cat((raveled, new_group.unsqueeze(1)), dim=1) 213 | 214 | 215 | def make_layer(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bn=nn.BatchNorm2d, activation=nn.ReLU()): 216 | modules = [nn.Conv2d(in_chan, out_chan, kernel_size, stride, padding)] + \ 217 | ([bn(out_chan)] if bn is not None else []) + \ 218 | ([activation] if activation is not None else []) 219 | return nn.Sequential(*modules) 220 | 221 | 222 | def _repeat_flatten(x, n): 223 | """[B0, B1, B2, ...] -> [B0, B0, ..., B1, B1, ..., B2, B2, ...]""" 224 | shape = x.shape 225 | return x.unsqueeze(1).expand(shape[0], n, *shape[1:]).reshape(shape[0] * n, *shape[1:]) 226 | 227 | 228 | if __name__ == "__main__": 229 | '''Test codes''' 230 | import argparse 231 | from tool import Timer 232 | 233 | parser = argparse.ArgumentParser(description='Test FeatureNet') 234 | parser.add_argument("--device", type=str, default='cuda', help="cuda, cuda:0, or cpu") 235 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 236 | parser.add_argument("--batch-size", type=int, default=10, help="number of minibatch size") 237 | parser.add_argument('--crop-size', nargs='+', type=int, default=[320, 320], help='image crop size') 238 | args = parser.parse_args() 239 | torch.manual_seed(args.seed) 240 | torch.cuda.manual_seed(args.seed) 241 | 242 | net = FeatureNet(512, 200).to(args.device).eval() 243 | inputs = torch.randn(args.batch_size, 3, *args.crop_size).to(args.device) 244 | 245 | timer = Timer() 246 | with torch.no_grad(): 247 | for i in range(5): 248 | descriptors, points, pointness, scores = net(inputs) 249 | print('%d D: %s, P: (%s, %s), S: %s' % (i, descriptors.shape, pointness.shape, points.shape, scores.shape)) 250 | print('time:', timer.end()) 251 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import kornia as kn 5 | import torch.nn as nn 6 | import kornia.feature as kf 7 | import torch.nn.functional as F 8 | import kornia.geometry.conversions as C 9 | 10 | from utils import Projector 11 | from utils import Visualizer 12 | from models.featurenet import GridSample 13 | from models.match import ConsecutiveMatch 14 | from models.tool import GlobalStepCounter 15 | 16 | class FeatureNetLoss(nn.Module): 17 | def __init__(self, beta=[1, 1], K=None, writer=None, viz_start=float('inf'), viz_freq=200, counter=None): 18 | super().__init__() 19 | self.writer, self.beta, self.counter = writer, beta, counter if counter is not None else GlobalStepCounter() 20 | self.score_corner = ScoreLoss() 21 | self.desc_match = DiscriptorMatchLoss(writer=writer, counter=self.counter) 22 | self.projector = Projector() 23 | self.viz = Visualizer() if self.writer is None else Visualizer('tensorboard', writer=self.writer) 24 | self.viz_start, self.viz_freq = viz_start, viz_freq 25 | 26 | def forward(self, descriptors, points, scores, score_map, depth_map, poses, K, imgs, env): 27 
| def batch_project(pts): 28 | return self.projector.cartesian(pts, depth_map, poses, K) 29 | 30 | H, W = score_map.size(2), score_map.size(3) 31 | cornerness = self.beta[0] * self.score_corner(score_map, imgs, batch_project) 32 | proj_pts, invis_idx = batch_project(points) 33 | match = self.beta[1] * self.desc_match(descriptors, scores, points.unsqueeze(0), proj_pts, invis_idx, H, W) 34 | loss = cornerness + match 35 | 36 | n_iter = self.counter.steps 37 | if self.writer is not None: 38 | self.writer.add_scalars('Loss', {'cornerness': cornerness, 39 | 'match': match, 40 | 'all': loss}, n_iter) 41 | 42 | if n_iter >= self.viz_start and n_iter % self.viz_freq == 0: 43 | self.viz.show(imgs, points, 'hot', values=scores.squeeze(-1).detach().cpu().numpy(), name='train', step=n_iter) 44 | 45 | self.viz.show(score_map, color='hot', name='score', step=n_iter) 46 | 47 | pair = torch.tensor([[0, 1], [0, 3], [0, 5], [0, 7]]) 48 | b_src, b_dst = pair[:, 0], pair[:, 1] 49 | matched, confidence = ConsecutiveMatch()(descriptors[b_src], descriptors[b_dst], points[b_dst]) 50 | top_conf, top_idx = confidence.topk(50, dim=1) 51 | top_conf, top_idx = top_conf.detach().cpu().numpy(), top_idx.unsqueeze(-1).repeat(1, 1, 2) 52 | self.viz.showmatch(imgs[b_src], points[b_src].gather(1, top_idx), imgs[b_dst], matched.gather(1, top_idx), 'hot', top_conf, 0.9, 1, name='match', step=n_iter) 53 | 54 | return loss 55 | 56 | 57 | class ScoreLoss(nn.Module): 58 | def __init__(self, radius=8, num_corners=500): 59 | super(ScoreLoss, self).__init__() 60 | self.bceloss = nn.BCELoss() 61 | self.corner_det = kf.CornerGFTT() 62 | self.num_corners = num_corners 63 | self.pool = nn.MaxPool2d(kernel_size=radius, return_indices=True) 64 | self.unpool = nn.MaxUnpool2d(kernel_size=radius) 65 | 66 | def forward(self, scores_dense, imgs, projector): 67 | corners = self.get_corners(imgs, projector) 68 | corners = kn.filters.gaussian_blur2d(corners, kernel_size=(7, 7), sigma=(1, 1)) 69 | lap = kn.filters.laplacian(scores_dense, 5) # smoothness 70 | 71 | return self.bceloss(scores_dense, corners) + (scores_dense * torch.exp(-lap)).mean() * 10 72 | 73 | def get_corners(self, imgs, projector=None): 74 | (B, _, H, W), N = imgs.shape, self.num_corners 75 | corners = kf.nms2d(self.corner_det(kn.rgb_to_grayscale(imgs)), (5, 5)) 76 | 77 | # only one in patch 78 | output, indices = self.pool(corners) 79 | corners = self.unpool(output, indices) 80 | 81 | # keep top 82 | values, idx = corners.view(B, -1).topk(N, dim=1) 83 | coords = torch.stack([idx % W, idx // W], dim=2) # (x, y), same below 84 | 85 | if not projector: 86 | # keep as-is 87 | b = torch.arange(0, B).repeat_interleave(N).to(idx) 88 | h, w = idx // W, idx % W 89 | values = values.flatten() 90 | else: 91 | # combine corners from all images 92 | coords = kn.normalize_pixel_coordinates(coords, H, W) 93 | coords, invis_idx = projector(coords) 94 | coords[tuple(invis_idx)] = -2 95 | coords_combined = coords.transpose(0, 1).reshape(B, B * N, 2) 96 | coords_combined = kn.denormalize_pixel_coordinates(coords_combined, H, W).round().to(torch.long) 97 | b = torch.arange(B).repeat_interleave(B * N).to(coords_combined) 98 | w, h = coords_combined.reshape(-1, 2).T 99 | mask = w >= 0 100 | b, h, w, values = b[mask], h[mask], w[mask], values.flatten().repeat(B)[mask] 101 | 102 | target = torch.zeros_like(corners) 103 | target[b, 0, h, w] = values 104 | target = kf.nms2d(target, (5, 5)) 105 | 106 | return (target > 0).to(target) 107 | 108 | 109 | class DiscriptorMatchLoss(nn.Module): 110 | eps = 1e-6 
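    # The loss below assumes a score-dependent exponential model over the cosine-similarity
    # loss l: p(l) = exp(-l / sig) / Z with Z = sig * (1 - exp(-1 / sig)) for l in [0, 1],
    # where sig = -log(average keypoint score), l = 1 - cos for pairs whose reprojection
    # distance is within `radius` (ground-truth matches), and l = cos otherwise.
    # nll() returns the corresponding negative log-likelihood
    # l / sig + log(sig * (1 - exp(-1 / sig))), so keypoints detected with high confidence
    # (small sig) are penalized more when their descriptors contradict the geometry; for
    # non-matches, only the hardest pairs (2x the number of matches) enter the average.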
111 | def __init__(self, radius=1, writer=None, counter=None): 112 | super(DiscriptorMatchLoss, self).__init__() 113 | self.radius, self.writer, self.counter = radius, writer, counter if counter is not None else GlobalStepCounter() 114 | self.cosine = PairwiseCosine(inter_batch=True) 115 | 116 | def forward(self, descriptors, scores, pts_src, pts_dst, invis_idx, height, width): 117 | pts_src = C.denormalize_pixel_coordinates(pts_src.detach(), height, width) 118 | pts_dst = C.denormalize_pixel_coordinates(pts_dst.detach(), height, width) 119 | 120 | dist = torch.cdist(pts_dst, pts_src) 121 | dist[tuple(invis_idx)] = float('nan') 122 | pcos = self.cosine(descriptors, descriptors) 123 | 124 | match = (dist <= self.radius).triu(diagonal=1) 125 | miss = (dist > self.radius).triu(diagonal=1) 126 | 127 | scores = scores.detach() 128 | score_ave = (scores[:, None, :, None] + scores[None, :, None, :]).clamp(min=self.eps) / 2 129 | pcos = self.cosine(descriptors, descriptors) 130 | 131 | sig_match = -torch.log(score_ave[match]) 132 | sig_miss = -torch.log(score_ave[miss]) 133 | 134 | s_match = pcos[match] 135 | s_miss = pcos[miss] 136 | 137 | if self.writer is not None: 138 | n_iter = self.counter.steps 139 | self.writer.add_scalars('Misc/DiscriptorMatch/Count', { 140 | 'n_match': match.sum(), 141 | 'n_miss': miss.sum(), 142 | }, n_iter) 143 | 144 | if len(sig_match) > 0: 145 | self.writer.add_histogram('Misc/DiscriptorMatch/Sim/match', s_match, n_iter) 146 | self.writer.add_histogram('Misc/DiscriptorMatch/Sim/miss', s_miss[:len(s_match)], n_iter) 147 | 148 | return self.nll(sig_match, s_match) + self.nll(sig_miss, s_miss, False, match.sum() * 2) 149 | 150 | def nll(self, sig, cos, match=True, topk=None): 151 | # p(x) = exp(-l / sig) * C; l = 1 - x if match else x 152 | norm_const = torch.log(sig * (1 - torch.exp(-1 / sig))) 153 | loss = (1 - cos if match else cos) / sig + norm_const 154 | return (loss if topk is None else loss.topk(topk).values).mean() 155 | 156 | 157 | class PairwiseCosine(nn.Module): 158 | def __init__(self, inter_batch=False, dim=-1, eps=1e-8): 159 | super(PairwiseCosine, self).__init__() 160 | self.inter_batch, self.dim, self.eps = inter_batch, dim, eps 161 | self.eqn = 'amd,bnd->abmn' if inter_batch else 'bmd,bnd->bmn' 162 | 163 | def forward(self, x, y): 164 | xx = torch.sum(x**2, dim=self.dim).unsqueeze(-1) # (A, M, 1) 165 | yy = torch.sum(y**2, dim=self.dim).unsqueeze(-2) # (B, 1, N) 166 | if self.inter_batch: 167 | xx, yy = xx.unsqueeze(1), yy.unsqueeze(0) # (A, 1, M, 1), (1, B, 1, N) 168 | xy = torch.einsum(self.eqn, x, y) 169 | return xy / (xx * yy).clamp(min=self.eps**2).sqrt() 170 | -------------------------------------------------------------------------------- /models/match.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | import models 7 | 8 | 9 | class ConsecutiveMatch(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.cosine = models.PairwiseCosine() 13 | 14 | def forward(self, desc_src, desc_dst, points_dst): 15 | confidence, idx = self.cosine(desc_src, desc_dst).max(dim=2) 16 | matched = points_dst.gather(1, idx.unsqueeze(2).expand(-1, -1, 2)) 17 | 18 | return matched, confidence 19 | 20 | -------------------------------------------------------------------------------- /models/tool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | import 
torch 5 | 6 | 7 | class GlobalStepCounter(): 8 | def __init__(self, initial_step=0): 9 | self._steps = initial_step 10 | 11 | @property 12 | def steps(self): 13 | return self._steps 14 | 15 | def step(self, step=1): 16 | self._steps += 1 17 | return self._steps 18 | 19 | 20 | class Timer: 21 | def __init__(self): 22 | torch.cuda.synchronize() 23 | self.start_time = time.time() 24 | 25 | def tic(self): 26 | self.start() 27 | 28 | def show(self, prefix="", output=True): 29 | torch.cuda.synchronize() 30 | duration = time.time()-self.start_time 31 | if output: 32 | print(prefix+"%fs" % duration) 33 | return duration 34 | 35 | def toc(self, prefix=""): 36 | self.end() 37 | print(prefix+"%fs = %fHz" % (self.duration, 1/self.duration)) 38 | return self.duration 39 | 40 | def start(self): 41 | torch.cuda.synchronize() 42 | self.start_time = time.time() 43 | 44 | def end(self): 45 | torch.cuda.synchronize() 46 | self.duration = time.time()-self.start_time 47 | self.start() 48 | return self.duration 49 | 50 | 51 | def count_parameters(model): 52 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 53 | 54 | 55 | class EarlyStopScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau): 56 | def __init__(self, optimizer, factor=0.1, patience=10, min_lr=0, verbose=False): 57 | super().__init__(optimizer, factor=factor, patience=patience, min_lr=min_lr, verbose=verbose) 58 | self.no_decrease = 0 59 | 60 | def step(self, metrics, epoch=None): 61 | # convert `metrics` to float, in case it's a zero-dim Tensor 62 | current = float(metrics) 63 | if epoch is None: 64 | epoch = self.last_epoch = self.last_epoch + 1 65 | self.last_epoch = epoch 66 | 67 | if self.is_better(current, self.best): 68 | self.best = current 69 | self.num_bad_epochs = 0 70 | else: 71 | self.num_bad_epochs += 1 72 | 73 | if self.in_cooldown: 74 | self.cooldown_counter -= 1 75 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown 76 | 77 | if self.num_bad_epochs > self.patience: 78 | self.cooldown_counter = self.cooldown 79 | self.num_bad_epochs = 0 80 | return self._reduce_lr(epoch) 81 | 82 | def _reduce_lr(self, epoch): 83 | for i, param_group in enumerate(self.optimizer.param_groups): 84 | old_lr = float(param_group['lr']) 85 | new_lr = max(old_lr * self.factor, self.min_lrs[i]) 86 | if old_lr - new_lr > self.eps: 87 | param_group['lr'] = new_lr 88 | if self.verbose: 89 | print('Epoch {:5d}: reducing learning rate' 90 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) 91 | return False 92 | else: 93 | return True 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import tqdm 6 | import copy 7 | import torch 8 | import random 9 | import warnings 10 | import argparse 11 | import numpy as np 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from collections import deque 15 | from tensorboard import program 16 | from torchvision import transforms as T 17 | from torch.utils.data import DataLoader 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | from datasets import AirSampler 21 | from models import FeatureNet 22 | from models import FeatureNetLoss 23 | from models import ConsecutiveMatch 24 | from models import EarlyStopScheduler 25 | from utils import MatchEvaluator, Visualizer 26 | from datasets import TartanAir, TartanAirTest, AirAugment 27 | from models import Timer, count_parameters, 
GlobalStepCounter 28 | 29 | 30 | @torch.no_grad() 31 | def evaluate(net, evaluator, loader, args): 32 | net.eval() 33 | for images, depths, poses, K, env_seq in tqdm.tqdm(loader): 34 | images = images.to(args.device) 35 | depths = depths.to(args.device) 36 | poses = poses.to(args.device) 37 | K = K.to(args.device) 38 | descriptors, points, pointness, scores = net(images) 39 | evaluator.observe(descriptors, points, scores, pointness, depths, poses, K, images, env_seq) 40 | 41 | evaluator.report() 42 | 43 | 44 | def train(net, loader, criterion, optimizer, counter, args=None, loss_ave=50, eval_loader=None, evaluator=None): 45 | net.train() 46 | train_loss, batches = deque(), len(loader) 47 | enumerator = tqdm.tqdm(loader) 48 | for images, depths, poses, K, env_seq in enumerator: 49 | images = images.to(args.device) 50 | depths = depths.to(args.device) 51 | poses = poses.to(args.device) 52 | K = K.to(args.device) 53 | optimizer.zero_grad() 54 | descriptors, points, pointness, scores = net(images) 55 | loss = criterion(descriptors, points, scores, pointness, depths, poses, K, images, env_seq[0]) 56 | loss.backward() 57 | optimizer.step() 58 | 59 | if np.isnan(loss.item()): 60 | print('Warning: loss is nan during iteration %d. BP skipped.' % counter.steps) 61 | else: 62 | train_loss.append(loss.item()) 63 | if len(train_loss) > loss_ave: 64 | train_loss.popleft() 65 | enumerator.set_description("Loss: %.4f at"%(np.average(train_loss))) 66 | 67 | if evaluator is not None and counter.steps % args.eval_freq == 0: 68 | evaluate(net, evaluator, eval_loader, args) 69 | net.train() 70 | 71 | counter.step() 72 | 73 | return np.average(train_loss) 74 | 75 | 76 | if __name__ == "__main__": 77 | # Arguements 78 | parser = argparse.ArgumentParser(description='Feature Graph Networks') 79 | parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu") 80 | parser.add_argument("--dataset", type=str, default='tartanair', help="TartanAir") 81 | parser.add_argument("--data-root", type=str, default='/data/datasets/tartanair', help="data location") 82 | parser.add_argument("--dataset-catalog", type=str, default='./.cache/tartanair-sequences.pbz2', help='dataset bookkeeping cache') 83 | parser.add_argument("--log-dir", type=str, default=None, help="TensorBoard log dir") 84 | parser.add_argument("--method", type=str, choices=["FGN", "GAT"], default="FGN", help="Method to train and evaluate") 85 | parser.add_argument("--load", type=str, default=None, help="load pretrained model") 86 | parser.add_argument("--save", type=str, default='./saved_models/featurenet.pth', help="model file to save") 87 | parser.add_argument("--feat-dim", type=int, default=256, help="feature dimension") 88 | parser.add_argument("--feat-num", type=int, default=300, help="feature number") 89 | parser.add_argument('--scale', type=float, default=0.5, help='image resize') 90 | parser.add_argument("--lr", type=float, default=1e-5, help="learning rate") 91 | parser.add_argument("--min-lr", type=float, default=1e-6, help="learning rate") 92 | parser.add_argument("--factor", type=float, default=0.1, help="factor of lr") 93 | parser.add_argument("--momentum", type=float, default=0.9, help="momentum of optim") 94 | parser.add_argument("--w-decay", type=float, default=0, help="weight decay of optim") 95 | parser.add_argument("--epoch", type=int, default=15, help="number of epoches") 96 | parser.add_argument("--batch-size", type=int, default=8, help="minibatch size") 97 | parser.add_argument("--patience", type=int, default=5, 
help="training patience") 98 | parser.add_argument("--num-workers", type=int, default=4, help="workers of dataloader") 99 | parser.add_argument("--seed", type=int, default=0, help='Random seed.') 100 | parser.add_argument("--viz_start", type=int, default=np.inf, help='Visualize starting from iteration') 101 | parser.add_argument("--viz_freq", type=int, default=1, help='Visualize every * iteration(s)') 102 | parser.add_argument("--eval-split-seed", type=int, default=42, help='Seed for splitting the dataset') 103 | parser.add_argument("--eval-percentage", type=float, default=0.2, help='Percentage of sequences for eval') 104 | parser.add_argument("--eval-freq", type=int, default=5000, help='Evaluate every * steps') 105 | parser.add_argument("--eval-topk", type=int, default=150, help='Only inspect top * matches') 106 | parser.add_argument("--eval-back", type=int, nargs='+', default=[1], help='Evaluate by matching each frame with * frames ago') 107 | args = parser.parse_args(); print(args) 108 | torch.manual_seed(args.seed) 109 | torch.cuda.manual_seed(args.seed) 110 | np.random.seed(args.seed) 111 | random.seed(args.seed) 112 | 113 | train_data, test_data = TartanAir(args.data_root, args.scale, catalog_path=args.dataset_catalog) \ 114 | .rand_split([1 - args.eval_percentage, args.eval_percentage], args.eval_split_seed) 115 | test_data.augment = AirAugment(args.scale, resize_only=True) 116 | 117 | train_sampler = AirSampler(train_data, args.batch_size, shuffle=True) 118 | test_sampler = AirSampler(test_data, args.batch_size, shuffle=False, overlap=False) 119 | 120 | train_loader = DataLoader(train_data, batch_sampler=train_sampler, pin_memory=True, num_workers=args.num_workers) 121 | eval_loader = DataLoader(test_data, batch_sampler=test_sampler, pin_memory=True, num_workers=args.num_workers) 122 | 123 | writer = None 124 | if args.log_dir is not None: 125 | from datetime import datetime 126 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 127 | writer = SummaryWriter(os.path.join(args.log_dir, current_time)) 128 | tb = program.TensorBoard() 129 | tb.configure(argv=[None, '--logdir', args.log_dir, '--bind_all']) 130 | print(('TensorBoard at %s \n' % tb.launch())) 131 | 132 | step_counter = GlobalStepCounter(initial_step=1) 133 | criterion = FeatureNetLoss(writer=writer, viz_start=args.viz_start, viz_freq=args.viz_freq, counter=step_counter) 134 | net = FeatureNet(args.feat_dim, args.feat_num, graph=args.method).to(args.device) if args.load is None else torch.load(args.load, args.device) 135 | if not isinstance(net, nn.DataParallel): 136 | net = nn.DataParallel(net) 137 | optimizer = optim.RMSprop(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.w_decay) 138 | scheduler = EarlyStopScheduler(optimizer, factor=args.factor, verbose=True, min_lr=args.min_lr, patience=args.patience) 139 | 140 | evaluator = MatchEvaluator(back=args.eval_back, viz=None, top=args.eval_topk, writer=writer, counter=step_counter) 141 | 142 | timer = Timer() 143 | for epoch in range(args.epoch): 144 | train_acc = train(net, train_loader, criterion, optimizer, step_counter, args, eval_loader=eval_loader, evaluator=evaluator) 145 | 146 | if args.save is not None: 147 | os.makedirs(os.path.dirname(args.save), exist_ok=True) 148 | save_path, save_file_dup = args.save, 0 149 | while os.path.exists(save_path): 150 | save_file_dup += 1 151 | save_path = args.save + '.%d' % save_file_dup 152 | torch.save(net, save_path) 153 | print('Saved model: %s' % save_path) 154 | 155 | if 
scheduler.step(1-train_acc): 156 | print('Early Stopping!') 157 | break 158 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .visualization import Visualizer 4 | 5 | from .geometry import Projector 6 | 7 | from .evaluation import MatchEvaluator 8 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle, bz2 4 | import torch.nn as nn 5 | from collections import deque 6 | from .geometry import Projector 7 | import kornia.geometry.conversions as C 8 | from prettytable import PrettyTable, MARKDOWN 9 | from models import ConsecutiveMatch, GridSample 10 | 11 | 12 | class MatchEvaluator(): 13 | def __init__(self, back=[1], viz=None, viz_dist_min=0, viz_dist_max=100, top=None, writer=None, counter=None): 14 | self.back, self.top = back, top 15 | self.viz, self.viz_min, self.viz_max = viz, viz_dist_min, viz_dist_max 16 | self.hist = [] 17 | self.grid_sample = GridSample() 18 | self.match = ConsecutiveMatch() 19 | self.error = {b: [] for b in back} 20 | self.writer = writer 21 | self.counter = counter if counter is not None else GlobalStepCounter() 22 | self.cur_env = None 23 | self.env_err_seg = {} 24 | 25 | @torch.no_grad() 26 | def observe(self, descriptors, points, scores, score_map, depth_map, poses, Ks, imgs, env_seq): 27 | B, N, _ = points.shape 28 | _, _, H, W = imgs.shape 29 | top = N if self.top is None else self.top 30 | 31 | env_seq = "_".join(zip(*env_seq).__next__()) 32 | if self.cur_env != env_seq: 33 | last_env, self.cur_env = self.cur_env, env_seq 34 | self.hist = [] 35 | n_batches = len(self.error[self.back[0]]) 36 | self.env_err_seg[self.cur_env] = [n_batches, -1] 37 | if last_env is not None: 38 | self.env_err_seg[last_env][1] = n_batches 39 | 40 | # populate hist until sufficient 41 | depths = self.grid_sample((depth_map, points)).squeeze(-1) 42 | for img, desc, pt, pose, K, depth in zip(imgs, descriptors, points, poses, Ks, depths): 43 | self.hist.insert(0, (img, desc, pt, pose, K, depth)) 44 | if len(self.hist) - B < max(self.back): 45 | return 46 | 47 | viz_img1, viz_img2, viz_pts1, viz_pts2, viz_val = [], [], [], [], [] 48 | imgs_new, descs_new, pts_new, poses_new, Ks_new, depths_new = self._unpack_hist(reversed(self.hist[:B])) 49 | for b in self.back: 50 | imgs_old, descs_old, pts_old, poses_old, Ks_old, depths_old = self._unpack_hist(reversed(self.hist[b:b+B])) 51 | matched_new_pt, conf = self.match(descs_old, descs_new, pts_new) 52 | match_conf, top_idx = conf.topk(top) 53 | match_src, match_dst = self._point_select(pts_old, top_idx), self._point_select(matched_new_pt, top_idx) 54 | depths_old = self._point_select(depths_old, top_idx) 55 | error = reproj_error(match_src, Ks_old, poses_old, depths_old, match_dst, Ks_new, poses_new, H, W) 56 | self.error[b].append(error) 57 | if self.viz is not None: 58 | viz_img1.append(imgs_old), viz_img2.append(imgs_new) 59 | viz_pts1.append(match_src), viz_pts2.append(match_dst) 60 | viz_val.append(error) 61 | 62 | if self.viz is not None: 63 | # B, back, * 64 | viz_img1, viz_img2 = torch.stack(viz_img1, dim=1), torch.stack(viz_img2, dim=1) 65 | viz_pts1, viz_pts2 = torch.stack(viz_pts1, dim=1), torch.stack(viz_pts2, dim=1) 66 | viz_val = torch.stack(viz_val, 
dim=1).detach().cpu().numpy() 67 | for img1, img2, pts1, pts2, val in zip(viz_img1, viz_img2, viz_pts1, viz_pts2, viz_val): 68 | self.viz.showmatch(img1, pts1, img2, pts2, 'hot', val, self.viz_min, self.viz_max, name='backs') 69 | break 70 | 71 | self.hist = self.hist[:-B] 72 | 73 | return self.error 74 | 75 | def ave_reproj_error(self, quantile=None, env=None): 76 | mean, quantiles = {}, {} 77 | for b in self.back: 78 | seg = self.env_err_seg[env] if env is not None else [0, -1] 79 | error = torch.cat(self.error[b][seg[0]:seg[1]]) 80 | mean[b] = error.mean().item() 81 | if quantile is not None: 82 | quantiles[b] = torch.quantile(error, torch.tensor(quantile).to(error)).tolist() 83 | return (mean, quantiles) if quantile is not None else mean 84 | 85 | def ave_prec(self, thresh=1, env=None): 86 | perc = {} 87 | for b in self.back: 88 | seg = self.env_err_seg[env] if env is not None else [0, -1] 89 | perc[b] = (torch.cat(self.error[b][seg[0]:seg[1]]) < 1).to(torch.float).mean().item() 90 | return perc 91 | 92 | def report(self, err_quant=[0.5, 0.9]): 93 | result = PrettyTable(['Env Name', 'n-back', 94 | 'Mean Err (%s)' % self._fmt_list(np.array(err_quant) * 100, '%d%%'), 'Ave Prec']) 95 | result.float_format['Ave Prec'] = '.2' 96 | result.align['Env Name'] = 'r' 97 | 98 | all_mean, all_quantiles = self.ave_reproj_error(quantile=err_quant) 99 | prec = self.ave_prec() 100 | result.add_rows([['All', b, '%.2f (%s)' % (all_mean[b], self._fmt_list(all_quantiles[b], '%.2f')), 101 | prec[b]] for b in self.back]) 102 | 103 | for e in self.env_err_seg: 104 | env_mean, env_quantiles = self.ave_reproj_error(quantile=err_quant, env=e) 105 | env_prec = self.ave_prec(env=e) 106 | result.add_rows([[e, b, '%.2f (%s)' % (env_mean[b], self._fmt_list(env_quantiles[b], '%.2f')), 107 | env_prec[b]] for b in self.back]) 108 | 109 | n_iter = self.counter.steps 110 | if self.writer is not None: 111 | self.writer.add_scalars('Eval/Match/MeanErr', {'%d-back' % b: v for b, v in all_mean.items()}, n_iter) 112 | self.writer.add_scalars('Eval/Match/AvePrec', {'%d-back' % b: v for b, v in prec.items()}, n_iter) 113 | for b in self.back: 114 | self.writer.add_histogram('Eval/Match/LogErr/%d-back' % b, 115 | torch.log10(torch.cat(self.error[b]).clamp(min=1e-10)), n_iter) 116 | # TensorBoard supports markdown table although column alignment is broken 117 | result.set_style(MARKDOWN) 118 | self.writer.add_text('Eval/Match/PerSeq', result.get_string(sortby='n-back'), n_iter) 119 | else: 120 | print('Evaluation: step %d' % n_iter) 121 | print(result.get_string(sortby='n-back')) 122 | 123 | # clear history 124 | self.error = {b: [] for b in self.back} 125 | self.hist = [] 126 | self.cur_env = None 127 | self.env_err_seg = {} 128 | 129 | @staticmethod 130 | def _fmt_list(elems, fmt, delim=', '): 131 | return delim.join([fmt % e for e in elems]) 132 | 133 | @staticmethod 134 | def _unpack_hist(hist): 135 | return [torch.stack(attr, dim=0) for attr in zip(*hist)] 136 | 137 | @staticmethod 138 | def _point_select(attr, idx): 139 | B, N = idx.shape 140 | return attr.gather(1, idx.view(B, N, *([1] * (len(attr.shape) - 2))).expand(B, N, *attr.shape[2:])) 141 | 142 | def save_error(self, file_path): 143 | error = {k: [v.detach().cpu().numpy() for v in vs] for k, vs in self.error.items()} 144 | with bz2.BZ2File(file_path, 'wb') as f: 145 | pickle.dump([error, self.env_err_seg], f) 146 | 147 | 148 | def reproj_error(pts_src, K_src, pose_src, depth_src, pts_dst, K_dst, pose_dst, H, W): 149 | projector = Projector() 150 | cam_src = 
Projector._make_camera(H, W, K_src, pose_src) 151 | cam_dst = Projector._make_camera(H, W, K_dst, pose_dst) 152 | pts_prj, _ = Projector._project_points(pts_src, depth_src, cam_src, cam_dst) 153 | diff = C.denormalize_pixel_coordinates(pts_prj, H, W) - C.denormalize_pixel_coordinates(pts_dst, H, W) 154 | return torch.hypot(diff[..., 0], diff[..., 1]) 155 | -------------------------------------------------------------------------------- /utils/geometry.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import numpy as np 5 | import kornia as kn 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Projector(nn.Module): 11 | def __init__(self, eps=1e-2, max_depth=1000): 12 | super().__init__() 13 | self.eps, self.max_depth = eps, max_depth 14 | 15 | def forward(self, points, depth_map, cam_src, cam_dst, depth_map_dst=None): 16 | B, N, _ = points.shape 17 | assert B == len(depth_map) == cam_src.batch_size == cam_dst.batch_size 18 | 19 | depths = self._sample_depths(depth_map, points) 20 | proj_p, proj_depth = self._project_points(points, depths, cam_src, cam_dst) 21 | 22 | # check for occlusion 23 | is_out_of_bound = torch.any((proj_p < -1) | (proj_p > 1), dim=-1) 24 | 25 | depths_dst = self._sample_depths(depth_map_dst, proj_p) if depth_map_dst is not None else self.max_depth 26 | is_occluded = proj_depth.isnan() | (proj_depth > depths_dst + self.eps) | (depths_dst > self.max_depth) 27 | 28 | return proj_p, torch.nonzero(is_out_of_bound | is_occluded, as_tuple=True) 29 | 30 | def cartesian(self, points, depth_map, poses, Ks): 31 | """Projects points from every view to every other view.""" 32 | B, N = points.shape[:2] 33 | H, W = depth_map.shape[2:] 34 | # TODO don't self-project 35 | # [0 0 0 ... 1 1 1 ... 2 2 2 ...] vs [0 1 2 ... 0 1 2 ... 0 1 2 ...] 36 | depths_rep = self._src_repeat(depth_map) 37 | points_rep = self._src_repeat(points) 38 | 39 | cam_src = self._make_camera(H, W, self._src_repeat(Ks), self._src_repeat(poses)) 40 | cam_dst = self._make_camera(H, W, self._dst_repeat(Ks), self._dst_repeat(poses)) 41 | 42 | proj_p, (pair_idx, point_idx) = self(points_rep, depths_rep, cam_src, cam_dst, depths_rep) 43 | return proj_p.reshape(B, B, N, 2), torch.stack([pair_idx // B, pair_idx % B, point_idx]) 44 | 45 | def pix2world(self, points, depth_map, poses, Ks): 46 | """Unprojects pixels to 3D coordinates.""" 47 | H, W = depth_map.shape[2:] 48 | cam = self._make_camera(H, W, Ks, poses) 49 | depths = self._sample_depths(depth_map, points) 50 | return self._pix2world(points, depths, cam) 51 | 52 | @staticmethod 53 | def _make_camera(height, width, K, pose): 54 | """Creates a PinholeCamera with specified intrinsics and extrinsics.""" 55 | intrinsics = torch.eye(4, 4).to(K).repeat(len(K), 1, 1) 56 | intrinsics[:, 0:3, 0:3] = K 57 | 58 | extrinsics = torch.eye(4, 4).to(pose).repeat(len(pose), 1, 1) 59 | extrinsics[:, 0:3, 0:4] = pose 60 | 61 | height, width = torch.tensor([height]).to(K), torch.tensor([width]).to(K) 62 | 63 | return kn.PinholeCamera(intrinsics, extrinsics, height, width) 64 | 65 | @staticmethod 66 | def _pix2world(p, depth, cam): 67 | """Projects p to world coordinate. 68 | 69 | Args: 70 | p: List of points in pixels (B, N, 2). 71 | depth: Depth of each point(B, N). 72 | cam: Camera with batch size B 73 | 74 | Returns: 75 | World coordinate of p (B, N, 3). 
76 |         """
77 |         p = kn.denormalize_pixel_coordinates(p, int(cam.height), int(cam.width))
78 |         p_h = kn.convert_points_to_homogeneous(p)
79 |         p_cam = kn.transform_points(cam.intrinsics_inverse(), p_h) * depth.unsqueeze(-1)
80 |         return kn.transform_points(kn.inverse_transformation(cam.extrinsics), p_cam)
81 | 
82 |     @staticmethod
83 |     def _world2pix(p_w, cam):
84 |         """Projects p to normalized camera coordinate.
85 | 
86 |         Args:
87 |             p_w: List of points in world coordinate (B, N, 3).
88 |             cam: Camera with batch size B
89 | 
90 |         Returns:
91 |             Normalized coordinates of p in the image plane of cam (B, N, 2) and screen depth (B, N).
92 |         """
93 |         proj = kn.compose_transformations(cam.intrinsics, cam.extrinsics)
94 |         p_h = kn.transform_points(proj, p_w)
95 |         p, d = kn.convert_points_from_homogeneous(p_h), p_h[..., 2]
96 |         return kn.normalize_pixel_coordinates(p, int(cam.height), int(cam.width)), d
97 | 
98 |     @staticmethod
99 |     def _project_points(p, depth_src, cam_src, cam_dst):
100 |         """Projects p visible in cam_src into cam_dst.
101 | 
102 |         Args:
103 |             p: List of points in pixels (B, N, 2).
104 |             depth_src: Depth of each point (B, N).
105 |             cam_src, cam_dst: Source and destination cameras with batch size B
106 | 
107 |         Returns:
108 |             Normalized coordinates of p in pose cam_dst (B, N, 2).
109 |         """
110 |         return Projector._world2pix(Projector._pix2world(p, depth_src, cam_src), cam_dst)
111 | 
112 |     @staticmethod
113 |     def _sample_depths(depths_map, points):
114 |         """Samples the depth of each point in points."""
115 |         assert depths_map.shape[:2] == (len(points), 1)
116 |         return F.grid_sample(depths_map, points[:, None], align_corners=False)[:, 0, 0, ...]
117 | 
118 |     @staticmethod
119 |     def _src_repeat(x):
120 |         """[b0 b1 b2 ...] -> [b0 b0 ... b1 b1 ...]"""
121 |         B, shape = x.shape[0], x.shape[1:]
122 |         return x.unsqueeze(1).expand(B, B, *shape).reshape(B**2, *shape)
123 | 
124 |     @staticmethod
125 |     def _dst_repeat(x):
126 |         """[b0 b1 b2 ...] -> [b0 b1 ... b0 b1 ...]"""
127 |         B, shape = x.shape[0], x.shape[1:]
128 |         return x.unsqueeze(0).expand(B, B, *shape).reshape(B**2, *shape)
129 | 
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | import cv2
5 | import torch
6 | import numpy as np
7 | import torchvision
8 | from matplotlib import cm
9 | import matplotlib.colors as mc
10 | import kornia.geometry.conversions as C
11 | 
12 | 
13 | class Visualizer():
14 |     vis_id = 0
15 | 
16 |     def __init__(self, display='imshow', default_name=None, **kwargs):
17 |         self.radius, self.thickness = 1, 1
18 |         self.default_name = 'Visualizer %d' % self.vis_id if default_name is None else default_name
19 |         Visualizer.vis_id += 1
20 | 
21 |         if display == 'imshow':
22 |             self.displayer = ImshowDisplayer()
23 |         elif display == 'tensorboard':
24 |             self.displayer = TBDisplayer(**kwargs)
25 |         elif display == 'video':
26 |             self.displayer = VideoFileDisplayer(**kwargs)
27 | 
28 |     def show(self, images, points=None, color='red', nrow=4, values=None, vmin=None, vmax=None, name=None, step=0):
29 |         b, c, h, w = images.shape
30 |         if c == 3:
31 |             images = torch2cv(images)
32 |         elif c == 1:  # show colored values
33 |             images = images.detach().cpu().numpy().transpose((0, 2, 3, 1))
34 |             images = get_colors(color, images.squeeze(-1), vmin, vmax)
35 | 
36 |         if points is not None:
37 |             points = C.denormalize_pixel_coordinates(points, h, w).to(torch.int)
38 |             for i, pts in enumerate(points):
39 |                 colors = get_colors(color, [0]*len(pts) if values is None else values[i], vmin, vmax)
40 |                 images[i] = circles(images[i], pts, self.radius, colors, self.thickness)
41 | 
42 |         disp_name = name if name is not None else self.default_name
43 | 
44 |         if nrow is not None:
45 |             images = torch.tensor(images.copy()).permute((0, 3, 1, 2))
46 |             grid = torchvision.utils.make_grid(images, nrow=nrow, padding=1).permute((1, 2, 0))
47 |             self.displayer.display(disp_name, grid.numpy(), step)
48 |         else:
49 |             for i, img in enumerate(images):
50 |                 self.displayer.display(disp_name + str(i), img, step)
51 | 
52 |     def showmatch(self, images1, points1, images2, points2, color='blue', values=None, vmin=None, vmax=None, name=None, step=0):
53 |         match_pairs = []
54 |         for i, (img1, pts1, img2, pts2) in enumerate(zip(images1, points1, images2, points2)):
55 |             assert len(pts1) == len(pts2)
56 |             h, w = img1.size(-2), img1.size(-1)
57 |             pts1 = C.denormalize_pixel_coordinates(pts1, h, w)
58 |             pts2 = C.denormalize_pixel_coordinates(pts2, h, w)
59 |             img1, img2 = torch2cv(torch.stack([img1, img2]))
60 |             colors = get_colors(color, [0]*len(pts1) if values is None else values[i], vmin, vmax)
61 |             match_pairs.append(torch.tensor(matches(img1, pts1, img2, pts2, colors)))
62 | 
63 |         images = torch.stack(match_pairs).permute((0, 3, 1, 2))
64 |         grid = torchvision.utils.make_grid(images, nrow=2, padding=1).permute((1, 2, 0))
65 |         self.displayer.display(name if name is not None else self.default_name, grid.numpy(), step)
66 | 
67 |     def reprojectshow(self, imgs, pts_src, pts_dst, src, dst):
68 |         # TODO not adapted for change in torch2cv
69 |         pts_src, pts_dst = pts_src[src], pts_src[dst]
70 |         for i in range(src[0].size(0)):
71 |             pts1 = pts_src[i].unsqueeze(0)
72 |             pts2 = pts_dst[i].unsqueeze(0)
73 |             img1 = torch2cv(imgs[src[0][i]]).copy()
74 |             img2 = torch2cv(imgs[dst[0][i]]).copy()
75 |             image = matches(img1, pts1, img2, pts2, get_colors('blue', [0] * len(pts1)), 2)
76 |             cv2.imshow(self.default_name + '-dst', image)
77 |             cv2.waitKey(1)
78 | 
79 |     def close(self):
80 |         self.displayer.close()
81 | 
82 | 
83 | class VisDisplayer():
84 |     def display(self, name, frame, step=0):
85 |         raise NotImplementedError()
86 | 
87 |     def close(self):
88 |         pass
89 | 
90 | 
91 | class ImshowDisplayer(VisDisplayer):
92 |     def display(self, name, frame, step=0):
93 |         cv2.imshow(name, frame)
94 |         cv2.waitKey(1)
95 | 
96 |     def close(self):
97 |         cv2.destroyAllWindows()
98 | 
99 | 
100 | class TBDisplayer(VisDisplayer):
101 |     def __init__(self, writer):
102 |         self.writer = writer
103 | 
104 |     def display(self, name, frame, step=0):
105 |         self.writer.add_image(name, frame[:, :, ::-1], step, dataformats='HWC')
106 | 
107 | 
108 | class VideoFileDisplayer(VisDisplayer):
109 |     def __init__(self, save_dir=None, framerate=10):
110 |         if save_dir is None:
111 |             from datetime import datetime
112 |             current_time = datetime.now().strftime('%b%d_%H-%M-%S')
113 |             self.save_dir = os.path.join('.', 'vidout', current_time)
114 |         else:
115 |             self.save_dir = save_dir
116 |         self.framerate = framerate
117 |         self.writer = {}
118 | 
119 |     def display(self, name, frame, step=0):
120 |         if name not in self.writer:
121 |             os.makedirs(self.save_dir, exist_ok=True)
122 |             self.writer[name] = cv2.VideoWriter(os.path.join(self.save_dir, '%s.avi' % name),
123 |                                                 cv2.VideoWriter_fourcc(*'avc1'),
124 |                                                 self.framerate, (frame.shape[1], frame.shape[0]))
125 |         self.writer[name].write(frame)
126 | 
127 |     def close(self):
128 |         for wn in self.writer:
129 |             self.writer[wn].release()
130 | 
131 | 
132 | def matches(img1, pts1, img2, pts2, colors, circ_radius=3, thickness=1):
133 |     ''' Assume pts1 are matched with pts2, respectively.
134 |     '''
135 |     H1, W1, C = img1.shape
136 |     H2, W2, _ = img2.shape
137 |     new_img = np.zeros((max(H1, H2), W1 + W2, C), img1.dtype)
138 |     new_img[:H1, :W1], new_img[:H2, W1:W1+W2] = img1, img2
139 |     new_img = circles(new_img, pts1, circ_radius, colors, thickness)
140 |     pts2[:, 0] += W1
141 |     new_img = circles(new_img, pts2, circ_radius, colors, thickness)
142 |     return lines(new_img, pts1, pts2, colors, thickness)
143 | 
144 | 
145 | def circles(image, points, radius, colors, thickness):
146 |     for pt, c in zip(points, colors):
147 |         if not torch.any(pt.isnan()):
148 |             image = cv2.circle(image.copy(), tuple(pt), radius, tuple(c.tolist()), thickness, cv2.LINE_AA)
149 |     return image
150 | 
151 | 
152 | def lines(image, pts1, pts2, colors, thickness):
153 |     for pt1, pt2, c in zip(pts1, pts2, colors):
154 |         if not torch.any(pt1.isnan() | pt2.isnan()):
155 |             image = cv2.line(image.copy(), tuple(pt1), tuple(pt2), tuple(c.tolist()), thickness, cv2.LINE_AA)
156 |     return image
157 | 
158 | 
159 | def get_colors(name, values=[0], vmin=None, vmax=None):
160 |     if name in mc.get_named_colors_mapping():
161 |         rgb = mc.to_rgba_array(name)[0, :3]
162 |         rgb = np.tile(rgb, (len(values), 1))
163 |     else:
164 |         values = np.array(values)
165 |         normalize = mc.Normalize(vmin=vmin, vmax=vmax)
166 |         cmap = cm.get_cmap(name)
167 |         rgb = cmap(normalize(values))
168 |     return (rgb[..., 2::-1] * 255).astype(np.uint8)
169 | 
170 | 
171 | def torch2cv(images):
172 |     rgb = (255 * images).type(torch.uint8).cpu().numpy()
173 |     bgr = rgb[:, ::-1, ...].transpose((0, 2, 3, 1))
174 |     return bgr
175 | 
--------------------------------------------------------------------------------
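
The snippet below is a minimal usage sketch, not part of the repository, showing how the `reproj_error` helper from `utils/evaluation.py` (which wraps `Projector._make_camera` and `Projector._project_points` from `utils/geometry.py`) could be exercised on its own. The tensor shapes (normalized keypoint coordinates in [-1, 1], `(B, 3, 3)` intrinsics, `(B, 3, 4)` poses), the constant depth, and the dummy values are assumptions for illustration; it also assumes the repository root is on `PYTHONPATH` and its dependencies (PyTorch, kornia, prettytable) are installed.

```python
import torch
from utils.evaluation import reproj_error  # assumed import path from the repo root

B, N, H, W = 1, 50, 480, 640  # batch size, number of keypoints, image size (assumed)

# Dummy matched keypoints in normalized [-1, 1] image coordinates.
pts_src = torch.rand(B, N, 2) * 2 - 1
pts_dst = torch.rand(B, N, 2) * 2 - 1
depth_src = torch.full((B, N), 5.0)  # placeholder depth of every source point

# Placeholder pinhole intrinsics and identity camera poses for both views.
K = torch.tensor([[320.0,   0.0, 320.0],
                  [  0.0, 320.0, 240.0],
                  [  0.0,   0.0,   1.0]]).expand(B, 3, 3)
pose_src = torch.eye(3, 4).expand(B, 3, 4)
pose_dst = torch.eye(3, 4).expand(B, 3, 4)

# Per-point reprojection error in pixels, shape (B, N): pts_src is lifted to 3D
# using depth_src, reprojected into the destination view, and compared to pts_dst.
err = reproj_error(pts_src, K, pose_src, depth_src, pts_dst, K, pose_dst, H, W)
print(err.shape, err.mean().item())
```

This is the same quantity the evaluator accumulates per `n-back` frame pair before `ave_reproj_error` and `ave_prec` reduce it to the mean error, quantiles, and precision reported above.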