├── .gitignore
├── .style.yapf
├── README.md
├── baseline
│   └── model
│       ├── __init__.py
│       ├── lfgc.py
│       └── oanet.py
├── config.py
├── config
│   ├── test_yfcc.txt
│   ├── train_yfcc.txt
│   └── val_yfcc.txt
├── demo_2d.py
├── imgs
│   ├── 54990444_8865247484.jpg
│   ├── 57895226_4857581382.jpg
│   ├── 68833924_5994205213.jpg
│   ├── calibration_000002.h5
│   ├── calibration_000344.h5
│   ├── calibration_000489.h5
│   ├── demo_0.png
│   ├── demo_1.png
│   └── demo_2.png
├── lib
│   ├── __init__.py
│   ├── all_data_loaders.py
│   ├── eval.py
│   ├── lfgc_trainer.py
│   ├── oa_trainer.py
│   ├── timer.py
│   ├── trainer.py
│   ├── twodim_data_loaders.py
│   ├── twodim_trainer.py
│   ├── util.py
│   ├── util_2d.py
│   └── util_data.py
├── model
│   ├── __init__.py
│   ├── common.py
│   ├── pyramidnet.py
│   ├── residual_block.py
│   ├── resnetsc.py
│   ├── resunet.py
│   └── simpleunet.py
├── requirements.txt
├── scripts
│   ├── benchmark_yfcc.py
│   ├── download_yfcc.sh
│   ├── gen_2d.py
│   ├── plot_yfcc.py
│   ├── train_2d.sh
│   ├── train_2d_onpaper.sh
│   ├── train_lfgc.sh
│   ├── train_oa.sh
│   └── train_oa_onpaper.sh
├── test.py
├── train.py
├── ucn
│   ├── blocks.py
│   └── resunet.py
└── util
    ├── __init__.py
    └── file.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Temp files
__pycache__
*.swp
*.swo
*.orig
.idea
outputs/
exp/
*.pyc
.vscode
.ipynb_checkpoints
visualize/
local_scripts/
figures/
.vim/
*.pth
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
[style]
based_on_style = pep8
column_limit = 88
indent_width = 2
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# High-dimensional Convolutional Networks for Geometric Pattern Recognition

### Image Correspondences

#### Download YFCC100M Dataset

```
bash scripts/download_yfcc.sh /path/to/yfcc100m
```

#### Preprocess YFCC100M Dataset

SIFT

```
python -m scripts.gen_2d \
  --source /path/to/yfcc100m \
  --target /path/to/save/processed/dataset
```

UCN

```
python -m scripts.gen_2d \
  --source /path/to/yfcc100m \
  --target /path/to/save/processed/dataset \
  --feature ucn \
  --onthefly \
  --ucn_weight /path/to/pretrained/ucn/weight
```

#### Training Network

Train an image correspondence network.

```
bash scripts/train_2d.sh "-experiment1" \
  "--data_dir_raw /path/to/raw/yfcc \
  --data_dir_processed /path/to/processed/yfcc"
```

#### Testing on YFCC100M Dataset

```
python -m scripts.benchmark_yfcc \
  --data_dir_raw /path/to/yfcc100m \
  --data_dir_processed /path/to/processed/dataset \
  --weights /path/to/checkpoint \
  --out_dir /path/to/save/outputs \
  --do_extract
```

#### Demo on YFCC100M Dataset

The following demo_2d script downloads the UCN and our best model (PyramidNetSCNoBlock) weights and tests them on a few pairs of images. The visualization output will be saved in the './visualize' directory.

```
python demo_2d.py
```

![demo0](imgs/demo_0.png)
![demo1](imgs/demo_1.png)

#### Model Zoo

| Model | Dataset | Link |
| ----- | ------- | ---- |
| PyramidNetSCNoBlock | YFCC100MDatasetUCN | [download](http://cvlab.postech.ac.kr/research/hcngpr/data/2d_pyramid_ucn.pth) |
| ResNetSC | YFCC100MDatasetExtracted | [download](http://cvlab.postech.ac.kr/research/hcngpr/data/2d_resnetsc.pth) |
| ResUNetINBN2G | YFCC100MDatasetExtracted | [download](http://cvlab.postech.ac.kr/research/hcngpr/data/2d_resunet.pth) |
| OANet | YFCC100MDatasetExtracted | [download](http://cvlab.postech.ac.kr/research/hcngpr/data/2d_oa.pth) |
| LFGCNet | YFCC100MDatasetExtracted | [download](http://cvlab.postech.ac.kr/research/hcngpr/data/2d_lfgc.pth) |

#### Raw data for Fig#

[Prec-Recall](http://cvlab.postech.ac.kr/research/hcngpr/data/prec_recall_raw.txt)

---
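Each Model Zoo checkpoint bundles the training configuration with the weights. A minimal loading sketch, mirroring the pattern `demo_2d.py` uses for the PyramidNetSCNoBlock checkpoint (the constructor arguments shown are the ones `demo_2d.py` passes; treat them as an assumption for the other checkpoints):

```python
import torch

from model import load_model

checkpoint = torch.load('2d_pyramid_ucn.pth')
opts = checkpoint['config']            # argparse Namespace saved at training time
Model = load_model(opts.inlier_model)  # model class name stored in the config
model = Model(in_channels=4, out_channels=1, clusters=opts.oa_clusters, D=4)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
```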
--------------------------------------------------------------------------------
/baseline/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/baseline/model/__init__.py
--------------------------------------------------------------------------------
/baseline/model/lfgc.py:
--------------------------------------------------------------------------------
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.util_2d import weighted_8points


class ClassificationLoss(nn.Module):

  def forward(self, logits, labels):
    is_pos = labels.to(torch.bool)
    is_neg = (~is_pos)
    is_pos = is_pos.to(torch.float)
    is_neg = is_neg.to(torch.float)
    c = is_pos - is_neg

    loss = -F.logsigmoid(c * logits)
    num_pos = F.relu(torch.sum(is_pos, dim=1) - 1) + 1
    num_neg = F.relu(torch.sum(is_neg, dim=1) - 1) + 1

    loss_pos = torch.sum(loss * is_pos, dim=1)
    loss_neg = torch.sum(loss * is_neg, dim=1)

    balanced_loss = torch.mean(loss_pos * 0.5 / num_pos + loss_neg * 0.5 / num_neg)
    return balanced_loss


class RegressionLoss(nn.Module):

  def forward(self, logits, coords, e_gt):
    e = weighted_8points(coords, logits)
    e_gt = torch.reshape(e_gt, (logits.shape[0], 9))
    e_gt = e_gt / torch.norm(e_gt, dim=1, keepdim=True)

    loss = torch.mean(
        torch.min(torch.sum((e - e_gt)**2, dim=1), torch.sum((e + e_gt)**2, dim=1)))
    return loss


class LFGCLoss(nn.Module):

  def __init__(self, alpha, beta, regression_iter):
    super(LFGCLoss, self).__init__()
    self.alpha = alpha
    self.beta = beta
    self.regression_iter = regression_iter

  def forward(self, logits, coords, labels, e_gt, iteration):
    ClsLoss = ClassificationLoss()
    RegLoss = RegressionLoss()

    cls_loss = ClsLoss(logits, labels)

    if iteration > self.regression_iter:
      reg_loss = RegLoss(logits, coords, e_gt)
      loss = cls_loss * self.alpha + reg_loss * self.beta
    else:
      loss = cls_loss * self.alpha

    return loss


class ContextNorm(nn.Module):

  def __init__(self, eps):
    super(ContextNorm, self).__init__()
    self.eps = eps

  def forward(self, x):
    variance, mean = torch.var_mean(x, dim=2, keepdim=True)
    std = torch.sqrt(variance)
    return (x - mean) / (std + self.eps)
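
# ContextNorm normalizes each channel across the correspondence axis (dim=2),
# i.e., per (batch, channel) slice over all points, not across the batch.
# A quick sanity check of that behavior (illustrative, not part of training):
#
#   cn = ContextNorm(eps=1e-3)
#   x = torch.randn(2, 8, 100) * 3.0 + 5.0  # (batch_size, channel, num_point)
#   out = cn(x)
#   out.mean(dim=2).abs().max()  # ~0
#   out.std(dim=2).mean()        # ~1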

class ResNetBlock(nn.Module):

  def __init__(self, in_channel, out_channel, kernel_size, stride):
    super(ResNetBlock, self).__init__()
    self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size, stride)
    self.cn1 = ContextNorm(eps=1e-3)
    self.bn1 = nn.BatchNorm1d(out_channel, eps=1e-3, momentum=0.99)
    self.conv2 = nn.Conv1d(out_channel, out_channel, kernel_size, stride)
    self.cn2 = ContextNorm(eps=1e-3)
    self.bn2 = nn.BatchNorm1d(out_channel, eps=1e-3, momentum=0.99)

  def forward(self, x):
    residual = x
    out = self.conv1(x)
    out = self.cn1(out)
    out = self.bn1(out)
    out = F.relu(out)

    out = self.conv2(out)
    out = self.cn2(out)
    out = self.bn2(out)
    out = F.relu(out)

    return out + residual


class LFGCNet(nn.Module):
  """LFGCNet

  This model needs normalized correspondences (4D) as input.
  The input shape should be (batch_size, 4, num_point).
  """

  def __init__(self, in_channel=4, out_channel=128, depth=12, config=None):
    super(LFGCNet, self).__init__()
    self.input = nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1)

    blocks = [
        ResNetBlock(out_channel, out_channel, kernel_size=1, stride=1)
        for _ in range(depth)
    ]
    self.blocks = nn.Sequential(*blocks)

    self.output = nn.Conv1d(out_channel, 1, kernel_size=1, stride=1)

    self.config = config

  def forward(self, x):
    out = self.input(x)
    out = self.blocks(out)
    out = self.output(out)

    return out
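

if __name__ == '__main__':
  # Usage sketch (not part of the original module). As the class docstring says,
  # LFGCNet consumes normalized correspondences stacked channel-first.
  net = LFGCNet()
  coords = torch.randn(2, 4, 512)  # (batch_size, 4, num_point)
  logits = net(coords)
  print(logits.shape)  # torch.Size([2, 1, 512]) -- per-correspondence inlier logits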
--------------------------------------------------------------------------------
/baseline/model/oanet.py:
--------------------------------------------------------------------------------
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.util_2d import batch_episym, torch_skew_symmetric, weighted_8points


class OALoss(object):

  def __init__(self, config):
    self.loss_essential = config.oa_loss_essential
    self.loss_classif = config.oa_loss_classif
    self.use_fundamental = config.oa_use_fundamental
    self.obj_geod_th = config.oa_obj_geod_th
    self.geo_loss_margin = config.oa_geo_loss_margin
    self.loss_essential_init_iter = config.oa_loss_essential_init_iter

  def run(self, global_step, data, logits, e_hat):
    e_gt_unnorm, labels, pts_virt = data['E'], data['labels'], data['virtPts']
    e_gt_unnorm = torch.reshape(e_gt_unnorm, (-1, 9))

    # Essential/Fundamental matrix loss
    pts1_virts, pts2_virts = pts_virt[:, :, :2], pts_virt[:, :, 2:]
    geod = batch_episym(pts1_virts, pts2_virts, e_hat)
    essential_loss = torch.min(geod, self.geo_loss_margin * geod.new_ones(geod.shape))
    essential_loss = essential_loss.mean()

    # Classification loss
    is_pos = labels.to(torch.bool)
    is_neg = ~is_pos
    is_pos = is_pos.to(logits.dtype)
    is_neg = is_neg.to(logits.dtype)
    c = is_pos - is_neg
    classif_losses = -torch.log(torch.sigmoid(c * logits) + np.finfo(float).eps.item())
    # balance
    num_pos = torch.relu(torch.sum(is_pos, dim=1) - 1.0) + 1.0
    num_neg = torch.relu(torch.sum(is_neg, dim=1) - 1.0) + 1.0
    classif_loss_p = torch.sum(classif_losses * is_pos, dim=1)
    classif_loss_n = torch.sum(classif_losses * is_neg, dim=1)
    classif_loss = torch.mean(classif_loss_p * 0.5 / num_pos +
                              classif_loss_n * 0.5 / num_neg)

    loss = 0
    # Check global_step and add essential loss
    if self.loss_essential > 0 and global_step >= self.loss_essential_init_iter:
      loss += self.loss_essential * essential_loss
    if self.loss_classif > 0:
      loss += self.loss_classif * classif_loss

    return loss


class PointCN(nn.Module):

  def __init__(self, channels, out_channels=None):
    nn.Module.__init__(self)
    if not out_channels:
      out_channels = channels
    self.shot_cut = None
    if out_channels != channels:
      self.shot_cut = nn.Conv2d(channels, out_channels, kernel_size=1)
    self.conv = nn.Sequential(
        nn.InstanceNorm2d(channels, eps=1e-3), nn.BatchNorm2d(channels), nn.ReLU(),
        nn.Conv2d(channels, out_channels, kernel_size=1),
        nn.InstanceNorm2d(out_channels, eps=1e-3), nn.BatchNorm2d(out_channels),
        nn.ReLU(), nn.Conv2d(out_channels, out_channels, kernel_size=1))

  def forward(self, x):
    out = self.conv(x)
    if self.shot_cut:
      out = out + self.shot_cut(x)
    else:
      out = out + x
    return out


class trans(nn.Module):

  def __init__(self, dim1, dim2):
    nn.Module.__init__(self)
    self.dim1 = dim1
    self.dim2 = dim2

  def forward(self, x):
    return x.transpose(self.dim1, self.dim2)


class OAFilter(nn.Module):

  def __init__(self, channels, points, out_channels=None):
    nn.Module.__init__(self)
    if not out_channels:
      out_channels = channels
    self.shot_cut = None
    if out_channels != channels:
      self.shot_cut = nn.Conv2d(channels, out_channels, kernel_size=1)
    self.conv1 = nn.Sequential(
        nn.InstanceNorm2d(channels, eps=1e-3),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, out_channels, kernel_size=1),  # b*c*n*1
        trans(1, 2))
    # Spatial Correlation Layer
    self.conv2 = nn.Sequential(
        nn.BatchNorm2d(points), nn.ReLU(), nn.Conv2d(points, points, kernel_size=1))
    self.conv3 = nn.Sequential(
        trans(1, 2), nn.InstanceNorm2d(out_channels, eps=1e-3),
        nn.BatchNorm2d(out_channels), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1))

  def forward(self, x):
    out = self.conv1(x)
    out = out + self.conv2(out)
    out = self.conv3(out)
    if self.shot_cut:
      out = out + self.shot_cut(x)
    else:
      out = out + x
    return out


# you can use this bottleneck block to prevent overfitting when your dataset is small
class OAFilterBottleneck(nn.Module):

  def __init__(self, channels, points1, points2, out_channels=None):
    nn.Module.__init__(self)
    if not out_channels:
      out_channels = channels
    self.shot_cut = None
    if out_channels != channels:
      self.shot_cut = nn.Conv2d(channels, out_channels, kernel_size=1)
    self.conv1 = nn.Sequential(
        nn.InstanceNorm2d(channels, eps=1e-3),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, out_channels, kernel_size=1),  # b*c*n*1
        trans(1, 2))
    self.conv2 = nn.Sequential(
        nn.BatchNorm2d(points1), nn.ReLU(), nn.Conv2d(points1, points2, kernel_size=1),
        nn.BatchNorm2d(points2), nn.ReLU(), nn.Conv2d(points2, points1, kernel_size=1))
    self.conv3 = nn.Sequential(
        trans(1, 2), nn.InstanceNorm2d(out_channels, eps=1e-3),
        nn.BatchNorm2d(out_channels), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1))

  def forward(self, x):
    out = self.conv1(x)
    out = out + self.conv2(out)
    out = self.conv3(out)
    if self.shot_cut:
      out = out + self.shot_cut(x)
    else:
      out = out + x
    return out


class diff_pool(nn.Module):

  def __init__(self, in_channel, output_points):
    nn.Module.__init__(self)
    self.output_points = output_points
    self.conv = nn.Sequential(
        nn.InstanceNorm2d(in_channel, eps=1e-3), nn.BatchNorm2d(in_channel), nn.ReLU(),
        nn.Conv2d(in_channel, output_points, kernel_size=1))

  def forward(self, x):
    embed = self.conv(x)  # b*k*n*1
    S = torch.softmax(embed, dim=2).squeeze(3)
    out = torch.matmul(x.squeeze(3), S.transpose(1, 2)).unsqueeze(3)
    return out


class diff_unpool(nn.Module):

  def __init__(self, in_channel, output_points):
    nn.Module.__init__(self)
    self.output_points = output_points
    self.conv = nn.Sequential(
        nn.InstanceNorm2d(in_channel, eps=1e-3), nn.BatchNorm2d(in_channel), nn.ReLU(),
        nn.Conv2d(in_channel, output_points, kernel_size=1))

  def forward(self, x_up, x_down):
    # x_up: b*c*n*1
    # x_down: b*c*k*1
    embed = self.conv(x_up)  # b*k*n*1
    S = torch.softmax(embed, dim=1).squeeze(3)  # b*k*n
    out = torch.matmul(x_down.squeeze(3), S).unsqueeze(3)
    return out  # b*c*n*1
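
# Shape walk-through for the two layers above (illustrative sizes): diff_pool
# learns a soft assignment of the n input correspondences to k clusters, and
# diff_unpool maps the k cluster features back onto the original n points.
#
#   b, c, n, k = 2, 128, 1000, 500
#   pool, unpool = diff_pool(c, k), diff_unpool(c, k)
#   x = torch.randn(b, c, n, 1)  # per-correspondence features, b*c*n*1
#   x_down = pool(x)             # b*c*k*1, soft-pooled cluster features
#   x_up = unpool(x, x_down)     # b*c*n*1, unpooled back to the input points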

class OANBlock(nn.Module):

  def __init__(self, net_channels, input_channel, depth, clusters):
    nn.Module.__init__(self)
    channels = net_channels
    self.layer_num = depth
    print('channels:' + str(channels) + ', layer_num:' + str(self.layer_num))
    self.conv1 = nn.Conv2d(input_channel, channels, kernel_size=1)

    l2_nums = clusters

    self.l1_1 = []
    for _ in range(self.layer_num // 2):
      self.l1_1.append(PointCN(channels))

    self.down1 = diff_pool(channels, l2_nums)

    self.l2 = []
    for _ in range(self.layer_num // 2):
      self.l2.append(OAFilter(channels, l2_nums))

    self.up1 = diff_unpool(channels, l2_nums)

    self.l1_2 = []
    self.l1_2.append(PointCN(2 * channels, channels))
    for _ in range(self.layer_num // 2 - 1):
      self.l1_2.append(PointCN(channels))

    self.l1_1 = nn.Sequential(*self.l1_1)
    self.l1_2 = nn.Sequential(*self.l1_2)
    self.l2 = nn.Sequential(*self.l2)

    self.output = nn.Conv2d(channels, 1, kernel_size=1)

  def forward(self, data, xs):
    # data: b*c*n*1
    batch_size, num_pts = data.shape[0], data.shape[2]
    x1_1 = self.conv1(data)
    x1_1 = self.l1_1(x1_1)
    x_down = self.down1(x1_1)
    x2 = self.l2(x_down)
    x_up = self.up1(x1_1, x2)
    out = self.l1_2(torch.cat([x1_1, x_up], dim=1))

    logits = torch.squeeze(torch.squeeze(self.output(out), 3), 1)
    e_hat = weighted_8points(xs.squeeze(1).permute(0, 2, 1), logits)

    x1, x2 = xs[:, 0, :, :2], xs[:, 0, :, 2:4]
    e_hat_norm = e_hat
    residual = batch_episym(x1, x2, e_hat_norm).reshape(batch_size, 1, num_pts, 1)

    return logits, e_hat, residual

class OANet(nn.Module):

  def __init__(self, config):
    nn.Module.__init__(self)
    self.iter_num = config.oa_iter_num
    depth_each_stage = config.oa_net_depth // (config.oa_iter_num + 1)
    self.side_channel = (config.oa_use_ratio == 2) + (config.oa_use_mutual == 2)
    self.weights_init = OANBlock(config.oa_net_channels, 4 + self.side_channel,
                                 depth_each_stage, config.oa_clusters)
    self.weights_iter = [
        OANBlock(config.oa_net_channels, 6 + self.side_channel, depth_each_stage,
                 config.oa_clusters) for _ in range(config.oa_iter_num)
    ]
    self.weights_iter = nn.Sequential(*self.weights_iter)
    self.config = config

  def forward(self, x):
    assert x.dim() == 4 and x.shape[1] == 1

    input = x.transpose(1, 3)

    res_logits, res_e_hat = [], []
    logits, e_hat, residual = self.weights_init(input, x)
    res_logits.append(logits), res_e_hat.append(e_hat)

    # For iterative network
    for i in range(self.iter_num):
      logits, e_hat, residual = self.weights_iter[i](
          torch.cat([
              input,
              residual.detach(),
              F.relu(torch.tanh(logits)).reshape(residual.shape).detach()
          ], dim=1), x)
      res_logits.append(logits), res_e_hat.append(e_hat)
    return res_logits, res_e_hat
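

if __name__ == '__main__':
  # Usage sketch (not part of the original module). OANet expects correspondences
  # shaped (batch, 1, num_point, 4) -- see the assert in forward -- and returns one
  # (logits, e_hat) pair per stage: the initial block plus oa_iter_num refinements.
  # The hypothetical config below mirrors the OANet defaults in config.py.
  from types import SimpleNamespace
  cfg = SimpleNamespace(
      oa_iter_num=1,
      oa_net_depth=12,
      oa_net_channels=128,
      oa_clusters=500,
      oa_use_ratio=0,
      oa_use_mutual=0)
  net = OANet(cfg)
  xs = torch.rand(2, 1, 1000, 4)
  res_logits, res_e_hat = net(xs)
  print(len(res_logits), res_logits[0].shape, res_e_hat[0].shape)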
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse
import os.path as osp

arg_lists = []
parser = argparse.ArgumentParser()


def add_argument_group(name):
  arg = parser.add_argument_group(name)
  arg_lists.append(arg)
  return arg


def str2bool(v):
  return v.lower() in ('true', '1')


logging_arg = add_argument_group('Logging')
logging_arg.add_argument('--out_dir', type=str, default='outputs')

trainer_arg = add_argument_group('Trainer')
trainer_arg.add_argument('--trainer', type=str, default='ContrastiveLossTrainer')
trainer_arg.add_argument('--save_freq_epoch', type=int, default=1)
trainer_arg.add_argument('--batch_size', type=int, default=4)
trainer_arg.add_argument('--val_batch_size', type=int, default=1)
trainer_arg.add_argument('--use_hard_negative', type=str2bool, default=True)
trainer_arg.add_argument('--hard_negative_sample_ratio', type=float, default=0.05)
trainer_arg.add_argument('--hard_negative_max_num', type=int, default=3000)

trainer_arg.add_argument('--num_pos_per_batch', type=int, default=1024)
trainer_arg.add_argument('--num_hn_samples_per_batch', type=int, default=256)

trainer_arg.add_argument('--neg_thresh', type=float, default=1.4)
trainer_arg.add_argument('--pos_thresh', type=float, default=0.1)
trainer_arg.add_argument('--pos_weight', type=float, default=1)
trainer_arg.add_argument('--neg_weight', type=float, default=1)
trainer_arg.add_argument('--use_random_scale', type=str2bool, default=False)
trainer_arg.add_argument('--min_scale', type=float, default=0.8)
trainer_arg.add_argument('--max_scale', type=float, default=1.2)
trainer_arg.add_argument('--use_random_rotation', type=str2bool, default=True)
trainer_arg.add_argument('--rotation_range', type=float, default=360)
trainer_arg.add_argument('--train_phase', type=str, default="train")
trainer_arg.add_argument('--val_phase', type=str, default="val")
trainer_arg.add_argument('--test_phase', type=str, default="test")
trainer_arg.add_argument('--stat_freq', type=int, default=40)
trainer_arg.add_argument('--final_test', type=str2bool, default=False)
trainer_arg.add_argument('--test_valid', type=str2bool, default=True)
trainer_arg.add_argument('--nn_max_n', type=int, default=250)
trainer_arg.add_argument('--val_max_iter', type=int, default=400)
trainer_arg.add_argument('--train_max_iter', type=int, default=2000)
trainer_arg.add_argument('--val_epoch_freq', type=int, default=1)
trainer_arg.add_argument(
    '--positive_pair_search_voxel_size_multiplier', type=float, default=1.5)

trainer_arg.add_argument('--hit_ratio_thresh', type=float, default=0.1)

# Triplets
trainer_arg.add_argument('--triplet_num_pos', type=int, default=256)
trainer_arg.add_argument('--triplet_num_hn', type=int, default=512)
trainer_arg.add_argument('--triplet_num_rand', type=int, default=1024)

# Inlier detection trainer
trainer_arg.add_argument('--inlier_model', type=str, default='ResUNetBN2C')
trainer_arg.add_argument('--inlier_training_start_epoch', type=int, default=-1)
trainer_arg.add_argument('--inlier_feature_type', type=str, default='ones')
trainer_arg.add_argument('--inlier_conv1_kernel_size', type=int, default=3)
trainer_arg.add_argument('--inlier_use_balanced_loss', type=str2bool, default=True)
trainer_arg.add_argument('--registration_min_pairs', type=int, default=100)
trainer_arg.add_argument('--inlier_bin_size', type=int, default=1)
trainer_arg.add_argument('--inlier_logit_thresh', type=float, default=-3)
trainer_arg.add_argument(
    '--inlier_threshold_pixel',
    type=float,
    default=8,
    help='ThreeDMatch inlier threshold in pixel')
trainer_arg.add_argument(
    '--inlier_threshold_type',
    type=str,
    default='hard',
    help='Inlier threshold type: hard, soft')
trainer_arg.add_argument(
    '--inlier_label_type',
    type=str,
    choices=['epipolar', 'projection'],
    default='epipolar',
    help='Inlier label type')
trainer_arg.add_argument(
    '--ucn_inlier_threshold_pixel',
    type=float,
    default=4,
    help='UCN hardest contrastive threshold in pixel')
trainer_arg.add_argument(
    '--ucn_use_sift_kp',
    type=str2bool,
    default=True,
    help='UCN use SIFT keypoints for hardest negative mining')
trainer_arg.add_argument(
    '--threed_feature',
    type=str,
    default='fpfh',
    choices=['fpfh', 'fcgf'],
    help='Features used for training inlier detection')
trainer_arg.add_argument(
    '--ucn_weights',
    type=str,
    help='path to pretrained UCN weights')

# Network specific configurations
net_arg = add_argument_group('Network')
net_arg.add_argument('--model', type=str, default='SimpleNetBN2C')
net_arg.add_argument('--model_n_out', type=int, default=32)
net_arg.add_argument('--conv1_kernel_size', type=int, default=3)
net_arg.add_argument('--use_color', type=str2bool, default=False)
net_arg.add_argument('--use_normal', type=str2bool, default=False)
net_arg.add_argument('--normalize_feature', type=str2bool, default=False)
net_arg.add_argument('--dist_type', type=str, default='L2')
net_arg.add_argument('--best_val_metric', type=str, default='feat_match_ratio')
net_arg.add_argument(
    '--best_val_comparator',
    type=str,
    choices=['smaller', 'larger'],
    default='larger',
    help='X the better')

# Optimizer arguments
opt_arg = add_argument_group('Optimizer')
opt_arg.add_argument('--optimizer', type=str, default='SGD')
opt_arg.add_argument('--max_epoch', type=int, default=100)
opt_arg.add_argument('--lr', type=float, default=1e-1)
opt_arg.add_argument('--momentum', type=float, default=0.8)
opt_arg.add_argument('--sgd_momentum', type=float, default=0.9)
opt_arg.add_argument('--sgd_dampening', type=float, default=0.1)
opt_arg.add_argument('--adam_beta1', type=float, default=0.9)
opt_arg.add_argument('--adam_beta2', type=float, default=0.999)
opt_arg.add_argument('--weight_decay', type=float, default=1e-4)
opt_arg.add_argument('--iter_size', type=int, default=1, help='accumulate gradient')
opt_arg.add_argument('--bn_momentum', type=float, default=0.05)
opt_arg.add_argument('--exp_gamma', type=float, default=0.99)
opt_arg.add_argument('--scheduler', type=str, default='ExpLR')
opt_arg.add_argument(
    '--icp_cache_path', type=str, default="/home/chrischoy/datasets/FCGF/kitti/icp/")
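
# How these flags are typically consumed (illustrative; the actual wiring lives in
# the trainers), e.g., with config.optimizer == 'SGD' and config.scheduler == 'ExpLR':
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=config.lr,
#                               momentum=config.sgd_momentum,
#                               dampening=config.sgd_dampening,
#                               weight_decay=config.weight_decay)
#   scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, config.exp_gamma)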

misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--use_gpu', type=str2bool, default=True)
misc_arg.add_argument(
    '--search_method', type=str, default='gpu', choices=['cpu', 'gpu'])
misc_arg.add_argument(
    '--data_loader_search_method', type=str, default='cpu', choices=['cpu', 'gpu'])
misc_arg.add_argument(
    '--eval_registration',
    type=str2bool,
    default=True,
    help='Evaluate RANSAC registration for a registration network')

misc_arg.add_argument('--weights', type=str, default=None)
misc_arg.add_argument('--weights_dir', type=str, default=None)
misc_arg.add_argument('--resume', type=str, default=None)
misc_arg.add_argument('--resume_dir', type=str, default=None)
misc_arg.add_argument('--train_num_workers', type=int, default=2)
misc_arg.add_argument('--val_num_workers', type=int, default=1)
misc_arg.add_argument('--test_num_workers', type=int, default=2)
misc_arg.add_argument('--fast_validation', type=str2bool, default=False)
misc_arg.add_argument(
    '--preselect',
    type=str2bool,
    default=False,
    help='preselect voxelized points to compute normals. Use of voxel_size < 5cm with preselect False is discouraged.'
)

data_arg = add_argument_group('Data')
data_arg.add_argument('--dataset', type=str, default='ThreeDMatchPairDataset03')
data_arg.add_argument('--voxel_size', type=float, default=0.05)
data_arg.add_argument(
    '--data_dir_25mm',
    type=str,
    default="/home/chrischoy/datasets/FCGF/dataset_full_25")
data_arg.add_argument(
    '--data_dir_10mm', type=str, default="/home/chrischoy/datasets/FCGF/dataset_full")
data_arg.add_argument(
    '--kitti_root', type=str, default="/home/chrischoy/datasets/FCGF/kitti/")
data_arg.add_argument('--use_10mm', type=str2bool, default=False)
data_arg.add_argument(
    '--kitti_max_time_diff',
    type=int,
    default=3,
    help='max time difference between pairs (non inclusive)')
data_arg.add_argument('--kitti_date', type=str, default='2011_09_26')
data_arg.add_argument(
    '--data_dir_3dmatch', type=str, default="/home/chrischoy/datasets/FCGF/3DMatch")
data_arg.add_argument(
    '--collation_3d', type=str, default='collate_pair', help="Collation function type")

# 2D
twod_arg = add_argument_group('2D')
twod_arg.add_argument('--data_dir_2d', type=str, help="path to 2d dataset")
twod_arg.add_argument(
    '--collation_2d', type=str, default='default', help="Collation function type")
twod_arg.add_argument(
    '--obj_num_kp',
    type=int,
    default=2000,
    help="number of keypoints to sample per image")
twod_arg.add_argument(
    '--obj_num_nn', type=int, default=1, help="number of nearest neighbor(s)")
twod_arg.add_argument(
    '--feature_extractor',
    type=str,
    default="sift",
    help="select feature extractor to use")
twod_arg.add_argument(
    '--quantization_size',
    type=float,
    default=0.01,
    help="quantization size to discretize image coordinates")
twod_arg.add_argument('--sample_minimum_coords', type=str2bool, default=False)
twod_arg.add_argument(
    '--use_ratio_test',
    type=str2bool,
    default=False,
    help='Use ratio test when matching features')
twod_arg.add_argument(
    '--use_8p', type=str2bool, default=False, help="Use eight-point coordinates")
twod_arg.add_argument('--use_gray', type=str2bool, default=False)
twod_arg.add_argument('--resize_ratio', type=float, default=1.0)
twod_arg.add_argument(
    '--frames_per_one_submap',
    type=int,
    default=200,
    help="Number of frames used to create one fragment")
twod_arg.add_argument(
    '--regression_loss_iter',
    type=int,
    default=20000,
    help="start calculating regression loss after this number of iterations")
twod_arg.add_argument(
    '--data_dir_raw',
    type=str,
    help="path to raw dataset sources, e.g., the folder that contains ['7-scenes-chess', '7-scenes-fire', ...]")
twod_arg.add_argument(
    '--data_dir_processed',
    type=str,
    help="path to preprocessed dataset, e.g., the folder that contains ['7-scenes-chess@seq-01', '7-scenes-fire@seq-01', ...]")
twod_arg.add_argument(
    '--pred_threshold',
    type=float,
    default=0.0,
    help="Threshold for inlier prediction confidence")
twod_arg.add_argument(
    '--use_balance_loss',
    type=str2bool,
    default=True,
    help="use balanced classification loss")
twod_arg.add_argument(
    '--post_ransac',
    type=str2bool,
    default=False,
    help="use post RANSAC on evaluation")

# Baseline
baseline_arg = add_argument_group('Baseline')
baseline_arg.add_argument('--baseline_model', type=str, default='Mlesac')
baseline_arg.add_argument('--baseline_num_iter', type=int, default=1000)
baseline_arg.add_argument('--baseline_num_sample', type=int, default=8)
baseline_arg.add_argument('--baseline_inlier_threshold', type=float, default=8)
baseline_arg.add_argument(
    '--baseline_use_normalized_coords', type=str2bool, default=False)
baseline_arg.add_argument('--mlesac_sigma', type=float, default=1.0)
baseline_arg.add_argument('--mlesac_em_iter', type=int, default=10)
baseline_arg.add_argument('--use_parallel', type=str2bool, default=True)
baseline_arg.add_argument(
    '--success_rte_thresh',
    type=float,
    default=0.3,
    help='Success if the RTE is below this (m)')
baseline_arg.add_argument(
    '--success_rre_thresh',
    type=float,
    default=15,
    help='Success if the RRE is below this (degree)')

# OANet
oanet_arg = add_argument_group('OANet')
oanet_arg.add_argument('--oa_loss_essential', type=float, default=0.1)
oanet_arg.add_argument('--oa_loss_classif', type=float, default=1.0)
oanet_arg.add_argument('--oa_use_fundamental', type=str2bool, default=False)
oanet_arg.add_argument('--oa_obj_geod_th', type=float, default=1e-4)
oanet_arg.add_argument('--oa_geo_loss_margin', type=float, default=0.1)
oanet_arg.add_argument('--oa_loss_essential_init_iter', type=int, default=20000)
oanet_arg.add_argument('--oa_iter_num', type=int, default=1)
oanet_arg.add_argument('--oa_net_depth', type=int, default=12)
oanet_arg.add_argument('--oa_net_channels', type=int, default=128)
oanet_arg.add_argument('--oa_clusters', type=int, default=500)
oanet_arg.add_argument('--oa_use_ratio', type=int, default=0)
oanet_arg.add_argument('--oa_use_mutual', type=int, default=0)

# ND Experiment
nd_arg = add_argument_group('ND')
nd_arg.add_argument('--nd_dataset', type=str, default='HyperLineDataset')
nd_arg.add_argument('--nd_dimension', type=int, default=4)
nd_arg.add_argument('--nd_use_coords_as_feats', type=str2bool, default=True)


def get_config():
  config = parser.parse_args()
  vars(config)['root_dir'] = osp.dirname(osp.abspath(__file__))
  return config


def get_parser():
  return parser
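

if __name__ == '__main__':
  # Usage sketch (not part of the original module). All groups share the single
  # module-level parser, so get_config() returns one flat Namespace, e.g.:
  #   python config.py --use_gpu true --batch_size 8
  # Note that str2bool flags accept 'true'/'1' (case-insensitive); any other
  # value parses to False.
  config = get_config()
  print(config.batch_size, config.lr, config.root_dir)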
--------------------------------------------------------------------------------
/config/test_yfcc.txt:
--------------------------------------------------------------------------------
buckingham_palace
notre_dame_front_facade
reichstag
sacre_coeur
--------------------------------------------------------------------------------
/config/train_yfcc.txt:
--------------------------------------------------------------------------------
grand_central_terminal_new_york
florence_cathedral_dome_interior_1
mount_rushmore
st_peters_square
colosseum_interior
london_bridge_1
palace_of_versailles_chapel
brandenburg_gate
st_peters_basilica_interior_2
sagrada_familia_1
louvre
pieta_michelangelo
lincoln_memorial_statue
blue_mosque_interior_2
big_ben_2
st_vitus_cathedral
old_town_square_prague_clock
westminster_abbey_1
pantheon_exterior
taj_mahal_entrance
piazza_san_marco
ruins_of_st_pauls
united_states_capitol_rotunda
grand_place_brussels_1
paris_opera_1
taj_mahal
temple_nara_japan
statue_of_liberty_1
lincoln_memorial
sistine_chapel_ceiling_1
trevi_fountain_2
petra_jordan
florence_cathedral_dome_interior_2
milan_cathedral
united_states_capitol
temple_kyoto_japan
london_bridge_2
old_town_square_prague
florence_cathedral_side
st_peters_basilica_interior_1
sagrada_familia_2
blue_mosque_interior_1
sagrada_familia_3
london_bridge_3
vatican_museum_ceiling
big_ben_1
colosseum_exterior
st_pauls_cathedral
grand_place_brussels_3
westminster_abbey_2
piazza_della_signoria
via_condotti
natural_history_museum_london
grand_place_brussels_2
paris_opera_2
palace_of_westminster
palazzo_pubblico
piazza_dei_miracoli
pantheon_interior
statue_of_liberty_2
western_wall_jerusalem
national_gallery_london
british_museum
sistine_chapel_ceiling_2
hagia_sophia_interior
pike_place_market
trevi_fountain_1
--------------------------------------------------------------------------------
/config/val_yfcc.txt:
--------------------------------------------------------------------------------
grand_central_terminal_new_york
florence_cathedral_dome_interior_1
mount_rushmore
st_peters_square
colosseum_interior
london_bridge_1
palace_of_versailles_chapel
brandenburg_gate
st_peters_basilica_interior_2
sagrada_familia_1
louvre
pieta_michelangelo
lincoln_memorial_statue
blue_mosque_interior_2
big_ben_2
st_vitus_cathedral
old_town_square_prague_clock
westminster_abbey_1
pantheon_exterior
taj_mahal_entrance
piazza_san_marco
ruins_of_st_pauls
united_states_capitol_rotunda
grand_place_brussels_1
paris_opera_1
taj_mahal
temple_nara_japan
statue_of_liberty_1
lincoln_memorial
sistine_chapel_ceiling_1
trevi_fountain_2
petra_jordan
florence_cathedral_dome_interior_2
milan_cathedral
united_states_capitol
temple_kyoto_japan
london_bridge_2
old_town_square_prague
florence_cathedral_side
st_peters_basilica_interior_1
sagrada_familia_2
blue_mosque_interior_1
sagrada_familia_3
london_bridge_3
vatican_museum_ceiling
big_ben_1
colosseum_exterior
st_pauls_cathedral
grand_place_brussels_3
westminster_abbey_2
piazza_della_signoria
via_condotti
natural_history_museum_london
grand_place_brussels_2
paris_opera_2
palace_of_westminster
palazzo_pubblico
piazza_dei_miracoli
pantheon_interior
statue_of_liberty_2
western_wall_jerusalem
national_gallery_london
british_museum
sistine_chapel_ceiling_2
hagia_sophia_interior
pike_place_market
trevi_fountain_1
--------------------------------------------------------------------------------
/demo_2d.py:
--------------------------------------------------------------------------------
import itertools
import os.path as osp
from urllib.request import urlretrieve

import open3d as o3d
import cv2
import matplotlib.gridspec as grid
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import numpy as np
import torch
import MinkowskiEngine as ME

from model import load_model
from lib.eval import find_nn_gpu
from lib.util import read_txt, ensure_dir
import lib.util_2d as util_2d
from ucn.resunet import ResUNetBN2D2
from util.file import loadh5

imgs = [
    '68833924_5994205213.jpg',
    '54990444_8865247484.jpg',
    '57895226_4857581382.jpg',
]
calibs = [
    'calibration_000002.h5',
    'calibration_000344.h5',
    'calibration_000489.h5',
]
output_dir = './visualize'

# download weights
if not osp.isfile('ResUNetBN2D2-YFCC100train.pth'):
  print("Downloading UCN weights...")
  urlretrieve(
      "https://node1.chrischoy.org/data/publications/ucn/ResUNetBN2D2-YFCC100train-100epoch.pth",
      'ResUNetBN2D2-YFCC100train.pth')

if not osp.isfile('2d_pyramid_ucn.pth'):
  print("Downloading PyramidSCNoBlock weights...")
  urlretrieve("http://cvlab.postech.ac.kr/research/hcngpr/data/2d_pyramid_ucn.pth",
              "2d_pyramid_ucn.pth")


def prep_image(full_path):
  assert osp.exists(full_path), f"File {full_path} does not exist."
  img = cv2.imread(full_path)
  img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  img_color = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  # return cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
  return img_gray, img_color


def to_normalized_torch(img, device):
  """
  Normalize the image to the [-0.5, 0.5] range and add batch and channel dimensions.
  """
  img = img.astype(np.float32) / 255 - 0.5
  return torch.from_numpy(img).to(device)[None, None, :, :]


def dump_correspondence(img0, img1, F0, F1, mode='gpu-all', pixel_ths=4):
  use_stability_test = True
  use_cyclic_test = False
  keypoint = 'sift'
  if keypoint == 'sift':
    sift = cv2.xfeatures2d.SIFT_create(
        0,
        9,
        0.01,  # Smaller more keypoints, default 0.04
        100  # larger more keypoints, default 10
    )
    kp0 = sift.detect(img0, None)
    kp1 = sift.detect(img1, None)
    xy_kp0 = np.floor(np.array([k.pt for k in kp0]).T)
    xy_kp1 = np.floor(np.array([k.pt for k in kp1]).T)
    x0, y0 = xy_kp0[0], xy_kp0[1]
    x1, y1 = xy_kp1[0], xy_kp1[1]
  elif keypoint == 'all':
    x0, y0 = None, None
    x1, y1 = None, None

  H0, W0 = img0.shape
  H1, W1 = img1.shape

  if mode == 'gpu-all':
    nn_inds1 = find_nn_gpu(F0[:, y0, x0],
                           F1.view(F1.shape[0], -1),
                           nn_max_n=50,
                           transposed=True).numpy()

    # Convert the index to coordinate: BxCxHxW
    xs1 = nn_inds1 % W1
    ys1 = nn_inds1 // W1

    if use_stability_test:
      # Stability test: check stable under perturbation
      noisex = 2 * (np.random.rand(len(xs1)) < 0.5) - 1
      noisey = 2 * (np.random.rand(len(ys1)) < 0.5) - 1
      xs1n = np.clip(xs1 + noisex, 0, W1 - 1)
      ys1n = np.clip(ys1 + noisey, 0, H1 - 1)
    else:
      xs1n = xs1
      ys1n = ys1

    if use_cyclic_test:
      # Test reciprocity
      nn_inds0 = find_nn_gpu(F1[:, ys1n, xs1n],
                             F0.view(F0.shape[0], -1),
                             nn_max_n=50,
                             transposed=True).numpy()

      # Convert the index to coordinate: BxCxHxW
      xs0 = (nn_inds0 % W0)
      ys0 = (nn_inds0 // W0)

      # Test cyclic consistency
      dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
      mask = dist_sq_nn < (pixel_ths**2)

    else:
      xs0 = x0
      ys0 = y0
      mask = np.ones(len(x0)).astype(bool)

  elif mode == 'gpu-all-all':
    nn_inds1 = find_nn_gpu(
        F0.view(F0.shape[0], -1),
        F1.view(F1.shape[0], -1),
        nn_max_n=50,
        transposed=True,
    ).numpy()

    inds0 = np.arange(len(nn_inds1))
    x0 = inds0 % W0
    y0 = inds0 // W0

    xs1 = nn_inds1 % W1
    ys1 = nn_inds1 // W1

    if use_stability_test:
      # Stability test: check stable under perturbation
      noisex = 2 * (np.random.rand(len(xs1)) < 0.5) - 1
      noisey = 2 * (np.random.rand(len(ys1)) < 0.5) - 1
      xs1n = np.clip(xs1 + noisex, 0, W1 - 1)
      ys1n = np.clip(ys1 + noisey, 0, H1 - 1)
    else:
      xs1n = xs1
      ys1n = ys1

    # Test reciprocity
    nn_inds0 = find_nn_gpu(F1[:, ys1n, xs1n],
                           F0.view(F0.shape[0], -1),
                           nn_max_n=50,
                           transposed=True).numpy()

    # Convert the index to coordinate: BxCxHxW
    xs0 = nn_inds0 % W0
    ys0 = nn_inds0 // W0

    # Filter out the points that fail the cycle consistency
    dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
    mask = dist_sq_nn < (pixel_ths**2)

  return x0[mask], y0[mask], xs1[mask], ys1[mask]


def draw_figure(img0, img1, coords, labels, preds):
  plt.clf()
  fig = plt.figure()
  ratios = [img0.shape[1] * img1.shape[0], img1.shape[1] * img0.shape[0]]
  gs = grid.GridSpec(nrows=2, ncols=1, height_ratios=ratios)
  ax1 = fig.add_subplot(gs[0])
  ax2 = fig.add_subplot(gs[1])
  ax1.axis('off')
  ax2.axis('off')
  preds = preds > 0.5
  coords = coords[preds]
  labels = labels[preds]

  for coord, is_inlier in zip(coords, labels):
    con = ConnectionPatch(xyA=coord[:2],
                          xyB=coord[2:],
                          coordsA="data",
                          coordsB="data",
                          axesA=ax2,
                          axesB=ax1,
                          color="green" if is_inlier else "red")
    ax2.add_artist(con)

  ax1.imshow(img1)
  ax2.imshow(img0)
  plt.subplots_adjust(left=0, bottom=0, right=1, top=1, hspace=0, wspace=0)
  return fig


def demo():
  root = './imgs'
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # load UCN model
  print("loading UCN model...")
  checkpoint = torch.load('ResUNetBN2D2-YFCC100train.pth')
  ucn = ResUNetBN2D2(1, 64, normalize_feature=True)
  ucn.load_state_dict(checkpoint['state_dict'])
  ucn.eval()
  ucn = ucn.to(device)

  # load HighDimConvNet
  print("loading HighDimConvNet model...")
  checkpoint = torch.load('2d_pyramid_ucn.pth')
  opts = checkpoint['config']
  Model = load_model(opts.inlier_model)
  model = Model(in_channels=4, out_channels=1, clusters=opts.oa_clusters, D=4)
  model.load_state_dict(checkpoint['state_dict'])
  model.eval()
  model = model.to(device)

  idx_list = itertools.combinations(range(len(imgs)), 2)
  with torch.no_grad():
    for i, (idx0, idx1) in enumerate(idx_list):
      # extract UCN features
      img0, img0_color = prep_image(osp.join(root, imgs[idx0]))
      img1, img1_color = prep_image(osp.join(root, imgs[idx1]))
      F0 = ucn(to_normalized_torch(img0, device))
      F1 = ucn(to_normalized_torch(img1, device))

      # load calibration data
      calib0 = loadh5(osp.join(root, calibs[idx0]))
      calib1 = loadh5(osp.join(root, calibs[idx1]))
      K0, K1 = calib0['K'], calib1['K']
      imsize0, imsize1 = calib0['imsize'], calib1['imsize']
      T0 = util_2d.build_extrinsic_matrix(calib0['R'], calib0['T'][0])
      T1 = util_2d.build_extrinsic_matrix(calib1['R'], calib1['T'][0])
      E, *_ = util_2d.compute_essential_matrix(T0, T1)
      E = E / np.linalg.norm(E)

      # dump correspondences
      x0, y0, x1, y1 = dump_correspondence(img0,
                                           img1,
                                           F0[0],
                                           F1[0],
                                           mode='gpu-all',
                                           pixel_ths=4)
      kp0 = np.stack((x0, y0), 1).astype(np.float)
      kp1 = np.stack((x1, y1), 1).astype(np.float)
      # normalize correspondences
      norm_kp0 = util_2d.normalize_keypoint(kp0, K0, imsize0 * 0.5)[:, :2]
      norm_kp1 = util_2d.normalize_keypoint(kp1, K1, imsize1 * 0.5)[:, :2]
      coords = np.concatenate((kp0, kp1), axis=1)
      norm_coords = np.concatenate((norm_kp0, norm_kp1), axis=1)

      # build HighDimConvNet input
      quan_coords = np.floor(norm_coords / opts.quantization_size)
      idx = ME.utils.sparse_quantize(quan_coords, return_index=True)
      C = quan_coords[idx]
      F = torch.from_numpy(norm_coords[idx]).float()
      C = ME.utils.batched_coordinates([C])
      sinput = ME.SparseTensor(coords=C, feats=F).to(device)
      input_dict = dict(xyz=F, len_batch=[len(F)])

      # feed forward
      logits, _ = model(sinput, input_dict)
      logits = logits[-1].squeeze().cpu()
      preds = np.hstack(torch.sigmoid(logits))
      residuals = util_2d.compute_symmetric_epipolar_residual(
          E, norm_coords[:, :2], norm_coords[:, 2:])
      labels = residuals < 0.01

      # draw figure
      fig = draw_figure(img0_color, img1_color, coords[idx], labels[idx], preds)
      filename = osp.join(output_dir, f"demo_{i}.png")
      fig.savefig(filename, bbox_inches='tight')
      print(f"save {filename}")


if __name__ == "__main__":
  ensure_dir(output_dir)
  demo()
--------------------------------------------------------------------------------
/imgs/54990444_8865247484.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/54990444_8865247484.jpg
--------------------------------------------------------------------------------
/imgs/57895226_4857581382.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/57895226_4857581382.jpg
--------------------------------------------------------------------------------
/imgs/68833924_5994205213.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/68833924_5994205213.jpg
--------------------------------------------------------------------------------
/imgs/calibration_000002.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/calibration_000002.h5
--------------------------------------------------------------------------------
/imgs/calibration_000344.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/calibration_000344.h5
--------------------------------------------------------------------------------
/imgs/calibration_000489.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/calibration_000489.h5
--------------------------------------------------------------------------------
/imgs/demo_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/demo_0.png
--------------------------------------------------------------------------------
/imgs/demo_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/demo_1.png
--------------------------------------------------------------------------------
/imgs/demo_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/imgs/demo_2.png
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/lib/__init__.py
--------------------------------------------------------------------------------
/lib/all_data_loaders.py:
--------------------------------------------------------------------------------
import lib.twodim_data_loaders
# import lib.threedim_data_loaders


def make_data_loader(config, phase, batch_size, num_workers, shuffle=None, repeat=False):
  if config.dataset in lib.twodim_data_loaders.dataset_str_mapping:
    return lib.twodim_data_loaders.make_data_loader(
        config, phase, batch_size, num_workers, shuffle=shuffle, repeat=repeat)
  # elif config.dataset in lib.threedim_data_loaders.dataset_str_mapping:
  #   return lib.threedim_data_loaders.make_data_loader(
  #       config, phase, batch_size, num_workers, shuffle=shuffle, repeat=repeat)
  else:
    raise ValueError(f'{config.dataset} not defined.')
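
# Usage sketch (illustrative): the phase names and worker counts come from config.py,
# while the set of valid config.dataset keys is defined by the registry in
# lib/twodim_data_loaders.py, so treat the exact dataset name as an assumption.
#
#   config = get_config()
#   train_loader = make_data_loader(config, config.train_phase, config.batch_size,
#                                   config.train_num_workers, shuffle=True)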
--------------------------------------------------------------------------------
/lib/eval.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


def pdist(A, B, dist_type='L2', transposed=False):
  """
  transposed: if True, A and B are D x N. If False (default), N x D.
  """
  if 'L2' in dist_type:
    if transposed:
      D2 = torch.sum((A.unsqueeze(2) - B.unsqueeze(1)).pow(2), 0)
    else:
      D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2)
    if dist_type == 'L2':
      return torch.sqrt(D2 + np.finfo(np.float32).eps)
    elif dist_type == 'SquareL2':
      return D2
  else:
    raise NotImplementedError('Not implemented')


def find_nn_gpu(F0, F1, nn_max_n=-1, return_distance=False, dist_type='SquareL2',
                transposed=False):
  """
  transposed: if True, F0, F1 have D x N. False by default N x D.
  """
  # Too much memory if F0 or F1 is large. Divide F0 into chunks.
  if nn_max_n > 1:
    if transposed:
      N = F0.shape[1]
    else:
      N = len(F0)
    C = int(np.ceil(N / nn_max_n))
    stride = nn_max_n
    dists, inds = [], []
    for i in range(C):
      if transposed:
        dist = pdist(F0[:, i * stride:(i + 1) * stride], F1, dist_type=dist_type,
                     transposed=transposed)
      else:
        dist = pdist(F0[i * stride:(i + 1) * stride], F1, dist_type=dist_type,
                     transposed=transposed)
      min_dist, ind = dist.min(dim=1)
      dists.append(min_dist.detach().unsqueeze(1).cpu())
      inds.append(ind.cpu())

    if C * stride < N:
      if transposed:
        dist = pdist(F0[:, C * stride:], F1, dist_type=dist_type, transposed=transposed)
      else:
        dist = pdist(F0[C * stride:], F1, dist_type=dist_type, transposed=transposed)
      min_dist, ind = dist.min(dim=1)
      dists.append(min_dist.detach().unsqueeze(1).cpu())
      inds.append(ind.cpu())

    dists = torch.cat(dists)
    inds = torch.cat(inds)
    assert len(inds) == N
  else:
    dist = pdist(F0, F1, dist_type=dist_type, transposed=transposed)
    min_dist, inds = dist.min(dim=1)
    dists = min_dist.detach().unsqueeze(1).cpu()
    inds = inds.cpu()
  if return_distance:
    return inds, dists
  else:
    return inds
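

if __name__ == '__main__':
  # Usage sketch (not part of the original module): chunked nearest-neighbor search
  # on random 64-dim descriptors, in the transposed D x N layout demo_2d.py uses.
  F0 = torch.randn(64, 500)   # 500 query descriptors
  F1 = torch.randn(64, 1000)  # 1000 database descriptors
  inds = find_nn_gpu(F0, F1, nn_max_n=50, transposed=True)
  print(inds.shape)  # torch.Size([500]); each entry indexes a column of F1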
--------------------------------------------------------------------------------
/lib/lfgc_trainer.py:
--------------------------------------------------------------------------------
import gc
import logging

import numpy as np
import sklearn.metrics as metrics
import torch

from baseline.model.lfgc import LFGCLoss, LFGCNet
from lib.timer import AverageMeter, Timer
from lib.trainer import Trainer
from lib.util_2d import (compute_angular_error, compute_symmetric_epipolar_residual,
                         weighted_8points)


class LFGCTrainer(Trainer):
  """LFGC trainer"""

  def __init__(self, config, data_loader, val_data_loader=None):
    Trainer.__init__(self, config, data_loader, val_data_loader)
    self.loss = LFGCLoss(
        alpha=1.0, beta=0.1, regression_iter=config.regression_loss_iter)

  def _initialize_model(self):
    model = LFGCNet()
    return model

  def _train_epoch(self, epoch):
    gc.collect()

    data_loader_iter = self.data_loader.__iter__()
    loss_meter, prec_meter, recall_meter, f1_meter, ap_meter = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter())
    data_timer, inlier_timer, total_timer = Timer(), Timer(), Timer()

    tot_num_data = len(data_loader_iter)
    if self.train_max_iter > 0:
      tot_num_data = min(self.train_max_iter, tot_num_data)
    start_iter = (epoch - 1) * tot_num_data

    self.model.train()
    for curr_iter in range(tot_num_data):
      self.optimizer.zero_grad()
      total_timer.tic()

      # Load data
      data_timer.tic()
      input_dict = data_loader_iter.next()
      data_timer.toc()

      # Feature extraction
      inlier_timer.tic()
      norm_coords = input_dict['norm_coords'].transpose(2, 1).to(self.device)
      logits = self.model(norm_coords).squeeze(1)
      inlier_timer.toc()

      # Calculate loss
      labels = input_dict['labels'].to(self.device)
      e = input_dict['E'].to(self.device)
      loss = self.loss(logits, norm_coords, labels, e, start_iter + curr_iter)
      loss.backward()

      # Check for gradient explosion; skip the update if any gradient is NaN
      explode = False
      for _, param in self.model.named_parameters():
        if torch.any(torch.isnan(param.grad)):
          explode = True

      if explode:
        total_timer.toc()
        continue

      self.optimizer.step()
      total_timer.toc()

      # Accumulate metrics
      pred = np.hstack(torch.sigmoid(logits).squeeze().detach().cpu().numpy())
      target = np.hstack(labels.cpu().numpy()).astype(np.int)
      prec, recall, f1, _ = metrics.precision_recall_fscore_support(
          target, (pred > 0.5).astype(np.int), average='binary')
      ap = metrics.average_precision_score(target, pred)

      prec_meter.update(prec)
      recall_meter.update(recall)
      f1_meter.update(f1)
      ap_meter.update(ap)
      loss_meter.update(loss.item())

      torch.cuda.empty_cache()

      if curr_iter % self.config.stat_freq == 0:
        stat = {
            'prec': prec_meter.avg,
            'recall': recall_meter.avg,
            'f1': f1_meter.avg,
            'ap': ap_meter.avg,
            'loss': loss_meter.avg
        }

        for k, v in stat.items():
          self.writer.add_scalar(f'train/{k}', v, start_iter + curr_iter)

        logging.info(
            ', '.join([f"Train Epoch: {epoch} [{curr_iter}/{tot_num_data}]"] +
                      [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()] + [
                          f"Data time: {data_timer.avg:.4f}",
                          f"Train time: {total_timer.avg - data_timer.avg:.4f}",
                          f"Total time: {total_timer.avg:.4f}"
                      ]))

  def _valid_epoch(self):
    gc.collect()

    data_loader_iter = self.val_data_loader.__iter__()
    loss_meter, prec_meter, recall_meter, f1_meter, ap_meter = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter())
    data_timer, inlier_timer, total_timer = Timer(), Timer(), Timer()

    tot_num_data = len(self.val_data_loader.dataset)
    if self.val_max_iter > 0:
      tot_num_data = min(self.val_max_iter, tot_num_data)

    self.model.eval()
    with torch.no_grad():
      for curr_iter in range(tot_num_data):
        total_timer.tic()

        data_timer.tic()
        input_dict = data_loader_iter.next()
        data_timer.toc()

        # Feature extraction
        inlier_timer.tic()
        norm_coords = input_dict['norm_coords'].transpose(2, 1).to(self.device)
        logits = self.model(norm_coords)
        logits = logits.squeeze(1)
        inlier_timer.toc()

        # Calculate loss
        labels = input_dict['labels'].to(self.device)
        e = input_dict['E'].to(self.device)
        loss = self.loss(logits, norm_coords, labels, e, curr_iter)
        total_timer.toc()

        # Accumulate metrics
        pred = np.hstack(torch.sigmoid(logits).squeeze().cpu().numpy())
        target = np.hstack(labels.cpu().numpy()).astype(np.int)
        prec, recall, f1, _ = metrics.precision_recall_fscore_support(
            target, (pred > 0.5).astype(np.int), average='binary')
        ap = metrics.average_precision_score(target, pred)

        prec_meter.update(prec)
        recall_meter.update(recall)
        f1_meter.update(f1)
        ap_meter.update(ap)
        loss_meter.update(loss.item())

        torch.cuda.empty_cache()

        if curr_iter % self.config.stat_freq == 0:
          stat = {
              'prec': prec_meter.avg,
              'recall': recall_meter.avg,
              'f1': f1_meter.avg,
              'ap': ap_meter.avg,
              'loss': loss_meter.avg
          }

          logging.info(
              ', '.join([f"Validation [{curr_iter}/{tot_num_data}]"] +
                        [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()] + [
                            f"Data time: {data_timer.avg:.4f}",
                            f"Train time: {total_timer.avg - data_timer.avg:.4f}",
                            f"Total time: {total_timer.avg:.4f}"
                        ]))

    stat = {
        'prec': prec_meter.avg,
        'recall': recall_meter.avg,
        'f1': f1_meter.avg,
        'ap': ap_meter.avg,
        'loss': loss_meter.avg
    }
    logging.info(', '.join([f"Validation"] +
                           [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()]))

    return stat

  def test(self, test_loader):
    test_iter = test_loader.__iter__()
    logging.info(f"Evaluating on {test_loader.dataset.scene}")

    self.model.eval()
    labels, preds, residuals, err_qs, err_ts = [], [], [], [], []
    with torch.no_grad():
      for _ in range(len(test_iter)):
        input_dict = test_iter.next()

        norm_coords = input_dict['norm_coords']

        coords_input = norm_coords.transpose(2, 1).to(self.device)
        logit = self.model(coords_input)
        logit = logit.squeeze(1)
        e_hat = weighted_8points(coords_input, logit)
        logit = logit.cpu()
        e_hat = e_hat.cpu().numpy()

        label = np.hstack(input_dict['labels'])
        pred = np.hstack(torch.sigmoid(logit))
        norm_coords = np.hstack(norm_coords)
        R = np.hstack(input_dict['R'])
        t = np.hstack(input_dict['t'])

        residual = compute_symmetric_epipolar_residual(
            e_hat.reshape(3, 3).T,
            norm_coords[label.astype(bool), :2],
            norm_coords[label.astype(bool), 2:],
        )

        err_q, err_t = compute_angular_error(
            R,
            t,
            e_hat.reshape(3, 3),
            norm_coords,
            pred,
        )

        labels.append(label.astype(np.int))
        preds.append(pred)
        residuals.append(residual)
        err_qs.append(err_q)
        err_ts.append(err_t)
    return labels, preds, residuals, err_qs, err_ts
--------------------------------------------------------------------------------
/lib/oa_trainer.py:
--------------------------------------------------------------------------------
import gc
import logging

import numpy as np
numpy as np 5 | import sklearn.metrics as metrics 6 | import torch 7 | 8 | from baseline.model.oanet import OALoss, OANet 9 | from lib.timer import AverageMeter, Timer 10 | from lib.trainer import Trainer 11 | from lib.util_2d import compute_symmetric_epipolar_residual, compute_angular_error 12 | 13 | 14 | class OATrainer(Trainer): 15 | """OANet trainer""" 16 | 17 | def __init__(self, config, data_loader, val_data_loader=None): 18 | Trainer.__init__(self, config, data_loader, val_data_loader) 19 | self.loss = OALoss(config) 20 | 21 | def _initialize_model(self): 22 | model = OANet(self.config) 23 | return model 24 | 25 | def _train_epoch(self, epoch): 26 | gc.collect() 27 | 28 | data_loader_iter = self.data_loader.__iter__() 29 | loss_meter, prec_meter, recall_meter, f1_meter, ap_meter = AverageMeter( 30 | ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 31 | data_timer, inlier_timer, total_timer = Timer(), Timer(), Timer() 32 | 33 | tot_num_data = len(data_loader_iter) 34 | if self.train_max_iter > 0: 35 | tot_num_data = min(self.train_max_iter, tot_num_data) 36 | start_iter = (epoch - 1) * tot_num_data 37 | 38 | self.model.train() 39 | for curr_iter in range(tot_num_data): 40 | self.optimizer.zero_grad() 41 | total_timer.tic() 42 | 43 | # Load data 44 | data_timer.tic() 45 | input_dict = data_loader_iter.next() 46 | data_timer.toc() 47 | 48 | # To Cuda 49 | for key in input_dict.keys(): 50 | if type(input_dict[key]) == torch.Tensor: 51 | input_dict[key] = input_dict[key].to(self.device) 52 | 53 | # Feature extraction 54 | inlier_timer.tic() 55 | norm_coords = input_dict['norm_coords'].unsqueeze(1) 56 | logits, e_hat = self.model(norm_coords) 57 | inlier_timer.toc() 58 | 59 | # Calculate loss 60 | labels = input_dict['labels'] 61 | loss = 0 62 | for i in range(len(logits)): 63 | loss_i = self.loss.run(start_iter + curr_iter, input_dict, logits[i], e_hat[i]) 64 | loss += loss_i 65 | loss.backward() 66 | self.optimizer.step() 67 | total_timer.toc() 68 | 69 | # Accumulate metrics 70 | pred = np.hstack(torch.sigmoid(logits[-1]).squeeze().detach().cpu().numpy()) 71 | target = np.hstack(labels.squeeze().detach().cpu().numpy()).astype(np.int) 72 | prec, recall, f1, _ = metrics.precision_recall_fscore_support( 73 | target, (pred > 0.5).astype(np.int), average='binary') 74 | ap = metrics.average_precision_score(target, pred) 75 | 76 | prec_meter.update(prec) 77 | recall_meter.update(recall) 78 | f1_meter.update(f1) 79 | ap_meter.update(ap) 80 | loss_meter.update(loss.item()) 81 | 82 | torch.cuda.empty_cache() 83 | 84 | if curr_iter % self.config.stat_freq == 0: 85 | # Use the current value to see how stochastic the metrics are 86 | stat = { 87 | 'prec': prec_meter.val, 88 | 'recall': recall_meter.val, 89 | 'f1': f1_meter.val, 90 | 'ap': ap_meter.val, 91 | 'loss': loss_meter.val 92 | } 93 | 94 | for k, v in stat.items(): 95 | self.writer.add_scalar(f'train/{k}', v, start_iter + curr_iter) 96 | 97 | logging.info( 98 | ', '.join([f"Train Epoch: {epoch} [{curr_iter}/{tot_num_data}]"] + 99 | [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()] + [ 100 | f"Data time: {data_timer.avg:.4f}", 101 | f"Train time: {total_timer.avg - data_timer.avg:.4f}", 102 | f"Total time: {total_timer.avg:.4f}" 103 | ])) 104 | 105 | def _valid_epoch(self): 106 | gc.collect() 107 | 108 | data_loader_iter = self.val_data_loader.__iter__() 109 | loss_meter, prec_meter, recall_meter, f1_meter, ap_meter = AverageMeter( 110 | ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 111 | data_timer, 
inlier_timer, total_timer = Timer(), Timer(), Timer() 112 | 113 | tot_num_data = len(self.val_data_loader.dataset) 114 | if self.val_max_iter > 0: 115 | tot_num_data = min(self.val_max_iter, tot_num_data) 116 | 117 | self.model.eval() 118 | with torch.no_grad(): 119 | for curr_iter in range(tot_num_data): 120 | total_timer.tic() 121 | 122 | data_timer.tic() 123 | input_dict = data_loader_iter.next() 124 | data_timer.toc() 125 | 126 | # To Cuda 127 | for key in input_dict.keys(): 128 | if type(input_dict[key]) == torch.Tensor: 129 | input_dict[key] = input_dict[key].to(self.device) 130 | 131 | # Feature extraction 132 | inlier_timer.tic() 133 | norm_coords = input_dict['norm_coords'].unsqueeze(1) 134 | logits, e_hats = self.model(norm_coords) 135 | inlier_timer.toc() 136 | 137 | # Calculate loss 138 | labels = input_dict['labels'] 139 | loss = 0 140 | for i in range(len(logits)): 141 | loss_i = self.loss.run(curr_iter, input_dict, logits[i], e_hats[i]) 142 | loss += loss_i 143 | total_timer.toc() 144 | 145 | # Accumulate metrics 146 | pred = np.hstack(torch.sigmoid(logits[-1]).squeeze().cpu().numpy()) 147 | target = np.hstack(labels.squeeze().cpu().numpy()).astype(np.int) 148 | prec, recall, f1, _ = metrics.precision_recall_fscore_support( 149 | target, (pred > 0.5).astype(np.int), average='binary') 150 | ap = metrics.average_precision_score(target, pred) 151 | 152 | prec_meter.update(prec) 153 | recall_meter.update(recall) 154 | f1_meter.update(f1) 155 | ap_meter.update(ap) 156 | loss_meter.update(loss.item()) 157 | 158 | # Clean 159 | torch.cuda.empty_cache() 160 | 161 | if curr_iter % self.config.stat_freq == 0: 162 | stat = { 163 | 'prec': prec_meter.avg, 164 | 'recall': recall_meter.avg, 165 | 'f1': f1_meter.avg, 166 | 'ap': ap_meter.avg, 167 | 'loss': loss_meter.avg 168 | } 169 | 170 | logging.info( 171 | ', '.join([f"Validation [{curr_iter}/{tot_num_data}]"] + 172 | [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()] + [ 173 | f"Data time: {data_timer.avg:.4f}", 174 | f"Train time: {total_timer.avg - data_timer.avg:.4f}", 175 | f"Total time: {total_timer.avg:.4f}" 176 | ])) 177 | 178 | stat = { 179 | 'prec': prec_meter.avg, 180 | 'recall': recall_meter.avg, 181 | 'f1': f1_meter.avg, 182 | 'ap': ap_meter.avg, 183 | 'loss': loss_meter.avg 184 | } 185 | logging.info(', '.join([f"Validation"] + 186 | [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()])) 187 | 188 | return stat 189 | 190 | def test(self, test_loader): 191 | test_iter = test_loader.__iter__() 192 | logging.info(f"Evaluating on {test_loader.dataset.scene}") 193 | 194 | self.model.eval() 195 | labels, preds, residuals, err_qs, err_ts = [], [], [], [], [] 196 | with torch.no_grad(): 197 | for _ in range(len(test_iter)): 198 | input_dict = test_iter.next() 199 | 200 | norm_coords = input_dict['norm_coords'] 201 | logits, e_hats = self.model(norm_coords.unsqueeze(1).to(self.device)) 202 | logit = logits[-1].squeeze().cpu() 203 | e_hat = e_hats[-1].cpu().numpy() 204 | 205 | label = np.hstack(input_dict['labels']) 206 | pred = np.hstack(torch.sigmoid(logit)) 207 | norm_coords = np.hstack(norm_coords) 208 | R = np.hstack(input_dict['R']) 209 | t = np.hstack(input_dict['t']) 210 | 211 | residual = compute_symmetric_epipolar_residual( 212 | e_hat.reshape(3, 3).T, 213 | norm_coords[label.astype(bool), :2], 214 | norm_coords[label.astype(bool), 2:], 215 | ) 216 | 217 | err_q, err_t = compute_angular_error( 218 | R, 219 | t, 220 | e_hat.reshape(3, 3), 221 | norm_coords, 222 | pred, 223 | ) 224 | 225 | 
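# Each test pair contributes five records: ground-truth inlier labels,
# sigmoid inlier scores, symmetric epipolar residuals of the ground-truth
# inliers under the estimated essential matrix, and rotation/translation
# angular errors in degrees, which downstream benchmarking aggregates into
# pose-accuracy statistics.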
labels.append(label.astype(np.int)) 226 | preds.append(pred) 227 | residuals.append(residual) 228 | err_qs.append(err_q) 229 | err_ts.append(err_t) 230 | return labels, preds, residuals, err_qs, err_ts 231 | -------------------------------------------------------------------------------- /lib/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class ConfusionMatrix(object): 6 | 7 | def __init__(self): 8 | self.eps = np.finfo(float).eps 9 | self.inlier_meter = AverageMeter() 10 | self.correspondence_meter = AverageMeter() 11 | self.reset() 12 | 13 | def update(self, pred, target): 14 | target = target.astype(np.bool) 15 | pred_on_pos = pred[target] 16 | pred_on_neg = pred[~target] 17 | 18 | tp = np.sum(pred_on_pos) 19 | fn = np.sum(~pred_on_pos) 20 | fp = np.sum(pred_on_neg) 21 | tn = np.sum(~pred_on_neg) 22 | 23 | self.tp += tp 24 | self.fn += fn 25 | self.fp += fp 26 | self.tn += tn 27 | 28 | inlier_ratio = np.sum(target) / target.size 29 | correspondence_accuracy = np.sum(pred) / pred.size 30 | 31 | self.inlier_meter.update(inlier_ratio) 32 | self.correspondence_meter.update(correspondence_accuracy) 33 | 34 | def eval(self): 35 | tp, tn, fp, fn, eps = self.tp, self.tn, self.fp, self.fn, self.eps 36 | 37 | accuracy = (tp + tn) / (tp + fp + tn + fn + eps) 38 | precision = tp / (tp + fp + eps) 39 | recall = tp / (tp + fn + eps) 40 | f1 = 2 * (precision * recall) / (precision + recall + eps) 41 | tpr = tp / (tp + fn + eps) 42 | tnr = tn / (tn + fp + eps) 43 | balanced_accuracy = (tpr + tnr) / 2 44 | 45 | return { 46 | 'inlier_ratio': self.inlier_meter.avg, 47 | 'correspondence_accuracy': self.correspondence_meter.avg, 48 | 'accuracy': accuracy, 49 | 'precision': precision, 50 | 'recall': recall, 51 | 'f1': f1, 52 | 'tpr': tpr, 53 | 'tnr': tnr, 54 | 'balanced_accuracy': balanced_accuracy, 55 | } 56 | 57 | def reset(self): 58 | self.tp = 0 59 | self.tn = 0 60 | self.fp = 0 61 | self.fn = 0 62 | self.inlier_meter.reset() 63 | self.correspondence_meter.reset() 64 | 65 | 66 | class GroupMeter(object): 67 | 68 | def __init__(self, keys): 69 | self.keys = keys 70 | for k in keys: 71 | setattr(self, k, AverageMeter()) 72 | 73 | def update(self, key, value): 74 | if hasattr(self, key): 75 | meter = getattr(self, key) 76 | meter.update(value) 77 | else: 78 | raise ValueError(f"{key} is not registered") 79 | 80 | def get(self, key, average=True): 81 | if hasattr(self, key): 82 | meter = getattr(self, key) 83 | if average: 84 | return meter.avg 85 | else: 86 | return meter.val 87 | else: 88 | raise ValueError(f"{key} is not registerd") 89 | 90 | def get_dict(self): 91 | return {k: self.get(k) for k in self.keys} 92 | 93 | 94 | class AverageMeter(object): 95 | """Computes and stores the average and current value""" 96 | 97 | def __init__(self): 98 | self.reset() 99 | 100 | def reset(self): 101 | self.val = 0 102 | self.avg = 0 103 | self.sum = 0.0 104 | self.sq_sum = 0.0 105 | self.count = 0 106 | 107 | def update(self, val, n=1): 108 | if not np.isnan(val): 109 | self.val = val 110 | self.sum += val * n 111 | self.count += n 112 | self.avg = self.sum / self.count 113 | self.sq_sum += val**2 * n 114 | self.var = self.sq_sum / self.count - self.avg**2 115 | 116 | 117 | class Timer(object): 118 | """A simple timer.""" 119 | 120 | def __init__(self): 121 | self.total_time = 0. 122 | self.calls = 0 123 | self.start_time = 0. 124 | self.diff = 0. 125 | self.avg = 0. 
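# Illustrative use of the meters above (values are hypothetical):
#   meter = AverageMeter()
#   meter.update(0.8); meter.update(0.6)
#   meter.avg   # -> 0.7 running mean; meter.val keeps the latest sample
# and of the timer below:
#   timer = Timer()
#   timer.tic(); ...; timer.toc()   # accumulates and returns the running
#                                   # average; toc(average=False) returns
#                                   # the last measured interval instead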
126 | 127 | def reset(self): 128 | self.total_time = 0 129 | self.calls = 0 130 | self.start_time = 0 131 | self.diff = 0 132 | self.avg = 0 133 | 134 | def tic(self): 135 | # using time.time instead of time.clock because time time.clock 136 | # does not normalize for multithreading 137 | self.start_time = time.time() 138 | 139 | def toc(self, average=True): 140 | self.diff = time.time() - self.start_time 141 | self.total_time += self.diff 142 | self.calls += 1 143 | self.avg = self.total_time / self.calls 144 | if average: 145 | return self.avg 146 | else: 147 | return self.diff 148 | -------------------------------------------------------------------------------- /lib/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import numpy as np 5 | import json 6 | from abc import abstractmethod, ABC 7 | 8 | import torch 9 | import torch.optim as optim 10 | from tensorboardX import SummaryWriter 11 | 12 | from lib.util import ensure_dir, count_parameters 13 | 14 | 15 | class Trainer(ABC): 16 | 17 | def __init__(self, config, data_loader, val_data_loader=None): 18 | if config.use_gpu and not torch.cuda.is_available(): 19 | logging.warning('Warning: There\'s no CUDA support on this machine, ' 20 | 'training is performed on CPU.') 21 | raise ValueError('GPU not available, but cuda flag set') 22 | self.config = config 23 | 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | # Prepare training config 27 | self.start_epoch = 1 28 | self.max_epoch = config.max_epoch 29 | self.save_freq = config.save_freq_epoch 30 | self.train_max_iter = config.train_max_iter 31 | self.val_max_iter = config.val_max_iter 32 | self.val_epoch_freq = config.val_epoch_freq 33 | self.iter_size = config.iter_size 34 | self.batch_size = config.batch_size 35 | self.data_loader = data_loader 36 | self.val_data_loader = val_data_loader 37 | self.test_valid = True if self.val_data_loader is not None else False 38 | self.pos_weight = config.pos_weight 39 | self.neg_weight = config.neg_weight 40 | 41 | # Prepare validation config 42 | self.best_val_comparator = config.best_val_comparator 43 | self.best_val_metric = config.best_val_metric 44 | self.best_val_epoch = -np.inf 45 | self.best_val = -np.inf 46 | 47 | # Initialize model, optimiser and scheduler 48 | model = self._initialize_model() 49 | logging.info(model) 50 | num_params = count_parameters(model) 51 | logging.info(f"=> Number of parameters = {num_params}") 52 | self.model = model.to(self.device) 53 | self.initialize_optimiser_and_scheduler() 54 | self.resume() 55 | 56 | # Prepare output directory 57 | self.checkpoint_dir = config.out_dir 58 | ensure_dir(self.checkpoint_dir) 59 | json.dump( 60 | config, 61 | open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'), 62 | indent=4, 63 | sort_keys=False) 64 | 65 | # Intialize tensorboard summary writer 66 | self.writer = SummaryWriter(logdir=config.out_dir) 67 | 68 | @abstractmethod 69 | def _initialize_model(self): 70 | pass 71 | 72 | @abstractmethod 73 | def _train_epoch(self, epoch): 74 | pass 75 | 76 | @abstractmethod 77 | def _valid_epoch(self): 78 | pass 79 | 80 | def initialize_optimiser_and_scheduler(self): 81 | config = self.config 82 | if config.optimizer == 'Adam': 83 | self.optimizer = getattr(optim, config.optimizer)( 84 | self.model.parameters(), lr=config.lr, weight_decay=config.weight_decay) 85 | else: 86 | self.optimizer = getattr(optim, config.optimizer)( 87 | 
self.model.parameters(), 88 | lr=config.lr, 89 | momentum=config.momentum, 90 | weight_decay=config.weight_decay) 91 | 92 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.exp_gamma) 93 | 94 | def resume(self): 95 | config = self.config 96 | if config.resume is None and config.weights: 97 | logging.info("=> Loading checkpoint '{}'".format(config.weights)) 98 | checkpoint = torch.load(config.weights) 99 | if 'state_dict' in checkpoint: 100 | self.model.load_state_dict(checkpoint['state_dict']) 101 | logging.info("=> Loaded inlier weights from '{}'".format(config.weights)) 102 | else: 103 | logging.warn("=> Inlier weight not found in '{}'".format(config.weights)) 104 | 105 | if config.resume is not None: 106 | if osp.isfile(config.resume): 107 | logging.info(f"=> Resuming training from the checkpoint '{config.resume}'") 108 | state = torch.load(config.resume) 109 | 110 | self.start_epoch = state['epoch'] + 1 111 | logging.info(f"=> Training from {self.start_epoch} to {self.max_epoch}'") 112 | self.model.load_state_dict(state['state_dict']) 113 | self.scheduler.load_state_dict(state['scheduler']) 114 | self.optimizer.load_state_dict(state['optimizer']) 115 | logging.info(f"=> Loaded weights, scheduler, optimizer from '{config.resume}'") 116 | 117 | if 'best_val' in state.keys(): 118 | self.best_val = state['best_val'] 119 | self.best_val_epoch = state['best_val_epoch'] 120 | self.best_val_metric = state['best_val_metric'] 121 | else: 122 | raise ValueError(f"=> no checkpoint found at '{config.resume}'") 123 | 124 | def save_checkpoint(self, epoch, filename='checkpoint'): 125 | state = { 126 | 'epoch': epoch, 127 | 'state_dict': self.model.state_dict(), 128 | 'optimizer': self.optimizer.state_dict(), 129 | 'scheduler': self.scheduler.state_dict(), 130 | 'config': self.config, 131 | 'best_val': self.best_val, 132 | 'best_val_epoch': self.best_val_epoch, 133 | 'best_val_metric': self.best_val_metric 134 | } 135 | filename = os.path.join(self.checkpoint_dir, f'{filename}.pth') 136 | logging.info("Saving checkpoint: {} ...".format(filename)) 137 | torch.save(state, filename) 138 | 139 | def train(self): 140 | """Full training logic""" 141 | 142 | # Baseline random feature performance 143 | if self.test_valid: 144 | val_dict = self._valid_epoch() 145 | for k, v in val_dict.items(): 146 | self.writer.add_scalar(f'val/{k}', v, 0) 147 | 148 | for epoch in range(self.start_epoch, self.max_epoch + 1): 149 | lr = self.scheduler.get_lr() 150 | logging.info(f" Epoch: {epoch}, LR: {lr}") 151 | self._train_epoch(epoch) 152 | self.save_checkpoint(epoch) 153 | self.scheduler.step() 154 | 155 | if self.test_valid and epoch % self.val_epoch_freq == 0: 156 | val_dict = self._valid_epoch() 157 | for k, v in val_dict.items(): 158 | self.writer.add_scalar(f'val/{k}', v, epoch) 159 | if (self.best_val_comparator == 'larger' and self.best_val < val_dict[self.best_val_metric]) or \ 160 | (self.best_val_comparator == 'smaller' and self.best_val > val_dict[self.best_val_metric]): 161 | logging.info( 162 | f'Saving the best val model with {self.best_val_metric}: {val_dict[self.best_val_metric]}' 163 | ) 164 | self.best_val = val_dict[self.best_val_metric] 165 | self.best_val_epoch = epoch 166 | self.save_checkpoint(epoch, 'best_val_checkpoint') 167 | else: 168 | logging.info( 169 | f'Current best val model with {self.best_val_metric}: {self.best_val} at epoch {self.best_val_epoch}' 170 | ) 171 | -------------------------------------------------------------------------------- 
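The `Trainer` above is the shared driver for every model in the repo: it owns the device, optimizer, exponential LR schedule, checkpointing, and best-validation tracking, while subclasses supply only the model and the per-epoch logic. A minimal sketch of a concrete subclass; `MyModel` is a placeholder, and the config is assumed to carry the fields read in `Trainer.__init__`:

```
from lib.trainer import Trainer


class MinimalTrainer(Trainer):

  def _initialize_model(self):
    # any torch.nn.Module; Trainer.__init__ moves it to self.device
    return MyModel()

  def _train_epoch(self, epoch):
    # one optimization pass over self.data_loader
    ...

  def _valid_epoch(self):
    # must return a dict containing config.best_val_metric; train() logs
    # every entry to TensorBoard under val/ and uses the chosen metric
    # for best-checkpoint tracking
    return {'loss': 0.0, 'ap': 0.0}


# trainer = MinimalTrainer(config, train_loader, val_loader)
# trainer.train()
```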
/lib/twodim_data_loaders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as osp 3 | 4 | import cv2 5 | import h5py 6 | import MinkowskiEngine as ME 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | 11 | from lib.util_data import InfSampler 12 | from lib.util_2d import compute_symmetric_epipolar_residual 13 | 14 | 15 | class CollationFunctionFactory: 16 | def __init__(self, config): 17 | self.config = config 18 | if config.collation_2d == 'default': 19 | self.fn = self.collate 20 | elif config.collation_2d == 'collate_correspondence': 21 | self.fn = self.collate_correspondence 22 | elif config.collation_2d == 'collate_lfgc': 23 | self.fn = self.collate_oa 24 | elif config.collation_2d == 'collate_oa': 25 | self.fn = self.collate_oa 26 | else: 27 | raise ValueError(f'Invalid collation_2d {config.collation_2d}') 28 | 29 | def collate(self, data): 30 | if isinstance(data[0], dict): 31 | return { 32 | # "img0": [b["img0"] for b in data], 33 | # "img1": [b["img1"] for b in data], 34 | "coords": [b["coords"] for b in data], 35 | "labels": [b["labels"] for b in data], 36 | "norm_coords": [b["norm_coords"] for b in data], 37 | "E": [b["E"] for b in data] 38 | } 39 | return list(zip(*data)) 40 | 41 | def collate_oa(self, batch): 42 | assert isinstance(batch[0], dict) 43 | 44 | img0_batch = [b['img0'] for b in batch] 45 | img1_batch = [b['img1'] for b in batch] 46 | coords_batch = [b['coords'] for b in batch] 47 | norm_coords_batch = [b['norm_coords'] for b in batch] 48 | labels_batch = [b['labels'] for b in batch] 49 | e_batch = [b['E'] for b in batch] 50 | virt_batch = [b['virtPts'] for b in batch] 51 | R_batch = [b['R'] for b in batch] 52 | t_batch = [b['t'] for b in batch] 53 | 54 | numkps = [coords.shape[0] for coords in norm_coords_batch] 55 | minkps = np.min(numkps) 56 | 57 | norm_coords_batch = [coords[:minkps, :] for coords in norm_coords_batch] 58 | labels_batch = [labels[:minkps] for labels in labels_batch] 59 | 60 | return { 61 | 'coords': coords_batch, 62 | 'norm_coords': torch.from_numpy(np.asarray(norm_coords_batch)), 63 | 'labels': torch.from_numpy(np.asarray(labels_batch)), 64 | 'E': torch.from_numpy(np.asarray(e_batch)), 65 | 'R': torch.from_numpy(np.asarray(R_batch)), 66 | 't': torch.from_numpy(np.asarray(t_batch)), 67 | 'virtPts': torch.from_numpy(np.asarray(virt_batch)), 68 | 'img0': img0_batch, 69 | 'img1': img1_batch 70 | } 71 | 72 | def collate_correspondence(self, batch): 73 | assert isinstance(batch[0], dict) 74 | 75 | img0_batch = [b['img0'] for b in batch] 76 | img1_batch = [b['img1'] for b in batch] 77 | coords_batch = [b['coords'] for b in batch] 78 | norm_coords_batch = [b['norm_coords'] for b in batch] 79 | labels_batch = [b['labels'] for b in batch] 80 | E_batch = [b['E'] for b in batch] 81 | R_batch = [b['R'] for b in batch] 82 | t_batch = [b['t'] for b in batch] 83 | virt_batch = [b['virtPts'] for b in batch] 84 | 85 | sinput_C, sinput_F, sinput_L, idx_batch = [], [], [], [] 86 | for norm_coords, labels in zip(norm_coords_batch, labels_batch): 87 | # quantize 88 | quan_coords = np.floor(norm_coords / self.config.quantization_size) 89 | idx, idx_reverse = ME.utils.sparse_quantize(quan_coords, 90 | return_index=True, 91 | return_inverse=True) 92 | C = quan_coords[idx] 93 | F = norm_coords[idx] 94 | L = labels[idx] 95 | 96 | sinput_C.append(C) 97 | sinput_F.append(F) 98 | sinput_L.append(L) 99 | idx_batch.append(idx) 100 | 101 | # Sample minimum number of coordinates for 
each batch 102 | if self.config.sample_minimum_coords: 103 | npts = [C.shape[0] for C in sinput_C] 104 | min_pts = np.min(npts) 105 | for i, (C, F, L) in enumerate(zip(sinput_C, sinput_F, sinput_L)): 106 | if C.shape[0] > min_pts: 107 | rand_idx = np.random.choice(C.shape[0], min_pts, replace=False) 108 | sinput_C[i] = C[rand_idx] 109 | sinput_F[i] = F[rand_idx] 110 | sinput_L[i] = L[rand_idx] 111 | idx_batch[i] = idx_batch[i][rand_idx] 112 | 113 | # Collate 114 | len_batch = [C.shape[0] for C in sinput_C] 115 | E_batch = torch.from_numpy(np.asarray(E_batch)) 116 | R_batch = torch.from_numpy(np.asarray(R_batch)) 117 | t_batch = torch.from_numpy(np.asarray(t_batch)) 118 | virt_batch = torch.from_numpy(np.asarray(virt_batch)) 119 | norm_coords_batch = sinput_F 120 | 121 | sinput_C_th = ME.utils.batched_coordinates(sinput_C) 122 | sinput_F_th = torch.from_numpy(np.vstack(sinput_F)) 123 | sinput_L_th = torch.from_numpy(np.hstack(sinput_L)) 124 | 125 | xyz = sinput_F_th 126 | 127 | # if inlier_feature_type is not coords, use an all-ones feature (one row per quantized point) 128 | if self.config.inlier_feature_type != 'coords': 129 | sinput_F_th = torch.ones((sinput_C_th.shape[0], 1)) 130 | 131 | return { 132 | 'sinput_C': sinput_C_th.int(), 133 | 'sinput_F': sinput_F_th.float(), 134 | 'sinput_L': sinput_L_th.int(), 135 | 'virtPts': virt_batch.float(), 136 | 'E': E_batch.float(), 137 | 'R': R_batch.float(), 138 | 't': t_batch.float(), 139 | 'len_batch': len_batch, 140 | 'xyz': xyz, 141 | 'coords': coords_batch, 142 | 'norm_coords': norm_coords_batch, 143 | 'labels': labels_batch, 144 | 'img0': img0_batch, 145 | 'img1': img1_batch, 146 | } 147 | 148 | def __call__(self, data): 149 | return self.fn(data) 150 | 151 | 152 | class YFCC100MDatasetExtracted(torch.utils.data.Dataset): 153 | DATA_FILES = { 154 | 'train': 'yfcc-sift-2000-train.h5', 155 | 'val': 'yfcc-sift-2000-val.h5', 156 | 'test': 'yfcc-sift-2000-test.h5' 157 | } 158 | 159 | def __init__(self, phase, manual_seed, config, scene=None): 160 | self.phase = phase 161 | self.manual_seed = manual_seed 162 | self.config = config 163 | self.scene = scene 164 | 165 | # self.source_dir = config.data_dir_raw 166 | self.target_dir = config.data_dir_processed 167 | self.inlier_threshold_pixel = config.inlier_threshold_pixel 168 | 169 | self.data = None 170 | config_name = self.DATA_FILES[phase] 171 | if phase == 'test' and scene is not None: 172 | splits = config_name.split('.') 173 | config_name = splits[0] + f'-{scene}.' + splits[1] 174 | self.filename = osp.join(self.target_dir, config_name) 175 | logging.info( 176 | f"Loading {self.__class__.__name__} subset {phase} from {self.filename} with {self.__len__()} pairs."
177 | ) 178 | 179 | def __len__(self): 180 | if self.data is None: 181 | self.data = h5py.File(self.filename, 'r') 182 | _len = len(self.data['coords']) 183 | self.data.close() 184 | self.data = None 185 | else: 186 | _len = len(self.data['coords']) 187 | return _len 188 | 189 | def __del__(self): 190 | if self.data is not None: 191 | self.data.close() 192 | 193 | def __getitem__(self, idx): 194 | if self.data is None: 195 | self.data = h5py.File(self.filename, 'r') 196 | 197 | idx = str(idx) 198 | coords = self.data['coords'][idx] 199 | # img_path0 = coords.attrs['img0'] 200 | # img_path1 = coords.attrs['img1'] 201 | coords = np.asarray(coords) 202 | norm_coords = np.asarray(self.data['n_coords'][idx]) 203 | E = np.asarray(self.data['E'][idx]) 204 | R = np.asarray(self.data['R'][idx]) 205 | t = np.asarray(self.data['t'][idx]) 206 | res = np.asarray(self.data['res'][idx]) 207 | E = E / np.linalg.norm(E) 208 | 209 | # img0 = cv2.imread(osp.join(self.source_dir, img_path0)) 210 | # img1 = cv2.imread(osp.join(self.source_dir, img_path1)) 211 | img0 = 1 212 | img1 = 1 213 | 214 | labels = res < self.inlier_threshold_pixel 215 | virtPts = self.correctMatches(E) 216 | 217 | return { 218 | 'img0': img0, 219 | 'img1': img1, 220 | 'coords': coords, 221 | 'norm_coords': norm_coords, 222 | 'labels': labels, 223 | 'E': E, 224 | 'R': R, 225 | 't': t, 226 | 'virtPts': virtPts, 227 | } 228 | 229 | def reset(self): 230 | if self.data is not None: 231 | self.data.close() 232 | self.data = None 233 | 234 | def correctMatches(self, E): 235 | step = 0.1 236 | xx, yy = np.meshgrid(np.arange(-1, 1, step), np.arange(-1, 1, step)) 237 | # Points in first image before projection 238 | pts1_virt_b = np.float32(np.vstack((xx.flatten(), yy.flatten())).T) 239 | # Points in second image before projection 240 | pts2_virt_b = np.float32(pts1_virt_b) 241 | pts1_virt_b, pts2_virt_b = pts1_virt_b.reshape(1, -1, 242 | 2), pts2_virt_b.reshape(1, -1, 2) 243 | 244 | pts1_virt_b, pts2_virt_b = cv2.correctMatches(E.reshape(3, 3), pts1_virt_b, 245 | pts2_virt_b) 246 | pts1_virt_b = pts1_virt_b.squeeze() 247 | pts2_virt_b = pts2_virt_b.squeeze() 248 | pts_virt = np.concatenate([pts1_virt_b, pts2_virt_b], axis=1) 249 | return pts_virt 250 | 251 | 252 | class YFCC100MDatasetUCN(YFCC100MDatasetExtracted): 253 | DATA_FILES = { 254 | 'train': 'yfcc-ucn-0-train.h5', 255 | 'val': 'yfcc-ucn-0-val.h5', 256 | 'test': 'yfcc-ucn-0-test.h5' 257 | } 258 | 259 | def __getitem__(self, idx): 260 | if self.data is None: 261 | self.data = h5py.File(self.filename, 'r') 262 | 263 | idx = str(idx) 264 | E = np.asarray(self.data['E'][idx]) 265 | E = E / np.linalg.norm(E) 266 | R = np.asarray(self.data['R'][idx]) 267 | t = np.asarray(self.data['t'][idx]) 268 | img0 = 1 269 | img1 = 1 270 | 271 | # if self.phase != 'train': 272 | coords = np.asarray(self.data['ucn_coords'][idx]) 273 | norm_coords = np.asarray(self.data['ucn_n_coords'][idx]) 274 | res = compute_symmetric_epipolar_residual(E, norm_coords[:, :2], norm_coords[:, 2:]) 275 | # else: 276 | # coords = np.asarray(self.data['coords'][idx]) 277 | # norm_coords = np.asarray(self.data['n_coords'][idx]) 278 | # res = np.asarray(self.data['res'][idx]) 279 | 280 | labels = res < self.inlier_threshold_pixel 281 | virtPts = self.correctMatches(E) 282 | 283 | return { 284 | 'img0': img0, 285 | 'img1': img1, 286 | 'coords': coords, 287 | 'norm_coords': norm_coords, 288 | 'labels': labels, 289 | 'E': E, 290 | 'R': R, 291 | 't': t, 292 | 'virtPts': virtPts, 293 | } 294 | 295 | 296 | ALL_DATASETS = 
[YFCC100MDatasetExtracted, YFCC100MDatasetUCN] 297 | dataset_str_mapping = {d.__name__: d for d in ALL_DATASETS} 298 | 299 | 300 | def make_data_loader(config, 301 | phase, 302 | batch_size, 303 | num_workers=0, 304 | shuffle=None, 305 | repeat=False, 306 | scene=None): 307 | if config.dataset not in dataset_str_mapping.keys(): 308 | logging.error(f'Dataset {config.dataset}, does not exists in ' + 309 | ', '.join(dataset_str_mapping.keys())) 310 | 311 | Dataset = dataset_str_mapping[config.dataset] 312 | 313 | dataset = Dataset(phase=phase, manual_seed=None, config=config, scene=scene) 314 | 315 | collate_fn = CollationFunctionFactory(config) 316 | 317 | if repeat: 318 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 319 | batch_size=batch_size, 320 | collate_fn=collate_fn, 321 | num_workers=num_workers, 322 | sampler=InfSampler(dataset, shuffle)) 323 | else: 324 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 325 | batch_size=batch_size, 326 | collate_fn=collate_fn, 327 | num_workers=num_workers, 328 | shuffle=shuffle) 329 | 330 | return data_loader 331 | -------------------------------------------------------------------------------- /lib/twodim_trainer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import gc 3 | import logging 4 | 5 | import MinkowskiEngine as ME 6 | import numpy as np 7 | import sklearn.metrics as metrics 8 | import torch 9 | 10 | from lib.timer import AverageMeter, Timer 11 | from lib.trainer import Trainer 12 | from lib.util_2d import (batch_episym, compute_e_hat, 13 | compute_symmetric_epipolar_residual, compute_angular_error) 14 | from model import load_model 15 | 16 | 17 | class Loss(object): 18 | 19 | def __init__(self, config, device): 20 | self.config = config 21 | self.device = device 22 | self.loss_essential = config.oa_loss_essential 23 | self.loss_classif = config.oa_loss_classif 24 | self.geo_loss_margin = config.oa_geo_loss_margin 25 | self.loss_essential_init_iter = config.oa_loss_essential_init_iter 26 | self.use_balance_loss = config.use_balance_loss 27 | self.bce = torch.nn.BCEWithLogitsLoss() 28 | 29 | def run(self, step, data, logits, e_hats): 30 | labels, pts_virt = data['sinput_L'], data['virtPts'] 31 | 32 | loss_e = self.essential_loss(logits, e_hats, pts_virt) 33 | loss_c = self.classif_loss(labels, logits) 34 | loss = 0 35 | 36 | # Check global_step and add essential loss 37 | if self.loss_essential > 0 and step >= self.loss_essential_init_iter: 38 | loss += self.loss_essential * loss_e 39 | if self.loss_classif > 0: 40 | loss += self.loss_classif * loss_c 41 | 42 | return loss 43 | 44 | def essential_loss(self, logits, e_hat, pts_virt): 45 | e_hat = e_hat.to(self.device) 46 | pts_virt = pts_virt.to(self.device) 47 | 48 | p1 = pts_virt[:, :, :2] 49 | p2 = pts_virt[:, :, 2:] 50 | geod = batch_episym(p1, p2, e_hat) 51 | loss = torch.min(geod, self.geo_loss_margin * geod.new_ones(geod.shape)) 52 | loss = loss.mean() 53 | return loss 54 | 55 | def classif_loss(self, labels, logits): 56 | is_pos = labels.to(device=self.device, dtype=torch.bool) 57 | is_neg = ~is_pos 58 | is_pos = is_pos.to(logits.dtype) 59 | is_neg = is_neg.to(logits.dtype) 60 | 61 | if self.use_balance_loss: 62 | c = is_pos - is_neg 63 | loss = -torch.log(torch.sigmoid(c * logits) + np.finfo(float).eps.item()) 64 | num_pos = torch.relu(torch.sum(is_pos, dim=0) - 1.0) + 1.0 65 | num_neg = torch.relu(torch.sum(is_neg, dim=0) - 1.0) + 1.0 66 | loss_p = torch.sum(loss * is_pos, dim=0) 
67 | loss_n = torch.sum(loss * is_neg, dim=0) 68 | loss = torch.mean(loss_p * 0.5 / num_pos + loss_n * 0.5 / num_neg) 69 | else: 70 | loss = self.bce(logits, is_pos) 71 | return loss 72 | 73 | 74 | class ImageCorrespondenceTrainer(Trainer): 75 | 76 | def __init__(self, config, data_loader, val_data_loader=None): 77 | self.is_netsc = 'NetSC' in config.inlier_model 78 | self.requires_e_hat = not self.is_netsc 79 | if 'PyramidIteration' in config.inlier_model: 80 | self.requires_e_hat = False 81 | Trainer.__init__(self, config, data_loader, val_data_loader) 82 | self.loss = Loss(config, self.device) 83 | 84 | def _initialize_model(self): 85 | config = self.config 86 | 87 | num_feats = 0 88 | if 'feats' in config.inlier_feature_type: 89 | num_feats += config.model_n_out * 2 90 | elif 'coords' in config.inlier_feature_type: 91 | num_feats += 4 92 | elif 'count' in config.inlier_feature_type: 93 | num_feats += 1 94 | elif 'ones' == config.inlier_feature_type: 95 | num_feats = 1 96 | 97 | Model = load_model(config.inlier_model) 98 | if self.is_netsc: 99 | model = Model( 100 | in_channels=num_feats, 101 | out_channels=1, 102 | clusters=config.oa_clusters, 103 | D=4, 104 | ) 105 | else: 106 | model = Model( 107 | num_feats, 108 | 1, 109 | bn_momentum=config.bn_momentum, 110 | conv1_kernel_size=config.inlier_conv1_kernel_size, 111 | normalize_feature=False, 112 | D=4) 113 | return model 114 | 115 | def forward(self, input_dict): 116 | reg_sinput = ME.SparseTensor( 117 | feats=input_dict['sinput_F'], 118 | coords=input_dict['sinput_C'], 119 | ).to(self.device) 120 | 121 | if self.requires_e_hat: 122 | logit = self.model(reg_sinput).F.squeeze() 123 | e_hat, _ = compute_e_hat(input_dict['xyz'], logit, input_dict['len_batch']) 124 | return ([logit], [e_hat]) 125 | else: 126 | return self.model(reg_sinput, input_dict) 127 | 128 | def _train_epoch(self, epoch): 129 | gc.collect() 130 | 131 | data_loader_iter = self.data_loader.__iter__() 132 | iter_size = self.config.iter_size 133 | 134 | loss_meter, prec_meter, recall_meter, f1_meter, ap_meter = AverageMeter( 135 | ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 136 | data_timer, inlier_timer, total_timer = Timer(), Timer(), Timer() 137 | 138 | tot_num_data = len(data_loader_iter) // iter_size 139 | if self.train_max_iter > 0: 140 | tot_num_data = min(self.train_max_iter, tot_num_data) 141 | start_iter = (epoch - 1) * (tot_num_data) 142 | 143 | self.model.train() 144 | for curr_iter in range(tot_num_data): 145 | self.optimizer.zero_grad() 146 | total_timer.tic() 147 | 148 | batch_loss = 0 149 | for iter_idx in range(iter_size): 150 | data_timer.tic() 151 | input_dict = data_loader_iter.next() 152 | data_timer.toc() 153 | 154 | # Feature extraction 155 | inlier_timer.tic() 156 | try: 157 | logits, e_hats = self.forward(input_dict) 158 | except RuntimeError: 159 | print("Runtime error") 160 | pass 161 | inlier_timer.toc() 162 | 163 | # Calculate loss 164 | loss = 0 165 | for logit, e_hat in zip(logits, e_hats): 166 | loss_i = self.loss.run(start_iter + curr_iter, input_dict, logit, e_hat) 167 | loss += loss_i 168 | loss = loss / iter_size 169 | loss.backward() 170 | 171 | # Accumulate metrics 172 | pred = np.hstack(torch.sigmoid(logits[-1].detach()).cpu().numpy()) 173 | target = np.hstack(input_dict['sinput_L'].numpy()) 174 | prec, recall, f1, _ = metrics.precision_recall_fscore_support( 175 | target, (pred > 0.5).astype(np.int), average='binary') 176 | ap = metrics.average_precision_score(target, pred) 177 | 178 | prec_meter.update(prec) 
179 | recall_meter.update(recall) 180 | f1_meter.update(f1) 181 | ap_meter.update(ap) 182 | batch_loss += loss.item() 183 | 184 | total_timer.toc() 185 | self.optimizer.step() 186 | loss_meter.update(batch_loss) 187 | # Clear 188 | torch.cuda.empty_cache() 189 | 190 | if curr_iter % self.config.stat_freq == 0: 191 | # Use the current value to see how stochastic the metrics are 192 | stat = { 193 | 'prec': prec_meter.avg, 194 | 'recall': recall_meter.avg, 195 | 'f1': f1_meter.avg, 196 | 'ap': ap_meter.avg, 197 | 'loss': loss_meter.avg 198 | } 199 | for k, v in stat.items(): 200 | self.writer.add_scalar(f'train/{k}', v, start_iter + curr_iter) 201 | 202 | logging.info( 203 | ', '.join([f"Train Epoch: {epoch} [{curr_iter}/{tot_num_data}]"] + 204 | [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()] + [ 205 | f"Data time: {data_timer.avg:.4f}", 206 | f"Train time: {total_timer.avg - data_timer.avg:.4f}", 207 | f"Total time: {total_timer.avg:.4f}" 208 | ])) 209 | 210 | prec_meter.reset() 211 | recall_meter.reset() 212 | f1_meter.reset() 213 | ap_meter.reset() 214 | loss_meter.reset() 215 | total_timer.reset() 216 | data_timer.reset() 217 | 218 | def _valid_epoch(self): 219 | gc.collect() 220 | 221 | data_loader_iter = self.val_data_loader.__iter__() 222 | loss_meter, prec_meter, recall_meter, f1_meter, ap_meter = AverageMeter( 223 | ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 224 | ang_errs = [] 225 | ths = np.arange(7) * 5 226 | data_timer, inlier_timer, total_timer = Timer(), Timer(), Timer() 227 | 228 | tot_num_data = len(self.val_data_loader.dataset) 229 | if self.val_max_iter > 0: 230 | tot_num_data = min(self.val_max_iter, tot_num_data) 231 | 232 | self.model.eval() 233 | with torch.no_grad(): 234 | for curr_iter in range(tot_num_data): 235 | total_timer.tic() 236 | 237 | data_timer.tic() 238 | input_dict = data_loader_iter.next() 239 | data_timer.toc() 240 | 241 | # Feature extraction 242 | inlier_timer.tic() 243 | logits, e_hats = self.forward(input_dict) 244 | inlier_timer.toc() 245 | 246 | # Calculate loss 247 | loss = 0 248 | for i in range(len(logits)): 249 | loss_i = self.loss.run(curr_iter, input_dict, logits[i], e_hats[i]) 250 | loss += loss_i 251 | total_timer.toc() 252 | 253 | # Accumulate metrics 254 | pred = np.hstack(torch.sigmoid(logits[-1].detach()).cpu().numpy()) 255 | target = np.hstack(input_dict['sinput_L'].numpy()) 256 | prec, recall, f1, _ = metrics.precision_recall_fscore_support( 257 | target, (pred > 0.5).astype(np.int), average='binary') 258 | ap = metrics.average_precision_score(target, pred) 259 | 260 | prec_meter.update(prec) 261 | recall_meter.update(recall) 262 | f1_meter.update(f1) 263 | ap_meter.update(ap) 264 | loss_meter.update(loss.item()) 265 | 266 | # calcute angular error 267 | norm_coords = input_dict['norm_coords'] 268 | len_batch = input_dict['len_batch'] 269 | R = input_dict['R'].numpy() 270 | t = input_dict['t'].numpy() 271 | e_hat = e_hats[-1].cpu().numpy() 272 | cursor = 0 273 | for i, n in enumerate(len_batch): 274 | _pred = pred[cursor:cursor + n] 275 | err_q, err_t = compute_angular_error(R[i], t[i], e_hat[i].reshape(3, 3), 276 | norm_coords[i], _pred) 277 | ang_errs.append(np.maximum(err_q, err_t)) 278 | 279 | torch.cuda.empty_cache() 280 | 281 | if curr_iter % self.config.stat_freq == 0: 282 | hist, _ = np.histogram(ang_errs, ths) 283 | hist = hist.astype(np.float) / len(ang_errs) 284 | acc = np.cumsum(hist) 285 | stat = { 286 | 'prec': prec_meter.avg, 287 | 'recall': recall_meter.avg, 288 | 'f1': f1_meter.avg, 
289 | 'ap': ap_meter.avg, 290 | 'mAP5': np.mean(acc[:1]), 291 | 'mAP20': np.mean(acc[:4]), 292 | 'loss': loss_meter.avg 293 | } 294 | logging.info( 295 | ', '.join([f"Validation [{curr_iter}/{tot_num_data}]"] + 296 | [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()] + [ 297 | f"Data time: {data_timer.avg:.4f}", 298 | f"Train time: {total_timer.avg - data_timer.avg:.4f}", 299 | f"Total time: {total_timer.avg:.4f}" 300 | ])) 301 | 302 | hist, _ = np.histogram(ang_errs, ths) 303 | hist = hist.astype(np.float) / len(ang_errs) 304 | acc = np.cumsum(hist) 305 | stat = { 306 | 'prec': prec_meter.avg, 307 | 'recall': recall_meter.avg, 308 | 'f1': f1_meter.avg, 309 | 'ap': ap_meter.avg, 310 | 'mAP5': np.mean(acc[:1]), 311 | 'mAP20': np.mean(acc[:4]), 312 | 'loss': loss_meter.avg 313 | } 314 | logging.info(', '.join([f"Validation"] + 315 | [f"{k.capitalize()}: {v:.4f}" for k, v in stat.items()])) 316 | 317 | return stat 318 | 319 | def test(self, test_loader): 320 | test_iter = test_loader.__iter__() 321 | logging.info(f"Evaluating on {test_loader.dataset.scene}") 322 | 323 | self.model.eval() 324 | targets, preds, residuals, err_qs, err_ts = [], [], [], [], [] 325 | with torch.no_grad(): 326 | for _ in range(len(test_iter)): 327 | input_dict = test_iter.next() 328 | 329 | logits, e_hats = self.forward(input_dict) 330 | logit = logits[-1].squeeze().cpu() 331 | e_hat = e_hats[-1].cpu().numpy() 332 | 333 | target = np.hstack(input_dict['sinput_L'].numpy()) 334 | pred = np.hstack(torch.sigmoid(logit)) 335 | norm_coords = np.hstack(input_dict['norm_coords']) 336 | R = np.hstack(input_dict['R']) 337 | t = np.hstack(input_dict['t']) 338 | 339 | residual = compute_symmetric_epipolar_residual( 340 | e_hat.reshape(3, 3).T, 341 | norm_coords[target.astype(bool), :2], 342 | norm_coords[target.astype(bool), 2:], 343 | ) 344 | 345 | err_q, err_t = compute_angular_error(R, t, e_hat.reshape(3, 3), norm_coords, 346 | pred) 347 | 348 | targets.append(target) 349 | preds.append(pred) 350 | residuals.append(residual) 351 | err_qs.append(err_q) 352 | err_ts.append(err_t) 353 | return targets, preds, residuals, err_qs, err_ts 354 | -------------------------------------------------------------------------------- /lib/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import copy 4 | import numpy as np 5 | import open3d as o3d 6 | import MinkowskiEngine as ME 7 | 8 | 9 | def read_txt(path): 10 | """Read txt file into lines. 
11 | """ 12 | with open(path) as f: 13 | lines = f.readlines() 14 | lines = [x.strip() for x in lines] 15 | return lines 16 | 17 | 18 | def ensure_dir(path): 19 | if not os.path.exists(path): 20 | os.makedirs(path, mode=0o755) 21 | 22 | 23 | def paint_overlap_label(pcd, overlap): 24 | npcd = np.asarray(pcd.points).shape[0] 25 | 26 | for i in range(npcd): 27 | if overlap[i] >= 1: 28 | pcd.colors[i] = [1.0, 0.0, 0.0] 29 | return pcd 30 | 31 | 32 | def visualize_overlap_label(source, target, source_overlap, target_overlap, trans): 33 | source_temp = copy.deepcopy(source) 34 | target_temp = copy.deepcopy(target) 35 | source_temp.transform(trans) 36 | source_temp.paint_uniform_color([1, 0.706, 0]) 37 | target_temp.paint_uniform_color([0, 0.651, 0.929]) 38 | paint_overlap_label(source_temp, source_overlap) 39 | paint_overlap_label(target_temp, target_overlap) 40 | o3d.draw_geometries([source_temp, target_temp]) 41 | 42 | 43 | def extract_graph_features_from_batch(batch, features, i): 44 | g = torch.masked_select(features, (batch.batch == i).byte().unsqueeze(1).expand( 45 | -1, features.size(1))) 46 | g = g.view(-1, features.size(1)) 47 | return g 48 | 49 | 50 | def get_pointcloud_from_pytorch(batch, idx, R=None, T=None): 51 | p0 = o3d.PointCloud() 52 | pts = extract_graph_features_from_batch(batch, batch.x, 0).data.cpu().numpy() 53 | 54 | if R is not None: 55 | pts = (R.data.cpu().numpy() @ pts.T).T + T.data.cpu().numpy() 56 | 57 | p0.points = o3d.Vector3dVector(pts) 58 | 59 | return p0 60 | 61 | 62 | def R_to_quad(R): 63 | q = torch.zeros(4) 64 | 65 | q[0] = 0.5 * ((1 + R[0, 0] + R[1, 1] + R[2, 2]).sqrt()) 66 | q[1] = (R[2, 1] - R[1, 2]) / (4 * q[0]) 67 | q[2] = (R[0, 2] - R[2, 0]) / (4 * q[0]) 68 | q[3] = (R[1, 0] - R[0, 1]) / (4 * q[0]) 69 | 70 | return q 71 | 72 | 73 | def extract_features(model, 74 | xyz, 75 | rgb=None, 76 | normal=None, 77 | voxel_size=0.05, 78 | device=None, 79 | skip_check=False): 80 | ''' 81 | xyz is a N x 3 matrix 82 | rgb is a N x 3 matrix and all color must range from [0, 1] or None 83 | normal is a N x 3 matrix and all normal range from [-1, 1] or None 84 | 85 | if both rgb and normal are None, we use Nx1 one vector as an input 86 | 87 | if device is None, it tries to use gpu by default 88 | 89 | if skip_check is True, skip rigorous checks to speed up 90 | 91 | model = model.to(device) 92 | xyz, feats = extract_features(model, xyz) 93 | ''' 94 | 95 | if not skip_check: 96 | assert xyz.shape[1] == 3 97 | 98 | N = xyz.shape[0] 99 | if rgb is not None: 100 | assert N == len(rgb) 101 | assert rgb.shape[1] == 3 102 | if np.any(rgb > 1): 103 | raise ValueError('Invalid color. Color must range from [0, 1]') 104 | 105 | if normal is not None: 106 | assert N == len(normal) 107 | assert normal.shape[1] == 3 108 | if np.any(normal > 1): 109 | raise ValueError('Invalid normal. 
Normal must range from [-1, 1]') 110 | 111 | if device is None: 112 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 113 | 114 | feats = [] 115 | if rgb is not None: 116 | # [0, 1] 117 | feats.append(rgb - 0.5) 118 | 119 | if normal is not None: 120 | # [-1, 1] 121 | feats.append(normal / 2) 122 | 123 | if rgb is None and normal is None: 124 | feats.append(np.ones((len(xyz), 1))) 125 | 126 | feats = np.hstack(feats) 127 | 128 | # Voxelize xyz and feats 129 | coords = np.floor(xyz / voxel_size) 130 | inds = ME.utils.sparse_quantize(coords, return_index=True) 131 | coords = coords[inds] 132 | # Append the batch index 133 | coords = np.hstack([coords, np.zeros((len(coords), 1))]) 134 | return_coords = xyz[inds] 135 | 136 | feats = feats[inds] 137 | 138 | feats = torch.tensor(feats, dtype=torch.float32) 139 | coords = torch.tensor(coords, dtype=torch.int32) 140 | 141 | stensor = ME.SparseTensor(feats, coords=coords).to(device) 142 | 143 | return return_coords, model(stensor).F 144 | 145 | 146 | def concat_pos_pairs(pos_pairs, len_batch): 147 | cat_pos_pairs = [] 148 | start_inds = torch.zeros((1, 2)).long() 149 | assert len(pos_pairs) == len(len_batch) 150 | for pos_pair, lens in zip(pos_pairs, len_batch): 151 | cat_pos_pairs.append(pos_pair + start_inds) 152 | start_inds += torch.LongTensor(lens) 153 | return torch.cat(cat_pos_pairs, 0) 154 | 155 | 156 | def random_sample(arr, num_sample, fix=True): 157 | """Sample elements from array 158 | 159 | Args: 160 | arr (array): array to sample 161 | num_sample (int): maximum number of elements to sample 162 | 163 | Returns: 164 | array: sampled array 165 | 166 | """ 167 | # Fix seed 168 | if fix: 169 | np.random.seed(0) 170 | 171 | total = len(arr) 172 | num_sample = min(total, num_sample) 173 | idx = sorted(np.random.choice(range(total), num_sample, replace=False)) 174 | return np.asarray(arr)[idx] 175 | 176 | 177 | def count_parameters(model): 178 | return sum(p.numel() for p in model.parameters() if p.requires_grad) -------------------------------------------------------------------------------- /lib/util_2d.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from scipy.spatial.transform import Rotation 7 | 8 | from util.file import loadh5 9 | 10 | 11 | def serialize_calibration(path): 12 | """Load calibration file and serialize 13 | 14 | Args: 15 | path (str): path to calibration file 16 | 17 | Returns: 18 | array: serialized 1-d calibration array 19 | 20 | """ 21 | calib_dict = loadh5(path) 22 | 23 | calib_list = [] 24 | calibration_keys = ["K", "R", "T", "imsize"] 25 | 26 | # Flatten calibration data 27 | for _key in calibration_keys: 28 | calib_list += [calib_dict[_key].flatten()] 29 | 30 | calib_list += [np.linalg.inv(calib_dict["K"]).flatten()] 31 | 32 | # Serialize calibration data into 1-d array 33 | calib = np.concatenate(calib_list) 34 | return calib 35 | 36 | 37 | def parse_calibration(calib): 38 | """Parse serialiazed calibration 39 | 40 | Args: 41 | calib (np.ndarray): serialized calibration 42 | 43 | Returns: 44 | dict: parsed calibration 45 | 46 | """ 47 | 48 | parsed_calib = {} 49 | parsed_calib["K"] = calib[:9].reshape((3, 3)) 50 | parsed_calib["R"] = calib[9:18].reshape((3, 3)) 51 | parsed_calib["t"] = calib[18:21].reshape(3) 52 | parsed_calib["imsize"] = calib[21:23].reshape(2) 53 | parsed_calib["K_inv"] = calib[23:32].reshape((3, 3)) 54 | return parsed_calib 55 | 56 | 57 | def 
computeNN(desc0, desc1): 58 | desc0, desc1 = torch.from_numpy(desc0).cuda(), torch.from_numpy(desc1).cuda() 59 | d1 = (desc0**2).sum(1) 60 | d2 = (desc1**2).sum(1) 61 | distmat = (d1.unsqueeze(1) + d2.unsqueeze(0) - 62 | 2 * torch.matmul(desc0, desc1.transpose(0, 1))).sqrt() 63 | distVals, nnIdx1 = torch.topk(distmat, k=2, dim=1, largest=False) 64 | nnIdx1 = nnIdx1[:, 0] 65 | idx_sort = [np.arange(nnIdx1.shape[0]), nnIdx1.cpu().numpy()] 66 | return idx_sort 67 | 68 | 69 | def normalize_keypoint(kp, K, center=None): 70 | """Normalize keypoint coordinate 71 | 72 | Convert pixel image coordinates into normalized image coordinates 73 | 74 | Args: 75 | kp (array): list of keypoints 76 | K (array): intrinsic matrix 77 | center (array, optional): principal point offset, for LFGC dataset because intrinsic matrix doensn't include principal offset 78 | Returns: 79 | array: normalized keypoints as homogenous coordinates 80 | 81 | """ 82 | kp = kp.copy() 83 | if center is not None: 84 | kp -= center 85 | 86 | kp = get_homogeneous_coords(kp) 87 | K_inv = np.linalg.inv(K) 88 | kp = np.dot(kp, K_inv.T) 89 | 90 | return kp 91 | 92 | 93 | def build_extrinsic_matrix(R, t): 94 | """Build extrinsic matrix 95 | 96 | Args: 97 | R (array): Rotation matrix of shape (3,3) 98 | t (array): Translation vector of shape (3,) 99 | 100 | Returns: 101 | array: extrinsic matrix 102 | 103 | """ 104 | return np.vstack((np.hstack((R, t[:, None])), [0, 0, 0, 1])) 105 | 106 | 107 | def compute_essential_matrix(T0, T1): 108 | """Compute essential matrix 109 | 110 | Args: 111 | T0 (array): extrinsic matrix 112 | T1 (array): extrinsic matrix 113 | 114 | Returns: 115 | array: essential matrix 116 | 117 | """ 118 | 119 | dT = T1 @ np.linalg.inv(T0) 120 | dR = dT[:3, :3] 121 | dt = dT[:3, 3] 122 | 123 | skew = skew_symmetric(dt) 124 | return dR.T @ skew, dR, dt 125 | 126 | 127 | def skew_symmetric(t): 128 | """Compute skew symmetric matrix of vector t 129 | 130 | Args: 131 | t (np.ndarray): vector of shape (3,) 132 | 133 | Returns: 134 | M (np.ndarray): skew-symmetrix matrix of shape (3, 3) 135 | 136 | """ 137 | M = np.array([[0, -t[2], t[1]], [t[2], 0, -t[0]], [-t[1], t[0], 0]]) 138 | return M 139 | 140 | 141 | def get_homogeneous_coords(coords, D=2): 142 | """Convert coordinates to homogeneous coordinates 143 | 144 | Args: 145 | coords (array): coordinates 146 | D (int): dimension. 
default to 2 147 | 148 | Returns: 149 | array: homogeneous coordinates 150 | 151 | """ 152 | 153 | assert len(coords.shape) == 2, "coords should be 2D array" 154 | 155 | if coords.shape[1] == D + 1: 156 | return coords 157 | elif coords.shape[1] == D: 158 | ones = np.ones((coords.shape[0], 1)) 159 | return np.hstack((coords, ones)) 160 | else: 161 | raise ValueError("Invalid coordinate dimension") 162 | 163 | 164 | def compute_symmetric_epipolar_residual(E, coords0, coords1): 165 | """Compute symmetric epipolar residual 166 | 167 | Symmetric epipolar distance 168 | 169 | Args: 170 | E (np.ndarray): essential matrix 171 | coord0 (np.ndarray): homogenous coordinates 172 | coord1 (np.ndarray): homogenous coordinates 173 | 174 | Returns: 175 | array: residuals 176 | 177 | """ 178 | with warnings.catch_warnings(): 179 | warnings.simplefilter("error", category=RuntimeWarning) 180 | coords0 = get_homogeneous_coords(coords0) 181 | coords1 = get_homogeneous_coords(coords1) 182 | 183 | line_2 = np.dot(E.T, coords0.T) 184 | line_1 = np.dot(E, coords1.T) 185 | 186 | dd = np.sum(line_2.T * coords1, 1) 187 | dd = np.abs(dd) 188 | 189 | d = dd * (1.0 / np.sqrt(line_1[0, :]**2 + line_1[1, :]**2 + 1e-7) + 190 | 1.0 / np.sqrt(line_2[0, :]**2 + line_2[1, :]**2 + 1e-7)) 191 | 192 | return d 193 | 194 | 195 | def compute_e_hat(coords, logits, len_batch): 196 | e_hats = [] 197 | logits_ = [] 198 | residuals = [] 199 | start_idx = 0 200 | 201 | if isinstance(coords, np.ndarray): 202 | coords = torch.from_numpy(coords).float() 203 | 204 | coords = coords.to(logits.device) 205 | for npts in len_batch: 206 | end_idx = start_idx + npts 207 | coord = coords[start_idx:end_idx] 208 | logit = logits[start_idx:end_idx] 209 | e_hat = weighted_8points( 210 | coord.unsqueeze(0).transpose(2, 1), 211 | logit.unsqueeze(0), 212 | ) 213 | e_hats.append(e_hat) 214 | logits_.append(logit) 215 | residual = compute_symmetric_epipolar_residual( 216 | e_hat.reshape(3, 3).detach().cpu(), coord[:, :2].detach().cpu(), coord[:, 2:].detach().cpu()) 217 | residuals.append(torch.from_numpy(residual)) 218 | start_idx = end_idx 219 | return torch.stack(e_hats, dim=0), torch.cat(residuals, dim=0) 220 | 221 | 222 | def weighted_8points(coords, logits): 223 | # logits shape = (batch, num_point, 1) 224 | # coords shape = (batch, 4, num_point) 225 | w = torch.nn.functional.relu(torch.tanh(logits)) 226 | X = torch.stack([ 227 | coords[:, 2] * coords[:, 0], coords[:, 2] * coords[:, 1], coords[:, 2], 228 | coords[:, 3] * coords[:, 0], coords[:, 3] * coords[:, 1], coords[:, 3], 229 | coords[:, 0], coords[:, 1], 230 | torch.ones_like(coords[:, 0]) 231 | ], 232 | dim=1).transpose(2, 1) 233 | # wX shape = (batch, num_point, 9) 234 | wX = w.unsqueeze(-1) * X 235 | # XwX shape = (batch, 9, 9) 236 | XwX = torch.bmm(X.transpose(2, 1), wX) 237 | 238 | v = batch_symeig(XwX) 239 | # _, v = torch.symeig(XwX, eigenvectors=True) 240 | e = torch.reshape(v[:, :, 0], (logits.shape[0], 9)) 241 | e = e / torch.norm(e, dim=1, keepdim=True) 242 | return e 243 | 244 | 245 | def quaternion_from_rotation(R): 246 | return Rotation.from_matrix(R).as_quat() 247 | 248 | 249 | def compute_angular_error(R_gt, t_gt, E_hat, coords, scores): 250 | num_top = len(scores) // 10 251 | num_top = max(1, num_top) 252 | th = np.sort(scores)[::-1][num_top] 253 | mask = scores >= th 254 | 255 | coords = coords.astype(np.float64) 256 | p1_good = coords[mask, :2] 257 | p2_good = coords[mask, 2:] 258 | E_hat = E_hat.astype(p1_good.dtype) 259 | 260 | # decompose estimated essential matrix 261 | 
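# cv2.recoverPose resolves the four-fold (R, t) ambiguity of the essential
# matrix with a cheirality check (triangulated points must lie in front of
# both cameras) on the supplied correspondences; only the top ~10% of
# matches by score (the `mask` above) are used, and t comes back only up
# to scale, which is why the translation error below is purely angular.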
num_inlier, R, t, _ = cv2.recoverPose(E_hat, p1_good, p2_good) 262 | 263 | eps = np.finfo(float).eps 264 | 265 | # calculate rotation error 266 | q = quaternion_from_rotation(R) 267 | q = q / (np.linalg.norm(q) + eps) 268 | q_gt = quaternion_from_rotation(R_gt) 269 | q_gt = q_gt / (np.linalg.norm(q_gt) + eps) 270 | loss_q = np.maximum(eps, (1 - np.sum(q * q_gt)**2)) 271 | err_q = np.arccos(1 - 2 * loss_q) 272 | 273 | # calculate translation error 274 | t = t.flatten() 275 | t = t / (np.linalg.norm(t) + eps) 276 | t_gt = t_gt / (np.linalg.norm(t_gt) + eps) 277 | loss_t = np.maximum(eps, (1 - np.sum(t * t_gt)**2)) 278 | err_t = np.arccos(np.sqrt(1 - loss_t)) 279 | err_q = err_q * 180 / np.pi 280 | err_t = err_t * 180 / np.pi 281 | return err_q, err_t 282 | 283 | 284 | def batch_symeig(X): 285 | # it is much faster to run symeig on CPU 286 | X = X.cpu() 287 | b, d, _ = X.size() 288 | bv = X.new(b, d, d) 289 | for batch_idx in range(X.shape[0]): 290 | e, v = torch.symeig(X[batch_idx, :, :].squeeze(), True) 291 | bv[batch_idx, :, :] = v 292 | bv = bv.cuda() 293 | return bv 294 | 295 | 296 | def batch_episym(x1, x2, F): 297 | batch_size, num_pts = x1.shape[0], x1.shape[1] 298 | x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts, 1)], 299 | dim=-1).reshape(batch_size, num_pts, 3, 1) 300 | x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts, 1)], 301 | dim=-1).reshape(batch_size, num_pts, 3, 1) 302 | F = F.reshape(-1, 1, 3, 3).repeat(1, num_pts, 1, 1) 303 | x2Fx1 = torch.matmul(x2.transpose(2, 3), 304 | torch.matmul(F, x1)).reshape(batch_size, num_pts) 305 | Fx1 = torch.matmul(F, x1).reshape(batch_size, num_pts, 3) 306 | Ftx2 = torch.matmul(F.transpose(2, 3), x2).reshape(batch_size, num_pts, 3) 307 | ys = x2Fx1**2 * (1.0 / (Fx1[:, :, 0]**2 + Fx1[:, :, 1]**2 + 1e-15) + 1.0 / 308 | (Ftx2[:, :, 0]**2 + Ftx2[:, :, 1]**2 + 1e-15)) 309 | return ys 310 | 311 | 312 | def torch_skew_symmetric(v): 313 | zero = torch.zeros_like(v[:, 0]) 314 | M = torch.stack( 315 | [zero, -v[:, 2], v[:, 1], v[:, 2], zero, -v[:, 0], -v[:, 1], v[:, 0], zero], 316 | dim=1) 317 | 318 | return M 319 | -------------------------------------------------------------------------------- /lib/util_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import Sampler 3 | 4 | 5 | class InfSampler(Sampler): 6 | """Samples elements randomly, without replacement. 
7 | 8 | Arguments: 9 | data_source (Dataset): dataset to sample from 10 | """ 11 | 12 | def __init__(self, data_source, shuffle=False): 13 | self.data_source = data_source 14 | self.shuffle = shuffle 15 | self.reset_permutation() 16 | 17 | def reset_permutation(self): 18 | n = len(self.data_source) 19 | # keep sequential order when not shuffling; a bare int has no .tolist() 20 | perm = torch.randperm(n) if self.shuffle else torch.arange(n) 21 | self._perm = perm.tolist() 22 | 23 | def __iter__(self): 24 | return self 25 | 26 | def __next__(self): 27 | if len(self._perm) == 0: 28 | self.reset_permutation() 29 | return self._perm.pop() 30 | 31 | def __len__(self): 32 | return len(self.data_source) 33 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import model.simpleunet as simpleunets 3 | import model.resunet as resunets 4 | import model.resnetsc as resnetsc 5 | import model.pyramidnet as pyramids 6 | 7 | MODELS = [] 8 | 9 | 10 | def add_models(module): 11 | MODELS.extend([getattr(module, a) for a in dir(module) if 'Net' in a or 'MLP' in a]) 12 | 13 | 14 | add_models(simpleunets) 15 | add_models(resunets) 16 | add_models(resnetsc) 17 | add_models(pyramids) 18 | 19 | 20 | def load_model(name): 21 | '''Returns the model class given its class name. 22 | ''' 23 | # Find the model class from its name 24 | all_models = MODELS 25 | mdict = {model.__name__: model for model in all_models} 26 | if name not in mdict: 27 | logging.info(f'Invalid model name {name}. Options are:') 28 | # Display a list of valid model names 29 | for model in all_models: 30 | logging.info('\t* {}'.format(model.__name__)) 31 | return None 32 | NetClass = mdict[name] 33 | 34 | return NetClass 35 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import MinkowskiEngine as ME 4 | 5 | 6 | class AIN(torch.nn.Module): 7 | # Attentive Instance Normalization 8 | 9 | def __init__(self, num_feats): 10 | super(AIN, self).__init__() 11 | self.num_feats = num_feats 12 | self.local_linear = torch.nn.Linear(num_feats, 1) 13 | self.global_linear = torch.nn.Linear(num_feats, 1) 14 | 15 | def forward(self, x): 16 | feats = x.feats 17 | local_w = self.local_linear(feats) 18 | global_w = self.global_linear(feats) 19 | weight = torch.zeros_like(local_w) 20 | for row_idx in x.coords_man.get_row_indices_per_batch(x.coords_key): 21 | _local_w = local_w[row_idx] 22 | _local_w = torch.sigmoid(_local_w) 23 | _global_w = global_w[row_idx] 24 | _global_w = torch.softmax(_global_w, dim=0) 25 | weight[row_idx] = _local_w * _global_w 26 | 27 | # normalize weight 28 | weight = weight / torch.sum(torch.abs(weight)) 29 | mean = torch.sum(feats * weight, dim=0, keepdim=True) / torch.sum(weight) 30 | std = torch.sqrt(torch.sum(weight*(feats - mean).pow(2), dim=0, keepdim=True)) 31 | return ME.SparseTensor( 32 | feats=(feats - mean) / std, 33 | coords_key=x.coords_key, 34 | coords_manager=x.coords_man, 35 | ) 36 | 37 | 38 | def get_norm(norm_type, num_feats, bn_momentum=0.05, dimension=-1): 39 | if norm_type == 'BN': 40 | return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) 41 | elif norm_type == 'IN': 42 | # return ME.MinkowskiInstanceNorm(num_feats, dimension=dimension) 43 | return ME.MinkowskiInstanceNorm(num_feats) 44 | elif norm_type == 'INBN': 45 | return
torch.nn.Sequential( 46 | ME.MinkowskiInstanceNorm(num_feats), 47 | ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum)) 48 | elif norm_type == 'AIN': 49 | return AIN(num_feats) 50 | elif norm_type == 'AINBN': 51 | return torch.nn.Sequential( 52 | AIN(num_feats), ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum)) 53 | else: 54 | raise ValueError(f'Type {norm_type}, not defined') 55 | 56 | 57 | def get_nonlinearity(non_type): 58 | if non_type == 'ReLU': 59 | return ME.MinkowskiReLU() 60 | elif non_type == 'ELU': 61 | # return ME.MinkowskiInstanceNorm(num_feats, dimension=dimension) 62 | return ME.MinkowskiELU() 63 | else: 64 | raise ValueError(f'Type {non_type}, not defined') 65 | 66 | 67 | def random_offsets(kernel_size, n_kernel, dimension): 68 | n = kernel_size**dimension 69 | offsets = np.random.choice(n, n_kernel, replace=False) 70 | offsets = np.unravel_index(offsets, [ 71 | kernel_size, 72 | ] * dimension) 73 | offsets = np.stack(offsets).T 74 | offsets = offsets - kernel_size // 2 75 | return offsets 76 | 77 | 78 | def conv(in_channels, 79 | out_channels, 80 | kernel_size, 81 | stride=1, 82 | dilation=1, 83 | has_bias=False, 84 | region_type=ME.RegionType.HYPERCUBE, 85 | num_kernels=-1, 86 | dimension=-1): 87 | assert dimension > 0, 'Dimension must be a positive integer' 88 | if num_kernels > 0: 89 | offsets = random_offsets(kernel_size, num_kernels, dimension) 90 | kernel_generator = ME.KernelGenerator( 91 | kernel_size, 92 | stride, 93 | dilation, 94 | region_type=ME.RegionType.CUSTOM, 95 | region_offsets=torch.IntTensor(offsets), 96 | dimension=dimension) 97 | else: 98 | kernel_generator = ME.KernelGenerator( 99 | kernel_size, stride, dilation, region_type=region_type, dimension=dimension) 100 | 101 | return ME.MinkowskiConvolution( 102 | in_channels=in_channels, 103 | out_channels=out_channels, 104 | kernel_size=kernel_size, 105 | stride=stride, 106 | dilation=dilation, 107 | has_bias=has_bias, 108 | kernel_generator=kernel_generator, 109 | dimension=dimension) 110 | 111 | 112 | def conv_tr(in_channels, 113 | out_channels, 114 | kernel_size, 115 | stride=1, 116 | dilation=1, 117 | has_bias=False, 118 | region_type=ME.RegionType.HYPERCUBE, 119 | num_kernels=-1, 120 | dimension=-1): 121 | assert dimension > 0, 'Dimension must be a positive integer' 122 | if num_kernels > 0: 123 | offsets = random_offsets(kernel_size, num_kernels, dimension) 124 | kernel_generator = ME.KernelGenerator( 125 | kernel_size, 126 | stride, 127 | dilation, 128 | is_transpose=True, 129 | region_type=ME.RegionType.CUSTOM, 130 | region_offsets=torch.IntTensor(offsets), 131 | dimension=dimension) 132 | else: 133 | kernel_generator = ME.KernelGenerator( 134 | kernel_size, 135 | stride, 136 | dilation, 137 | is_transpose=True, 138 | region_type=region_type, 139 | dimension=dimension) 140 | 141 | kernel_generator = ME.KernelGenerator( 142 | kernel_size, 143 | stride, 144 | dilation, 145 | is_transpose=True, 146 | region_type=region_type, 147 | dimension=dimension) 148 | 149 | return ME.MinkowskiConvolutionTranspose( 150 | in_channels=in_channels, 151 | out_channels=out_channels, 152 | kernel_size=kernel_size, 153 | stride=stride, 154 | dilation=dilation, 155 | has_bias=has_bias, 156 | kernel_generator=kernel_generator, 157 | dimension=dimension) 158 | 159 | 160 | def conv_norm_non(inc, 161 | outc, 162 | kernel_size, 163 | stride, 164 | dimension, 165 | bn_momentum=0.05, 166 | region_type=ME.RegionType.HYPERCUBE, 167 | norm_type='BN', 168 | nonlinearity='ELU'): 169 | return torch.nn.Sequential( 170 | 
conv( 171 | in_channels=inc, 172 | out_channels=outc, 173 | kernel_size=kernel_size, 174 | stride=stride, 175 | dilation=1, 176 | has_bias=False, 177 | region_type=region_type, 178 | dimension=dimension), 179 | get_norm(norm_type, outc, bn_momentum=bn_momentum, dimension=dimension), 180 | get_nonlinearity(nonlinearity)) 181 | -------------------------------------------------------------------------------- /model/pyramidnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import MinkowskiEngine as ME 5 | from model.common import get_norm, get_nonlinearity, conv, conv_tr, conv_norm_non 6 | 7 | from model.residual_block import get_block 8 | from model.resnetsc import SCBlock 9 | 10 | from lib.util_2d import compute_e_hat 11 | 12 | 13 | class PyramidModule(ME.MinkowskiNetwork): 14 | NONLINEARITY = 'ELU' 15 | NORM_TYPE = 'BN' 16 | REGION_TYPE = ME.RegionType.HYPERCUBE 17 | 18 | def __init__(self, 19 | inc, 20 | outc, 21 | inner_inc, 22 | inner_outc, 23 | inner_module=None, 24 | depth=1, 25 | bn_momentum=0.05, 26 | dimension=-1): 27 | ME.MinkowskiNetwork.__init__(self, dimension) 28 | self.depth = depth 29 | 30 | self.conv = nn.Sequential( 31 | conv_norm_non( 32 | inc, 33 | inner_inc, 34 | 3, 35 | 2, 36 | dimension, 37 | region_type=self.REGION_TYPE, 38 | norm_type=self.NORM_TYPE, 39 | nonlinearity=self.NONLINEARITY), *[ 40 | get_block( 41 | self.NORM_TYPE, 42 | inner_inc, 43 | inner_inc, 44 | bn_momentum=bn_momentum, 45 | region_type=self.REGION_TYPE, 46 | dimension=dimension) for d in range(depth) 47 | ]) 48 | self.inner_module = inner_module 49 | self.convtr = nn.Sequential( 50 | conv_tr( 51 | in_channels=inner_outc, 52 | out_channels=inner_outc, 53 | kernel_size=3, 54 | stride=2, 55 | dilation=1, 56 | has_bias=False, 57 | region_type=self.REGION_TYPE, 58 | dimension=dimension), 59 | get_norm( 60 | self.NORM_TYPE, inner_outc, bn_momentum=bn_momentum, dimension=dimension), 61 | get_nonlinearity(self.NONLINEARITY)) 62 | 63 | self.cat_conv = conv_norm_non( 64 | inner_outc + inc, 65 | outc, 66 | 1, 67 | 1, 68 | dimension, 69 | norm_type=self.NORM_TYPE, 70 | nonlinearity=self.NONLINEARITY) 71 | 72 | def forward(self, x): 73 | y = self.conv(x) 74 | if self.inner_module: 75 | y = self.inner_module(y) 76 | y = self.convtr(y) 77 | y = ME.cat(x, y) 78 | return self.cat_conv(y) 79 | 80 | 81 | class PyramidModuleINBN(PyramidModule): 82 | NORM_TYPE = 'INBN' 83 | 84 | 85 | class PyramidModuleAINBN(PyramidModule): 86 | NORM_TYPE = 'AINBN' 87 | 88 | 89 | class PyramidNet(ME.MinkowskiNetwork): 90 | NORM_TYPE = 'BN' 91 | NONLINEARITY = 'ELU' 92 | PYRAMID_MODULE = PyramidModule 93 | CHANNELS = [32, 64, 128, 128] 94 | TR_CHANNELS = [64, 128, 128, 128] 95 | DEPTHS = [1, 1, 1, 1] 96 | # None b1, b2, b3, btr3, btr2 97 | # 1 2 3 -3 -2 -1 98 | REGION_TYPE = ME.RegionType.HYPERCUBE 99 | 100 | # To use the model, must call initialize_coords before forward pass. 
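# A rough sketch of what initialize_network below builds (assuming the default
# 4-level CHANNELS/TR_CHANNELS config; C = CHANNELS, T = TR_CHANNELS): the loop
# nests each coarser level inside the finer one as its inner_module,
#
#   self.pyramid = PyramidModule(C[0], T[0], C[1], T[1],
#       inner_module=PyramidModule(C[1], T[1], C[2], T[2],
#           inner_module=PyramidModule(C[2], T[2], C[3], T[3])))
#
# so each level strides down by 2 on the way in, recurses, then upsamples with
# a transposed convolution and concatenates the skip connection on the way out.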
101 | # Once data is processed, call clear to reset the model before calling initialize_coords 102 | def __init__(self, 103 | in_channels=3, 104 | out_channels=32, 105 | bn_momentum=0.1, 106 | conv1_kernel_size=3, 107 | normalize_feature=False, 108 | D=3): 109 | ME.MinkowskiNetwork.__init__(self, D) 110 | self.conv1_kernel_size = conv1_kernel_size 111 | self.normalize_feature = normalize_feature 112 | 113 | self.initialize_network(in_channels, out_channels, bn_momentum, D) 114 | 115 | def initialize_network(self, in_channels, out_channels, bn_momentum, dimension): 116 | NORM_TYPE = self.NORM_TYPE 117 | NONLINEARITY = self.NONLINEARITY 118 | CHANNELS = self.CHANNELS 119 | TR_CHANNELS = self.TR_CHANNELS 120 | DEPTHS = self.DEPTHS 121 | REGION_TYPE = self.REGION_TYPE 122 | 123 | self.conv = conv_norm_non( 124 | in_channels, 125 | CHANNELS[0], 126 | kernel_size=self.conv1_kernel_size, 127 | stride=1, 128 | dimension=dimension, 129 | bn_momentum=bn_momentum, 130 | region_type=REGION_TYPE, 131 | norm_type=NORM_TYPE, 132 | nonlinearity=NONLINEARITY) 133 | 134 | pyramid = None 135 | for d in range(len(DEPTHS) - 1, 0, -1): 136 | pyramid = self.PYRAMID_MODULE( 137 | CHANNELS[d - 1], 138 | TR_CHANNELS[d - 1], 139 | CHANNELS[d], 140 | TR_CHANNELS[d], 141 | pyramid, 142 | DEPTHS[d], 143 | dimension=dimension) 144 | self.pyramid = pyramid 145 | self.final = nn.Sequential( 146 | conv_norm_non( 147 | TR_CHANNELS[0], 148 | TR_CHANNELS[0], 149 | kernel_size=3, 150 | stride=1, 151 | dimension=dimension), 152 | conv(TR_CHANNELS[0], out_channels, 1, 1, dimension=dimension)) 153 | 154 | def forward(self, x): 155 | out = self.conv(x) 156 | out = self.pyramid(out) 157 | out = self.final(out) 158 | 159 | if self.normalize_feature: 160 | return ME.SparseTensor( 161 | out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8), 162 | coords_key=out.coords_key, 163 | coords_manager=out.coords_man) 164 | else: 165 | return out 166 | 167 | 168 | class PyramidNet8(PyramidNet): 169 | CHANNELS = [32, 64, 128, 128, 192, 192, 256, 256] 170 | TR_CHANNELS = [64, 128, 128, 192, 192, 192, 256, 256] 171 | DEPTHS = [1, 1, 1, 1, 1, 1, 1, 1] 172 | 173 | 174 | class PyramidNet8INBN(PyramidNet8): 175 | NORM_TYPE = 'INBN' 176 | PYRAMID_MODULE = PyramidModuleINBN 177 | 178 | 179 | class PyramidNet8AINBN(PyramidNet8): 180 | NORM_TYPE = 'AINBN' 181 | PYRAMID_MODULE = PyramidModuleAINBN 182 | 183 | class PyramidNet8AINBNNoBlock(PyramidNet8): 184 | NORM_TYPE = 'AINBN' 185 | PYRAMID_MODULE = PyramidModuleAINBN 186 | DEPTHS = [0, 0, 0, 0, 0, 0, 0, 0] 187 | 188 | class PyramidNet8NoBlock(PyramidNet8): 189 | DEPTHS = [0, 0, 0, 0, 0, 0, 0, 0] 190 | 191 | 192 | class PyramidNet12(PyramidNet): 193 | CHANNELS = [32, 64, 128, 128, 128, 192, 192, 192, 256, 256, 256, 512] 194 | TR_CHANNELS = [64, 64, 128, 128, 128, 192, 192, 192, 256, 256, 256, 512] 195 | DEPTHS = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 196 | 197 | 198 | class PyramidNet12INBN(PyramidNet12): 199 | NORM_TYPE = 'INBN' 200 | PYRAMID_MODULE = PyramidModuleINBN 201 | 202 | 203 | class PyramidNet12AINBN(PyramidNet12): 204 | NORM_TYPE = 'AINBN' 205 | PYRAMID_MODULE = PyramidModuleAINBN 206 | 207 | class PyramidNet12AINBNNoBlock(PyramidNet12AINBN): 208 | DEPTHS = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 209 | 210 | # --sample_minimum_coords False 211 | class PyramidIterationNet(PyramidNet): 212 | PYRAMID_NET = PyramidNet8INBN 213 | 214 | # To use the model, must call initialize_coords before forward pass. 
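# PyramidIterationNet runs two pyramid networks back to back: pyramid1 predicts
# per-correspondence logits, from which compute_e_hat recovers a weighted
# 8-point essential-matrix estimate and its epipolar residuals; pyramid2 then
# re-scores every correspondence with two extra input features appended (hence
# in_channels + 2 below): the detached residual and F.relu(torch.tanh(logit)),
# the latter acting as a soft inlier weight in [0, 1).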
215 | # Once data is processed, call clear to reset the model before calling initialize_coords 216 | def __init__(self, 217 | in_channels=3, 218 | out_channels=32, 219 | bn_momentum=0.05, 220 | conv1_kernel_size=3, 221 | normalize_feature=False, 222 | D=4): 223 | ME.MinkowskiNetwork.__init__(self, D) 224 | self.conv1_kernel_size = conv1_kernel_size 225 | 226 | self.pyramid1 = self.PYRAMID_NET( 227 | in_channels, 228 | out_channels, 229 | bn_momentum, 230 | conv1_kernel_size, 231 | normalize_feature=False, 232 | D=D) 233 | self.pyramid2 = self.PYRAMID_NET( 234 | in_channels + 2, 235 | out_channels, 236 | bn_momentum, 237 | conv1_kernel_size, 238 | normalize_feature=False, 239 | D=D) 240 | 241 | def forward(self, x, data): 242 | xyz, len_batch = data['xyz'], data['len_batch'] 243 | 244 | out = self.pyramid1(x) 245 | logits = out.F.squeeze() 246 | e_hats, residuals = compute_e_hat(xyz, logits, len_batch) 247 | 248 | new_feat = torch.cat([ 249 | x.feats, 250 | residuals.detach().float().to(logits.device).unsqueeze(1), 251 | F.relu(torch.tanh(logits)).detach().unsqueeze(1) 252 | ], 253 | dim=1) 254 | x_iter = ME.SparseTensor( 255 | feats=new_feat, coords_key=x.coords_key, 256 | coords_manager=x.coords_man).to(logits.device) 257 | 258 | out_iter = self.pyramid2(x_iter) 259 | logits_iter = out_iter.F.squeeze() 260 | e_hats_iter, residuals_iter = compute_e_hat(xyz, logits_iter, len_batch) 261 | res_logits, res_e_hats = [logits, logits_iter], [e_hats, e_hats_iter] 262 | return res_logits, res_e_hats 263 | 264 | 265 | class PyramidIterationNetNoBlock(PyramidIterationNet): 266 | PYRAMID_NET = PyramidNet8NoBlock 267 | 268 | 269 | class PyramidNetSC(PyramidNet): 270 | CHANNELS = [32, 64, 128, 128, 192, 192, 256, 256] 271 | TR_CHANNELS = [128, 128, 128, 192, 192, 192, 256, 256] 272 | DEPTHS = [1, 1, 1, 1, 1, 1, 1, 1] 273 | NORM_TYPE = 'INBN' 274 | PYRAMID_MODULE = PyramidModuleINBN 275 | 276 | # To use the model, must call initialize_coords before forward pass. 
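# PyramidNetSC uses the same two-pass refinement as PyramidIterationNet, but
# the second pass is an SCBlock (model/resnetsc.py) rather than a second
# pyramid: features are pooled onto a fixed number of clusters with DiffPool,
# mixed by OANet's order-aware filters, and unpooled back onto the
# correspondences before the final logits are predicted.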
277 | # Once data is processed, call clear to reset the model before calling initialize_coords 278 | def __init__(self, 279 | in_channels=3, 280 | out_channels=32, 281 | bn_momentum=0.05, 282 | conv1_kernel_size=3, 283 | depth=6, 284 | clusters=500, 285 | D=4): 286 | ME.MinkowskiNetwork.__init__(self, D) 287 | self.conv1_kernel_size = conv1_kernel_size 288 | 289 | self.initialize_network(in_channels, out_channels, bn_momentum, D) 290 | self.initialize_scblocks( 291 | in_channels + 2, out_channels, depth=depth, clusters=clusters, D=D) 292 | 293 | def initialize_scblocks(self, in_channels, out_channels, depth, clusters, D): 294 | self.scblock = SCBlock( 295 | in_channels, out_channels, depth=depth, clusters=clusters, D=D) 296 | 297 | def forward(self, x, data): 298 | xyz, len_batch = data['xyz'], data['len_batch'] 299 | 300 | out = self.conv(x) 301 | out = self.pyramid(out) 302 | out = self.final(out) 303 | 304 | logits = out.F.squeeze() 305 | e_hats, residuals = compute_e_hat(xyz, logits, len_batch) 306 | 307 | new_feat = torch.cat([ 308 | x.feats, 309 | residuals.detach().float().to(logits.device).unsqueeze(1), 310 | F.relu(torch.tanh(logits)).detach().unsqueeze(1) 311 | ], 312 | dim=1) 313 | x_iter = ME.SparseTensor( 314 | feats=new_feat, coords_key=x.coords_key, 315 | coords_manager=x.coords_man).to(logits.device) 316 | 317 | logits_iter, e_hats_iter, residuals_iter = self.scblock(x_iter, data) 318 | 319 | res_logits, res_e_hats = [logits, logits_iter], [e_hats, e_hats_iter] 320 | return res_logits, res_e_hats 321 | 322 | 323 | class PyramidNetSCNoBlock(PyramidNetSC): 324 | DEPTHS = [0, 0, 0, 0, 0, 0, 0, 0] 325 | -------------------------------------------------------------------------------- /model/residual_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from model.common import get_norm, conv 4 | 5 | import MinkowskiEngine as ME 6 | import MinkowskiEngine.MinkowskiFunctional as MEF 7 | 8 | 9 | class BasicBlockBase(nn.Module): 10 | expansion = 1 11 | NORM_TYPE = 'BN' 12 | 13 | def __init__(self, 14 | inplanes, 15 | planes, 16 | stride=1, 17 | dilation=1, 18 | downsample=None, 19 | bn_momentum=0.1, 20 | region_type=ME.RegionType.HYPERCUBE, 21 | dimension=3): 22 | super(BasicBlockBase, self).__init__() 23 | 24 | self.conv1 = conv( 25 | inplanes, 26 | planes, 27 | kernel_size=3, 28 | stride=stride, 29 | region_type=region_type, 30 | dimension=dimension) 31 | self.norm1 = get_norm( 32 | self.NORM_TYPE, planes, bn_momentum=bn_momentum, dimension=dimension) 33 | self.conv2 = conv( 34 | planes, 35 | planes, 36 | kernel_size=3, 37 | stride=1, 38 | dilation=dilation, 39 | has_bias=False, 40 | region_type=region_type, 41 | dimension=dimension) 42 | self.norm2 = get_norm( 43 | self.NORM_TYPE, planes, bn_momentum=bn_momentum, dimension=dimension) 44 | self.downsample = downsample 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.norm1(out) 51 | out = MEF.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.norm2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = MEF.relu(out) 61 | 62 | return out 63 | 64 | 65 | class BasicBlockBN(BasicBlockBase): 66 | NORM_TYPE = 'BN' 67 | 68 | 69 | class BasicBlockIN(BasicBlockBase): 70 | NORM_TYPE = 'IN' 71 | 72 | 73 | class BasicBlockINBN(BasicBlockBase): 74 | expansion = 1 75 | 76 | def __init__(self, 77 | inplanes, 78 | planes, 79 | stride=1, 80 | dilation=1, 81 | 
downsample=None, 82 | bn_momentum=0.1, 83 | region_type=ME.RegionType.HYPERCUBE, 84 | dimension=3): 85 | super(BasicBlockBase, self).__init__() 86 | 87 | self.conv1 = conv( 88 | inplanes, 89 | planes, 90 | kernel_size=3, 91 | stride=stride, 92 | region_type=region_type, 93 | dimension=dimension) 94 | self.norm1in = get_norm('IN', planes, bn_momentum=bn_momentum, dimension=dimension) 95 | self.norm1bn = get_norm('BN', planes, bn_momentum=bn_momentum, dimension=dimension) 96 | self.conv2 = conv( 97 | planes, 98 | planes, 99 | kernel_size=3, 100 | stride=1, 101 | dilation=dilation, 102 | has_bias=False, 103 | region_type=region_type, 104 | dimension=dimension) 105 | self.norm2in = get_norm('IN', planes, bn_momentum=bn_momentum, dimension=dimension) 106 | self.norm2bn = get_norm('BN', planes, bn_momentum=bn_momentum, dimension=dimension) 107 | self.downsample = downsample 108 | 109 | def forward(self, x): 110 | residual = x 111 | 112 | out = self.conv1(x) 113 | out = self.norm1in(out) 114 | out = self.norm1bn(out) 115 | out = MEF.elu(out) 116 | 117 | out = self.conv2(out) 118 | out = self.norm2in(out) 119 | out = self.norm2bn(out) 120 | 121 | if self.downsample is not None: 122 | residual = self.downsample(x) 123 | 124 | out += residual 125 | out = MEF.elu(out) 126 | 127 | return out 128 | 129 | 130 | class BasicBlockAINBN(BasicBlockBase): 131 | NORM_TYPE = 'AINBN' 132 | 133 | 134 | def get_block(norm_type, 135 | inplanes, 136 | planes, 137 | stride=1, 138 | dilation=1, 139 | downsample=None, 140 | bn_momentum=0.1, 141 | region_type=ME.RegionType.HYPERCUBE, 142 | dimension=3): 143 | if norm_type == 'BN': 144 | return BasicBlockBN(inplanes, planes, stride, dilation, downsample, bn_momentum, 145 | region_type, dimension) 146 | elif norm_type == 'IN': 147 | return BasicBlockIN(inplanes, planes, stride, dilation, downsample, bn_momentum, 148 | region_type, dimension) 149 | elif norm_type == 'INBN': 150 | return BasicBlockINBN(inplanes, planes, stride, dilation, downsample, bn_momentum, 151 | region_type, dimension) 152 | elif norm_type == 'AINBN': 153 | return BasicBlockAINBN(inplanes, planes, stride, dilation, downsample, bn_momentum, 154 | region_type, dimension) 155 | else: 156 | raise ValueError(f'Type {norm_type}, not defined') 157 | -------------------------------------------------------------------------------- /model/resnetsc.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | import MinkowskiEngine.MinkowskiFunctional as MEF 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from baseline.model.oanet import OAFilter, diff_pool, diff_unpool 8 | from lib.util_2d import compute_e_hat 9 | from model.common import conv, conv_tr, get_norm 10 | 11 | 12 | class BasicBlock(ME.MinkowskiNetwork): 13 | 14 | def __init__(self, in_channels, out_channels=None, D=4, transpose=False, stride=1): 15 | ME.MinkowskiNetwork.__init__(self, D) 16 | if not out_channels: 17 | out_channels = in_channels 18 | self.shot_cut = None 19 | if out_channels != in_channels: 20 | self.shot_cut = conv( 21 | in_channels=in_channels, 22 | out_channels=out_channels, 23 | kernel_size=1, 24 | dimension=D) 25 | if transpose: 26 | self.conv = nn.Sequential( 27 | get_norm('IN', in_channels, bn_momentum=0.1, dimension=D), 28 | get_norm('BN', in_channels, bn_momentum=0.1, dimension=D), 29 | conv_tr(in_channels, out_channels, kernel_size=3, stride=stride, dimension=D), 30 | get_norm('IN', out_channels, bn_momentum=0.1, dimension=D), 
31 | get_norm('BN', out_channels, bn_momentum=0.1, dimension=D), 32 | ME.MinkowskiReLU(), 33 | conv_tr( 34 | out_channels, out_channels, kernel_size=3, stride=stride, dimension=D)) 35 | else: 36 | self.conv = nn.Sequential( 37 | get_norm('IN', in_channels, bn_momentum=0.1, dimension=D), 38 | get_norm('BN', in_channels, bn_momentum=0.1, dimension=D), 39 | conv(in_channels, out_channels, kernel_size=3, stride=stride, dimension=D), 40 | get_norm('IN', out_channels, bn_momentum=0.1, dimension=D), 41 | get_norm('BN', out_channels, bn_momentum=0.1, dimension=D), 42 | ME.MinkowskiReLU(), 43 | conv(out_channels, out_channels, kernel_size=3, stride=stride, dimension=D)) 44 | 45 | def forward(self, x): 46 | out = self.conv(x) 47 | if self.shot_cut: 48 | out = out + self.shot_cut(x) 49 | else: 50 | out = out + x 51 | return out 52 | 53 | 54 | class DiffPool(diff_pool): 55 | 56 | def forward(self, x, len_batch): 57 | num_points = len_batch[0] 58 | batch_size = len(len_batch) 59 | assert len_batch.count( 60 | num_points) == batch_size, f'batch contains different numbers of coordinates' 61 | # x: (n, c), n = m*b 62 | input = x.reshape(batch_size, num_points, -1) # input: (b,m,c) 63 | input = input.transpose(1, 2).unsqueeze(-1) # input: (b,c,m,1) 64 | embed = self.conv(input) # embed: (b,k,m,1) 65 | S = torch.softmax(embed, dim=2).squeeze(3) # (b,k,m) 66 | out = torch.matmul(input.squeeze(3), S.transpose(1, 2)).unsqueeze(3) 67 | return out 68 | 69 | 70 | class DiffUnpool(diff_unpool): 71 | 72 | def forward(self, x_up, x_down, len_batch): 73 | num_points = len_batch[0] 74 | batch_size = len(len_batch) 75 | assert len_batch.count( 76 | num_points) == batch_size, f'batch contains different numbers of coordinates' 77 | 78 | input = x_up.reshape(batch_size, num_points, -1) # input: (b,m,c) 79 | input = input.transpose(1, 2).unsqueeze(-1) # input: (b,c,m,1) 80 | embed = self.conv(input) # embed: (b,k,m,1) 81 | S = torch.softmax(embed, dim=1).squeeze(3) # (b,k,m) 82 | out = torch.matmul(x_down.squeeze(3), S) # (b,c,k) * (b,k,m) => (b,c,m) 83 | num_channel = out.shape[1] 84 | out = out.transpose(1, 2).reshape(-1, num_channel) 85 | return out 86 | 87 | 88 | class SCBlock(ME.MinkowskiNetwork): 89 | """Spatial Correlation Block""" 90 | NET_CHANNEL = 128 91 | 92 | def __init__(self, in_channels, out_channels, depth, clusters, D=4): 93 | ME.MinkowskiNetwork.__init__(self, D) 94 | self.depth = depth 95 | self.clusters = clusters 96 | net_channels = self.NET_CHANNEL 97 | 98 | self.conv1 = conv(in_channels, net_channels, kernel_size=1, dimension=D) 99 | 100 | self.l1_1 = [] 101 | for _ in range(depth // 2): 102 | self.l1_1.append(BasicBlock(in_channels=net_channels, D=D)) 103 | 104 | self.down1 = DiffPool(net_channels, clusters) 105 | 106 | self.l2 = [] 107 | for _ in range(depth // 2): 108 | self.l2.append(OAFilter(net_channels, clusters)) 109 | 110 | self.up1 = DiffUnpool(net_channels, clusters) 111 | 112 | self.l1_2 = [] 113 | self.l1_2.append(BasicBlock(2 * net_channels, net_channels, D=D, transpose=True)) 114 | for _ in range(depth // 2 - 1): 115 | self.l1_2.append(BasicBlock(net_channels, net_channels, D=D, transpose=True)) 116 | 117 | self.l1_1 = nn.Sequential(*self.l1_1) 118 | self.l1_2 = nn.Sequential(*self.l1_2) 119 | self.l2 = nn.Sequential(*self.l2) 120 | self.output = conv(net_channels, out_channels, kernel_size=1, dimension=D) 121 | 122 | def weight_initialization(self): 123 | for m in self.modules(): 124 | if isinstance(m, ME.MinkowskiConvolution): 125 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', 
nonlinearity='relu') 126 | 127 | if isinstance(m, ME.MinkowskiBatchNorm): 128 | nn.init.constant_(m.bn.weight, 1) 129 | nn.init.constant_(m.bn.bias, 0) 130 | 131 | def forward(self, x, data): 132 | xyz, len_batch = data['xyz'], data['len_batch'] 133 | x1_1 = self.conv1(x) 134 | x1_1 = self.l1_1(x1_1) 135 | x1_1_F = x1_1.F 136 | 137 | x_down = self.down1(x1_1_F, len_batch) 138 | x2 = self.l2(x_down) 139 | x_up = self.up1(x1_1_F, x2, len_batch) 140 | x_up = ME.SparseTensor( 141 | x_up, 142 | coords_key=x1_1.coords_key, 143 | coords_manager=x1_1.coords_man, 144 | ) 145 | x = ME.cat(x1_1, x_up) 146 | out = self.l1_2(x) 147 | out = self.output(out) 148 | 149 | logits = out.F.squeeze() 150 | e_hats, residuals = compute_e_hat(xyz, logits, len_batch) 151 | 152 | return logits, e_hats, residuals 153 | 154 | 155 | class ResNetSC(ME.MinkowskiNetwork): 156 | BLOCK = SCBlock 157 | 158 | def __init__(self, in_channels, out_channels=None, clusters=None, D=4): 159 | ME.MinkowskiNetwork.__init__(self, D) 160 | 161 | self.iter_num = 1 162 | self.depth = 6 163 | self.clusters = clusters 164 | self.weight_init = self.BLOCK( 165 | in_channels, 166 | out_channels, 167 | self.depth, 168 | self.clusters, 169 | D, 170 | ) 171 | self.weight_iter = [ 172 | self.BLOCK( 173 | in_channels + 2, 174 | out_channels, 175 | self.depth, 176 | self.clusters, 177 | D, 178 | ) for _ in range(self.iter_num) 179 | ] 180 | self.weight_iter = nn.Sequential(*self.weight_iter) 181 | 182 | def forward(self, x, data): 183 | res_logits, res_e_hats = [], [] 184 | 185 | logits, e_hat, residual = self.weight_init(x, data) 186 | res_logits.append(logits) 187 | res_e_hats.append(e_hat) 188 | 189 | for i in range(self.iter_num): 190 | new_feat = torch.cat([ 191 | x.feats, 192 | residual.detach().float().to(logits.device).unsqueeze(1), 193 | F.relu(torch.tanh(logits)).detach().unsqueeze(1) 194 | ], 195 | dim=1) 196 | new_tensor = ME.SparseTensor( 197 | new_feat, coords_key=x.coords_key, 198 | coords_manager=x.coords_man).to(logits.device) 199 | logits, e_hat, residual = self.weight_iter[i](new_tensor, data) 200 | res_logits.append(logits) 201 | res_e_hats.append(e_hat) 202 | 203 | return res_logits, res_e_hats 204 | -------------------------------------------------------------------------------- /model/simpleunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | import MinkowskiEngine.MinkowskiFunctional as MEF 4 | from model.common import get_norm 5 | 6 | 7 | class SimpleNet(ME.MinkowskiNetwork): 8 | NORM_TYPE = None 9 | CHANNELS = [None, 32, 64, 128] 10 | TR_CHANNELS = [None, 32, 32, 64] 11 | 12 | # To use the model, must call initialize_coords before forward pass. 
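# SimpleNet is a small 3-level U-Net on sparse tensors: conv2 and conv3 halve
# the resolution, the transposed convolutions bring it back up, and ME.cat
# concatenates the encoder output of matching resolution before each decoder
# convolution (hence the CHANNELS[i] + TR_CHANNELS[i + 1] input widths below).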
13 | # Once data is processed, call clear to reset the model before calling initialize_coords 14 | def __init__(self, 15 | in_channels=3, 16 | out_channels=32, 17 | bn_momentum=0.1, 18 | conv1_kernel_size=3, 19 | normalize_feature=False, 20 | D=3): 21 | super(SimpleNet, self).__init__(D) 22 | NORM_TYPE = self.NORM_TYPE 23 | CHANNELS = self.CHANNELS 24 | TR_CHANNELS = self.TR_CHANNELS 25 | self.normalize_feature = normalize_feature 26 | self.conv1 = ME.MinkowskiConvolution( 27 | in_channels=in_channels, 28 | out_channels=CHANNELS[1], 29 | kernel_size=conv1_kernel_size, 30 | stride=1, 31 | dilation=1, 32 | has_bias=False, 33 | dimension=D) 34 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 35 | 36 | self.conv2 = ME.MinkowskiConvolution( 37 | in_channels=CHANNELS[1], 38 | out_channels=CHANNELS[2], 39 | kernel_size=3, 40 | stride=2, 41 | dilation=1, 42 | has_bias=False, 43 | dimension=D) 44 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 45 | 46 | self.conv3 = ME.MinkowskiConvolution( 47 | in_channels=CHANNELS[2], 48 | out_channels=CHANNELS[3], 49 | kernel_size=3, 50 | stride=2, 51 | dilation=1, 52 | has_bias=False, 53 | dimension=D) 54 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 55 | 56 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 57 | in_channels=CHANNELS[3], 58 | out_channels=TR_CHANNELS[3], 59 | kernel_size=3, 60 | stride=2, 61 | dilation=1, 62 | has_bias=False, 63 | dimension=D) 64 | self.norm3_tr = get_norm( 65 | NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 66 | 67 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 68 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 69 | out_channels=TR_CHANNELS[2], 70 | kernel_size=3, 71 | stride=2, 72 | dilation=1, 73 | has_bias=False, 74 | dimension=D) 75 | self.norm2_tr = get_norm( 76 | NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 77 | 78 | self.conv1_tr = ME.MinkowskiConvolution( 79 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 80 | out_channels=TR_CHANNELS[1], 81 | kernel_size=3, 82 | stride=1, 83 | dilation=1, 84 | has_bias=False, 85 | dimension=D) 86 | self.norm1_tr = get_norm( 87 | NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 88 | 89 | self.final = ME.MinkowskiConvolution( 90 | in_channels=TR_CHANNELS[1], 91 | out_channels=out_channels, 92 | kernel_size=1, 93 | stride=1, 94 | dilation=1, 95 | has_bias=True, 96 | dimension=D) 97 | 98 | def forward(self, x): 99 | out_s1 = self.conv1(x) 100 | out_s1 = self.norm1(out_s1) 101 | out = MEF.relu(out_s1) 102 | 103 | out_s2 = self.conv2(out) 104 | out_s2 = self.norm2(out_s2) 105 | out = MEF.relu(out_s2) 106 | 107 | out_s4 = self.conv3(out) 108 | out_s4 = self.norm3(out_s4) 109 | out = MEF.relu(out_s4) 110 | 111 | out = self.conv3_tr(out) 112 | out = self.norm3_tr(out) 113 | out_s2_tr = MEF.relu(out) 114 | 115 | out = ME.cat((out_s2_tr, out_s2)) 116 | 117 | out = self.conv2_tr(out) 118 | out = self.norm2_tr(out) 119 | out_s1_tr = MEF.relu(out) 120 | 121 | out = ME.cat((out_s1_tr, out_s1)) 122 | out = self.conv1_tr(out) 123 | out = self.norm1_tr(out) 124 | out = MEF.relu(out) 125 | 126 | out = self.final(out) 127 | 128 | if self.normalize_feature: 129 | return ME.SparseTensor( 130 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 131 | coords_key=out.coords_key, 132 | coords_manager=out.coords_man) 133 | else: 134 | return out 135 | 136 | 137 | class SimpleNetIN(SimpleNet): 138 | NORM_TYPE = 'IN' 139 | 140 | 141 | class 
SimpleNetBN(SimpleNet): 142 | NORM_TYPE = 'BN' 143 | 144 | 145 | class SimpleNetBNE(SimpleNetBN): 146 | CHANNELS = [None, 16, 32, 32] 147 | TR_CHANNELS = [None, 16, 16, 32] 148 | 149 | 150 | class SimpleNetINE(SimpleNetBNE): 151 | NORM_TYPE = 'IN' 152 | 153 | 154 | class SimpleNet2(ME.MinkowskiNetwork): 155 | NORM_TYPE = None 156 | CHANNELS = [None, 32, 64, 128, 256] 157 | TR_CHANNELS = [None, 32, 32, 64, 64] 158 | 159 | # To use the model, must call initialize_coords before forward pass. 160 | # Once data is processed, call clear to reset the model before calling initialize_coords 161 | def __init__(self, 162 | in_channels=3, 163 | out_channels=32, 164 | bn_momentum=0.1, 165 | conv1_kernel_size=3, 166 | normalize_feature=False, 167 | D=3): 168 | ME.MinkowskiNetwork.__init__(self, D) 169 | NORM_TYPE = self.NORM_TYPE 170 | CHANNELS = self.CHANNELS 171 | TR_CHANNELS = self.TR_CHANNELS 172 | self.normalize_feature = normalize_feature 173 | self.conv1 = ME.MinkowskiConvolution( 174 | in_channels=in_channels, 175 | out_channels=CHANNELS[1], 176 | kernel_size=conv1_kernel_size, 177 | stride=1, 178 | dilation=1, 179 | has_bias=False, 180 | dimension=D) 181 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 182 | 183 | self.conv2 = ME.MinkowskiConvolution( 184 | in_channels=CHANNELS[1], 185 | out_channels=CHANNELS[2], 186 | kernel_size=3, 187 | stride=2, 188 | dilation=1, 189 | has_bias=False, 190 | dimension=D) 191 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 192 | 193 | self.conv3 = ME.MinkowskiConvolution( 194 | in_channels=CHANNELS[2], 195 | out_channels=CHANNELS[3], 196 | kernel_size=3, 197 | stride=2, 198 | dilation=1, 199 | has_bias=False, 200 | dimension=D) 201 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 202 | 203 | self.conv4 = ME.MinkowskiConvolution( 204 | in_channels=CHANNELS[3], 205 | out_channels=CHANNELS[4], 206 | kernel_size=3, 207 | stride=2, 208 | dilation=1, 209 | has_bias=False, 210 | dimension=D) 211 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 212 | 213 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 214 | in_channels=CHANNELS[4], 215 | out_channels=TR_CHANNELS[4], 216 | kernel_size=3, 217 | stride=2, 218 | dilation=1, 219 | has_bias=False, 220 | dimension=D) 221 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 222 | 223 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 224 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 225 | out_channels=TR_CHANNELS[3], 226 | kernel_size=3, 227 | stride=2, 228 | dilation=1, 229 | has_bias=False, 230 | dimension=D) 231 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 232 | 233 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 234 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 235 | out_channels=TR_CHANNELS[2], 236 | kernel_size=3, 237 | stride=2, 238 | dilation=1, 239 | has_bias=False, 240 | dimension=D) 241 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 242 | 243 | self.conv1_tr = ME.MinkowskiConvolution( 244 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 245 | out_channels=TR_CHANNELS[1], 246 | kernel_size=3, 247 | stride=1, 248 | dilation=1, 249 | has_bias=False, 250 | dimension=D) 251 | self.norm1_tr = get_norm(NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 252 | 253 | self.final = ME.MinkowskiConvolution( 254 | 
in_channels=TR_CHANNELS[1], 255 | out_channels=out_channels, 256 | kernel_size=1, 257 | stride=1, 258 | dilation=1, 259 | has_bias=True, 260 | dimension=D) 261 | 262 | def forward(self, x): 263 | out_s1 = self.conv1(x) 264 | out_s1 = self.norm1(out_s1) 265 | out = MEF.relu(out_s1) 266 | 267 | out_s2 = self.conv2(out) 268 | out_s2 = self.norm2(out_s2) 269 | out = MEF.relu(out_s2) 270 | 271 | out_s4 = self.conv3(out) 272 | out_s4 = self.norm3(out_s4) 273 | out = MEF.relu(out_s4) 274 | 275 | out_s8 = self.conv4(out) 276 | out_s8 = self.norm4(out_s8) 277 | out = MEF.relu(out_s8) 278 | 279 | out = self.conv4_tr(out) 280 | out = self.norm4_tr(out) 281 | out_s4_tr = MEF.relu(out) 282 | 283 | out = ME.cat((out_s4_tr, out_s4)) 284 | 285 | out = self.conv3_tr(out) 286 | out = self.norm3_tr(out) 287 | out_s2_tr = MEF.relu(out) 288 | 289 | out = ME.cat((out_s2_tr, out_s2)) 290 | 291 | out = self.conv2_tr(out) 292 | out = self.norm2_tr(out) 293 | out_s1_tr = MEF.relu(out) 294 | 295 | out = ME.cat((out_s1_tr, out_s1)) 296 | out = self.conv1_tr(out) 297 | out = self.norm1_tr(out) 298 | out = MEF.relu(out) 299 | 300 | out = self.final(out) 301 | 302 | if self.normalize_feature: 303 | return ME.SparseTensor( 304 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 305 | coords_key=out.coords_key, 306 | coords_manager=out.coords_man) 307 | else: 308 | return out 309 | 310 | 311 | class SimpleNetIN2(SimpleNet2): 312 | NORM_TYPE = 'IN' 313 | 314 | 315 | class SimpleNetBN2(SimpleNet2): 316 | NORM_TYPE = 'BN' 317 | 318 | 319 | class SimpleNetBN2B(SimpleNet2): 320 | NORM_TYPE = 'BN' 321 | CHANNELS = [None, 32, 64, 128, 256] 322 | TR_CHANNELS = [None, 64, 64, 64, 64] 323 | 324 | 325 | class SimpleNetBN2C(SimpleNet2): 326 | NORM_TYPE = 'BN' 327 | CHANNELS = [None, 32, 64, 128, 256] 328 | TR_CHANNELS = [None, 32, 64, 64, 128] 329 | 330 | 331 | class SimpleNetBN2D(SimpleNet2): 332 | NORM_TYPE = 'BN' 333 | CHANNELS = [None, 32, 64, 128, 256] 334 | TR_CHANNELS = [None, 32, 64, 64, 128] 335 | 336 | 337 | class SimpleNetBN2E(SimpleNet2): 338 | NORM_TYPE = 'BN' 339 | CHANNELS = [None, 16, 32, 64, 128] 340 | TR_CHANNELS = [None, 16, 32, 32, 64] 341 | 342 | 343 | class SimpleNetIN2E(SimpleNetBN2E): 344 | NORM_TYPE = 'IN' 345 | 346 | 347 | class SimpleNet3(ME.MinkowskiNetwork): 348 | NORM_TYPE = None 349 | CHANNELS = [None, 32, 64, 128, 256, 512] 350 | TR_CHANNELS = [None, 32, 32, 64, 64, 128] 351 | 352 | # To use the model, must call initialize_coords before forward pass. 
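# SimpleNet3 extends the same encoder-decoder pattern to five levels (16x
# downsampling). As in the shallower variants, index 0 of CHANNELS and
# TR_CHANNELS is a placeholder so that CHANNELS[i] lines up with convi; the
# lettered subclasses below only swap in different channel widths.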
353 | # Once data is processed, call clear to reset the model before calling initialize_coords 354 | def __init__(self, 355 | in_channels=3, 356 | out_channels=32, 357 | bn_momentum=0.1, 358 | conv1_kernel_size=3, 359 | normalize_feature=False, 360 | D=3): 361 | ME.MinkowskiNetwork.__init__(self, D) 362 | NORM_TYPE = self.NORM_TYPE 363 | CHANNELS = self.CHANNELS 364 | TR_CHANNELS = self.TR_CHANNELS 365 | self.normalize_feature = normalize_feature 366 | self.conv1 = ME.MinkowskiConvolution( 367 | in_channels=in_channels, 368 | out_channels=CHANNELS[1], 369 | kernel_size=conv1_kernel_size, 370 | stride=1, 371 | dilation=1, 372 | has_bias=False, 373 | dimension=D) 374 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 375 | 376 | self.conv2 = ME.MinkowskiConvolution( 377 | in_channels=CHANNELS[1], 378 | out_channels=CHANNELS[2], 379 | kernel_size=3, 380 | stride=2, 381 | dilation=1, 382 | has_bias=False, 383 | dimension=D) 384 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 385 | 386 | self.conv3 = ME.MinkowskiConvolution( 387 | in_channels=CHANNELS[2], 388 | out_channels=CHANNELS[3], 389 | kernel_size=3, 390 | stride=2, 391 | dilation=1, 392 | has_bias=False, 393 | dimension=D) 394 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 395 | 396 | self.conv4 = ME.MinkowskiConvolution( 397 | in_channels=CHANNELS[3], 398 | out_channels=CHANNELS[4], 399 | kernel_size=3, 400 | stride=2, 401 | dilation=1, 402 | has_bias=False, 403 | dimension=D) 404 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 405 | 406 | self.conv5 = ME.MinkowskiConvolution( 407 | in_channels=CHANNELS[4], 408 | out_channels=CHANNELS[5], 409 | kernel_size=3, 410 | stride=2, 411 | dilation=1, 412 | has_bias=False, 413 | dimension=D) 414 | self.norm5 = get_norm(NORM_TYPE, CHANNELS[5], bn_momentum=bn_momentum, dimension=D) 415 | 416 | self.conv5_tr = ME.MinkowskiConvolutionTranspose( 417 | in_channels=CHANNELS[5], 418 | out_channels=TR_CHANNELS[5], 419 | kernel_size=3, 420 | stride=2, 421 | dilation=1, 422 | has_bias=False, 423 | dimension=D) 424 | self.norm5_tr = get_norm(NORM_TYPE, TR_CHANNELS[5], bn_momentum=bn_momentum, dimension=D) 425 | 426 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 427 | in_channels=CHANNELS[4] + TR_CHANNELS[5], 428 | out_channels=TR_CHANNELS[4], 429 | kernel_size=3, 430 | stride=2, 431 | dilation=1, 432 | has_bias=False, 433 | dimension=D) 434 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 435 | 436 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 437 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 438 | out_channels=TR_CHANNELS[3], 439 | kernel_size=3, 440 | stride=2, 441 | dilation=1, 442 | has_bias=False, 443 | dimension=D) 444 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 445 | 446 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 447 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 448 | out_channels=TR_CHANNELS[2], 449 | kernel_size=3, 450 | stride=2, 451 | dilation=1, 452 | has_bias=False, 453 | dimension=D) 454 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 455 | 456 | self.conv1_tr = ME.MinkowskiConvolution( 457 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 458 | out_channels=TR_CHANNELS[1], 459 | kernel_size=1, 460 | stride=1, 461 | dilation=1, 462 | has_bias=True, 463 | dimension=D) 464 | 465 | def forward(self, x): 
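    # Encoder: keep each scale's (post-norm, pre-activation) output so it can
    # be concatenated with the decoder stream at the same resolution below.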
466 | out_s1 = self.conv1(x) 467 | out_s1 = self.norm1(out_s1) 468 | out = MEF.relu(out_s1) 469 | 470 | out_s2 = self.conv2(out) 471 | out_s2 = self.norm2(out_s2) 472 | out = MEF.relu(out_s2) 473 | 474 | out_s4 = self.conv3(out) 475 | out_s4 = self.norm3(out_s4) 476 | out = MEF.relu(out_s4) 477 | 478 | out_s8 = self.conv4(out) 479 | out_s8 = self.norm4(out_s8) 480 | out = MEF.relu(out_s8) 481 | 482 | out_s16 = self.conv5(out) 483 | out_s16 = self.norm5(out_s16) 484 | out = MEF.relu(out_s16) 485 | 486 | out = self.conv5_tr(out) 487 | out = self.norm5_tr(out) 488 | out_s8_tr = MEF.relu(out) 489 | 490 | out = ME.cat((out_s8_tr, out_s8)) 491 | 492 | out = self.conv4_tr(out) 493 | out = self.norm4_tr(out) 494 | out_s4_tr = MEF.relu(out) 495 | 496 | out = ME.cat((out_s4_tr, out_s4)) 497 | 498 | out = self.conv3_tr(out) 499 | out = self.norm3_tr(out) 500 | out_s2_tr = MEF.relu(out) 501 | 502 | out = ME.cat((out_s2_tr, out_s2)) 503 | 504 | out = self.conv2_tr(out) 505 | out = self.norm2_tr(out) 506 | out_s1_tr = MEF.relu(out) 507 | 508 | out = ME.cat((out_s1_tr, out_s1)) 509 | out = self.conv1_tr(out) 510 | 511 | if self.normalize_feature: 512 | return ME.SparseTensor( 513 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 514 | coords_key=out.coords_key, 515 | coords_manager=out.coords_man) 516 | else: 517 | return out 518 | 519 | 520 | class SimpleNetIN3(SimpleNet3): 521 | NORM_TYPE = 'IN' 522 | 523 | 524 | class SimpleNetBN3(SimpleNet3): 525 | NORM_TYPE = 'BN' 526 | 527 | 528 | class SimpleNetBN3B(SimpleNet3): 529 | NORM_TYPE = 'BN' 530 | CHANNELS = [None, 32, 64, 128, 256, 512] 531 | TR_CHANNELS = [None, 32, 64, 64, 64, 128] 532 | 533 | 534 | class SimpleNetBN3C(SimpleNet3): 535 | NORM_TYPE = 'BN' 536 | CHANNELS = [None, 32, 64, 128, 256, 512] 537 | TR_CHANNELS = [None, 32, 32, 64, 128, 128] 538 | 539 | 540 | class SimpleNetBN3D(SimpleNet3): 541 | NORM_TYPE = 'BN' 542 | CHANNELS = [None, 32, 64, 128, 256, 512] 543 | TR_CHANNELS = [None, 32, 64, 64, 128, 128] 544 | 545 | 546 | class SimpleNetBN3E(SimpleNet3): 547 | NORM_TYPE = 'BN' 548 | CHANNELS = [None, 16, 32, 64, 128, 256] 549 | TR_CHANNELS = [None, 16, 32, 32, 64, 128] 550 | 551 | 552 | class SimpleNetIN3E(SimpleNetBN3E): 553 | NORM_TYPE = 'IN' 554 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | numpy 3 | # pytorch # for anaconda, please refer to pytorch.org for installation 4 | scipy>=1.4.1 5 | matplotlib 6 | open3d-python 7 | # to visualize it, you need tensorflow, but it doesn't have to be in the same virual environment :) 8 | tensorboardX 9 | MinkowskiEngine 10 | future-fstrings 11 | easydict 12 | joblib 13 | 14 | scikit-learn 15 | 16 | # For scannet segmentation 17 | pandas 18 | plyfile 19 | 20 | # For 2d data generation 21 | opencv-contrib-python==3.4.2.17 22 | tqdm 23 | -------------------------------------------------------------------------------- /scripts/benchmark_yfcc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as osp 3 | import sys 4 | 5 | import numpy as np 6 | import open3d 7 | import pandas as pd 8 | import torch 9 | from easydict import EasyDict as edict 10 | from sklearn.metrics import (average_precision_score, precision_recall_curve, 11 | precision_recall_fscore_support) 12 | 13 | from config import get_parser 14 | from lib.twodim_data_loaders import make_data_loader 15 | from lib.util import 
ensure_dir, read_txt
16 | from train import get_trainer
17 | 
18 | 
19 | def print_table(scenes, keys, values, out_dir, filename):
20 |   data = dict()
21 |   metrics = list(zip(*values))
22 |   for k, metric in zip(keys, metrics):
23 |     data[k] = metric
24 | 
25 |   df = pd.DataFrame(data, index=scenes)
26 |   df.loc['mean'] = df.mean()
27 |   print(df.to_string())
28 |   df.to_csv(osp.join(out_dir, filename))
29 | 
30 | 
31 | def load_scenes(config):
32 |   dataset = config.dataset
33 | 
34 |   if 'YFCC100M' in dataset:
35 |     scene_path = 'config/test_yfcc.txt'
36 |   elif dataset == 'ThreeDMatchPairDataset':
37 |     scene_path = 'config/test_3dmatch.txt'
38 |   elif dataset == 'SUN3DDatasetExtracted':
39 |     scene_path = 'config/test_sun3d.txt'
40 |   else:
41 |     raise ValueError(f"{dataset} is not supported")
42 | 
43 |   scene_list = read_txt(scene_path)
44 |   return scene_list
45 | 
46 | 
47 | def exp_prec_recall(target_list, pred_list, residual_list, scene_list, out_dir):
48 |   logging.info("Exp 1. Evaluating classification scores")
49 | 
50 |   target_list = [np.hstack(targets) for targets in target_list]
51 |   pred_list = [np.hstack(preds) for preds in pred_list]
52 |   residual_list = [np.hstack(residuals) for residuals in residual_list]
53 | 
54 |   keys = ['prec', 'recall', 'f1', 'ap', 'mean', 'median']
55 |   metrics = []
56 |   for targets, preds, residuals in zip(target_list, pred_list, residual_list):
57 |     prec, recall, f1, _ = precision_recall_fscore_support(
58 |         targets, (preds > 0.5).astype(np.int), average='binary')
59 |     ap = average_precision_score(targets, preds)
60 |     mean = np.mean(residuals)
61 |     median = np.median(residuals)
62 |     metrics.append([prec, recall, f1, ap, mean, median])
63 | 
64 |   logging.info("Classification Scores")
65 |   print_table(scene_list, keys, metrics, out_dir, 'prec_recall.csv')
66 | 
67 | 
68 | def exp_ap_curve(target_list, pred_list, out_dir):
69 |   logging.info("Exp 2. Drawing Prec-Recall curve")
70 | 
71 |   targets = np.hstack([np.hstack(targets) for targets in target_list])
72 |   preds = np.hstack([np.hstack(preds) for preds in pred_list])
73 | 
74 |   prec, recall, _ = precision_recall_curve(targets, preds)
75 |   idx = np.linspace(0, len(recall) - 1, 100).astype(np.int)
76 |   prec = prec[idx]
77 |   recall = recall[idx]
78 |   np.savez(osp.join(out_dir, 'ap_curve.npz'), prec=prec, recall=recall)
79 | 
80 | 
81 | def exp_distance_ap(residual_list, scene_list, out_dir):
82 |   logging.info("Exp 3. Evaluating distance AP")
83 | 
84 |   ths = np.arange(20) * 0.01
85 | 
86 |   metrics = []
87 |   for residuals in residual_list:
88 |     res_acc_hist, _ = np.histogram(residuals, ths)
89 |     res_acc_hist = res_acc_hist.astype(np.float) / float(residuals.shape[0])
90 |     res_acc = np.cumsum(res_acc_hist)  # fraction of pairs below each threshold
91 |     ap_list = [np.mean(res_acc[:i]) for i in range(1, len(ths))]  # mAP@t = mean accuracy over thresholds up to t
92 |     metrics.append(ap_list)
93 | 
94 |   logging.info("mAP - Epipolar Distance")
95 |   print_table(scene_list, ths[1:], metrics, out_dir, 'dist_ap.csv')
96 | 
97 | 
98 | def exp_angular_ap(err_q_list, err_t_list, scene_list, out_dir):
99 |   logging.info("Exp 4. Evaluating angular AP")
100 | 
101 |   num_ths = 7
102 |   ths = np.arange(num_ths) * 5
103 | 
104 |   metric_q, metric_t, metric_qt = [], [], []
105 |   for err_q, err_t in zip(err_q_list, err_t_list):
106 |     q_acc_hist, _ = np.histogram(err_q, ths)
107 |     q_acc_hist = q_acc_hist.astype(np.float) / float(err_q.shape[0])
108 |     q_acc = np.cumsum(q_acc_hist)
109 |     q_ap = [np.mean(q_acc[:i]) for i in range(1, len(ths))]
110 |     metric_q.append(q_ap)
111 | 
112 |     t_acc_hist, _ = np.histogram(err_t, ths)
113 |     t_acc_hist = t_acc_hist.astype(np.float) / float(err_t.shape[0])
114 |     t_acc = np.cumsum(t_acc_hist)
115 |     t_ap = [np.mean(t_acc[:i]) for i in range(1, len(ths))]
116 |     metric_t.append(t_ap)
117 | 
118 |     qt_acc_hist, _ = np.histogram(np.maximum(err_q, err_t), ths)  # a pair counts only if both errors are under the threshold
119 |     qt_acc_hist = qt_acc_hist.astype(np.float) / float(err_q.shape[0])
120 |     qt_acc = np.cumsum(qt_acc_hist)
121 |     qt_ap = [np.mean(qt_acc[:i]) for i in range(1, len(ths))]
122 |     metric_qt.append(qt_ap)
123 | 
124 |   logging.info("mAP - Rotation")
125 |   print_table(scene_list, ths[1:], metric_q, out_dir, 'q_ap.csv')
126 | 
127 |   logging.info("mAP - Translation")
128 |   print_table(scene_list, ths[1:], metric_t, out_dir, 't_ap.csv')
129 | 
130 |   logging.info("mAP - Rotation & Translation")
131 |   print_table(scene_list, ths[1:], metric_qt, out_dir, 'qt_ap.csv')
132 | 
133 | 
134 | if __name__ == "__main__":
135 |   # args
136 |   parser = get_parser()
137 |   parser.add_argument(
138 |       '--do_extract',
139 |       action='store_true',
140 |       help='extract network output by feed-forwarding data')
141 | 
142 |   args = parser.parse_args()
143 |   ensure_dir(args.out_dir)
144 | 
145 |   # setup logger
146 |   ch = logging.StreamHandler(sys.stdout)
147 |   logging.getLogger().setLevel(logging.INFO)
148 |   logging.basicConfig(
149 |       format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch])
150 | 
151 |   logging.info("Start Benchmark")
152 | 
153 |   # prepare model
154 |   checkpoint = torch.load(args.weights)
155 | 
156 |   config = checkpoint['config']
157 |   config.data_dir_raw = args.data_dir_raw
158 |   config.data_dir_processed = args.data_dir_processed
159 |   config.weights = args.weights
160 |   config.out_dir = args.out_dir
161 |   config.resume = None
162 | 
163 |   vargs = vars(args)
164 |   for k, v in config.items():
165 |     vargs[k] = v
166 |   config = edict(vargs)
167 | 
168 |   scenes = load_scenes(config)
169 | 
170 |   if args.do_extract:
171 |     Trainer = get_trainer(config.trainer)
172 |     model = Trainer(config, [])
173 | 
174 |     target_list, pred_list, residual_list = [], [], []
175 |     for scene in scenes:
176 |       test_loader = make_data_loader(
177 |           config,
178 |           'test',
179 |           batch_size=1,
180 |           num_workers=1,
181 |           shuffle=False,
182 |           repeat=False,
183 |           scene=scene)
184 | 
185 |       targets, preds, residuals, err_qs, err_ts = model.test(test_loader)
186 |       mean_residuals = [np.mean(res) for res in residuals]
187 |       err_qs = np.hstack(err_qs)
188 |       err_ts = np.hstack(err_ts)
189 | 
190 |       logging.info(f"Save raw data - {scene}")
191 |       np.savez(
192 |           osp.join(args.out_dir, f"{scene}_raw"),
193 |           targets=targets,
194 |           preds=preds,
195 |           residuals=residuals,
196 |           mean_residuals=mean_residuals,
197 |           err_qs=err_qs,
198 |           err_ts=err_ts)
199 | 
200 |   target_list, pred_list, residual_list, mean_residual_list, err_q_list, err_t_list = [], [], [], [], [], []
201 |   for scene in scenes:
202 |     logging.info(f"Load raw data - {scene}")
203 |     data = np.load(osp.join(args.out_dir, f"{scene}_raw.npz"), allow_pickle=True)
204 |     target_list.append(data['targets'])
205 |     pred_list.append(data['preds'])
206 |     residual_list.append(data['residuals'])
207 |     mean_residual_list.append(data['mean_residuals'])
208 |     err_q_list.append(data['err_qs'])
209 |     err_t_list.append(data['err_ts'])
210 | 
211 |   exp_prec_recall(target_list, pred_list, residual_list, scenes, args.out_dir)
212 |   exp_ap_curve(target_list, pred_list, args.out_dir)
213 |   exp_distance_ap(mean_residual_list, scenes, args.out_dir)
214 |   exp_angular_ap(err_q_list, err_t_list, scenes, args.out_dir)
-------------------------------------------------------------------------------- /scripts/download_yfcc.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 | # Usage : bash download_yfcc.sh [path/to/download/dataset]
3 | # Example : bash download_yfcc.sh /root/data
4 | 
5 | DATA_DIR=$1
6 | 
7 | DATA_NAME=oanet_data
8 | FILE_NAME=raw_data
9 | OUTPUT_NAME=raw_data_yfcc.tar.gz
10 | 
11 | cd $DATA_DIR
12 | 
13 | if [ ! -d download_data_$DATA_NAME ]; then
14 |   mkdir -p download_data_$DATA_NAME
15 | fi
16 | 
17 | let CHUNK_START=0
18 | let CHUNK_END=8
19 | 
20 | 
21 | for ((i=CHUNK_START;i<=CHUNK_END;i++)); do
22 |   IDX=$(printf "%03d" $i)
23 |   URL=research.altizure.com/data/$DATA_NAME/$FILE_NAME.tar.$IDX
24 |   wget -c $URL -P download_data_$DATA_NAME
25 |   echo $URL
26 | done
27 | 
28 | 
29 | cat download_data_oanet_data/*.tar.* > $OUTPUT_NAME
30 | rm -r download_data_oanet_data
31 | 
32 | # Unzip
33 | tar -xvzf $OUTPUT_NAME
34 | mv raw_data/yfcc100m .
35 | rm -rf raw_data
36 | 
37 | cd -
-------------------------------------------------------------------------------- /scripts/plot_yfcc.py: --------------------------------------------------------------------------------
1 | import logging
2 | import os.path as osp
3 | import sys
4 | 
5 | import cv2
6 | import matplotlib.gridspec as grid
7 | import matplotlib.pyplot as plt
8 | from matplotlib.patches import ConnectionPatch
9 | import numpy as np
10 | import open3d
11 | import pandas as pd
12 | import torch
13 | from easydict import EasyDict as edict
14 | 
15 | from config import get_parser
16 | from lib.twodim_data_loaders import make_data_loader
17 | from lib.util import ensure_dir, read_txt
18 | from train import get_trainer
19 | from scripts.benchmark_yfcc import load_scenes
20 | 
21 | 
22 | def draw_figure(img0, img1, coords, labels, preds):
23 |   # prepare figure
24 |   plt.clf()
25 |   fig = plt.figure()
26 |   ratios = [img0.shape[1] * img1.shape[0], img1.shape[1] * img0.shape[0]]
27 |   gs = grid.GridSpec(nrows=2, ncols=1, height_ratios=ratios)
28 |   ax1 = fig.add_subplot(gs[0])
29 |   ax2 = fig.add_subplot(gs[1])
30 |   ax1.axis('off')
31 |   ax2.axis('off')
32 |   preds = preds > 0.5
33 |   coords = coords[preds]
34 |   labels = labels[preds]
35 | 
36 |   img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
37 |   img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
38 |   for coord, is_inlier in zip(coords, labels):
39 |     con = ConnectionPatch(
40 |         xyA=coord[:2],
41 |         xyB=coord[2:],
42 |         coordsA="data",
43 |         coordsB="data",
44 |         axesA=ax2,
45 |         axesB=ax1,
46 |         color="green" if is_inlier else "red")
47 |     ax2.add_artist(con)
48 | 
49 |   ax1.imshow(img1)
50 |   ax2.imshow(img0)
51 |   plt.subplots_adjust(left=0, bottom=0, right=1, top=1, hspace=0, wspace=0)
52 |   return fig
53 | 
54 | 
55 | if __name__ == "__main__":
56 |   # args
57 |   parser = get_parser()
58 |   parser.add_argument(
59 |       '--do_extract',
60 |       action='store_true',
61 |       help='extract network output by feed-forwarding data')
62 | 
63 |   args = parser.parse_args()
64 |   ensure_dir(args.out_dir)
65 | 
66 |   # setup logger
67 |   ch = 
logging.StreamHandler(sys.stdout) 68 | logging.getLogger().setLevel(logging.INFO) 69 | logging.basicConfig( 70 | format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 71 | 72 | logging.info("Start Benchmark") 73 | 74 | # prepare model 75 | checkpoint = torch.load(args.weights) 76 | 77 | config = checkpoint['config'] 78 | config.data_dir_raw = args.data_dir_raw 79 | config.data_dir_processed = args.data_dir_processed 80 | config.weights = args.weights 81 | config.out_dir = args.out_dir 82 | config.resume = None 83 | 84 | vargs = vars(args) 85 | for k, v in config.items(): 86 | vargs[k] = v 87 | config = edict(vargs) 88 | 89 | scenes = load_scenes(config) 90 | 91 | for scene in scenes: 92 | logging.info(f"Load raw data - {scene}") 93 | data = np.load(osp.join(args.out_dir, f"{scene}_raw.npz"), allow_pickle=True) 94 | preds = data['preds'] 95 | 96 | figure_dir = osp.join(args.out_dir, "figures") 97 | ensure_dir(figure_dir) 98 | 99 | test_loader = make_data_loader( 100 | config, 101 | 'test', 102 | batch_size=1, 103 | num_workers=1, 104 | shuffle=False, 105 | repeat=False, 106 | scene=scene) 107 | test_iter = test_loader.__iter__() 108 | 109 | for i in range(len(test_iter)): 110 | input_dict = test_iter.next() 111 | fig = draw_figure( 112 | img0=input_dict['img0'][0], 113 | img1=input_dict['img1'][0], 114 | coords=input_dict['coords'][0], 115 | labels=input_dict['labels'][0], 116 | preds=preds[i], 117 | ) 118 | filename = osp.join(figure_dir, f"{scene[0]}{i:03d}.png") 119 | fig.savefig(filename, dpi=100, bbox_inches='tight') 120 | logging.info(f"save {filename}") 121 | plt.close(fig) 122 | -------------------------------------------------------------------------------- /scripts/train_2d.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/2d" 6 | export TRAINER=${TRAINER:-ImageCorrespondenceTrainer} 7 | export INLIER_MODEL=${INLIER_MODEL:-PyramidNetSCNoBlock} 8 | export DATASET=${DATASET:-YFCC100MDatasetUCN} 9 | export OPTIMIZER=${OPTIMIZER:-SGD} 10 | export LR=${LR:-1e-1} 11 | export MAX_EPOCH=${MAX_EPOCH:-250} 12 | export BATCH_SIZE=${BATCH_SIZE:-16} 13 | export QUANTIZATION_SIZE=${QUANTIZATION_SIZE:-0.01} 14 | export INLIER_THRESHOLD_PIXEL=${INLIER_THRESHOLD_PIXEL:-0.01} 15 | export INLIER_FEATURE_TYPE=${INLIER_FEATURE_TYPE:-coords} 16 | export COLLATION_2D=${COLLATION_2D:-collate_correspondence} 17 | export BEST_VAL_METRIC=${BEST_VAL_METRIC:-mAP20} 18 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 19 | export VERSION=$(git rev-parse HEAD) 20 | 21 | export OUT_DIR=${DATA_ROOT}/${DATASET}-v${QUANTIZATION_SIZE}-i${INLIER_THRESHOLD_PIXEL}/${INLIER_MODEL}-${BEST_VAL_METRIC}-${INLIER_FEATURE_TYPE}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}/${PATH_POSTFIX}/${TIME} 22 | 23 | export PYTHONUNBUFFERED="True" 24 | 25 | echo $OUT_DIR 26 | 27 | mkdir -m 755 -p $OUT_DIR 28 | 29 | LOG=${OUT_DIR}/log_${TIME}.txt 30 | 31 | echo "Host: " $(hostname) | tee -a $LOG 32 | echo "Conda " $(which conda) | tee -a $LOG 33 | echo $(pwd) | tee -a $LOG 34 | echo "Version: " $VERSION | tee -a $LOG 35 | echo "Git diff" | tee -a $LOG 36 | echo "" | tee -a $LOG 37 | git diff | tee -a $LOG 38 | echo "" | tee -a $LOG 39 | nvidia-smi | tee -a $LOG 40 | 41 | set -x 42 | 43 | echo " 44 | python train.py \ 45 | --optimizer ${OPTIMIZER} \ 46 | --lr ${LR} \ 47 | --batch_size ${BATCH_SIZE} \ 48 | --val_batch_size ${BATCH_SIZE} \ 49 | --max_epoch ${MAX_EPOCH} \ 50 | --dataset ${DATASET} \ 51 | --trainer ${TRAINER} \ 52 | --inlier_model ${INLIER_MODEL} \ 53 | --inlier_feature_type ${INLIER_FEATURE_TYPE} \ 54 | --quantization_size ${QUANTIZATION_SIZE} \ 55 | --inlier_threshold_pixel ${INLIER_THRESHOLD_PIXEL} \ 56 | --collation_2d ${COLLATION_2D} \ 57 | --best_val_metric ${BEST_VAL_METRIC} \ 58 | --out_dir ${OUT_DIR} \ 59 | --sample_minimum_coords True \ 60 | ${MISC_ARGS} 2>&1" | tee -a $LOG 61 | 62 | 63 | # Training 64 | python train.py \ 65 | --optimizer ${OPTIMIZER} \ 66 | --lr ${LR} \ 67 | --batch_size ${BATCH_SIZE} \ 68 | --val_batch_size ${BATCH_SIZE} \ 69 | --max_epoch ${MAX_EPOCH} \ 70 | --dataset ${DATASET} \ 71 | --trainer ${TRAINER} \ 72 | --inlier_model ${INLIER_MODEL} \ 73 | --inlier_feature_type ${INLIER_FEATURE_TYPE} \ 74 | --quantization_size ${QUANTIZATION_SIZE} \ 75 | --inlier_threshold_pixel ${INLIER_THRESHOLD_PIXEL} \ 76 | --collation_2d ${COLLATION_2D} \ 77 | --best_val_metric ${BEST_VAL_METRIC} \ 78 | --out_dir ${OUT_DIR} \ 79 | --sample_minimum_coords True \ 80 | ${MISC_ARGS} 2>&1 | tee -a $LOG 81 | 82 | -------------------------------------------------------------------------------- /scripts/train_2d_onpaper.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/2d" 6 | export TRAINER=${TRAINER:-ImageCorrespondenceTrainer} 7 | export INLIER_MODEL=${INLIER_MODEL:-ResNetSC} 8 | export DATASET=${DATASET:-YFCC100MDatasetExtracted} 9 | export OPTIMIZER=${OPTIMIZER:-SGD} 10 | export LR=${LR:-1e-1} 11 | export MAX_EPOCH=${MAX_EPOCH:-100} 12 | export BATCH_SIZE=${BATCH_SIZE:-32} 13 | export QUANTIZATION_SIZE=${QUANTIZATION_SIZE:-0.01} 14 | export INLIER_THRESHOLD_PIXEL=${INLIER_THRESHOLD_PIXEL:-0.01} 15 | export INLIER_FEATURE_TYPE=${INLIER_FEATURE_TYPE:-coords} 16 | export COLLATION_2D=${COLLATION_2D:-collate_correspondence} 17 | export BEST_VAL_METRIC=${BEST_VAL_METRIC:-ap} 18 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 19 | export VERSION=$(git rev-parse HEAD) 20 | 21 | export OUT_DIR=${DATA_ROOT}/${DATASET}-v${QUANTIZATION_SIZE}-i${INLIER_THRESHOLD_PIXEL}/${INLIER_MODEL}-${BEST_VAL_METRIC}-${INLIER_FEATURE_TYPE}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}/${PATH_POSTFIX}/${TIME} 22 | 23 | export PYTHONUNBUFFERED="True" 24 | 25 | echo $OUT_DIR 26 | 27 | mkdir -m 755 -p $OUT_DIR 28 | 29 | LOG=${OUT_DIR}/log_${TIME}.txt 30 | 31 | echo "Host: " $(hostname) | tee -a $LOG 32 | echo "Conda " $(which conda) | tee -a $LOG 33 | echo $(pwd) | tee -a $LOG 34 | echo "Version: " $VERSION | tee -a $LOG 35 | echo "Git diff" | tee -a $LOG 36 | echo "" | tee -a $LOG 37 | git diff | tee -a $LOG 38 | echo "" | tee -a $LOG 39 | nvidia-smi | tee -a $LOG 40 | 41 | # Training 42 | python train.py \ 43 | --optimizer ${OPTIMIZER} \ 44 | --lr ${LR} \ 45 | --batch_size ${BATCH_SIZE} \ 46 | --val_batch_size ${BATCH_SIZE} \ 47 | --max_epoch ${MAX_EPOCH} \ 48 | --dataset ${DATASET} \ 49 | --trainer ${TRAINER} \ 50 | --inlier_model ${INLIER_MODEL} \ 51 | --inlier_feature_type ${INLIER_FEATURE_TYPE} \ 52 | --quantization_size ${QUANTIZATION_SIZE} \ 53 | --inlier_threshold_pixel ${INLIER_THRESHOLD_PIXEL} \ 54 | --collation_2d ${COLLATION_2D} \ 55 | --best_val_metric ${BEST_VAL_METRIC} \ 56 | --out_dir ${OUT_DIR} \ 57 | --sample_minimum_coords True \ 58 | ${MISC_ARGS} 2>&1 | tee -a $LOG 59 | 60 | # Test 61 | # TODO 62 | -------------------------------------------------------------------------------- /scripts/train_lfgc.sh: -------------------------------------------------------------------------------- 1 | #! 
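# With the defaults below, OUT_DIR expands to a run directory of the form
# (postfix and timestamp vary per invocation):
#   ./outputs/LFGC/YFCC100MDatasetExtracted-i0.01/LFGCNet-f1/Adam-lr1e-4-e100-b32/<postfix>/<time>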
/bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/LFGC" 6 | export TRAINER=${TRAINER:-LFGCTrainer} 7 | export INLIER_MODEL=${INLIER_MODEL:-LFGCNet} 8 | export DATASET=${DATASET:-YFCC100MDatasetExtracted} 9 | export OPTIMIZER=${OPTIMIZER:-Adam} 10 | export LR=${LR:-1e-4} 11 | export MAX_EPOCH=${MAX_EPOCH:-100} 12 | export BATCH_SIZE=${BATCH_SIZE:-32} 13 | export INLIER_THRESHOLD_PIXEL=${INLIER_THRESHOLD_PIXEL:-0.01} 14 | export COLLATION_2D=${COLLATION_2D:-collate_lfgc} 15 | export BEST_VAL_METRIC=${BEST_VAL_METRIC:-f1} 16 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 17 | export VERSION=$(git rev-parse HEAD) 18 | 19 | export OUT_DIR=${DATA_ROOT}/${DATASET}-i${INLIER_THRESHOLD_PIXEL}/${INLIER_MODEL}-${BEST_VAL_METRIC}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}/${PATH_POSTFIX}/${TIME} 20 | 21 | export PYTHONUNBUFFERED="True" 22 | 23 | echo $OUT_DIR 24 | 25 | mkdir -m 755 -p $OUT_DIR 26 | 27 | LOG=${OUT_DIR}/log_${TIME}.txt 28 | 29 | echo "Host: " $(hostname) | tee -a $LOG 30 | echo "Conda " $(which conda) | tee -a $LOG 31 | echo $(pwd) | tee -a $LOG 32 | echo "Version: " $VERSION | tee -a $LOG 33 | echo "Git diff" | tee -a $LOG 34 | echo "" | tee -a $LOG 35 | git diff | tee -a $LOG 36 | echo "" | tee -a $LOG 37 | nvidia-smi | tee -a $LOG 38 | 39 | # Training 40 | python train.py \ 41 | --optimizer ${OPTIMIZER} \ 42 | --lr ${LR} \ 43 | --batch_size ${BATCH_SIZE} \ 44 | --val_batch_size ${BATCH_SIZE} \ 45 | --max_epoch ${MAX_EPOCH} \ 46 | --dataset ${DATASET} \ 47 | --trainer ${TRAINER} \ 48 | --inlier_threshold_pixel ${INLIER_THRESHOLD_PIXEL} \ 49 | --collation_2d ${COLLATION_2D} \ 50 | --best_val_metric ${BEST_VAL_METRIC} \ 51 | --out_dir ${OUT_DIR} \ 52 | ${MISC_ARGS} 2>&1 | tee -a $LOG 53 | 54 | # Test 55 | # TODO 56 | -------------------------------------------------------------------------------- /scripts/train_oa.sh: -------------------------------------------------------------------------------- 1 | #! 
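# Each run below is self-documenting: hostname, conda path, working directory,
# git SHA and the uncommitted diff, and `nvidia-smi` output are appended to
# $LOG via `tee -a` before training starts, and the training command pipes
# both stdout and stderr (2>&1) into the same log.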
/bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/OA" 6 | export TRAINER=${TRAINER:-OATrainer} 7 | export INLIER_MODEL=${INLIER_MODEL:-OANet} 8 | export DATASET=${DATASET:-YFCC100MDatasetExtracted} 9 | export OPTIMIZER=${OPTIMIZER:-Adam} 10 | export LR=${LR:-1e-4} 11 | export MAX_EPOCH=${MAX_EPOCH:-250} 12 | export BATCH_SIZE=${BATCH_SIZE:-32} 13 | export INLIER_THRESHOLD_PIXEL=${INLIER_THRESHOLD_PIXEL:-0.01} 14 | export COLLATION_2D=${COLLATION_2D:-collate_oa} 15 | export BEST_VAL_METRIC=${BEST_VAL_METRIC:-ap} 16 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 17 | export VERSION=$(git rev-parse HEAD) 18 | 19 | export OUT_DIR=${DATA_ROOT}/${DATASET}-i${INLIER_THRESHOLD_PIXEL}/${INLIER_MODEL}-${BEST_VAL_METRIC}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}/${PATH_POSTFIX}/${TIME} 20 | 21 | export PYTHONUNBUFFERED="True" 22 | 23 | echo $OUT_DIR 24 | 25 | mkdir -m 755 -p $OUT_DIR 26 | 27 | LOG=${OUT_DIR}/log_${TIME}.txt 28 | 29 | echo "Host: " $(hostname) | tee -a $LOG 30 | echo "Conda " $(which conda) | tee -a $LOG 31 | echo $(pwd) | tee -a $LOG 32 | echo "Version: " $VERSION | tee -a $LOG 33 | echo "Git diff" | tee -a $LOG 34 | echo "" | tee -a $LOG 35 | git diff | tee -a $LOG 36 | echo "" | tee -a $LOG 37 | nvidia-smi | tee -a $LOG 38 | 39 | # Training 40 | python train.py \ 41 | --optimizer ${OPTIMIZER} \ 42 | --lr ${LR} \ 43 | --batch_size ${BATCH_SIZE} \ 44 | --val_batch_size ${BATCH_SIZE} \ 45 | --max_epoch ${MAX_EPOCH} \ 46 | --dataset ${DATASET} \ 47 | --trainer ${TRAINER} \ 48 | --inlier_threshold_pixel ${INLIER_THRESHOLD_PIXEL} \ 49 | --collation_2d ${COLLATION_2D} \ 50 | --best_val_metric ${BEST_VAL_METRIC} \ 51 | --out_dir ${OUT_DIR} \ 52 | ${MISC_ARGS} 2>&1 | tee -a $LOG 53 | 54 | # Test 55 | # TODO 56 | -------------------------------------------------------------------------------- /scripts/train_oa_onpaper.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/OA" 6 | export TRAINER=${TRAINER:-OATrainer} 7 | export INLIER_MODEL=${INLIER_MODEL:-OANet} 8 | export DATASET=${DATASET:-YFCC100MDatasetExtracted} 9 | export OPTIMIZER=${OPTIMIZER:-Adam} 10 | export LR=${LR:-1e-3} 11 | export MAX_EPOCH=${MAX_EPOCH:-100} 12 | export BATCH_SIZE=${BATCH_SIZE:-32} 13 | export INLIER_THRESHOLD_PIXEL=${INLIER_THRESHOLD_PIXEL:-0.01} 14 | export COLLATION_2D=${COLLATION_2D:-collate_oa} 15 | export BEST_VAL_METRIC=${BEST_VAL_METRIC:-ap} 16 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 17 | export VERSION=$(git rev-parse HEAD) 18 | 19 | export OUT_DIR=${DATA_ROOT}/${DATASET}-i${INLIER_THRESHOLD_PIXEL}/${INLIER_MODEL}-${BEST_VAL_METRIC}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}/${PATH_POSTFIX}/${TIME} 20 | 21 | export PYTHONUNBUFFERED="True" 22 | 23 | echo $OUT_DIR 24 | 25 | mkdir -m 755 -p $OUT_DIR 26 | 27 | LOG=${OUT_DIR}/log_${TIME}.txt 28 | 29 | echo "Host: " $(hostname) | tee -a $LOG 30 | echo "Conda " $(which conda) | tee -a $LOG 31 | echo $(pwd) | tee -a $LOG 32 | echo "Version: " $VERSION | tee -a $LOG 33 | echo "Git diff" | tee -a $LOG 34 | echo "" | tee -a $LOG 35 | git diff | tee -a $LOG 36 | echo "" | tee -a $LOG 37 | nvidia-smi | tee -a $LOG 38 | 39 | # Training 40 | python train.py \ 41 | --optimizer ${OPTIMIZER} \ 42 | --lr ${LR} \ 43 | --batch_size ${BATCH_SIZE} \ 44 | --val_batch_size ${BATCH_SIZE} \ 45 | --max_epoch ${MAX_EPOCH} \ 46 | --dataset ${DATASET} \ 47 | --trainer ${TRAINER} \ 48 | --inlier_threshold_pixel ${INLIER_THRESHOLD_PIXEL} \ 49 | --collation_2d ${COLLATION_2D} \ 50 | --best_val_metric ${BEST_VAL_METRIC} \ 51 | --out_dir ${OUT_DIR} \ 52 | ${MISC_ARGS} 2>&1 | tee -a $LOG 53 | 54 | # Test 55 | # TODO 56 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | import os 3 | import sys 4 | import json 5 | import logging 6 | import torch 7 | from easydict import EasyDict as edict 8 | 9 | from lib.trainer import get_trainer 10 | from lib.data_loaders import make_data_loader 11 | from lib.loss import pts_loss2 12 | from config import get_config 13 | from model import load_model 14 | 15 | import MinkowskiEngine as ME 16 | 17 | ch = logging.StreamHandler(sys.stdout) 18 | logging.getLogger().setLevel(logging.INFO) 19 | logging.basicConfig( 20 | format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 21 | 22 | torch.manual_seed(0) 23 | torch.cuda.manual_seed(0) 24 | 25 | logging.basicConfig(level=logging.INFO, format="") 26 | 27 | 28 | def main(config, resume=False): 29 | test_loader = make_data_loader( 30 | config, 31 | config.test_phase, 32 | 1, 33 | num_threads=config.test_num_thread) 34 | 35 | num_feats = 0 36 | if config.use_color: 37 | num_feats += 3 38 | if config.use_normal: 39 | num_feats += 3 40 | num_feats = max(1, num_feats) 41 | 42 | Model = load_model(config.model) 43 | model = Model(num_feats, config.model_n_out, config=config) 44 | 45 | if config.weights: 46 | logging.info(f"Loading the weights {config.weights}") 47 | checkpoint = torch.load(config.weights, map_location=lambda storage, loc: storage) 48 | model.load_state_dict(checkpoint['state_dict']) 49 | 50 | logging.info(model) 51 | 52 | metrics_fn = [pts_loss2] 53 | Trainer = get_trainer(config.trainer) 54 | trainer = Trainer( 55 | model, 56 | metrics_fn, 57 | config=config, 58 | 
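# The same test loader is passed as both data_loader and val_data_loader, so
# the _valid_epoch() call below evaluates the restored checkpoint on the test
# split rather than on a separate validation set.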
data_loader=test_loader, 59 | val_data_loader=test_loader, 60 | ) 61 | 62 | test_dict = trainer._valid_epoch() 63 | 64 | 65 | if __name__ == "__main__": 66 | logger = logging.getLogger() 67 | config = get_config() 68 | if config.me_num_thread < 0: 69 | config.me_num_thread = os.cpu_count() 70 | 71 | dconfig = vars(config) 72 | if config.weights_dir: 73 | resume_config = json.load(open(config.weights_dir + '/config.json', 'r')) 74 | for k in dconfig: 75 | if k not in ['weights_dir', 'dataset'] and k in resume_config: 76 | dconfig[k] = resume_config[k] 77 | dconfig['weights'] = config.weights_dir + '/checkpoint.pth' 78 | 79 | logging.info('===> Configurations') 80 | for k in dconfig: 81 | logging.info(' {}: {}'.format(k, dconfig[k])) 82 | 83 | # Convert to dict 84 | config = edict(dconfig) 85 | ME.initialize_nthreads(config.me_num_thread, D=3) 86 | 87 | main(config) 88 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d # prevent loading error 2 | 3 | import sys 4 | import json 5 | import logging 6 | import torch 7 | from easydict import EasyDict as edict 8 | 9 | from lib.all_data_loaders import make_data_loader 10 | from config import get_config 11 | 12 | from lib.lfgc_trainer import LFGCTrainer 13 | from lib.oa_trainer import OATrainer 14 | from lib.twodim_trainer import ImageCorrespondenceTrainer 15 | 16 | ch = logging.StreamHandler(sys.stdout) 17 | logging.getLogger().setLevel(logging.INFO) 18 | logging.basicConfig( 19 | format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 20 | 21 | torch.manual_seed(0) 22 | torch.cuda.manual_seed(0) 23 | 24 | logging.basicConfig(level=logging.INFO, format="") 25 | 26 | TRAINERS = [ 27 | # Register Trainer here 28 | LFGCTrainer, 29 | OATrainer, 30 | ImageCorrespondenceTrainer, 31 | ] 32 | 33 | trainer_map = {t.__name__: t for t in TRAINERS} 34 | 35 | 36 | def get_trainer(trainer): 37 | if trainer in trainer_map.keys(): 38 | return trainer_map[trainer] 39 | else: 40 | raise ValueError(f'Trainer {trainer} not found') 41 | 42 | 43 | def main(config, resume=False): 44 | train_loader = make_data_loader( 45 | config, 46 | config.train_phase, 47 | config.batch_size, 48 | shuffle=True, 49 | repeat=True, 50 | num_workers=config.train_num_workers) 51 | if config.test_valid: 52 | val_loader = make_data_loader( 53 | config, 54 | config.val_phase, 55 | config.val_batch_size, 56 | shuffle=True, 57 | repeat=True, 58 | num_workers=config.val_num_workers) 59 | else: 60 | val_loader = None 61 | 62 | Trainer = get_trainer(config.trainer) 63 | trainer = Trainer( 64 | config=config, 65 | data_loader=train_loader, 66 | val_data_loader=val_loader, 67 | ) 68 | 69 | trainer.train() 70 | 71 | if config.final_test: 72 | test_loader = make_data_loader( 73 | config, "test", config.val_batch_size, num_workers=config.val_num_workers) 74 | trainer.val_data_loader = test_loader 75 | test_dict = trainer._valid_epoch() 76 | test_loss = test_dict['loss'] 77 | trainer.writer.add_scalar('test/loss', test_loss, config.max_epoch) 78 | logging.info(f" Test loss: {test_loss}") 79 | 80 | 81 | if __name__ == "__main__": 82 | logger = logging.getLogger() 83 | config = get_config() 84 | dconfig = vars(config) 85 | if config.resume_dir: 86 | resume_config = json.load(open(config.resume_dir + '/config.json', 'r')) 87 | for k in dconfig: 88 | if k not in ['resume_dir'] and k in resume_config: 89 | dconfig[k] = resume_config[k] 90 | 
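# When resuming, every CLI option except resume_dir is overwritten by the
# saved run's config.json; 'resume' is then pointed at that run's checkpoint.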
dconfig['resume'] = resume_config['out_dir'] + '/checkpoint.pth' 91 | 92 | logging.info('===> Configurations') 93 | for k in dconfig: 94 | logging.info(' {}: {}'.format(k, dconfig[k])) 95 | 96 | # Convert to dict 97 | config = edict(dconfig) 98 | 99 | main(config) 100 | -------------------------------------------------------------------------------- /ucn/blocks.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Chris Choy (chrischoy@ai.stanford.edu) 4 | # Junha Lee (junhakiwi@postech.ac.kr) 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | # this software and associated documentation files (the "Software"), to deal in 8 | # the Software without restriction, including without limitation the rights to 9 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | # of the Software, and to permit persons to whom the Software is furnished to do 11 | # so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | import torch.nn as nn 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d( 29 | in_planes, 30 | out_planes, 31 | kernel_size=3, 32 | stride=stride, 33 | padding=dilation, 34 | groups=groups, 35 | bias=False, 36 | dilation=dilation) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1): 40 | """1x1 convolution""" 41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | __constants__ = ['downsample'] 47 | 48 | def __init__(self, 49 | inplanes, 50 | planes, 51 | stride=1, 52 | downsample=None, 53 | groups=1, 54 | base_width=64, 55 | dilation=1, 56 | norm_layer=None): 57 | super(BasicBlock, self).__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 62 | if dilation > 1: 63 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 64 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 65 | self.conv1 = conv3x3(inplanes, planes, stride) 66 | self.bn1 = norm_layer(planes) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.conv2 = conv3x3(planes, planes) 69 | self.bn2 = norm_layer(planes) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class Bottleneck(nn.Module): 93 | 
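# Torchvision-style bottleneck: a 1x1 conv reduces the width, a 3x3 conv
# carries the stride/dilation, and a 1x1 conv expands it back by `expansion`
# (4x); `downsample` adapts the identity path when shapes change.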
expansion = 4 94 | __constants__ = ['downsample'] 95 | 96 | def __init__(self, 97 | inplanes, 98 | planes, 99 | stride=1, 100 | downsample=None, 101 | groups=1, 102 | base_width=64, 103 | dilation=1, 104 | norm_layer=None): 105 | super(Bottleneck, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | width = int(planes * (base_width / 64.)) * groups 109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 110 | self.conv1 = conv1x1(inplanes, width) 111 | self.bn1 = norm_layer(width) 112 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 113 | self.bn2 = norm_layer(width) 114 | self.conv3 = conv1x1(width, planes * self.expansion) 115 | self.bn3 = norm_layer(planes * self.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.downsample = downsample 118 | self.stride = stride 119 | 120 | def forward(self, x): 121 | identity = x 122 | 123 | out = self.conv1(x) 124 | out = self.bn1(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv2(out) 128 | out = self.bn2(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv3(out) 132 | out = self.bn3(out) 133 | 134 | if self.downsample is not None: 135 | identity = self.downsample(x) 136 | 137 | out += identity 138 | out = self.relu(out) 139 | 140 | return out 141 | 142 | 143 | def get_block(norm_type, 144 | inplanes, 145 | planes, 146 | stride=1, 147 | dilation=1, 148 | downsample=None, 149 | bn_momentum=0.1): 150 | return BasicBlock( 151 | inplanes, planes, stride=stride, dilation=dilation, downsample=downsample) 152 | -------------------------------------------------------------------------------- /ucn/resunet.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Chris Choy (chrischoy@ai.stanford.edu) 4 | # Junha Lee (junhakiwi@postech.ac.kr) 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | # this software and associated documentation files (the "Software"), to deal in 8 | # the Software without restriction, including without limitation the rights to 9 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | # of the Software, and to permit persons to whom the Software is furnished to do 11 | # so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | from .blocks import get_block 28 | 29 | 30 | def get_norm(norm_type, num_features, bn_momentum=0.1): 31 | return nn.BatchNorm2d(num_features, momentum=bn_momentum) 32 | 33 | 34 | class ResUNet2(nn.Module): 35 | CHANNELS = [None, 32, 64, 128, 256] 36 | TR_CHANNELS = [None, 32, 64, 64, 128] 37 | OUT_TENSOR_STRIDE = 1 38 | DEPTHS = [1, 1, 1, 1, 1, 1, 1] 39 | 40 | # To use the model, must call initialize_coords before forward pass. 
41 | # Once data is processed, call clear to reset the model before calling initialize_coords 42 | def __init__(self, 43 | in_channels=1, 44 | out_channels=32, 45 | bn_momentum=0.1, 46 | normalize_feature=False): 47 | nn.Module.__init__(self) 48 | CHANNELS = self.CHANNELS 49 | TR_CHANNELS = self.TR_CHANNELS 50 | DEPTHS = self.DEPTHS 51 | NORM_TYPE = 'BN' 52 | BLOCK_NORM_TYPE = 'BN' 53 | self.normalize_feature = normalize_feature 54 | 55 | self.conv1 = nn.Conv2d( 56 | in_channels=in_channels, 57 | out_channels=CHANNELS[1], 58 | kernel_size=3, 59 | stride=1, 60 | padding=1, 61 | dilation=1, 62 | bias=False) 63 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum) 64 | 65 | self.blocks1 = nn.Sequential(*[ 66 | get_block(BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum) 67 | for d in range(DEPTHS[0]) 68 | ]) 69 | 70 | self.conv2 = nn.Conv2d( 71 | in_channels=CHANNELS[1], 72 | out_channels=CHANNELS[2], 73 | kernel_size=3, 74 | stride=2, 75 | padding=1, 76 | dilation=1, 77 | bias=False) 78 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum) 79 | 80 | self.blocks2 = nn.Sequential(*[ 81 | get_block(BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum) 82 | for d in range(DEPTHS[1]) 83 | ]) 84 | 85 | self.conv3 = nn.Conv2d( 86 | in_channels=CHANNELS[2], 87 | out_channels=CHANNELS[3], 88 | kernel_size=3, 89 | stride=2, 90 | padding=1, 91 | dilation=1, 92 | bias=False) 93 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum) 94 | 95 | self.blocks3 = nn.Sequential(*[ 96 | get_block(BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum) 97 | for d in range(DEPTHS[2]) 98 | ]) 99 | 100 | self.conv4 = nn.Conv2d( 101 | in_channels=CHANNELS[3], 102 | out_channels=CHANNELS[4], 103 | kernel_size=3, 104 | stride=2, 105 | padding=1, 106 | dilation=1, 107 | bias=False) 108 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum) 109 | 110 | self.blocks4 = nn.Sequential(*[ 111 | get_block(BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum) 112 | for d in range(DEPTHS[3]) 113 | ]) 114 | 115 | self.conv4_tr = nn.ConvTranspose2d( 116 | in_channels=CHANNELS[4], 117 | out_channels=TR_CHANNELS[4], 118 | kernel_size=3, 119 | stride=2, 120 | padding=0, 121 | output_padding=0, 122 | dilation=1, 123 | bias=False) 124 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum) 125 | 126 | self.blocks4_tr = nn.Sequential(*[ 127 | get_block( 128 | BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum) 129 | for d in range(DEPTHS[4]) 130 | ]) 131 | 132 | self.conv3_tr = nn.ConvTranspose2d( 133 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 134 | out_channels=TR_CHANNELS[3], 135 | kernel_size=3, 136 | stride=2, 137 | padding=0, 138 | output_padding=0, 139 | dilation=1, 140 | bias=False) 141 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum) 142 | 143 | self.blocks3_tr = nn.Sequential(*[ 144 | get_block( 145 | BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum) 146 | for d in range(DEPTHS[5]) 147 | ]) 148 | 149 | self.conv2_tr = nn.ConvTranspose2d( 150 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 151 | out_channels=TR_CHANNELS[2], 152 | kernel_size=3, 153 | stride=2, 154 | padding=0, 155 | output_padding=0, 156 | dilation=1, 157 | bias=False) 158 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum) 159 | 160 | self.blocks2_tr = nn.Sequential(*[ 161 | get_block( 162 | BLOCK_NORM_TYPE, 
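# Decoder note: each ConvTranspose2d above uses kernel 3, stride 2, and no
# padding, mapping H to 2H+1, which slightly overshoots the encoder skip
# tensor; forward() crops the upsampled map to the skip's spatial size
# before torch.cat.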
TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum) 163 | for d in range(DEPTHS[6]) 164 | ]) 165 | 166 | self.conv1_tr = nn.Conv2d( 167 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 168 | out_channels=TR_CHANNELS[1], 169 | kernel_size=1, 170 | stride=1, 171 | padding=0, 172 | dilation=1, 173 | bias=False) 174 | # self.norm1_tr = get_norm(NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum) 175 | 176 | self.final = nn.Conv2d( 177 | in_channels=TR_CHANNELS[1], 178 | out_channels=out_channels, 179 | kernel_size=1, 180 | stride=1, 181 | padding=0, 182 | dilation=1, 183 | bias=True) 184 | 185 | self.weight_initialization() 186 | 187 | def weight_initialization(self): 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, nn.ConvTranspose2d): 192 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 193 | elif isinstance(m, nn.BatchNorm2d): 194 | nn.init.constant_(m.weight, 1) 195 | nn.init.constant_(m.bias, 0) 196 | 197 | def forward(self, x): 198 | out_s1 = self.conv1(x) 199 | out_s1 = self.norm1(out_s1) 200 | out_s1 = F.relu(out_s1) 201 | out_s1 = self.blocks1(out_s1) 202 | 203 | out_s2 = self.conv2(out_s1) 204 | out_s2 = self.norm2(out_s2) 205 | out_s2 = F.relu(out_s2) 206 | out_s2 = self.blocks2(out_s2) 207 | 208 | out_s4 = self.conv3(out_s2) 209 | out_s4 = self.norm3(out_s4) 210 | out_s4 = F.relu(out_s4) 211 | out_s4 = self.blocks3(out_s4) 212 | 213 | out_s8 = self.conv4(out_s4) 214 | out_s8 = self.norm4(out_s8) 215 | out_s8 = F.relu(out_s8) 216 | out_s8 = self.blocks4(out_s8) 217 | 218 | out_s4_tr = self.conv4_tr(out_s8) 219 | out_s4_tr = self.norm4_tr(out_s4_tr) 220 | out_s4_tr = F.relu(out_s4_tr) 221 | out_s4_tr = self.blocks4_tr(out_s4_tr) 222 | 223 | out = torch.cat((out_s4_tr[:, :, :out_s4.shape[2], :out_s4.shape[3]], out_s4), 224 | dim=1) 225 | 226 | out_s2_tr = self.conv3_tr(out) 227 | out_s2_tr = self.norm3_tr(out_s2_tr) 228 | out_s2_tr = F.relu(out_s2_tr) 229 | out_s2_tr = self.blocks3_tr(out_s2_tr) 230 | 231 | out = torch.cat((out_s2_tr[:, :, :out_s2.shape[2], :out_s2.shape[3]], out_s2), 232 | dim=1) 233 | 234 | out_s1_tr = self.conv2_tr(out) 235 | out_s1_tr = self.norm2_tr(out_s1_tr) 236 | out_s1_tr = F.relu(out_s1_tr) 237 | out_s1_tr = self.blocks2_tr(out_s1_tr) 238 | 239 | out = torch.cat((out_s1_tr[:, :, :out_s1.shape[2], :out_s1.shape[3]], out_s1), 240 | dim=1) 241 | out = self.conv1_tr(out) 242 | out = F.relu(out) 243 | out = self.final(out) 244 | 245 | if self.normalize_feature: 246 | return out / (torch.norm(out, p=2, dim=1, keepdim=True) + 1e-8) 247 | else: 248 | return out 249 | 250 | 251 | class ResUNetBN2(ResUNet2): 252 | NORM_TYPE = 'BN' 253 | 254 | 255 | class ResUNetBN2B(ResUNet2): 256 | NORM_TYPE = 'BN' 257 | CHANNELS = [None, 32, 64, 128, 256] 258 | TR_CHANNELS = [None, 64, 64, 64, 64] 259 | 260 | 261 | class ResUNetBN2C(ResUNet2): 262 | NORM_TYPE = 'BN' 263 | CHANNELS = [None, 32, 64, 128, 256] 264 | TR_CHANNELS = [None, 64, 64, 64, 128] 265 | 266 | 267 | class ResUNetBN2D(ResUNet2): 268 | NORM_TYPE = 'BN' 269 | CHANNELS = [None, 32, 64, 128, 256] 270 | TR_CHANNELS = [None, 64, 64, 128, 128] 271 | 272 | 273 | class ResUNetBN2D2(ResUNet2): 274 | NORM_TYPE = 'BN' 275 | CHANNELS = [None, 32, 64, 128, 256] 276 | TR_CHANNELS = [None, 128, 128, 128, 128] 277 | 278 | 279 | class ResUNetBN2D3(ResUNet2): 280 | NORM_TYPE = 'BN' 281 | CHANNELS = [None, 32, 64, 128, 256] 282 | TR_CHANNELS = [None, 128, 128, 192, 192] 283 | 284 | 285 | class 
ResUNetBN2E(ResUNet2): 286 | NORM_TYPE = 'BN' 287 | CHANNELS = [None, 128, 128, 128, 256] 288 | TR_CHANNELS = [None, 64, 128, 128, 128] 289 | 290 | 291 | class ResUNetBN2F(ResUNet2): 292 | NORM_TYPE = 'BN' 293 | CHANNELS = [None, 16, 32, 64, 128] 294 | TR_CHANNELS = [None, 16, 32, 64, 128] 295 | 296 | 297 | class ResUNetBN2G(ResUNet2): 298 | NORM_TYPE = 'BN' 299 | CHANNELS = [None, 128, 128, 192, 256] 300 | TR_CHANNELS = [None, 128, 128, 192, 256] 301 | 302 | 303 | class ResUNetBN2G2(ResUNet2): 304 | NORM_TYPE = 'BN' 305 | CHANNELS = [None, 128, 128, 192, 256] 306 | TR_CHANNELS = [None, 192, 128, 192, 256] 307 | 308 | 309 | class ResUNetBN2G3(ResUNet2): 310 | NORM_TYPE = 'BN' 311 | CHANNELS = [None, 128, 128, 192, 256] 312 | TR_CHANNELS = [None, 192, 128, 192, 192] 313 | 314 | 315 | class ResUNetBN2H(ResUNet2): 316 | NORM_TYPE = 'BN' 317 | CHANNELS = [None, 128, 128, 192, 256] 318 | TR_CHANNELS = [None, 128, 128, 192, 256] 319 | DEPTHS = [2, 2, 2, 2, 2, 2, 2] 320 | 321 | 322 | MODELS = [ 323 | ResUNetBN2, ResUNetBN2B, ResUNetBN2C, ResUNetBN2D, ResUNetBN2D2, ResUNetBN2D3, 324 | ResUNetBN2E, ResUNetBN2F, ResUNetBN2G, ResUNetBN2G2, ResUNetBN2G3, ResUNetBN2H 325 | ] 326 | 327 | mdict = {model.__name__: model for model in MODELS} 328 | 329 | 330 | def load_model(name): 331 | if name in mdict.keys(): 332 | NetClass = mdict[name] 333 | return NetClass 334 | else: 335 | raise ValueError(f'{name} model does not exist in {list(mdict.keys())}') 336 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/HighDimConvNets/bd8f03150b4d639db61109a93c37f3be0dcaec38/util/__init__.py -------------------------------------------------------------------------------- /util/file.py: -------------------------------------------------------------------------------- 1 | import re 2 | from os import listdir 3 | from os.path import isfile, isdir, join, splitext 4 | 5 | import h5py 6 | 7 | 8 | def sorted_alphanum(file_list_ordered): 9 | convert = lambda text: int(text) if text.isdigit() else text 10 | alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] 11 | return sorted(file_list_ordered, key=alphanum_key) 12 | 13 | 14 | def get_file_list(path, extension=None): 15 | if extension is None: 16 | file_list = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 17 | else: 18 | file_list = [ 19 | join(path, f) 20 | for f in listdir(path) 21 | if isfile(join(path, f)) and splitext(f)[1] == extension 22 | ] 23 | file_list = sorted_alphanum(file_list) 24 | return file_list 25 | 26 | 27 | def get_file_list_specific(path, string, extension=None): 28 | if extension is None: 29 | file_list = sorted_alphanum([join(path, f) for f in listdir(path) if isfile(join(path, f)) and string in f]) 30 | elif type(extension) == list: 31 | file_list = [ 32 | join(path, f) 33 | for f in listdir(path) 34 | if isfile(join(path, f)) and string in f and splitext(f)[1] in extension 35 | ] 36 | file_list = sorted_alphanum(file_list) 37 | else: 38 | file_list = [ 39 | join(path, f) 40 | for f in listdir(path) 41 | if isfile(join(path, f)) and string in f and splitext(f)[1] == extension 42 | ] 43 | file_list = sorted_alphanum(file_list) 44 | return file_list 45 | 46 | 47 | def get_folder_list(path): 48 | folder_list = [join(path, f) for f in listdir(path) if isdir(join(path, f))] 49 | folder_list = sorted_alphanum(folder_list) 50 | return folder_list 51 | 52 |
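# Quick illustration of the natural-sort helpers above (file names are
# hypothetical):
#   sorted(['f10.png', 'f2.png'])           # ['f10.png', 'f2.png'] (lexicographic)
#   sorted_alphanum(['f10.png', 'f2.png'])  # ['f2.png', 'f10.png'] (numeric-aware)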
53 | def loadh5(path): 54 | """Load h5 file as dictionary 55 | 56 | Args: 57 | path (str): h5 file path 58 | 59 | Returns: 60 | dict_file (dict): loaded dictionary 61 | 62 | """ 63 | try: 64 | with h5py.File(path, "r") as h5file: 65 | return readh5(h5file) 66 | except Exception as e: 67 | print("Error while loading {}".format(path)) 68 | raise e 69 | 70 | 71 | def readh5(h5node): 72 | """Read h5 node recursively and load it into a dict 73 | 74 | Args: 75 | h5node (h5py._hl.files.File): h5py File object 76 | 77 | Returns: 78 | dict_file (dict): loaded dictionary 79 | 80 | """ 81 | dict_file = {} 82 | for key in h5node.keys(): 83 | if type(h5node[key]) == h5py._hl.group.Group: 84 | dict_file[key] = readh5(h5node[key]) 85 | else: 86 | dict_file[key] = h5node[key][...] 87 | return dict_file 88 | 89 | 90 | def saveh5(dict_file, target_path): 91 | """Save dictionary as h5 file 92 | 93 | Args: 94 | dict_file (dict): dictionary to save 95 | target_path (str): target path string 96 | 97 | """ 98 | 99 | with h5py.File(target_path, "w") as h5file: 100 | if isinstance(dict_file, list): 101 | for i, d in enumerate(dict_file): 102 | newdict = {"dict" + str(i): d} 103 | writeh5(newdict, h5file) 104 | else: 105 | writeh5(dict_file, h5file) 106 | 107 | 108 | def writeh5(dict_file, h5node): 109 | """Write dictionary recursively into h5py file 110 | 111 | Args: 112 | dict_file (dict): dictionary to write 113 | h5node (h5py._hl.files.File): target h5py file 114 | """ 115 | 116 | for key in dict_file.keys(): 117 | if isinstance(dict_file[key], dict): 118 | h5node.create_group(key) 119 | cur_grp = h5node[key] 120 | writeh5(dict_file[key], cur_grp) 121 | else: 122 | h5node[key] = dict_file[key] 123 | --------------------------------------------------------------------------------
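# Round-trip sketch for the h5 helpers above; the path and dictionary contents
# are placeholders, not repo data:
import numpy as np
from util.file import loadh5, saveh5

d = {'K': np.eye(3), 'pair': {'idx': np.arange(4)}}  # nested dicts become h5 groups
saveh5(d, '/tmp/example.h5')
back = loadh5('/tmp/example.h5')
assert np.allclose(back['pair']['idx'], d['pair']['idx'])

# A minimal smoke test for the UCN backbone in ucn/resunet.py; the shapes are
# illustrative and the model choice is arbitrary among the registered variants:
import torch
from ucn.resunet import load_model

Net = load_model('ResUNetBN2D2')
net = Net(in_channels=3, out_channels=64, normalize_feature=True).eval()
with torch.no_grad():
  feat = net(torch.randn(1, 3, 240, 320))  # NCHW image batch
print(feat.shape)  # torch.Size([1, 64, 240, 320]): per-pixel, L2-normalized descriptors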