├── cirtorch ├── __pycache__ │ └── functional.cpython-38.pyc └── functional.py ├── color.py ├── config ├── Cambridge_training_setup.yaml ├── DL3DV_indoor_training_setup.yaml ├── DL3DV_outdoor_training_setup.yaml ├── NEU_training_setup.yaml ├── SIASUN_training_setup.yaml ├── data │ ├── NEU.yaml │ ├── SIASUN.yaml │ ├── cambridge.yaml │ ├── dtu_indoor.yaml │ ├── dtu_outdoor.yaml │ └── shapenet.yaml ├── model │ └── model.yaml ├── shapnet_training_setup.yaml └── trainer │ └── trainer.yaml ├── copynerf.py ├── data ├── NEU.py ├── SIASUN.py ├── __init__.py ├── cambridge.py ├── dl3dv.py ├── poses_avg_stats │ ├── GreatCourt.txt │ ├── KingsCollege.txt │ ├── OldHospital.txt │ ├── ShopFacade.txt │ └── StMarysChurch.txt └── shapenet.py ├── datasets ├── Cambridge.py ├── __pycache__ │ └── pitts.cpython-38.pyc └── pitts.py ├── evaluation ├── __pycache__ │ └── pretrained_model.cpython-38.pyc ├── get_calibration_metrics.py ├── get_visual_output.py └── pretrained_model.py ├── getxy.py ├── load_LINEMOD.py ├── load_blender.py ├── load_deepvoxels.py ├── load_llff.py ├── main.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── code.cpython-38.pyc │ ├── encoder.cpython-38.pyc │ ├── loss_type.cpython-38.pyc │ ├── mlp.cpython-38.pyc │ ├── network.cpython-38.pyc │ └── renderer.cpython-38.pyc ├── code.py ├── encoder.py ├── loss_type.py ├── mlp.py ├── network.py ├── renderer.py └── testnetwork.py ├── nerf_init.py ├── netvlad.py ├── networks ├── CricaVPR.py ├── __init__.py ├── __pycache__ │ ├── LightViT_model.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ └── mobilenet.cpython-38.pyc ├── classification │ ├── .figures │ │ ├── efficientvit_main.gif │ │ ├── efficientvit_main_static.png │ │ └── modelACC_gpu.png │ ├── README.md │ ├── data │ │ ├── __init__.py │ │ ├── datasets.py │ │ ├── samplers.py │ │ └── threeaugment.py │ ├── engine.py │ ├── losses.py │ ├── main.py │ ├── model │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── build.cpython-38.pyc │ │ │ └── efficientvit.cpython-38.pyc │ │ ├── build.py │ │ └── efficientvit.py │ ├── requirements.txt │ ├── speed_test.py │ └── utils.py ├── efficientViT.py ├── eigenplaces.py ├── eigenplaces_model.zip ├── mixvpr.py ├── mobilenet.py ├── models.zip └── test.py ├── options.py ├── readme.md ├── run_nerf_helpers.py ├── script ├── config_dfnet.txt ├── config_dfnetdm.txt ├── config_nerfh.txt ├── dataset │ ├── UGNA_VPR_logo.png │ ├── calib.txt │ ├── pose.txt │ └── pose_00000.txt ├── dm │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── callbacks.cpython-38.pyc │ │ ├── direct_pose_model.cpython-38.pyc │ │ └── pose_model.cpython-38.pyc │ ├── callbacks.py │ ├── direct_pose_model.py │ ├── options.py │ ├── pose_model.py │ └── prepare_data.py ├── feature │ ├── __pycache__ │ │ ├── dfnet.cpython-38.pyc │ │ ├── misc.cpython-38.pyc │ │ └── options_nerf.cpython-38.pyc │ ├── dfnet.py │ ├── direct_feature_matching.py │ ├── efficientnet.py │ ├── misc.py │ ├── model.py │ └── options_nerf.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── losses.cpython-38.pyc │ │ ├── nerf.cpython-38.pyc │ │ ├── nerfw.cpython-38.pyc │ │ ├── options.cpython-38.pyc │ │ ├── ray_utils.cpython-38.pyc │ │ └── rendering.cpython-38.pyc │ ├── losses.py │ ├── metrics.py │ ├── nerf.py │ ├── nerfw.py │ ├── options.py │ ├── ray_utils.py │ └── rendering.py ├── nerf_random.py ├── nerf_random_wo.py ├── run_feature.py ├── run_nerf.py ├── train.py └── utils │ ├── __pycache__ │ ├── set_sys_path.cpython-38.pyc │ └── 
utils.cpython-38.pyc │ ├── set_sys_path.py │ └── utils.py ├── train_npn.py ├── trainer.py ├── training ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── trainer.cpython-38.pyc └── trainer.py ├── util.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── data_augmentation.cpython-38.pyc │ ├── parser.cpython-38.pyc │ └── util.cpython-38.pyc ├── data_augmentation.py ├── json4txt.py ├── parser.py └── util.py └── vis.py /cirtorch/__pycache__/functional.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/cirtorch/__pycache__/functional.cpython-38.pyc -------------------------------------------------------------------------------- /cirtorch/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | # -------------------------------------- 8 | # pooling 9 | # -------------------------------------- 10 | 11 | 12 | def mac(x): 13 | return F.max_pool2d(x, (x.size(-2), x.size(-1))) 14 | # return F.adaptive_max_pool2d(x, (1,1)) # alternative 15 | 16 | 17 | def spoc(x): 18 | return F.avg_pool2d(x, (x.size(-2), x.size(-1))) 19 | # return F.adaptive_avg_pool2d(x, (1,1)) # alternative 20 | 21 | 22 | def gem(x, p=3, eps=1e-6): 23 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) 24 | # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative 25 | 26 | 27 | def rmac(x, L=3, eps=1e-6): 28 | ovr = 0.4 # desired overlap of neighboring regions 29 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 30 | 31 | W = x.size(3) 32 | H = x.size(2) 33 | 34 | w = min(W, H) 35 | w2 = math.floor(w / 2.0 - 1) 36 | 37 | b = (max(H, W) - w) / (steps - 1) 38 | (tmp, idx) = torch.min(torch.abs(((w**2 - w * b) / w**2) - ovr), 0) # steps(idx) regions for long dimension 39 | 40 | # region overplus per dimension 41 | Wd = 0 42 | Hd = 0 43 | if H < W: 44 | Wd = idx.item() + 1 45 | elif H > W: 46 | Hd = idx.item() + 1 47 | 48 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 49 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 50 | 51 | for l in range(1, L + 1): 52 | wl = math.floor(2 * w / (l + 1)) 53 | wl2 = math.floor(wl / 2 - 1) 54 | 55 | if l + Wd == 1: 56 | b = 0 57 | else: 58 | b = (W - wl) / (l + Wd - 1) 59 | cenW = torch.floor(wl2 + torch.Tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates 60 | if l + Hd == 1: 61 | b = 0 62 | else: 63 | b = (H - wl) / (l + Hd - 1) 64 | cenH = torch.floor(wl2 + torch.Tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates 65 | 66 | for i_ in cenH.tolist(): 67 | for j_ in cenW.tolist(): 68 | if wl == 0: 69 | continue 70 | R = x[:, :, (int(i_) + torch.Tensor(range(wl)).long()).tolist(), :] 71 | R = R[:, :, :, (int(j_) + torch.Tensor(range(wl)).long()).tolist()] 72 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 73 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 74 | v += vt 75 | 76 | return v 77 | 78 | 79 | def roipool(x, rpool, L=3, eps=1e-6): 80 | ovr = 0.4 # desired overlap of neighboring regions 81 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 82 | 83 | W = x.size(3) 84 | H = x.size(2) 85 | 86 | w = min(W, H) 87 | w2 = math.floor(w / 2.0 - 1) 88 | 89 | b = (max(H, W) - w) / (steps - 
1) 90 | _, idx = torch.min(torch.abs(((w**2 - w * b) / w**2) - ovr), 0) # steps(idx) regions for long dimension 91 | 92 | # region overplus per dimension 93 | Wd = 0 94 | Hd = 0 95 | if H < W: 96 | Wd = idx.item() + 1 97 | elif H > W: 98 | Hd = idx.item() + 1 99 | 100 | vecs = [] 101 | vecs.append(rpool(x).unsqueeze(1)) 102 | 103 | for l in range(1, L + 1): 104 | wl = math.floor(2 * w / (l + 1)) 105 | wl2 = math.floor(wl / 2 - 1) 106 | 107 | if l + Wd == 1: 108 | b = 0 109 | else: 110 | b = (W - wl) / (l + Wd - 1) 111 | cenW = torch.floor(wl2 + torch.Tensor(range(l - 1 + Wd + 1)) * b).int() - wl2 # center coordinates 112 | if l + Hd == 1: 113 | b = 0 114 | else: 115 | b = (H - wl) / (l + Hd - 1) 116 | cenH = torch.floor(wl2 + torch.Tensor(range(l - 1 + Hd + 1)) * b).int() - wl2 # center coordinates 117 | 118 | for i_ in cenH.tolist(): 119 | for j_ in cenW.tolist(): 120 | if wl == 0: 121 | continue 122 | vecs.append(rpool(x.narrow(2, i_, wl).narrow(3, j_, wl)).unsqueeze(1)) 123 | 124 | return torch.cat(vecs, dim=1) 125 | 126 | 127 | # -------------------------------------- 128 | # normalization 129 | # -------------------------------------- 130 | 131 | 132 | def l2n(x, eps=1e-6): 133 | return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x) 134 | 135 | 136 | def powerlaw(x, eps=1e-6): 137 | x = x + eps 138 | return x.abs().sqrt().mul(x.sign()) 139 | 140 | 141 | # -------------------------------------- 142 | # loss 143 | # -------------------------------------- 144 | 145 | 146 | def contrastive_loss(x, label, margin=0.7, eps=1e-6): 147 | # x is D x N 148 | dim = x.size(0) # D 149 | nq = torch.sum(label.data == -1) # number of tuples 150 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 151 | 152 | x1 = x[:, ::S].permute(1, 0).repeat(1, S - 1).view((S - 1) * nq, dim).permute(1, 0) 153 | idx = [i for i in range(len(label)) if label.data[i] != -1] 154 | x2 = x[:, idx] 155 | lbl = label[label != -1] 156 | 157 | dif = x1 - x2 158 | D = torch.pow(dif + eps, 2).sum(dim=0).sqrt() 159 | 160 | y = 0.5 * lbl * torch.pow(D, 2) + 0.5 * (1 - lbl) * torch.pow(torch.clamp(margin - D, min=0), 2) 161 | y = torch.sum(y) 162 | return y 163 | 164 | 165 | def triplet_loss(x, label, margin=0.1): 166 | # x is D x N 167 | dim = x.size(0) # D 168 | nq = torch.sum(label.data == -1).item() # number of tuples 169 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 170 | 171 | xa = x[:, label.data == -1].permute(1, 0).repeat(1, S - 2).view((S - 2) * nq, dim).permute(1, 0) 172 | xp = x[:, label.data == 1].permute(1, 0).repeat(1, S - 2).view((S - 2) * nq, dim).permute(1, 0) 173 | xn = x[:, label.data == 0] 174 | 175 | dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0) 176 | dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0) 177 | 178 | return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0)) -------------------------------------------------------------------------------- /color.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def rgb_to_yuv(image: torch.Tensor) -> torch.Tensor: 5 | r""" 6 | From Kornia. 7 | Convert an RGB image to YUV. 8 | 9 | .. image:: _static/img/rgb_to_yuv.png 10 | 11 | The image data is assumed to be in the range of (0, 1). 12 | 13 | Args: 14 | image: RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`. 15 | 16 | Returns: 17 | YUV version of the image with shape :math:`(*, 3, H, W)`. 
18 | 19 | Example: 20 | >>> input = torch.rand(2, 3, 4, 5) 21 | >>> output = rgb_to_yuv(input) # 2x3x4x5 22 | """ 23 | if not isinstance(image, torch.Tensor): 24 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") 25 | 26 | if len(image.shape) < 3 or image.shape[-3] != 3: 27 | raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") 28 | 29 | r: torch.Tensor = image[..., 0, :, :] 30 | g: torch.Tensor = image[..., 1, :, :] 31 | b: torch.Tensor = image[..., 2, :, :] 32 | 33 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b 34 | u: torch.Tensor = -0.147 * r - 0.289 * g + 0.436 * b 35 | v: torch.Tensor = 0.615 * r - 0.515 * g - 0.100 * b 36 | 37 | out: torch.Tensor = torch.stack([y, u, v], -3) 38 | 39 | return out 40 | 41 | if __name__ == "__main__": 42 | img = torch.rand(58, 3, 224, 224) 43 | out = rgb_to_yuv(img) 44 | print(out.shape) -------------------------------------------------------------------------------- /config/Cambridge_training_setup.yaml: -------------------------------------------------------------------------------- 1 | data_cfg_path: config/data/cambridge.yaml 2 | model_cfg_path: config/model/model.yaml 3 | trainer_cfg_path: config/trainer/trainer.yaml 4 | -------------------------------------------------------------------------------- /config/DL3DV_indoor_training_setup.yaml: -------------------------------------------------------------------------------- 1 | data_cfg_path: config/data/dtu_indoor.yaml 2 | model_cfg_path: config/model/model.yaml 3 | trainer_cfg_path: config/trainer/trainer.yaml 4 | -------------------------------------------------------------------------------- /config/DL3DV_outdoor_training_setup.yaml: -------------------------------------------------------------------------------- 1 | data_cfg_path: config/data/dtu_outdoor.yaml 2 | model_cfg_path: config/model/model.yaml 3 | trainer_cfg_path: config/trainer/trainer.yaml 4 | -------------------------------------------------------------------------------- /config/NEU_training_setup.yaml: -------------------------------------------------------------------------------- 1 | data_cfg_path: config/data/NEU.yaml 2 | model_cfg_path: config/model/model.yaml 3 | trainer_cfg_path: config/trainer/trainer.yaml 4 | -------------------------------------------------------------------------------- /config/SIASUN_training_setup.yaml: -------------------------------------------------------------------------------- 1 | data_cfg_path: config/data/SIASUN.yaml 2 | model_cfg_path: config/model/model.yaml 3 | trainer_cfg_path: config/trainer/trainer.yaml 4 | -------------------------------------------------------------------------------- /config/data/NEU.yaml: -------------------------------------------------------------------------------- 1 | name: NEU 2 | batch_size: 4 3 | shuffle: true 4 | num_workers: 0 5 | 6 | dataset: 7 | data_rootdir: data/dataset/NEU 8 | max_imgs: 10 9 | image_size: [224, 224] # H W 10 | z_near: 0.1 11 | z_far: 3.5 12 | format: opencv 13 | 14 | data_augmentation: 15 | color_jitter: 16 | hue_range: 0.1 17 | saturation_range: 0.1 18 | brightness_range: 0.1 19 | contrast_range: 0.1 20 | -------------------------------------------------------------------------------- /config/data/SIASUN.yaml: -------------------------------------------------------------------------------- 1 | name: SIASUN 2 | batch_size: 4 3 | shuffle: true 4 | num_workers: 0 5 | 6 | dataset: 7 | data_rootdir: data/dataset/SIASUN 8 | max_imgs: 10 9 | image_size: [224, 224] # H W 10 | z_near: 
0.1 11 | z_far: 3.5 12 | format: opencv 13 | 14 | data_augmentation: 15 | color_jitter: 16 | hue_range: 0.1 17 | saturation_range: 0.1 18 | brightness_range: 0.1 19 | contrast_range: 0.1 20 | -------------------------------------------------------------------------------- /config/data/cambridge.yaml: -------------------------------------------------------------------------------- 1 | name: Cambridge 2 | batch_size: 4 3 | shuffle: true 4 | num_workers: 0 5 | 6 | dataset: 7 | data_rootdir: data/dataset/Cambridge 8 | max_imgs: 10 9 | image_size: [224, 224] # H W 10 | z_near: 0.1 11 | z_far: 3.5 12 | format: opencv 13 | 14 | data_augmentation: 15 | color_jitter: 16 | hue_range: 0.1 17 | saturation_range: 0.1 18 | brightness_range: 0.1 19 | contrast_range: 0.1 20 | -------------------------------------------------------------------------------- /config/data/dtu_indoor.yaml: -------------------------------------------------------------------------------- 1 | name: DL3DV 2 | batch_size: 4 3 | shuffle: true 4 | num_workers: 0 5 | 6 | dataset: 7 | data_rootdir: data/dataset/DL3DV/DL3DV_indoor 8 | max_imgs: 320 9 | image_size: [224, 224] # H W 10 | z_near: 0.1 11 | z_far: 3.5 12 | format: opencv 13 | 14 | data_augmentation: 15 | color_jitter: 16 | hue_range: 0.1 17 | saturation_range: 0.1 18 | brightness_range: 0.1 19 | contrast_range: 0.1 20 | -------------------------------------------------------------------------------- /config/data/dtu_outdoor.yaml: -------------------------------------------------------------------------------- 1 | name: DL3DV 2 | batch_size: 4 3 | shuffle: true 4 | num_workers: 0 5 | 6 | dataset: 7 | data_rootdir: data/dataset/DL3DV/DL3DV_outdoor 8 | max_imgs: 300 9 | image_size: [224, 224] # H W 10 | z_near: 0.1 11 | z_far: 3.5 12 | format: opencv 13 | 14 | data_augmentation: 15 | color_jitter: 16 | hue_range: 0.1 17 | saturation_range: 0.1 18 | brightness_range: 0.1 19 | contrast_range: 0.1 20 | -------------------------------------------------------------------------------- /config/data/shapenet.yaml: -------------------------------------------------------------------------------- 1 | name: shapenet 2 | batch_size: 4 3 | shuffle: true 4 | num_workers: 0 5 | 6 | dataset: 7 | data_rootdir: data/dataset/shapenet 8 | image_size: [200, 200] # H W 9 | z_near: 0.1 10 | z_far: 4.0 11 | format: normal 12 | 13 | data_augmentation: 14 | color_jitter: 15 | hue_range: 0.1 16 | saturation_range: 0.1 17 | brightness_range: 0.1 18 | contrast_range: 0.1 -------------------------------------------------------------------------------- /config/model/model.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | encoder: 3 | backbone: resnet34 4 | pretrained: true 5 | num_layers: 3 6 | index_interp: bilinear 7 | index_padding: border 8 | upsample_interp: bilinear 9 | use_first_pool: true 10 | norm_type: batch 11 | 12 | mlp: 13 | mlp_feature: 14 | d_latent: 256 15 | d_feature: 128 16 | use_encoding: true 17 | use_view: true 18 | block_num: 2 19 | positional_encoding: 20 | num_freqs: 6 21 | d_in: 3 22 | include_input: true 23 | freq_factor: 2 24 | 25 | mlp_output: 26 | d_feature: 256 27 | block_num: 2 28 | d_out: 1024 29 | 30 | renderer: 31 | d_in: 256 32 | d_hidden: 128 33 | raymarch_steps: 16 34 | trainable: false 35 | use_encoding: false 36 | positional_encoding: 37 | num_freqs: 6 38 | d_in: 3 39 | include_input: true 40 | freq_factor: 2 -------------------------------------------------------------------------------- 
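The *_training_setup.yaml files in this config directory only store paths to the data, model and trainer sub-configs shown around them. Below is a minimal sketch of how such a setup file could be resolved into a single config dict keyed by "data", "model" and "trainer"; the load_training_setup helper is illustrative and is not a function defined in this repository:

import yaml

def load_training_setup(setup_path):
    # the top-level setup file only lists the paths of the three sub-configs
    with open(setup_path, "r") as f:
        setup = yaml.safe_load(f)
    cfg = {}
    for key, name in [("data_cfg_path", "data"),
                      ("model_cfg_path", "model"),
                      ("trainer_cfg_path", "trainer")]:
        with open(setup[key], "r") as f:
            cfg[name] = yaml.safe_load(f)
    return cfg

# e.g. cfg = load_training_setup("config/Cambridge_training_setup.yaml")
# cfg["data"]["batch_size"]                        -> 4
# cfg["model"]["network"]["encoder"]["backbone"]   -> "resnet34"
# cfg["trainer"]["optimizer"]["learning_rate"]     -> 1.0e-4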
/config/shapnet_training_setup.yaml: -------------------------------------------------------------------------------- 1 | data_cfg_path: config/data/shapenet.yaml 2 | model_cfg_path: config/model/model.yaml 3 | trainer_cfg_path: config/trainer/trainer.yaml 4 | -------------------------------------------------------------------------------- /config/trainer/trainer.yaml: -------------------------------------------------------------------------------- 1 | tracking_metric: min 2 | num_epoch_repeats: 16 3 | num_epochs: 1000 4 | ray_batch_size: 1200 5 | nviews: [3, 4, 5] 6 | use_data_augmentation: true 7 | freeze_encoder: false 8 | print_interval: 5 9 | save_interval: 1 10 | vis_interval: 1 11 | vis_repeat: 10 12 | gpu_id: [0] 13 | 14 | loss: 15 | loss_type: logit 16 | rgb_loss_type: mse 17 | 18 | optimizer: 19 | learning_rate: 1.0e-4 20 | gamma: 0.999 21 | 22 | -------------------------------------------------------------------------------- /copynerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | 5 | # source folder path 6 | source_folder = "C:/Users/65309/Desktop/nerfCambridge/GreatCourt" 7 | # destination folder 1, holds 3/10 of the files 8 | destination_folder1 = "C:/Users/65309/Desktop/nerfCambridge/query/GreatCourt" 9 | # destination folder 2, holds the remaining 7/10 of the files 10 | destination_folder2 = "C:/Users/65309/Desktop/nerfCambridge/database/GreatCourt" 11 | 12 | 13 | image_files = [f for f in os.listdir(source_folder) if f.endswith('.jpg') or f.endswith('.png')] 14 | # count how many files to select 15 | total_files = len(image_files) 16 | num_files_to_copy1 = total_files * 3 // 10 17 | num_files_to_copy2 = total_files - num_files_to_copy1 18 | 19 | # create the destination folders 20 | os.makedirs(destination_folder1, exist_ok=True) 21 | os.makedirs(destination_folder2, exist_ok=True) 22 | 23 | # build the list of all file indices 24 | files_list = list(range(total_files)) 25 | 26 | # randomly choose which files to copy 27 | files_to_copy1 = random.sample(files_list, num_files_to_copy1) 28 | files_to_copy2 = [file_index for file_index in files_list if file_index not in files_to_copy1] 29 | 30 | # copy 3/10 of the files (each image together with its pose file) 31 | for file_index in files_to_copy1: 32 | image_name = f"{file_index:05d}.png" 33 | txt_name = f"pose_{file_index:05d}.txt" 34 | source_path = os.path.join(source_folder, image_name) 35 | dest_path = os.path.join(destination_folder1, image_name) 36 | shutil.copyfile(source_path, dest_path) 37 | source_path = os.path.join(source_folder, txt_name) 38 | dest_path = os.path.join(destination_folder1, txt_name) 39 | shutil.copyfile(source_path, dest_path) 40 | 41 | print(f"{num_files_to_copy1} files copied to the first destination folder.") 42 | 43 | # copy the remaining 7/10 of the files 44 | for file_index in files_to_copy2: 45 | image_name = f"{file_index:05d}.png" 46 | txt_name = f"pose_{file_index:05d}.txt" 47 | source_path = os.path.join(source_folder, image_name) 48 | dest_path = os.path.join(destination_folder2, image_name) 49 | shutil.copyfile(source_path, dest_path) 50 | source_path = os.path.join(source_folder, txt_name) 51 | dest_path = os.path.join(destination_folder2, txt_name) 52 | shutil.copyfile(source_path, dest_path) 53 | 54 | print(f"{num_files_to_copy2} files copied to the second destination folder.") 55 | -------------------------------------------------------------------------------- /data/NEU.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | import torch 6 | import torch.nn.functional 
as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import glob 9 | import imageio 10 | import numpy as np 11 | import cv2 12 | from utils.util import get_image_to_tensor_balanced, coordinate_transformation 13 | from utils.data_augmentation import get_transformation 14 | 15 | 16 | class NEUDataModule: 17 | def __init__(self, cfg): 18 | self.batch_size = cfg["batch_size"] 19 | self.shuffle = cfg["shuffle"] 20 | self.num_workers = cfg["num_workers"] 21 | 22 | self.dataset_cfg = cfg["dataset"] 23 | self.data_augmentation = cfg["data_augmentation"] 24 | 25 | def load_dataset(self, mode, use_data_augmentation=False, scene_list=None): 26 | self.mode = mode 27 | self.dataset_cfg["mode"] = mode 28 | self.dataset_cfg["scene_list"] = scene_list 29 | 30 | if use_data_augmentation: 31 | self.dataset_cfg["transformation"] = self.data_augmentation 32 | else: 33 | self.dataset_cfg["transformation"] = None 34 | 35 | return NEUDataset.init_from_cfg(self.dataset_cfg) 36 | 37 | def get_dataloader(self, dataset): 38 | batch_size = self.batch_size 39 | shuffle = self.shuffle 40 | num_workers = self.num_workers 41 | 42 | if self.mode == "test": 43 | batch_size = 1 44 | shuffle = False 45 | num_workers = 0 46 | 47 | dataloader = DataLoader( 48 | dataset, 49 | batch_size=batch_size, 50 | shuffle=shuffle, 51 | num_workers=num_workers, 52 | ) 53 | return dataloader 54 | 55 | 56 | class NEUDataset(Dataset): 57 | def __init__( 58 | self, 59 | mode, 60 | data_rootdir, 61 | max_imgs, 62 | image_size, 63 | z_near, 64 | z_far, 65 | trans_cfg, 66 | dataset_format, 67 | scene_list, 68 | ): 69 | """ 70 | Inits DTU dataset instance 71 | 72 | Args: 73 | mode: either train, val or test 74 | data_rootdir: root directory of dataset 75 | max_imgs: maximal images for the object 76 | image_size: [H, W] pixels 77 | z_near: minimal distance of the object 78 | z_far: maximal distance of the object 79 | trans_cfg: configurations for data augmentations(transformation) 80 | dataset_formate: the coordinate system the original dataset uses 81 | """ 82 | 83 | super().__init__() 84 | self.max_imgs = max_imgs 85 | self.image_size = image_size 86 | self.z_near = z_near 87 | self.z_far = z_far 88 | self.dataset_format = dataset_format 89 | self.rootdir=data_rootdir 90 | 91 | self.transformations = [] 92 | if trans_cfg is not None: 93 | self.transformations = get_transformation(trans_cfg) 94 | 95 | assert os.path.exists(data_rootdir) 96 | file_list = os.path.join(data_rootdir, f"{mode}.lst") 97 | assert os.path.exists(file_list) 98 | base_dir = os.path.dirname(file_list) 99 | if scene_list is None: 100 | with open(file_list, "r") as f: 101 | self.scene_list = [x.strip() for x in f.readlines()] 102 | else: 103 | self.scene_list = [f"scan{x}" for x in scene_list] 104 | 105 | self.objs_path = [os.path.join(base_dir, scene) for scene in self.scene_list] 106 | 107 | self.image_to_tensor = get_image_to_tensor_balanced() 108 | 109 | def __len__(self): 110 | return len(self.objs_path) 111 | 112 | def __getitem__(self, index): 113 | scan_name = self.scene_list[index] 114 | root_dir = self.objs_path[index] 115 | rgb_paths = [ 116 | x 117 | for x in glob.glob(os.path.join(root_dir, "images", "*")) 118 | if (x.endswith(".jpg") or x.endswith(".png")) 119 | ] 120 | 121 | rgb_paths = sorted(rgb_paths) 122 | #print(len(rgb_paths)) 123 | #print(self.max_imgs) 124 | if len(rgb_paths) <= self.max_imgs: 125 | sel_indices = np.arange(len(rgb_paths)) 126 | else: 127 | sel_indices = np.random.choice(len(rgb_paths), self.max_imgs, replace=False) 128 | 
rgb_paths = [rgb_paths[i] for i in sel_indices] 129 | 130 | transforms=[] 131 | 132 | poses_folder=os.path.join(self.rootdir, scan_name, "poses") 133 | 134 | Fx=211 135 | Fy=211 136 | Cx=240 137 | Cy=135 138 | for file_name in os.listdir(poses_folder): 139 | if file_name.endswith('.txt'): 140 | file_path = os.path.join(poses_folder, file_name) 141 | # 读取txt文件中的位姿数据 142 | with open(file_path, 'r') as file: 143 | lines = file.readlines() 144 | pose_matrix = np.zeros((4, 4)) 145 | for i, line in enumerate(lines): 146 | values = [float(val) for val in line.strip().split()] 147 | pose_matrix[i, :] = values 148 | 149 | # 将位姿数据矩阵添加到列表中 150 | transforms.append(pose_matrix) 151 | 152 | #cam_path = os.path.join(root_dir, "cameras.npz") 153 | #all_cam = np.load(cam_path) 154 | all_imgs = [] 155 | all_poses = [] 156 | focal = None 157 | fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0 158 | 159 | for idx, rgb_path in enumerate(rgb_paths): 160 | i = sel_indices[idx] 161 | img = imageio.imread(rgb_path)[..., :3] 162 | 163 | # decompose projection matrix 164 | P = transforms[i] 165 | fx += Fx 166 | fy += Fy 167 | cx += Cx 168 | cy += Cy 169 | 170 | pose = np.eye(4, dtype=np.float32) 171 | pose = P 172 | 173 | 174 | #scale_mtx = all_cam.get("scale_mat_" + str(i)) 175 | #if scale_mtx is not None: 176 | #norm_trans = scale_mtx[:3, 3:] 177 | #norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None] 178 | #pose[:3, 3:] -= norm_trans 179 | #pose[:3, 3:] /= norm_scale 180 | 181 | # camera poses in world coordinate 182 | pose = coordinate_transformation(pose, format=self.dataset_format) 183 | img_tensor = self.image_to_tensor(img) 184 | all_imgs.append(img_tensor) 185 | all_poses.append(pose)#pose is c2w 186 | 187 | # get average intrinsics for one object 188 | fx /= len(rgb_paths) 189 | fy /= len(rgb_paths) 190 | cx /= len(rgb_paths) 191 | cy /= len(rgb_paths) 192 | focal = torch.tensor((fx, fy), dtype=torch.float32) 193 | c = torch.tensor((cx, cy), dtype=torch.float32) 194 | 195 | all_imgs = torch.stack(all_imgs) 196 | all_poses = torch.stack(all_poses) 197 | print("all_img_shape") 198 | print(all_imgs.shape) 199 | # resize images if given image size is not euqal to original size 200 | if np.any(np.array(all_imgs.shape[-2:]) != self.image_size): 201 | scale_h = self.image_size[0] / all_imgs.shape[-2] 202 | scale_w = self.image_size[1] / all_imgs.shape[-1] 203 | print(self.image_size[0]) 204 | print(all_imgs.shape[-2]) 205 | print(self.image_size[1]) 206 | print(all_imgs.shape[-1]) 207 | focal[0] *= scale_w 208 | focal[1] *= scale_h 209 | c[0] *= scale_w 210 | c[1] *= scale_h 211 | print("all_img_shape") 212 | print(all_imgs.shape) 213 | print(type(self.image_size)) 214 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 215 | print(all_imgs.shape) 216 | # aplly data augmentations 217 | for transformer in self.transformations: 218 | all_imgs = transformer(all_imgs) 219 | data_instance = { 220 | "scan_name": scan_name, 221 | "path": root_dir, 222 | "img_id": index, 223 | "focal": focal, 224 | "c": c, 225 | "images": all_imgs, 226 | "poses": all_poses, 227 | } 228 | return data_instance 229 | 230 | @classmethod 231 | def init_from_cfg(cls, cfg): 232 | return cls( 233 | mode=cfg["mode"], 234 | data_rootdir=cfg["data_rootdir"], 235 | max_imgs=cfg["max_imgs"], 236 | image_size=cfg["image_size"], 237 | z_near=cfg["z_near"], 238 | z_far=cfg["z_far"], 239 | trans_cfg=cfg["transformation"], 240 | dataset_format=cfg["format"], 241 | scene_list=cfg["scene_list"], 242 | ) 243 | 
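A minimal usage sketch for the NEU data module above, assuming the dataset has been unpacked under data/dataset/NEU with a train.lst split file, as config/data/NEU.yaml and the loader expect; the snippet is illustrative and not a script in this repository:

import yaml
from data.NEU import NEUDataModule

with open("config/data/NEU.yaml", "r") as f:
    data_cfg = yaml.safe_load(f)

datamodule = NEUDataModule(data_cfg)
train_set = datamodule.load_dataset("train", use_data_augmentation=True)
train_loader = datamodule.get_dataloader(train_set)

for batch in train_loader:
    # each scene yields up to max_imgs images stacked as (num_imgs, 3, H, W),
    # 4x4 camera-to-world poses, and averaged intrinsics (focal, c);
    # the default collate adds a leading batch dimension
    print(batch["scan_name"], batch["images"].shape, batch["poses"].shape)
    break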
-------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dl3dv import DL3DVDataModule 2 | from .shapenet import ShapenetDataModule 3 | from .cambridge import CamDataModule 4 | from .NEU import NEUDataModule 5 | from .SIASUN import SIASUNDataModule 6 | def get_data(cfg): 7 | dataset_name = cfg["name"] 8 | 9 | if dataset_name == "NEU": 10 | print(f"loading NEU dataset \n") 11 | return NEUDataModule(cfg) 12 | elif dataset_name == "Cambridge": 13 | print(f"loading Cambridge dataset \n") 14 | return CamDataModule(cfg) 15 | elif dataset_name == "SIASUN": 16 | print(f"loading SIASUN dataset \n") 17 | return SIASUNDataModule(cfg) 18 | elif dataset_name == "shapenet": 19 | print(f"loading shapenet dataset \n") 20 | return ShapenetDataModule(cfg) 21 | else: 22 | raise RuntimeError("dataset is not implemented!") 23 | -------------------------------------------------------------------------------- /data/dl3dv.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import glob 9 | import imageio 10 | import numpy as np 11 | import cv2 12 | from utils.util import get_image_to_tensor_balanced, coordinate_transformation 13 | from utils.data_augmentation import get_transformation 14 | 15 | 16 | class DL3DVDataModule: 17 | def __init__(self, cfg): 18 | self.batch_size = cfg["batch_size"] 19 | self.shuffle = cfg["shuffle"] 20 | self.num_workers = cfg["num_workers"] 21 | 22 | self.dataset_cfg = cfg["dataset"] 23 | self.data_augmentation = cfg["data_augmentation"] 24 | 25 | def load_dataset(self, mode, use_data_augmentation=False, scene_list=None): 26 | self.mode = mode 27 | self.dataset_cfg["mode"] = mode 28 | self.dataset_cfg["scene_list"] = scene_list 29 | 30 | if use_data_augmentation: 31 | self.dataset_cfg["transformation"] = self.data_augmentation 32 | else: 33 | self.dataset_cfg["transformation"] = None 34 | 35 | return DL3DVDataset.init_from_cfg(self.dataset_cfg) 36 | 37 | def get_dataloader(self, dataset): 38 | batch_size = self.batch_size 39 | shuffle = self.shuffle 40 | num_workers = self.num_workers 41 | 42 | if self.mode == "test": 43 | batch_size = 1 44 | shuffle = False 45 | num_workers = 0 46 | 47 | dataloader = DataLoader( 48 | dataset, 49 | batch_size=batch_size, 50 | shuffle=shuffle, 51 | num_workers=num_workers, 52 | ) 53 | return dataloader 54 | 55 | 56 | class DL3DVDataset(Dataset): 57 | def __init__( 58 | self, 59 | mode, 60 | data_rootdir, 61 | max_imgs, 62 | image_size, 63 | z_near, 64 | z_far, 65 | trans_cfg, 66 | dataset_format, 67 | scene_list, 68 | ): 69 | """ 70 | Inits DL3DV dataset instance 71 | 72 | Args: 73 | mode: either train, val or test 74 | data_rootdir: root directory of dataset 75 | max_imgs: maximal images for the object 76 | image_size: [H, W] pixels 77 | z_near: minimal distance of the object 78 | z_far: maximal distance of the object 79 | trans_cfg: configurations for data augmentations (transformation) 80 | dataset_format: the coordinate system the original dataset uses 81 | """ 82 | 83 | super().__init__() 84 | self.max_imgs = max_imgs 85 | self.image_size = image_size 86 | self.z_near = z_near 87 | self.z_far = z_far 88 | self.dataset_format = dataset_format 89 | 
self.rootdir=data_rootdir 90 | 91 | self.transformations = [] 92 | if trans_cfg is not None: 93 | self.transformations = get_transformation(trans_cfg) 94 | 95 | assert os.path.exists(data_rootdir) 96 | file_list = os.path.join(data_rootdir, f"{mode}.lst") 97 | assert os.path.exists(file_list) 98 | base_dir = os.path.dirname(file_list) 99 | if scene_list is None: 100 | with open(file_list, "r") as f: 101 | self.scene_list = [x.strip() for x in f.readlines()] 102 | else: 103 | self.scene_list = [f"scan{x}" for x in scene_list] 104 | 105 | self.objs_path = [os.path.join(base_dir, scene) for scene in self.scene_list] 106 | 107 | self.image_to_tensor = get_image_to_tensor_balanced() 108 | 109 | def __len__(self): 110 | return len(self.objs_path) 111 | 112 | def __getitem__(self, index): 113 | scan_name = self.scene_list[index] 114 | root_dir = self.objs_path[index] 115 | rgb_paths = [ 116 | x 117 | for x in glob.glob(os.path.join(root_dir, "images", "*")) 118 | if (x.endswith(".jpg") or x.endswith(".png")) 119 | ] 120 | print(root_dir) 121 | rgb_paths = sorted(rgb_paths) 122 | 123 | if len(rgb_paths) <= self.max_imgs: 124 | sel_indices = np.arange(len(rgb_paths)) 125 | else: 126 | sel_indices = np.random.choice(len(rgb_paths), self.max_imgs, replace=False) 127 | rgb_paths = [rgb_paths[i] for i in sel_indices] 128 | 129 | transforms=[] 130 | json_file=os.path.join(self.rootdir, scan_name, "transforms.json") 131 | with open(json_file, 'r') as f: 132 | data = json.load(f) 133 | Fx=data["fl_x"]/8 134 | Fy=data["fl_y"]/8 135 | Cx=data["cx"]/8 136 | Cy=data["cy"]/8 137 | for frame in data['frames']: 138 | transform_matrix = np.array(frame['transform_matrix']) 139 | transforms.append(transform_matrix) 140 | 141 | #cam_path = os.path.join(root_dir, "cameras.npz") 142 | #all_cam = np.load(cam_path) 143 | all_imgs = [] 144 | all_poses = [] 145 | focal = None 146 | fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0 147 | 148 | for idx, rgb_path in enumerate(rgb_paths): 149 | i = sel_indices[idx] 150 | img = imageio.imread(rgb_path)[..., :3] 151 | 152 | # decompose projection matrix 153 | P = transforms[i] 154 | fx += Fx 155 | fy += Fy 156 | cx += Cx 157 | cy += Cy 158 | 159 | pose = np.eye(4, dtype=np.float32) 160 | pose = P 161 | 162 | 163 | #scale_mtx = all_cam.get("scale_mat_" + str(i)) 164 | #if scale_mtx is not None: 165 | #norm_trans = scale_mtx[:3, 3:] 166 | #norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None] 167 | #pose[:3, 3:] -= norm_trans 168 | #pose[:3, 3:] /= norm_scale 169 | 170 | # camera poses in world coordinate 171 | pose = coordinate_transformation(pose, format=self.dataset_format) 172 | img_tensor = self.image_to_tensor(img) 173 | all_imgs.append(img_tensor) 174 | all_poses.append(pose)#pose is c2w 175 | 176 | # get average intrinsics for one object 177 | fx /= len(rgb_paths) 178 | fy /= len(rgb_paths) 179 | cx /= len(rgb_paths) 180 | cy /= len(rgb_paths) 181 | focal = torch.tensor((fx, fy), dtype=torch.float32) 182 | c = torch.tensor((cx, cy), dtype=torch.float32) 183 | 184 | all_imgs = torch.stack(all_imgs) 185 | all_poses = torch.stack(all_poses) 186 | 187 | # resize images if given image size is not euqal to original size 188 | if np.any(np.array(all_imgs.shape[-2:]) != self.image_size): 189 | scale_h = self.image_size[0] / all_imgs.shape[-2] 190 | scale_w = self.image_size[1] / all_imgs.shape[-1] 191 | focal[0] *= scale_w 192 | focal[1] *= scale_h 193 | c[0] *= scale_w 194 | c[1] *= scale_h 195 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 196 | 197 | 198 | # aplly 
data augmentations 199 | for transformer in self.transformations: 200 | all_imgs = transformer(all_imgs) 201 | data_instance = { 202 | "scan_name": scan_name, 203 | "path": root_dir, 204 | "img_id": index, 205 | "focal": focal, 206 | "c": c, 207 | "images": all_imgs, 208 | "poses": all_poses, 209 | } 210 | return data_instance 211 | 212 | @classmethod 213 | def init_from_cfg(cls, cfg): 214 | return cls( 215 | mode=cfg["mode"], 216 | data_rootdir=cfg["data_rootdir"], 217 | max_imgs=cfg["max_imgs"], 218 | image_size=cfg["image_size"], 219 | z_near=cfg["z_near"], 220 | z_far=cfg["z_far"], 221 | trans_cfg=cfg["transformation"], 222 | dataset_format=cfg["format"], 223 | scene_list=cfg["scene_list"], 224 | ) 225 | -------------------------------------------------------------------------------- /data/poses_avg_stats/GreatCourt.txt: -------------------------------------------------------------------------------- 1 | 3.405292754461299309e-01 4.953070871398960184e-01 7.991937825040464904e-01 4.704508373014760281e+01 2 | -9.402316354416121458e-01 1.812586354202151695e-01 2.882876667504056800e-01 3.467785281210451842e+01 3 | -2.069849976503225410e-03 -8.495976674371187309e-01 5.274272643753655787e-01 1.101080132352710184e+00 4 | -------------------------------------------------------------------------------- /data/poses_avg_stats/KingsCollege.txt: -------------------------------------------------------------------------------- 1 | 9.995083419588323137e-01 -1.453974655309233331e-02 2.777895111190991154e-02 2.004095163645802913e+01 2 | -2.395968310872182219e-02 2.172811532927548528e-01 9.758149588979971867e-01 -2.354010655332784197e+01 3 | -2.022394471995193205e-02 -9.760007664924973403e-01 2.168259575466510436e-01 1.650110331018928678e+00 4 | -------------------------------------------------------------------------------- /data/poses_avg_stats/OldHospital.txt: -------------------------------------------------------------------------------- 1 | 9.997941252129602940e-01 6.239930741698496326e-03 1.930726428032739084e-02 1.319547963328867723e+01 2 | -3.333807443587469103e-03 -8.880897259859261705e-01 4.596580515189216398e-01 -6.473184854291670343e-01 3 | 2.001481745059596404e-02 -4.596277862168271500e-01 -8.878860879751624413e-01 2.310333011616541654e+01 4 | -------------------------------------------------------------------------------- /data/poses_avg_stats/ShopFacade.txt: -------------------------------------------------------------------------------- 1 | 2.084004683986779016e-01 1.972095064159990266e-02 9.778447365901210553e-01 -4.512817941282106560e+00 2 | -9.780353328393808221e-01 8.307943784904847639e-03 2.082735359757174609e-01 1.914896116567694540e+00 3 | -4.016526979027209426e-03 -9.997710048685441997e-01 2.101916590087021808e-02 1.768500113487243564e+00 4 | -------------------------------------------------------------------------------- /data/poses_avg_stats/StMarysChurch.txt: -------------------------------------------------------------------------------- 1 | -6.692001528162709878e-01 7.430812642562667492e-01 1.179059789653581552e-03 1.114036505648812359e+01 2 | 3.891382817260490012e-02 3.662925707351961935e-02 -9.985709847092467673e-01 -5.441265972613005403e-02 3 | -7.420625778515127502e-01 -6.681979738352623599e-01 -5.342844106669619036e-02 1.708768320112491068e+01 4 | -------------------------------------------------------------------------------- /data/shapenet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | 
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import glob 9 | import imageio 10 | import numpy as np 11 | import yaml 12 | import cv2 13 | from utils.util import get_image_to_tensor_balanced, coordinate_transformation 14 | from utils.data_augmentation import get_transformation 15 | 16 | 17 | class ShapenetDataModule: 18 | def __init__(self, cfg): 19 | self.batch_size = cfg["batch_size"] 20 | self.shuffle = cfg["shuffle"] 21 | self.num_workers = cfg["num_workers"] 22 | 23 | self.dataset_cfg = cfg["dataset"] 24 | self.data_augmentation = cfg["data_augmentation"] 25 | 26 | def load_dataset(self, mode, use_data_augmentation=False, scene_list=None): 27 | self.mode = mode 28 | self.dataset_cfg["mode"] = mode 29 | self.dataset_cfg["scene_list"] = scene_list 30 | 31 | if use_data_augmentation: 32 | self.dataset_cfg["transformation"] = self.data_augmentation 33 | else: 34 | self.dataset_cfg["transformation"] = None 35 | 36 | return ShapenetDataset.init_from_cfg(self.dataset_cfg) 37 | 38 | def get_dataloader(self, dataset): 39 | batch_size = self.batch_size 40 | shuffle = self.shuffle 41 | num_workers = self.num_workers 42 | 43 | if self.mode == "test": 44 | batch_size = 1 45 | shuffle = False 46 | num_workers = 0 47 | 48 | dataloader = DataLoader( 49 | dataset, 50 | batch_size=batch_size, 51 | shuffle=shuffle, 52 | num_workers=num_workers, 53 | ) 54 | return dataloader 55 | 56 | 57 | class ShapenetDataset(Dataset): 58 | def __init__( 59 | self, 60 | mode, 61 | data_rootdir, 62 | image_size, 63 | z_near, 64 | z_far, 65 | trans_cfg, 66 | dataset_format, 67 | scene_list, 68 | ): 69 | """ 70 | Inits Shapenet dataset instance 71 | 72 | Args: 73 | mode: either train, val or test 74 | data_rootdir: root directory of dataset 75 | image_size: [H, W] pixels 76 | z_near: minimal distance of the object 77 | z_far: maximal distance of the object 78 | trans_cfg: configurations for data augmentations(transformation) 79 | dataset_formate: the coordinate system the original dataset uses 80 | """ 81 | 82 | super().__init__() 83 | self.image_size = image_size 84 | self.z_near = z_near 85 | self.z_far = z_far 86 | self.dataset_format = dataset_format 87 | 88 | self.transformations = [] 89 | if trans_cfg is not None: 90 | self.transformations = get_transformation(trans_cfg) 91 | 92 | assert os.path.exists(data_rootdir) 93 | file_list = os.path.join(data_rootdir, f"{mode}.lst") 94 | assert os.path.exists(file_list) 95 | base_dir = os.path.dirname(file_list) 96 | if scene_list is None: 97 | with open(file_list, "r") as f: 98 | self.scene_list = [x.strip() for x in f.readlines()] 99 | else: 100 | self.scene_list = [f"scan{x}" for x in scene_list] 101 | 102 | self.objs_path = [os.path.join(base_dir, scene) for scene in self.scene_list] 103 | 104 | self.image_to_tensor = get_image_to_tensor_balanced() 105 | 106 | def __len__(self): 107 | return len(self.objs_path) 108 | 109 | def __getitem__(self, index): 110 | scan_name = self.scene_list[index] 111 | root_dir = self.objs_path[index] 112 | rgb_paths = [ 113 | x 114 | for x in glob.glob(os.path.join(root_dir, "images", "*")) 115 | if (x.endswith(".jpg") or x.endswith(".png")) 116 | ] 117 | rgb_paths = sorted(rgb_paths) 118 | 119 | cam_path = os.path.join(root_dir, "trajectory.npy") 120 | intrinsic_path = os.path.join(root_dir, "camera_info.yaml") 121 | all_cam = np.load(cam_path) 122 | all_imgs = [] 123 | all_poses = [] 124 | 125 
| for idx, rgb_path in enumerate(rgb_paths): 126 | img = imageio.imread(rgb_path)[..., :3] 127 | pose = all_cam[idx] 128 | 129 | # camera poses in world coordinate 130 | pose = coordinate_transformation(pose, format=self.dataset_format) 131 | img_tensor = self.image_to_tensor(img) 132 | all_imgs.append(img_tensor) 133 | all_poses.append(pose) 134 | 135 | with open(intrinsic_path, "r") as file: 136 | intrinsics = yaml.safe_load(file) 137 | 138 | focal = intrinsics["focal"] 139 | c = intrinsics["c"] 140 | focal = torch.tensor(focal, dtype=torch.float32) 141 | c = torch.tensor(c, dtype=torch.float32) 142 | 143 | all_imgs = torch.stack(all_imgs) 144 | all_poses = torch.stack(all_poses) 145 | 146 | # resize images if given image size is not euqal to original size 147 | if np.any(np.array(all_imgs.shape[-2:]) != self.image_size): 148 | scale = self.image_size[0] / all_imgs.shape[-2] 149 | focal *= scale 150 | c *= scale 151 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 152 | 153 | # aplly data augmentations 154 | for transformer in self.transformations: 155 | all_imgs = transformer(all_imgs) 156 | data_instance = { 157 | "scan_name": scan_name, 158 | "path": root_dir, 159 | "img_id": index, 160 | "focal": focal, 161 | "c": c, 162 | "images": all_imgs, 163 | "poses": all_poses, 164 | } 165 | return data_instance 166 | 167 | @classmethod 168 | def init_from_cfg(cls, cfg): 169 | return cls( 170 | mode=cfg["mode"], 171 | data_rootdir=cfg["data_rootdir"], 172 | image_size=cfg["image_size"], 173 | z_near=cfg["z_near"], 174 | z_far=cfg["z_far"], 175 | trans_cfg=cfg["transformation"], 176 | dataset_format=cfg["format"], 177 | scene_list=cfg["scene_list"], 178 | ) 179 | -------------------------------------------------------------------------------- /datasets/__pycache__/pitts.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/datasets/__pycache__/pitts.cpython-38.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/pretrained_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/evaluation/__pycache__/pretrained_model.cpython-38.pyc -------------------------------------------------------------------------------- /evaluation/get_visual_output.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | #python .\evaluation\get_visual_output.py -M first -si 0 -ri "0 1 2" -ti 76 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | from evaluation.pretrained_model import PretrainedModel 6 | from data import get_data 7 | from utils import parser, util 8 | import yaml 9 | from dotmap import DotMap 10 | import torch 11 | import warnings 12 | import numpy as np 13 | import imageio 14 | from datetime import datetime 15 | from networks.mobilenet import TeacherNet 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def main(): 21 | """ 22 | given scene index, reference index and novel view index, 23 | this script outputs ground truth, reference, uncertainty, depth and RGB images. 24 | used for sanity check. 
25 | """ 26 | 27 | args = parser.parse_args(visual_args) 28 | log_path = os.path.join("logs", args.model_name) 29 | 30 | assert os.path.exists(log_path), "experiment does not exist" 31 | with open(f"{log_path}/training_setup.yaml", "r") as config_file: 32 | cfg = yaml.safe_load(config_file) 33 | 34 | checkpoint_path = os.path.join(log_path, "checkpoints", "best.ckpt") 35 | assert os.path.exists(checkpoint_path), "checkpoint does not exist" 36 | ckpt_file = torch.load(checkpoint_path) 37 | 38 | gpu_id = list(map(int, args.gpu_id.split())) 39 | device = util.get_cuda(gpu_id[0]) 40 | 41 | model = PretrainedModel(cfg["model"], ckpt_file, device, gpu_id) 42 | encoder = TeacherNet() 43 | pretrained_weights_path = 'logs/ckpt_best.pth.tar' 44 | pretrained_state_dict = torch.load(pretrained_weights_path) 45 | encoder.load_state_dict(pretrained_state_dict["state_dict"]) 46 | encoder = encoder.to(device) 47 | datamodule = get_data(cfg["data"]) 48 | dataset = datamodule.load_dataset("val") 49 | z_near = dataset.z_near 50 | z_far = dataset.z_far 51 | 52 | scene_idx = args.scene_idx 53 | ref_idx = list(map(int, args.ref_idx.split())) 54 | target_idx = args.target_idx 55 | 56 | data_instance = dataset.__getitem__(scene_idx) 57 | scene_title = data_instance["scan_name"] 58 | print(f"visual test on {scene_title}") 59 | 60 | images = data_instance["images"].to(device) 61 | images_0to1 = images * 0.5 + 0.5 62 | _, _, H, W = images.shape 63 | print(images.shape) 64 | focal = data_instance["focal"].to(device) 65 | c = data_instance["c"].to(device) 66 | poses = data_instance["poses"].to(device) 67 | print(poses.shape) 68 | with torch.no_grad(): 69 | model.network.encode( 70 | images[ref_idx].unsqueeze(0), 71 | poses[ref_idx].unsqueeze(0), 72 | focal.unsqueeze(0), 73 | c.unsqueeze(0), 74 | ) 75 | 76 | novel_pose = poses[target_idx] 77 | novel_pose = novel_pose.unsqueeze(0) 78 | novel_pose = novel_pose.unsqueeze(0) 79 | print(novel_pose.dtype) 80 | predict = DotMap(model.network(novel_pose)) 81 | print("uncertainty") 82 | print(predict.uncertainty) 83 | print(predict.uncertainty.shape) 84 | print("all_uncertainty") 85 | print(predict.all_uncertainty) 86 | 87 | des_np=predict.des[0].cpu().numpy() 88 | des_np1=des_np 89 | #print(des_np.shape)#(1.512) 90 | des_np = np.square(des_np) * 50 # 为了绘图提高对比度 91 | rgb_np=util.visualize_descriptors(des_np, (200, 512)) 92 | rgb_np1=util.visualize_descriptors(des_np1, (200, 512)) 93 | uncertainty = predict.uncertainty[0].cpu().numpy() 94 | uncertainty = np.square(uncertainty) # 为了绘图提高对比度 95 | gt_encoder=images_0to1[target_idx].unsqueeze(0) 96 | gt = images_0to1[target_idx].permute(1, 2, 0).cpu().numpy() 97 | gt_des, _ = encoder(gt_encoder) 98 | print(gt_des.shape) 99 | gt_des = gt_des.cpu().numpy() 100 | gt_np = util.visualize_descriptors(gt_des, (200,512)) 101 | error_np = np.abs(rgb_np1 - gt_np)*4 102 | print("error_np") 103 | print(np.mean(error_np)) 104 | uncertainty_np = util.visualize_descriptors(uncertainty, (200, 512)) 105 | 106 | ref_images = images_0to1[ref_idx].permute(0, 2, 3, 1).cpu().numpy() 107 | 108 | ref_images = np.hstack((*ref_images,)) 109 | 110 | error_map = util.error_cmap(error_np) 111 | rgb_map = util.des_cmap(rgb_np) 112 | gt_map = util.des_cmap(gt_np) 113 | uncertainty_map = util.unc_cmap(uncertainty_np) 114 | 115 | experiment_path = os.path.join( 116 | "experiments", 117 | args.model_name, 118 | "visual_experiment", 119 | datetime.now().strftime("%d-%m-%Y-%H-%M"), 120 | ) 121 | 122 | os.makedirs(experiment_path) 123 | 124 | imageio.imwrite( 125 | 
f"{experiment_path}/{scene_title}_reference_images_{ref_idx}.jpg", 126 | (ref_images * 255).astype(np.uint8), 127 | ) # ref img 128 | imageio.imwrite( 129 | f"{experiment_path}/{scene_title}_rgb_{target_idx}.jpg", 130 | rgb_map 131 | ) # gt des 132 | imageio.imwrite( 133 | f"{experiment_path}/{scene_title}_gt_{target_idx}.jpg", 134 | gt_map 135 | ) # gt des 136 | imageio.imwrite( 137 | f"{experiment_path}/{scene_title}_uncertainty_{target_idx}.jpg", uncertainty_map 138 | ) 139 | imageio.imwrite( 140 | f"{experiment_path}/{scene_title}_error_{target_idx}.jpg", error_map 141 | ) 142 | imageio.imwrite( 143 | f"{experiment_path}/{scene_title}_ground_truth.jpg", (gt * 255).astype(np.uint8) 144 | ) # gt img 145 | 146 | 147 | def visual_args(parser): 148 | """ 149 | Parse arguments for novel view synthesis setup. 150 | """ 151 | 152 | # mandatory arguments 153 | parser.add_argument( 154 | "--model_name", 155 | "-M", 156 | type=str, 157 | required=True, 158 | help="model name of pretrained model", 159 | ) 160 | 161 | parser.add_argument( 162 | "--scene_idx", 163 | "-si", 164 | type=int, 165 | required=True, 166 | help="scene index in DTU validation split", 167 | ) 168 | 169 | parser.add_argument( 170 | "--ref_idx", 171 | "-ri", 172 | type=str, 173 | required=True, 174 | help="reference view index, space delimited", 175 | ) 176 | 177 | parser.add_argument( 178 | "--target_idx", "-ti", type=int, required=True, help="target view index" 179 | ) 180 | 181 | # arguments with default values 182 | parser.add_argument( 183 | "--gpu_id", type=str, default="0", help="GPU(s) to use, space delimited" 184 | ) 185 | return parser 186 | 187 | 188 | if __name__ == "__main__": 189 | main() 190 | #python .\evaluation\get_visual_output.py -M first -si 0 -ri "0 1 2" -ti 3 -------------------------------------------------------------------------------- /evaluation/pretrained_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | from model import get_model 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | 11 | class PretrainedModel: 12 | def __init__(self, model_config, checkpoint_file, device, gpu_id): 13 | self.device = device 14 | self.network, self.renderer = self.load_pretrained_model( 15 | model_config, checkpoint_file 16 | ) 17 | self.renderer_par = self.renderer.parallelize(self.network, gpu_id).eval() 18 | 19 | def load_pretrained_model(self, model_config, checkpoint_file): 20 | print("------ configure model ------") 21 | 22 | network, renderer = get_model(model_config) 23 | 24 | network = network.to(self.device).eval() 25 | renderer = renderer.to(self.device).eval() 26 | 27 | print("------ load model parameters ------") 28 | 29 | network.load_state_dict(checkpoint_file["network_state_dict"]) 30 | #renderer.load_state_dict(checkpoint_file["renderer_state_dict"]) 31 | 32 | return network, renderer 33 | -------------------------------------------------------------------------------- /getxy.py: -------------------------------------------------------------------------------- 1 | #从pose中提取xy保存 2 | import os 3 | import numpy as np 4 | 5 | # 原始文件夹和新文件夹路径 6 | original_folder = 'C:/Users/65309/Desktop/Cambridge_nerf/poses4' 7 | new_folder = 'C:/Users/65309/Desktop/Cambridge_nerf/poses' 8 | 9 | # 创建新文件夹(如果不存在) 10 | if not os.path.exists(new_folder): 11 | os.makedirs(new_folder) 12 | 13 | # 遍历原始文件夹中的所有txt文件 14 | for filename in 
os.listdir(original_folder): 15 | if filename.endswith('.txt'): 16 | original_filepath = os.path.join(original_folder, filename) 17 | new_filepath = os.path.join(new_folder, filename) 18 | parts = filename.split("_") 19 | # 获取第一个部分作为目标字符部分 20 | target_string = parts[0] 21 | print(target_string) 22 | # 读取原始txt文件的内容并提取第一行第四列和第二行第四列的数据 23 | with open(original_filepath, 'r') as original_file: 24 | lines = original_file.readlines() 25 | if(target_string == "GreatCourt"): 26 | element_1 = float(lines[0].split()[3]) 27 | element_2 = float(lines[1].split()[3]) 28 | if(target_string == "KingsCollege"): 29 | element_1 = float(lines[0].split()[3])+1000 30 | element_2 = float(lines[1].split()[3])+1000 31 | if (target_string == "OldHospital"): 32 | element_1 = float(lines[0].split()[3])+2000 33 | element_2 = float(lines[1].split()[3])+2000 34 | if(target_string == "ShopFacade"): 35 | element_1 = float(lines[0].split()[3])+3000 36 | element_2 = float(lines[1].split()[3])+3000 37 | if(target_string == "StMarysChurch"): 38 | element_1 = float(lines[0].split()[3])+4000 39 | element_2 = float(lines[1].split()[3])+4000 40 | pose_loc = np.array([[element_1, element_2]]) 41 | 42 | # 将提取的数据保存到新文件夹中与原文件同名的txt文件中 43 | np.savetxt(new_filepath, pose_loc) 44 | -------------------------------------------------------------------------------- /load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = 
meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | if __name__=='__main__': 97 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 98 | 99 | basedir = "./logs" 100 | imgs, poses, render_poses, [H, W, focal], K, i_split, near, far = load_LINEMOD_data(basedir, half_res=False, testskip=1) -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | import os.path as osp 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | class CameraParams: 37 | def __init__(self, near, far, pose_scale, pose_scale2, move_all_cam_vec): 38 | self.near = near 39 | self.far = far 40 | self.pose_scale = pose_scale 41 | self.pose_scale2 = pose_scale2 42 | self.move_all_cam_vec = move_all_cam_vec 43 | def load_blender_data_Cam(datadir, half_res=False, testskip=1): 44 | splits = ['train', 'val', 'test'] 45 | base_dir, scene = osp.split(datadir) 46 | 47 | world_setup_fn = osp.join(base_dir, scene) + '/world_setup.json' 48 | 49 | # read json file 50 | with open(world_setup_fn, 'r') as myfile: 51 | data = myfile.read() 52 | 53 | # parse json file 54 | obj = json.loads(data) 55 | near = obj['near'] 56 | far = obj['far'] 57 | pose_scale = obj['pose_scale'] 58 | pose_scale2 = obj['pose_scale2'] 59 | move_all_cam_vec = obj['move_all_cam_vec'] 60 | 61 | camera_params = CameraParams(near, far, pose_scale, pose_scale2, move_all_cam_vec) 62 | 63 | 64 | all_imgs = [] 65 | all_poses = [] 66 | counts = [0] 67 | 68 | for s in splits: 69 | root_dir = os.path.join(datadir,s) 70 | rgb_dir = root_dir + '/rgb/' 71 | pose_dir = root_dir + '/poses/' 72 | if s=='train' or testskip==0: 73 | skip = 4 74 | else: 75 | skip = testskip 76 | 77 | 78 | rgb_files = os.listdir(rgb_dir) 79 | rgb_files = [rgb_dir + f for f in rgb_files] 80 | rgb_files.sort() 81 | 82 | pose_files = os.listdir(pose_dir) 83 | pose_files = [pose_dir + f for f in pose_files] 84 | pose_files.sort() 85 | 86 | if scene == 'ShopFacade' and s == 'train' : 87 | 
del rgb_files[42] 88 | del rgb_files[35] 89 | del pose_files[42] 90 | del pose_files[35] 91 | if len(rgb_files) != len(pose_files): 92 | raise Exception('RGB file count does not match pose file count!') 93 | 94 | # trainskip and testskip 95 | frame_idx = np.arange(len(rgb_files)) 96 | if s == 'train' and skip > 1 : 97 | frame_idx_tmp = frame_idx[::skip] 98 | frame_idx = frame_idx_tmp 99 | elif s != 'train' and testskip > 1: 100 | frame_idx_tmp = frame_idx[::testskip] 101 | frame_idx = frame_idx_tmp 102 | gt_idx = frame_idx 103 | 104 | rgb_files = [rgb_files[i] for i in frame_idx] 105 | pose_files = [pose_files[i] for i in frame_idx] 106 | 107 | if len(rgb_files) != len(pose_files): 108 | raise Exception('RGB file count does not match pose file count!') 109 | imgs = [] 110 | # read poses 111 | poses = [] 112 | for i in range(len(pose_files)): 113 | pose = np.loadtxt(pose_files[i]) 114 | poses.append(pose) 115 | image = imageio.imread(rgb_files[i]) 116 | if image.shape[-1] == 3: 117 | alpha_channel = np.ones((image.shape[0], image.shape[1], 1), dtype=image.dtype) * 255 118 | image = np.concatenate((image, alpha_channel), axis=-1) 119 | imgs.append(image) 120 | 121 | poses = np.array(poses).astype(np.float32) # [N, 4, 4] 122 | all_poses.append(poses) 123 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 124 | counts.append(counts[-1] + imgs.shape[0]) 125 | all_imgs.append(imgs) 126 | 127 | 128 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 129 | 130 | imgs = np.concatenate(all_imgs, 0) 131 | poses = np.concatenate(all_poses, 0) 132 | 133 | [H, W, focal] = [480, 854, 744.] 134 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0) 135 | 136 | if half_res: 137 | H = H // 2 138 | W = W // 2 139 | focal = focal / 2. 
140 | 141 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 142 | for i, img in enumerate(imgs): 143 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 144 | imgs = imgs_half_res 145 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 146 | 147 | return imgs, poses, render_poses, [H, W, focal], i_split, camera_params 148 | 149 | 150 | def load_blender_data(basedir, half_res=False, testskip=1): 151 | splits = ['train', 'val', 'test'] 152 | metas = {} 153 | for s in splits: 154 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 155 | metas[s] = json.load(fp) 156 | 157 | all_imgs = [] 158 | all_poses = [] 159 | counts = [0] 160 | for s in splits: 161 | meta = metas[s] 162 | imgs = [] 163 | poses = [] 164 | if s == 'train' or testskip == 0: 165 | skip = 1 166 | else: 167 | skip = testskip 168 | 169 | for frame in meta['frames'][::skip]: 170 | fname = os.path.join(basedir, frame['file_path'] + '.png') 171 | imgs.append(imageio.imread(fname)) 172 | poses.append(np.array(frame['transform_matrix'])) 173 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 174 | poses = np.array(poses).astype(np.float32) 175 | print(imgs.shape) 176 | counts.append(counts[-1] + imgs.shape[0]) 177 | all_imgs.append(imgs) 178 | all_poses.append(poses) 179 | 180 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 181 | 182 | imgs = np.concatenate(all_imgs, 0) 183 | poses = np.concatenate(all_poses, 0) 184 | 185 | H, W = imgs[0].shape[:2] 186 | camera_angle_x = float(meta['camera_angle_x']) 187 | focal = .5 * W / np.tan(.5 * camera_angle_x) 188 | 189 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0) 190 | 191 | if half_res: 192 | H = H // 2 193 | W = W // 2 194 | focal = focal / 2. 
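        # Descriptive note on the resize below: cv2.resize takes its target size as
        # (width, height), so passing (W, H) is the correct argument order, and
        # INTER_AREA is generally the preferred interpolation when downscaling.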
195 | 196 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 197 | for i, img in enumerate(imgs): 198 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 199 | imgs = imgs_half_res 200 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 201 | 202 | return imgs, poses, render_poses, [H, W, focal], i_split 203 | if __name__=='__main__': 204 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 205 | datadir="./data/nerf_synthetic/lego" 206 | datadir2 = "./data/Cambridge/GreatCourt" 207 | imgs, poses, render_poses, [H, W, focal], i_split = load_blender_data(datadir, half_res=False, testskip=1) 208 | #imgs, poses, render_poses, [H, W, focal], i_split = load_blender_data_Cam(datadir2, half_res=False, testskip=1) 209 | #print(i_split[0]) 210 | #print(poses.shape) 211 | #print(render_poses.shape) -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in 
sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join 2 | 3 | import trainer 4 | from options import Options 5 | 6 | options_handler = Options() 7 | options = options_handler.parse() 8 | 9 | if __name__ == "__main__": 10 | 11 | if options.phase in ['test_tea', 'test_stu', 'train_stu']: 12 | print(f'resume from {options.resume}') 13 | options = options_handler.update_opt_from_json(join(dirname(options.resume), 'flags.json'), options) 14 | #options.nEpochs = 200 15 | options.batchSize=4 16 | options.cacheBatchSize=4 17 | options.threads=0 18 | tr = trainer.Trainer(options) 19 | print(tr.opt.phase, '-->', tr.opt.runsPath) 20 | elif options.phase in ['train_tea']: 21 | tr = trainer.Trainer(options) 22 | print(tr.opt.phase, '-->', tr.opt.runsPath) 23 | 24 | if options.phase in ['train_tea']: 25 | tr.train() 26 | elif options.phase in ['train_stu']: 27 | tr.train_student() 28 | elif options.phase in ['test_tea', 'test_stu']: 29 | tr.test() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .network import Network 2 | from .renderer import Renderer 3 | 4 | 5 | def get_model(cfg): 6 | print(f"loading model \n") 7 | network = Network(cfg["network"]) 8 | renderer = Renderer.init_from_cfg(cfg["renderer"]) 9 | 10 | return network, renderer 11 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/code.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/code.cpython-38.pyc -------------------------------------------------------------------------------- 
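A minimal usage sketch for the get_model factory shown in model/__init__.py above. It only assumes that the YAML loaded from config/model/model.yaml exposes the top-level "network" and "renderer" sections consumed by Network(...) and Renderer.init_from_cfg(...); the exact contents of those sections are not documented here, so treat this as an illustration rather than the project's official entry point.

import yaml
from model import get_model

# Hypothetical driver: path and config layout are assumptions, only the
# "network"/"renderer" keys come from get_model itself.
with open("config/model/model.yaml", "r") as f:
    cfg = yaml.safe_load(f)
network, renderer = get_model(cfg)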
/model/__pycache__/encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/encoder.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/loss_type.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/loss_type.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/mlp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/mlp.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/network.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/model/__pycache__/renderer.cpython-38.pyc -------------------------------------------------------------------------------- /model/code.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.autograd.profiler as profiler 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__(self, num_freqs, d_in, include_input, freq_factor): 9 | """ 10 | Init poistional encoding instance. 11 | 12 | Args: 13 | num_freqs: frequency level for positional encoding. 14 | d_in: input dimension, by dedault should be 3 for x, y ,z or 6 for x y z vx, vy, vz. 15 | include_input: whether use pure input embedding vector. 16 | freq_factor: coefficient for positional encoding. 17 | """ 18 | 19 | super().__init__() 20 | self.num_freqs = num_freqs 21 | self.d_in = d_in 22 | self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) 23 | self.d_out = self.num_freqs * 2 * d_in 24 | self.include_input = include_input 25 | if include_input: 26 | self.d_out += d_in 27 | 28 | self.register_buffer( 29 | "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) 30 | ) 31 | _phases = torch.zeros(2 * self.num_freqs) 32 | _phases[1::2] = np.pi * 0.5 33 | self.register_buffer("_phases", _phases.view(1, -1, 1)) 34 | 35 | def forward(self, x): 36 | """ 37 | Apply positional encoding. 38 | 39 | Args: 40 | x: (batch, self.d_in), pose information. 41 | 42 | Returns: 43 | embed: (batch, self.d_out), postional embedding. 
44 | """ 45 | 46 | # with profiler.record_function("positional_encoding"): 47 | embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1) 48 | embed=embed.to(self._phases.dtype) 49 | embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs)) 50 | embed = embed.view(x.shape[0], -1) 51 | if self.include_input: 52 | embed = torch.cat((x, embed), dim=-1) 53 | return embed 54 | 55 | @classmethod 56 | def init_from_cfg(cls, cfg): 57 | return cls( 58 | num_freqs=cfg["num_freqs"], 59 | d_in=cfg["d_in"], 60 | include_input=cfg["include_input"], 61 | freq_factor=cfg["freq_factor"], 62 | ) 63 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import torch.autograd.profiler as profiler 6 | from utils import util 7 | 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__( 12 | self, 13 | backbone="resnet34", 14 | pretrained=True, 15 | num_layers=3, 16 | index_interp="bilinear", 17 | index_padding="border", 18 | upsample_interp="bilinear", 19 | use_first_pool=True, 20 | norm_type="batch", 21 | ): 22 | """ 23 | Inits Encoder instance. 24 | 25 | Args: 26 | backbone: encoder model resnest34. 27 | pretrained: whether to use model weights pretrained on ImageNet. 28 | num_layers: number of resnet layers to use (1-5). 29 | index_interp: interpolation to use for feature map indexing. 30 | index_padding: padding mode to use for indexing (border, zeros, reflection). 31 | upsample_interp: interpolation to use for upscaling latent code. 32 | use_first_pool: whether to use first maxpool layer. 33 | norm_type: norm type to use. usually "batch" 34 | """ 35 | 36 | super().__init__() 37 | self.num_layers = num_layers 38 | self.index_interp = index_interp 39 | self.index_padding = index_padding 40 | self.upsample_interp = upsample_interp 41 | self.use_first_pool = use_first_pool 42 | self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers] 43 | 44 | norm_layer = util.get_norm_layer(norm_type) 45 | self.model = getattr(torchvision.models, backbone)( 46 | pretrained=pretrained, norm_layer=norm_layer 47 | ) 48 | 49 | def forward(self, x): 50 | """ 51 | Get feature maps of RGB image inputs. 52 | 53 | Args: 54 | x: image (RN, C, H, W). 55 | 56 | Returns: 57 | latent: features (RN, L, H, W), L is feature map channel length. 
58 | """ 59 | 60 | # with profiler.record_function("encoder_inference"): 61 | x = self.model.conv1(x) 62 | x = self.model.bn1(x) 63 | x = self.model.relu(x) 64 | 65 | latents = [x] 66 | if self.num_layers > 1: 67 | if self.use_first_pool: 68 | x = self.model.maxpool(x) 69 | x = self.model.layer1(x) 70 | latents.append(x) 71 | if self.num_layers > 2: 72 | x = self.model.layer2(x) 73 | latents.append(x) 74 | if self.num_layers > 3: 75 | x = self.model.layer3(x) 76 | latents.append(x) 77 | if self.num_layers > 4: 78 | x = self.model.layer4(x) 79 | latents.append(x) 80 | 81 | align_corners = None if self.index_interp == "nearest " else True 82 | latent_sz = latents[0].shape[-2:] 83 | 84 | # unpsample feature map from different layers to the same dimension 85 | for i in range(len(latents)): 86 | latents[i] = F.interpolate( 87 | latents[i], 88 | latent_sz, 89 | mode=self.upsample_interp, 90 | align_corners=align_corners, 91 | ) 92 | latent = torch.cat(latents, dim=1) 93 | return latent 94 | 95 | @classmethod 96 | def init_from_cfg(cls, cfg): 97 | return cls( 98 | backbone=cfg["backbone"], 99 | pretrained=cfg["pretrained"], 100 | num_layers=cfg["num_layers"], 101 | index_interp=cfg["index_interp"], 102 | index_padding=cfg["index_padding"], 103 | upsample_interp=cfg["upsample_interp"], 104 | use_first_pool=cfg["use_first_pool"], 105 | norm_type=cfg["norm_type"], 106 | ) 107 | if __name__ == '__main__': 108 | tea = Encoder() 109 | inputs = torch.rand((1, 3, 224, 224)) 110 | outputs_tea = tea(inputs) 111 | print(outputs_tea.shape) 112 | -------------------------------------------------------------------------------- /model/loss_type.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from utils import util 5 | 6 | 7 | class LogitWithUncertaintyLoss(nn.Module): 8 | """ 9 | uncertainty estimation in logit space 10 | """ 11 | 12 | def __init__(self, loss_type, reduction="none"): 13 | super().__init__() 14 | self.rgb_loss = get_rgb_loss(loss_type, reduction) 15 | 16 | def forward(self, predict, ground_truth): 17 | logit_mean = predict.logit_mean # (ON, N-RN, 512) mean of RGB logit 18 | logit_log_var = predict.logit_log_var # (ON, N-RN, 512) log variance of RGB logit 19 | gt = torch.clamp(ground_truth, min=1.0e-3, max=1.0 - 1.0e-3) # (ON, N-RN, 512) 20 | logit_diff = self.rgb_loss(torch.logit(gt), logit_mean) # (ON, N-RN, 512) 21 | gt_term = torch.log(gt * (1.0 - gt)) # (ON, N-RN, 512) 22 | loss = ( 23 | 0.5 * logit_log_var + gt_term + 0.5 * logit_diff / torch.exp(logit_log_var) 24 | ) 25 | 26 | return torch.mean(loss) 27 | 28 | 29 | class RGBLoss(nn.Module): 30 | """ 31 | pure RGB photometric loss 32 | """ 33 | 34 | def __init__(self, loss_type, reduction="mean"): 35 | super().__init__() 36 | self.rgb_loss = get_rgb_loss(loss_type, reduction) 37 | 38 | def forward(self, predict, ground_truth): 39 | rgb = predict.rgb 40 | return self.rgb_loss(rgb, ground_truth) 41 | 42 | 43 | def get_rgb_loss(loss_type, reduction): 44 | if loss_type == "mse": 45 | return nn.MSELoss(reduction=reduction) 46 | elif loss_type == "l1": 47 | return nn.L1Loss(reduction=reduction) 48 | elif loss_type == "smooth_l1": 49 | return nn.SmoothL1Loss(reduction=reduction) 50 | -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import 
torch.autograd.profiler as profiler 4 | from .code import PositionalEncoding 5 | 6 | 7 | def km_init(l): 8 | if isinstance(l, nn.Linear): 9 | nn.init.kaiming_normal_(l.weight, a=0, mode="fan_in") 10 | nn.init.constant_(l.bias, 0.0) 11 | 12 | 13 | class MLPFeature(nn.Module): 14 | def __init__( 15 | self, d_latent, d_feature, block_num, use_encoding, pe_config, use_view 16 | ): 17 | """ 18 | Inits MLP_feature model. 19 | 20 | Args: 21 | d_latent: encoder latent size. 22 | d_feature: mlp feature size. 23 | use_encoding: whether use positional encoding 24 | pe_config: configuration for positional encoding 25 | """ 26 | 27 | super().__init__() 28 | 29 | self.d_latent = d_latent 30 | self.d_pose = 3 # (x, y, z, vx, vy, vz) 31 | self.d_feature = d_feature 32 | self.block_num = block_num 33 | self.use_view = use_view 34 | 35 | self.activation = nn.ReLU() 36 | self.sigmoid = nn.Sigmoid() 37 | 38 | self.use_encoding = use_encoding 39 | if self.use_encoding: 40 | self.positional_encoding = PositionalEncoding.init_from_cfg(pe_config) 41 | self.d_pose = self.positional_encoding.d_out 42 | 43 | if self.use_view: 44 | self.d_pose += 3 45 | 46 | self.lin_in_p = nn.Sequential( 47 | nn.Linear(self.d_pose, self.d_feature), self.activation 48 | ) 49 | self.out_feat = nn.Sequential( 50 | nn.Linear(self.d_feature, self.d_feature), self.activation 51 | ) 52 | self.out_weight = nn.Sequential( 53 | nn.Linear(self.d_feature, self.d_feature), self.sigmoid 54 | ) 55 | 56 | self.lin_in_p.apply(km_init) 57 | self.out_feat.apply(km_init) 58 | 59 | self.blocks = nn.ModuleList() 60 | self.lin_in_z = nn.ModuleList() 61 | for _ in range(self.block_num): 62 | lin_z = nn.Sequential( 63 | nn.Linear(self.d_latent, self.d_feature), self.activation 64 | ) 65 | lin_z.apply(km_init) 66 | self.lin_in_z.append(lin_z) 67 | 68 | self.blocks.append(ResnetBlock(self.d_feature)) 69 | 70 | def forward(self, z, x): 71 | if self.use_encoding: 72 | p = self.positional_encoding(x[..., :3]) 73 | if self.use_view: 74 | p = torch.cat((p, x[..., 3:]), dim=-1) 75 | #print(p.dtype) 76 | p=p.to(torch.float32) 77 | p = self.lin_in_p(p) 78 | for i in range(self.block_num): 79 | tz = self.lin_in_z[i](z) 80 | p = p + tz 81 | p = self.blocks[i](p) 82 | 83 | out = self.out_feat(p) # (ON*RN*RB, d_feature) 84 | weight = self.out_weight(p) # (ON*RN*RB, 1) 85 | 86 | return out, weight 87 | 88 | @classmethod 89 | def init_from_cfg(cls, cfg): 90 | return cls( 91 | d_latent=cfg["d_latent"], 92 | d_feature=cfg["d_feature"], 93 | use_encoding=cfg["use_encoding"], 94 | use_view=cfg["use_view"], 95 | block_num=cfg["block_num"], 96 | pe_config=cfg["positional_encoding"], 97 | ) 98 | 99 | 100 | class MLPOut(nn.Module): 101 | def __init__(self, d_feature, d_out, block_num): 102 | """ 103 | Inits MLP_out model. 104 | 105 | Args: 106 | d_feature: feature size. 107 | d_out: output size. 108 | block_num: number of Resnet blocks. 
109 | """ 110 | 111 | super().__init__() 112 | self.d_feature = d_feature 113 | self.d_out = d_out 114 | self.block_num = block_num 115 | 116 | self.lin_out = nn.Linear(self.d_feature, self.d_out) 117 | 118 | self.blocks = nn.ModuleList() 119 | for _ in range(self.block_num): 120 | self.blocks.append(ResnetBlock(self.d_feature)) 121 | 122 | def forward(self, x): 123 | for blkid in range(self.block_num): 124 | x = self.blocks[blkid](x) 125 | 126 | out = self.lin_out(x) 127 | return out 128 | 129 | @classmethod 130 | def init_from_cfg(cls, cfg): 131 | return cls( 132 | d_feature=cfg["d_feature"], 133 | block_num=cfg["block_num"], 134 | d_out=cfg["d_out"], 135 | ) 136 | 137 | 138 | class ResnetBlock(nn.Module): 139 | """ 140 | Fully connected ResNet Block class. 141 | """ 142 | 143 | def __init__(self, size_in, size_out=None, size_h=None, beta=0.0): 144 | """ 145 | Inits Resnet block. 146 | 147 | Args: 148 | size_in: input dimension. 149 | size_out: output dimension. 150 | size_h: hidden dimension. 151 | """ 152 | 153 | super().__init__() 154 | 155 | if size_out is None: 156 | size_out = size_in 157 | 158 | if size_h is None: 159 | size_h = min(size_in, size_out) 160 | self.size_in = size_in 161 | self.size_h = size_h 162 | self.size_out = size_out 163 | self.fc_0 = nn.Linear(size_in, size_h) 164 | self.fc_1 = nn.Linear(size_h, size_out) 165 | 166 | self.fc_0.apply(km_init) 167 | nn.init.constant_(self.fc_1.bias, 0.0) 168 | nn.init.zeros_(self.fc_1.weight) 169 | 170 | if beta > 0: 171 | self.activation = nn.Softplus(beta=beta) 172 | else: 173 | self.activation = nn.ReLU() 174 | 175 | if size_in == size_out: 176 | self.shortcut = None 177 | else: 178 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 179 | self.shortcut.apply(km_init) 180 | 181 | def forward(self, x): 182 | with profiler.record_function("resblock"): 183 | res = self.fc_0(x) 184 | res = self.activation(res) 185 | res = self.fc_1(res) 186 | 187 | if self.shortcut is not None: 188 | x_s = self.shortcut(x) 189 | else: 190 | x_s = x 191 | out = self.activation(x_s + res) 192 | return out 193 | 194 | 195 | -------------------------------------------------------------------------------- /model/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .code import PositionalEncoding 4 | import torch.autograd.profiler as profiler 5 | from dotmap import DotMap 6 | from utils import util 7 | 8 | 9 | def init_recurrent_weights(self): 10 | for m in self.modules(): 11 | if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: 12 | for name, param in m.named_parameters(): 13 | if "weight_ih" in name: 14 | nn.init.kaiming_normal_(param.data) 15 | elif "weight_hh" in name: 16 | nn.init.orthogonal_(param.data) 17 | elif "bias" in name: 18 | param.data.fill_(0) 19 | 20 | 21 | def lstm_forget_gate_init(lstm_layer): 22 | for name, parameter in lstm_layer.named_parameters(): 23 | if not "bias" in name: 24 | continue 25 | n = parameter.size(0) 26 | start, end = n // 4, n // 2 27 | parameter.data[start:end].fill_(1.0) 28 | 29 | 30 | class RayMarcher(nn.Module): 31 | def __init__( 32 | self, 33 | d_in, 34 | d_hidden, 35 | raymarch_steps, 36 | use_encoding, 37 | pe_config, 38 | ): 39 | """ 40 | Inits LSTM ray marcher 41 | 42 | Args: 43 | d_in: input feature size. 44 | d_hidden: hidden feature size of LSTM cell. 
45 | raymarch_steps: number of iteration 46 | use_encoding: whether use positional encoding for lstm input 47 | pe_config: positional encoding (pe) configuration 48 | """ 49 | 50 | super().__init__() 51 | 52 | self.lstm_d_in = d_in 53 | self.lstm_d_hidden = d_hidden 54 | self.lstm_d_out = 1 55 | 56 | self.raymarch_steps = raymarch_steps 57 | self.use_encoding = use_encoding 58 | 59 | if use_encoding: 60 | self.positional_encoding = PositionalEncoding.init_from_cfg(pe_config) 61 | self.lstm_d_in += self.positional_encoding.d_out 62 | 63 | self.lstm = nn.LSTMCell( 64 | input_size=self.lstm_d_in, hidden_size=self.lstm_d_hidden 65 | ) 66 | self.lstm.apply(init_recurrent_weights) 67 | lstm_forget_gate_init(self.lstm) 68 | self.out_layer = nn.Linear(self.lstm_d_hidden, self.lstm_d_out) 69 | nn.init.constant_(self.out_layer.bias, 0.0) 70 | nn.init.kaiming_normal_(self.out_layer.weight, a=0, mode="fan_in") 71 | 72 | self.sigmoid = nn.Sigmoid() 73 | 74 | def forward(self, network, rays): 75 | """ 76 | Predicting surface sampling points using LSTM 77 | 78 | Args: 79 | network: network model 80 | rays: ray (ON, RB, 8) 81 | 82 | Returns: 83 | sample_points: predicted surface points (x, y, z) in world coordinate, (ON, RB, 3) 84 | depth: distance to the camera center point, is the depth prediction, (ON, RB, 1) 85 | depth_confidence: accumulated confidence of current depth prediction (ON. RB, 1) 86 | 87 | """ 88 | with profiler.record_function("ray_marching"): 89 | ON, RB, _ = rays.shape 90 | RN = network.num_ref_views 91 | ray_dirs = rays[:, :, 3:6] # (ON, RB, 3) 92 | z_near = rays[:, :, 6:7] 93 | z_far = rays[:, :, 7:8] 94 | z_scale = z_far - z_near 95 | 96 | scaled_depth = [0 * z_scale] # scaled_depth should range from 0 -1 97 | sample_points = rays[:, :, :3] + z_near * ray_dirs # (ON, RB, 3) 98 | states = [None] 99 | 100 | for _ in range(self.raymarch_steps): 101 | with torch.no_grad(): 102 | # print("1", sample_points.shape) 103 | latent, p_feature = network.get_features( 104 | sample_points, ray_dirs 105 | ) # (ON*RN*RB, d_latent) 106 | 107 | feature, weight = network.mlp_feature( 108 | latent, 109 | p_feature, 110 | ) 111 | # (ON*RN*RB, d_feature) 112 | lstm_feature = util.weighted_pooling( 113 | feature, inner_dims=(RN, RB), weight=weight 114 | ).reshape( 115 | ON * RB, -1 116 | ) # (ON*RB, 2*d_feature) 117 | 118 | state = self.lstm(lstm_feature, states[-1]) # (2, ON*RB, d_hidden) 119 | 120 | if state[0].requires_grad: 121 | state[0].register_hook(lambda x: x.clamp(min=-5, max=5)) 122 | 123 | states.append(state) 124 | 125 | lstm_out = self.out_layer(state[0]).view( 126 | ON, RB, self.lstm_d_out 127 | ) # (ON, RB, 1) 128 | signed_distance = lstm_out 129 | depth_scaling = 1.0 / (1.0 * self.raymarch_steps) 130 | signed_distance = depth_scaling * signed_distance 131 | scaled_depth.append(self.sigmoid(scaled_depth[-1] + signed_distance)) 132 | depth = scaled_depth[-1] * z_scale + z_near # (ON, RB, 1) 133 | sample_points = rays[:, :, :3] + depth * ray_dirs # (ON, RB, 3) 134 | 135 | return sample_points, depth, scaled_depth 136 | 137 | 138 | class _RenderWrapper(nn.Module): 139 | def __init__(self, network, renderer): 140 | super().__init__() 141 | self.network = network 142 | self.renderer = renderer 143 | 144 | def forward(self, rays): 145 | outputs = self.renderer(self.network, rays) 146 | return outputs.toDict() 147 | 148 | 149 | class Renderer(nn.Module): 150 | def __init__( 151 | self, 152 | d_in, 153 | d_hidden, 154 | raymarch_steps, 155 | trainable, 156 | use_encoding, 157 | pe_config, 158 | ): 
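        """
        Inits Renderer with an LSTM ray marcher.

        Args:
            d_in: input feature size for the ray marcher.
            d_hidden: hidden feature size of the LSTM cell.
            raymarch_steps: number of ray-marching iterations.
            trainable: whether the renderer is treated as trainable (stored as is_trainable).
            use_encoding: whether to apply positional encoding to the LSTM input.
            pe_config: positional encoding (pe) configuration.
        """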
159 | super().__init__() 160 | self.ray_marcher = RayMarcher( 161 | d_in, d_hidden, raymarch_steps, use_encoding, pe_config 162 | ) 163 | self.is_trainable = trainable 164 | self.sigmoid = nn.Sigmoid() 165 | 166 | def forward(self, network, rays): 167 | """ 168 | Ray marching rendering. 169 | 170 | Args: 171 | network: network model 172 | rays: ray (ON, RB, 8) 173 | 174 | Returns: 175 | render dict 176 | """ 177 | 178 | with profiler.record_function("rendering"): 179 | assert len(rays.shape) == 3 180 | 181 | ( 182 | sample_points, 183 | depth_final, 184 | scaled_depth, 185 | ) = self.ray_marcher( 186 | network, rays 187 | ) # (ON, RB, 3), (ON, RB, 1), (ON, RB, 1), (step, ON, RB, 1) 188 | render_dict = { 189 | "depth": depth_final, 190 | "scaled_depth": torch.stack(scaled_depth), 191 | } 192 | 193 | out = network(sample_points, rays[:, :, 3:6]) # (ON, RB, 4) 194 | 195 | logit_mean = out[:, :, :3] # (ON, RB, 3) 196 | logit_log_var = out[:, :, 3:] # (ON, RB , 3) 197 | 198 | render_dict["logit_mean"] = logit_mean 199 | render_dict["logit_log_var"] = logit_log_var 200 | 201 | with torch.no_grad(): 202 | sampled_predictions = util.get_samples( 203 | logit_mean, torch.sqrt(torch.exp(logit_log_var)), 100 204 | ) 205 | rgb_mean = torch.mean(sampled_predictions, axis=0) 206 | rgb_std = torch.std(sampled_predictions, axis=0) 207 | render_dict["rgb"] = rgb_mean 208 | render_dict["uncertainty"] = torch.mean(rgb_std, dim=-1) 209 | 210 | return DotMap(render_dict) 211 | 212 | def parallelize(self, network, gpus=None): 213 | """ 214 | Returns a wrapper module compatible with DataParallel. 215 | Specify a list of GPU ids in 'gpus' to apply DataParallel automatically. 216 | 217 | Args: 218 | network: network model 219 | gpus: list of GPU ids to parallize to. No parallelization if gpus length is 1. 220 | 221 | Returns: 222 | wrapper module 223 | """ 224 | 225 | wrapped = _RenderWrapper(network, self) 226 | if gpus is not None and len(gpus) > 1: 227 | print("Using multi-GPU", gpus) 228 | wrapped = nn.DataParallel(wrapped, gpus, dim=1) 229 | return wrapped 230 | 231 | @classmethod 232 | def init_from_cfg(cls, cfg): 233 | return cls( 234 | d_in=cfg["d_in"], 235 | d_hidden=cfg["d_hidden"], 236 | raymarch_steps=cfg["raymarch_steps"], 237 | trainable=cfg["trainable"], 238 | use_encoding=cfg["use_encoding"], 239 | pe_config=cfg["positional_encoding"], 240 | ) 241 | -------------------------------------------------------------------------------- /model/testnetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def generate_pixel_grid(H, W, device): 3 | """ 4 | Generate a pixel grid. 5 | 6 | Args: 7 | H: Height of the image. 8 | W: Width of the image. 9 | device: Device to place the tensor on. 10 | 11 | Returns: 12 | pixel_grid: Tensor of shape (H, W, 2) representing the pixel grid. 13 | """ 14 | y, x = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device)) 15 | pixel_grid = torch.stack((x, y), dim=-1) 16 | return pixel_grid 17 | 18 | def pixel_to_camera(pixel_grid, focal, c): 19 | """ 20 | Convert pixel coordinates to camera coordinates. 21 | 22 | Args: 23 | pixel_grid: Tensor of shape (H, W, 2) representing the pixel grid. 24 | pose1: Tensor of shape (ON, X-RN, 4, 4) representing the c2w pose of the first camera. 25 | focal: Tensor of shape (ON, X-RN,2) representing the focal lengths of the first cameras. 26 | c: Tensor of shape (ON, X-RN, 2) representing the principal points of the first cameras. 
27 | 28 | Returns: 29 | camera_coords: Tensor of shape (ON, X-RN, H, W, 3) representing the camera coordinates. 30 | """ 31 | camera_coords = ((pixel_grid[None, None] - c[:, :, None, None]) / focal[:, :, None, None]) 32 | camera_coords = torch.cat((camera_coords, torch.ones_like(camera_coords[..., :1])), dim=-1) 33 | 34 | return camera_coords 35 | 36 | 37 | import torch 38 | 39 | 40 | def world_to_pixel(world_coords, poses_ref, focal_ref, c_ref): 41 | ON, X_RN, _, _, _= world_coords.shape 42 | RN, _, _, _ = poses_ref.shape 43 | H = 7 44 | W = 7 45 | pixels_ref = torch.zeros(ON, X_RN, RN, H, W, 2, device=world_coords.device) 46 | 47 | for i in range(ON): 48 | for j in range(X_RN): 49 | for k in range(RN): 50 | # Construct the camera intrinsic matrix 51 | K = torch.tensor([[focal_ref[i, k, 0], 0, c_ref[i, k, 0]], 52 | [0, focal_ref[i, k, 1], c_ref[i, k, 1]], 53 | [0, 0, 1]], device=world_coords.device) 54 | 55 | print(world_coords[i, j].shape) 56 | # Transform world coordinates to camera coordinates 57 | camera_coords = (torch.matmul(world_coords[i, j], poses_ref[i, k])) 58 | print(camera_coords.shape) 59 | # Project camera coordinates to pixel coordinates 60 | pixels = torch.matmul(K, camera_coords) 61 | pixels_ref[i, j, k] = pixels.transpose(-1, -2) 62 | 63 | return pixels_ref 64 | 65 | 66 | def normalize_coordinates(pixels_ref, W, H): 67 | """ 68 | Normalize pixel coordinates to [-1, 1] range. 69 | 70 | Args: 71 | pixels_ref: Tensor of shape (ON, N-RN, RN, H, W, 2) representing the pixel coordinates for reference images. 72 | W: Width of the image. 73 | H: Height of the image. 74 | 75 | Returns: 76 | uv_feats: Tensor of shape (ON, N-RN, RN, H, W, 1, 2) representing the normalized pixel coordinates. 77 | """ 78 | uv_feats = 2 * pixels_ref / torch.tensor([W, H], device=pixels_ref.device).reshape(1, 1, 1, 1, 1, 2) - 1.0 79 | return uv_feats 80 | 81 | def camera_to_world(camera_coords, pose1): 82 | """ 83 | Convert camera coordinates to world coordinates. 84 | 85 | Args: 86 | camera_coords: Tensor of shape (ON, H, W, 3) representing the camera coordinates. 87 | pose1: Tensor of shape (ON, 4, 4) representing the c2w pose of the first camera. 88 | 89 | Returns: 90 | world_coords: Tensor of shape (ON, H, W, 3) representing the world coordinates. 91 | """ 92 | # Add singleton dimensions to camera_coords to match the shape of pose1_inv[..., :3, :3].transpose(-1, -2) 93 | camera_coords = camera_coords.unsqueeze(1) 94 | print(camera_coords.shape) 95 | # Transpose pose1 to match the shape of camera_coords for batched matrix multiplication 96 | pose1 = pose1.transpose(-1, -2) 97 | print(pose1.shape) 98 | # Perform matrix multiplication 99 | world_coords = torch.matmul(camera_coords, pose1[..., :3, :3]) + pose1[..., :3, 3:4] 100 | return world_coords 101 | 102 | def transform_images_to_reference_image_coordinates(pose1, poses_ref, focal, c, focal_ref, c_ref): 103 | """ 104 | Transform points from the coordinate system of image1 to the coordinate system of reference images. 105 | 106 | Args: 107 | pose1: Tensor of shape (ON, 4, 4) representing the c2w pose of the first camera. 108 | poses_ref: Tensor of shape (ON, RN, 4, 4) representing the w2c poses of reference cameras. 109 | focal: Tensor of shape (ON, 2) representing the focal lengths of the first cameras. 110 | c: Tensor of shape (ON, 2) representing the principal points of the first cameras. 111 | focal_ref: Tensor of shape (ON, 2) representing the focal lengths of reference cameras. 
112 | c_ref: Tensor of shape (ON, 2) representing the principal points of reference cameras. 113 | 114 | Returns: 115 | uv_feats: Transformed points in the coordinate system of reference images, compatible with F.grid_sample. 116 | """ 117 | ON, X_RN, _, _ = pose1.shape 118 | _, RN, _, _ = poses_ref.shape 119 | H, W = 7, 7 # Assuming fixed size for simplicity, adjust according to your actual data 120 | # Generate pixel grid for image1 121 | focal=focal.unsqueeze(1).expand(ON, X_RN, 2) # (ON, X_RN, 2) 122 | c=c.unsqueeze(1).expand(ON,X_RN,2) 123 | focal_ref=focal_ref.unsqueeze(1).expand(ON, RN, 2) # (ON, RN, 2) 124 | c_ref=c_ref.unsqueeze(1).expand(ON,RN,2) 125 | pixel_grid1 = generate_pixel_grid(H, W, device=pose1.device) 126 | pixel_grid1 = pixel_grid1.to(focal.device) 127 | # Convert pixel grid to camera coordinates for image1 128 | camera_coords1 = pixel_to_camera(pixel_grid1, focal, c) 129 | # Convert camera coordinates to world coordinates for image1 130 | world_coords1 = camera_to_world(camera_coords1, pose1) 131 | # Convert world coordinates to pixel coordinates for reference images 132 | pixels_ref = world_to_pixel(world_coords1, poses_ref, focal_ref, c_ref) 133 | # Normalize pixel coordinates to [-1, 1] range 134 | uv_feats = normalize_coordinates(pixels_ref, W, H) 135 | return uv_feats 136 | 137 | # Define other helper functions (generate_pixel_grid, pixel_to_camera, world_to_pixel, normalize_coordinates) as before 138 | 139 | # Example usage 140 | #pose1 = torch.randn(2, 2, 4, 4) # Example c2w pose of the first camera 141 | #poses_ref = torch.randn(2, 3, 4, 4) # Example w2c poses of reference cameras 142 | #focal = torch.randn(2, 2) # Example focal lengths of the first cameras 143 | #c = torch.randn(2, 2) # Example principal points of the first cameras 144 | #focal_ref = torch.randn(2, 2) # Example focal lengths of reference cameras 145 | #_ref = torch.randn(2, 2) # Example principal points of reference cameras 146 | 147 | # Transform points to reference image coordinates 148 | #uv_feats = transform_images_to_reference_image_coordinates(pose1, poses_ref, focal, c, focal_ref, c_ref) 149 | 150 | def camera_to_world1(camera_coords, c2w_pose): 151 | print(camera_coords.shape) 152 | # 将相机坐标扩展为齐次坐标形式 153 | homogeneous_coords = torch.cat((camera_coords, torch.ones_like(camera_coords[..., :1])), dim=-1) 154 | print(homogeneous_coords.shape) 155 | print(c2w_pose.shape) 156 | #b=homogeneous_coords.unsqueeze(-1) 157 | #print(b.shape) 158 | # 将相机坐标系上的点转换到世界坐标系上 159 | world_coords = torch.matmul(homogeneous_coords,c2w_pose) 160 | print(world_coords.shape) 161 | # 去除齐次坐标,保留前三个坐标 162 | world_coords = world_coords[..., :3] 163 | print(world_coords.shape) 164 | return world_coords 165 | """ 166 | 167 | # Example usage 168 | ON = 2 169 | X_RN = 175 170 | RN = 3 171 | H = 7 172 | W = 7 173 | world_coords = torch.randn(ON, X_RN, H, W, 4) # Example world coordinates 174 | poses_ref = torch.randn(ON, RN, 4, 4) # Example reference camera poses 175 | focal_ref = torch.randn(ON, RN, 2) # Example focal lengths of reference cameras 176 | c_ref = torch.randn(ON, RN, 2) # Example principal points of reference cameras 177 | 178 | # Convert world coordinates to pixel coordinates 179 | pixels_ref = world_to_pixel(world_coords, poses_ref, focal_ref, c_ref) 180 | 181 | print("Pixel coordinates shape:", pixels_ref.shape) 182 | """ 183 | #a = torch.randn(3,3) 184 | #b = torch.randn(49,3) 185 | #c= torch.matmul(a,b.transpose(-1, -2)) 186 | #print(c.shape) 187 | 188 | torch.sin(torch.addcmul(self._phases, embed, 
self._freqs)) -------------------------------------------------------------------------------- /netvlad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class NetVLADLoupe(nn.Module): 7 | def __init__(self, feature_size, max_samples, cluster_size, output_dim, 8 | gating=True, add_batch_norm=True, is_training=True): 9 | super(NetVLADLoupe, self).__init__() 10 | self.feature_size = feature_size 11 | self.max_samples = max_samples 12 | self.output_dim = output_dim 13 | self.is_training = is_training 14 | self.gating = gating 15 | self.add_batch_norm = add_batch_norm 16 | self.cluster_size = cluster_size 17 | self.softmax = nn.Softmax(dim=-1) 18 | 19 | self.cluster_weights = nn.Parameter(torch.randn( 20 | feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 21 | self.cluster_weights2 = nn.Parameter(torch.randn( 22 | 1, feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 23 | self.hidden1_weights = nn.Parameter(torch.randn( 24 | cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size)) 25 | 26 | if add_batch_norm: 27 | self.cluster_biases = None 28 | self.bn1 = nn.BatchNorm1d(cluster_size) 29 | else: 30 | self.cluster_biases = nn.Parameter(torch.randn( 31 | cluster_size) * 1 / math.sqrt(feature_size)) 32 | self.bn1 = None 33 | 34 | self.bn2 = nn.BatchNorm1d(output_dim) 35 | 36 | if gating: 37 | self.context_gating = GatingContext( 38 | output_dim, add_batch_norm=add_batch_norm) 39 | 40 | def forward(self, x): 41 | x = x.transpose(1, 3).contiguous() 42 | x = x.view((-1, self.max_samples, self.feature_size)) 43 | activation = torch.matmul(x, self.cluster_weights) 44 | if self.add_batch_norm: 45 | activation = activation.view(-1, self.cluster_size) 46 | activation = self.bn1(activation) 47 | activation = activation.view(-1, self.max_samples, self.cluster_size) 48 | else: 49 | activation = activation + self.cluster_biases 50 | activation = self.softmax(activation) 51 | activation = activation.view((-1, self.max_samples, self.cluster_size)) 52 | 53 | a_sum = activation.sum(-2, keepdim=True) 54 | a = a_sum * self.cluster_weights2 55 | 56 | activation = torch.transpose(activation, 2, 1) 57 | x = x.view((-1, self.max_samples, self.feature_size)) 58 | vlad = torch.matmul(activation, x) 59 | vlad = torch.transpose(vlad, 2, 1) 60 | vlad = vlad - a 61 | 62 | vlad = F.normalize(vlad, dim=1, p=2) 63 | vlad = vlad.reshape((-1, self.cluster_size * self.feature_size)) 64 | vlad = F.normalize(vlad, dim=1, p=2) 65 | vlad = torch.matmul(vlad, self.hidden1_weights) 66 | 67 | if self.gating: 68 | vlad = self.context_gating(vlad) 69 | 70 | return vlad 71 | 72 | 73 | class GatingContext(nn.Module): 74 | def __init__(self, dim, add_batch_norm=True): 75 | super(GatingContext, self).__init__() 76 | self.dim = dim 77 | self.add_batch_norm = add_batch_norm 78 | self.gating_weights = nn.Parameter( 79 | torch.randn(dim, dim) * 1 / math.sqrt(dim)) 80 | self.sigmoid = nn.Sigmoid() 81 | 82 | if add_batch_norm: 83 | self.gating_biases = None 84 | self.bn1 = nn.BatchNorm1d(dim) 85 | else: 86 | self.gating_biases = nn.Parameter( 87 | torch.randn(dim) * 1 / math.sqrt(dim)) 88 | self.bn1 = None 89 | 90 | def forward(self, x): 91 | gates = torch.matmul(x, self.gating_weights) 92 | 93 | if self.add_batch_norm: 94 | gates = self.bn1(gates) 95 | else: 96 | gates = gates + self.gating_biases 97 | 98 | gates = self.sigmoid(gates) 99 | activation = x * gates 100 | 101 | return 
activation 102 | 103 | if __name__ == '__main__': 104 | net_vlad = NetVLADLoupe(feature_size=512, max_samples=224, cluster_size=64, 105 | output_dim=256, gating=True, add_batch_norm=False, 106 | is_training=True) 107 | inputs = torch.rand((1, 512, 224, 1)) 108 | outputs_tea = net_vlad(inputs) 109 | print(outputs_tea.shape) 110 | -------------------------------------------------------------------------------- /networks/CricaVPR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | from backbone.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 6 | import math 7 | from sklearn.preprocessing import StandardScaler 8 | from sklearn.decomposition import PCA 9 | import torchvision.models as models 10 | 11 | class GeM(nn.Module): 12 | def __init__(self, p=3, eps=1e-6, work_with_tokens=False): 13 | super().__init__() 14 | self.p = Parameter(torch.ones(1)*p) 15 | self.eps = eps 16 | self.work_with_tokens=work_with_tokens 17 | def forward(self, x): 18 | return gem(x, p=self.p, eps=self.eps, work_with_tokens=self.work_with_tokens) 19 | def __repr__(self): 20 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 21 | 22 | def gem(x, p=3, eps=1e-6, work_with_tokens=False): 23 | if work_with_tokens: 24 | x = x.permute(0, 2, 1) 25 | # unseqeeze to maintain compatibility with Flatten 26 | return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p).unsqueeze(3) 27 | else: 28 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 29 | 30 | class Flatten(nn.Module): 31 | def __init__(self): super().__init__() 32 | def forward(self, x): assert x.shape[2] == x.shape[3] == 1; return x[:,:,0,0] 33 | 34 | class L2Norm(nn.Module): 35 | def __init__(self, dim=1): 36 | super().__init__() 37 | self.dim = dim 38 | def forward(self, x): 39 | return F.normalize(x, p=2, dim=self.dim) 40 | 41 | 42 | class CricaVPRNet(nn.Module): 43 | """The used networks are composed of a backbone and an aggregation layer. 44 | """ 45 | def __init__(self, pretrained_foundation = False, foundation_model_path = None): 46 | super().__init__() 47 | self.backbone = get_backbone(pretrained_foundation, foundation_model_path) 48 | self.aggregation = nn.Sequential(L2Norm(), GeM(work_with_tokens=None), Flatten()) 49 | 50 | # In TransformerEncoderLayer, "batch_first=False" means the input tensors should be provided as (seq, batch, feature) to encode on the "seq" dimension. 51 | # Our input tensor is provided as (batch, seq, feature), which performs encoding on the "batch" dimension. 
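        # Concretely: the stacked regional descriptors passed to this encoder in forward()
        # have shape (B, 14, 768); with batch_first=False the TransformerEncoder reads dim 0
        # as the sequence axis, so self-attention mixes the B images of a batch for each of
        # the 14 regions, which is what makes this a cross-image encoder.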
52 | encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=16, dim_feedforward=2048, activation="gelu", dropout=0.1, batch_first=False) 53 | self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2) # Cross-image encoder 54 | 55 | self.linear = nn.Linear(10752, 512) 56 | def forward(self, x): 57 | x = self.backbone(x) 58 | B,P,D = x["x_prenorm"].shape 59 | W = H = int(math.sqrt(P-1)) 60 | x0 = x["x_norm_clstoken"] 61 | x_p = x["x_norm_patchtokens"].view(B,W,H,D).permute(0, 3, 1, 2) 62 | feature=x_p 63 | x10,x11,x12,x13 = self.aggregation(x_p[:,:,0:8,0:8]),self.aggregation(x_p[:,:,0:8,8:]),self.aggregation(x_p[:,:,8:,0:8]),self.aggregation(x_p[:,:,8:,8:]) 64 | x20,x21,x22,x23,x24,x25,x26,x27,x28 = self.aggregation(x_p[:,:,0:5,0:5]),self.aggregation(x_p[:,:,0:5,5:11]),self.aggregation(x_p[:,:,0:5,11:]),\ 65 | self.aggregation(x_p[:,:,5:11,0:5]),self.aggregation(x_p[:,:,5:11,5:11]),self.aggregation(x_p[:,:,5:11,11:]),\ 66 | self.aggregation(x_p[:,:,11:,0:5]),self.aggregation(x_p[:,:,11:,5:11]),self.aggregation(x_p[:,:,11:,11:]) 67 | x = [i.unsqueeze(1) for i in [x0,x10,x11,x12,x13,x20,x21,x22,x23,x24,x25,x26,x27,x28]] 68 | x = torch.cat(x,dim=1) 69 | #print(x.shape) 70 | x = self.encoder(x).view(B,14*D) 71 | #print(x.shape) 72 | x=self.linear(x) 73 | x = torch.nn.functional.normalize(x, p=2, dim=-1) 74 | return x,feature 75 | 76 | def get_backbone(pretrained_foundation, foundation_model_path): 77 | backbone = vit_base(patch_size=14,img_size=224,init_values=1,block_chunks=0) 78 | if pretrained_foundation: 79 | assert foundation_model_path is not None, "Please specify foundation model path." 80 | model_dict = backbone.state_dict() 81 | state_dict = torch.load(foundation_model_path) 82 | model_dict.update(state_dict.items()) 83 | backbone.load_state_dict(model_dict) 84 | return backbone 85 | 86 | 87 | class Shen(nn.Module): #整合Vit和resnet 88 | def __init__(self, opt=None): 89 | super().__init__() 90 | self.backbone = CricaVPRNet() 91 | 92 | def forward(self, inputs): 93 | out, feature=self.backbone(inputs) #(B,S,C) 94 | 95 | return out, feature 96 | 97 | 98 | class Backbone(nn.Module): 99 | def __init__(self, opt=None): 100 | super().__init__() 101 | 102 | self.sigma_dim = 2048 103 | self.mu_dim = 2048 104 | 105 | self.backbone = Shen() 106 | 107 | 108 | class Stu_Backbone(nn.Module): 109 | def __init__(self): 110 | super(Stu_Backbone, self).__init__() 111 | self.resnet50 = models.resnet50(pretrained=True) 112 | 113 | 114 | def forward(self, inputs): 115 | #Res branch(1*1024) 116 | outRR = self.resnet50(inputs) 117 | 118 | 119 | return outRR 120 | 121 | 122 | class TeacherNet(Backbone): 123 | def __init__(self, opt=None): 124 | super().__init__() 125 | self.id = 'teacher' 126 | self.mean_head = nn.Sequential(L2Norm(dim=1)) 127 | 128 | def forward(self, inputs): 129 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 130 | # inputs = inputs.view(B * L, C, H, W) # ([B, 3, 224, 224]) 131 | 132 | backbone_output,shen = self.backbone(inputs) # ([B, 2048, 1, 1]) 133 | #print(backbone_output.shape) 134 | #mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 135 | 136 | return backbone_output,shen 137 | 138 | 139 | class StudentNet(TeacherNet): 140 | def __init__(self, opt=None): 141 | super().__init__() 142 | self.id = 'student' 143 | self.var_head = nn.Sequential(nn.Linear(2048, self.sigma_dim), nn.Sigmoid()) 144 | self.backboneS = Stu_Backbone() 145 | def forward(self, inputs): 146 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 147 | inputs = 
inputs.view(B, C, H, W) # ([B, 3, 224, 224]) 148 | backbone_output = self.backboneS(inputs) 149 | 150 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 151 | log_sigma_sq = self.var_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 152 | 153 | return mu, log_sigma_sq 154 | 155 | 156 | def deliver_model(opt, id): 157 | if id == 'tea': 158 | return TeacherNet(opt) 159 | elif id == 'stu': 160 | return StudentNet(opt) 161 | 162 | if __name__ == '__main__': 163 | tea = TeacherNet() 164 | #stu = StudentNet() 165 | inputs = torch.rand((1, 3, 224, 224)) 166 | #pretrained_weights_path = '../logs/ckpt_best.pth.tar' 167 | #pretrained_state_dict = torch.load(pretrained_weights_path) 168 | #tea.load_state_dict(pretrained_state_dict["state_dict"]) 169 | outputs_tea,shen= tea(inputs) 170 | print(outputs_tea.shape) 171 | print(shen.shape) -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/__init__.py -------------------------------------------------------------------------------- /networks/__pycache__/LightViT_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/__pycache__/LightViT_model.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/mobilenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/__pycache__/mobilenet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/classification/.figures/efficientvit_main.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/.figures/efficientvit_main.gif -------------------------------------------------------------------------------- /networks/classification/.figures/efficientvit_main_static.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/.figures/efficientvit_main_static.png -------------------------------------------------------------------------------- /networks/classification/.figures/modelACC_gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/.figures/modelACC_gpu.png -------------------------------------------------------------------------------- /networks/classification/data/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/data/__init__.py -------------------------------------------------------------------------------- /networks/classification/data/datasets.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Build trainining/testing datasets 3 | ''' 4 | import os 5 | import json 6 | 7 | from torchvision import datasets, transforms 8 | from torchvision.datasets.folder import ImageFolder, default_loader 9 | import torch 10 | 11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.data import create_transform 13 | 14 | try: 15 | from timm.data import TimmDatasetTar 16 | except ImportError: 17 | # for higher version of timm 18 | from timm.data import ImageDataset as TimmDatasetTar 19 | 20 | class INatDataset(ImageFolder): 21 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 22 | category='name', loader=default_loader): 23 | self.transform = transform 24 | self.loader = loader 25 | self.target_transform = target_transform 26 | self.year = year 27 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 28 | path_json = os.path.join( 29 | root, f'{"train" if train else "val"}{year}.json') 30 | with open(path_json) as json_file: 31 | data = json.load(json_file) 32 | 33 | with open(os.path.join(root, 'categories.json')) as json_file: 34 | data_catg = json.load(json_file) 35 | 36 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 37 | 38 | with open(path_json_for_targeter) as json_file: 39 | data_for_targeter = json.load(json_file) 40 | 41 | targeter = {} 42 | indexer = 0 43 | for elem in data_for_targeter['annotations']: 44 | king = [] 45 | king.append(data_catg[int(elem['category_id'])][category]) 46 | if king[0] not in targeter.keys(): 47 | targeter[king[0]] = indexer 48 | indexer += 1 49 | self.nb_classes = len(targeter) 50 | 51 | self.samples = [] 52 | for elem in data['images']: 53 | cut = elem['file_name'].split('/') 54 | target_current = int(cut[2]) 55 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 56 | 57 | categors = data_catg[target_current] 58 | target_current_true = targeter[categors[category]] 59 | self.samples.append((path_current, target_current_true)) 60 | 61 | # __getitem__ and __len__ inherited from ImageFolder 62 | 63 | 64 | def build_dataset(is_train, args): 65 | transform = build_transform(is_train, args) 66 | 67 | if args.data_set == 'CIFAR': 68 | dataset = datasets.CIFAR100( 69 | args.data_path, train=is_train, transform=transform) 70 | nb_classes = 100 71 | elif args.data_set == 'IMNET': 72 | prefix = 'train' if is_train else 'val' 73 | data_dir = os.path.join(args.data_path, f'{prefix}.tar') 74 | if os.path.exists(data_dir): 75 | dataset = TimmDatasetTar(data_dir, transform=transform) 76 | else: 77 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 78 | dataset = datasets.ImageFolder(root, transform=transform) 79 | nb_classes = 1000 80 | elif args.data_set == 'IMNETEE': 81 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 82 | dataset = datasets.ImageFolder(root, transform=transform) 83 | nb_classes = 10 84 | elif args.data_set == 'FLOWERS': 85 | root = os.path.join(args.data_path, 'train' if is_train else 'test') 86 | dataset = datasets.ImageFolder(root, 
transform=transform) 87 | if is_train: 88 | dataset = torch.utils.data.ConcatDataset( 89 | [dataset for _ in range(100)]) 90 | nb_classes = 102 91 | elif args.data_set == 'INAT': 92 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 93 | category=args.inat_category, transform=transform) 94 | nb_classes = dataset.nb_classes 95 | elif args.data_set == 'INAT19': 96 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 97 | category=args.inat_category, transform=transform) 98 | nb_classes = dataset.nb_classes 99 | return dataset, nb_classes 100 | 101 | 102 | def build_transform(is_train, args): 103 | resize_im = args.input_size > 32 104 | if is_train: 105 | # this should always dispatch to transforms_imagenet_train 106 | transform = create_transform( 107 | input_size=args.input_size, 108 | is_training=True, 109 | color_jitter=args.color_jitter, 110 | auto_augment=args.aa, 111 | interpolation=args.train_interpolation, 112 | re_prob=args.reprob, 113 | re_mode=args.remode, 114 | re_count=args.recount, 115 | ) 116 | if not resize_im: 117 | # replace RandomResizedCropAndInterpolation with 118 | # RandomCrop 119 | transform.transforms[0] = transforms.RandomCrop( 120 | args.input_size, padding=4) 121 | return transform 122 | 123 | t = [] 124 | if args.finetune: 125 | t.append( 126 | transforms.Resize((args.input_size, args.input_size), 127 | interpolation=3) 128 | ) 129 | else: 130 | if resize_im: 131 | size = int((256 / 224) * args.input_size) 132 | t.append( 133 | # to maintain same ratio w.r.t. 224 images 134 | transforms.Resize(size, interpolation=3), 135 | ) 136 | t.append(transforms.CenterCrop(args.input_size)) 137 | 138 | t.append(transforms.ToTensor()) 139 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 140 | return transforms.Compose(t) 141 | -------------------------------------------------------------------------------- /networks/classification/data/samplers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Build samplers for data loading 3 | ''' 4 | import torch 5 | import torch.distributed as dist 6 | import math 7 | 8 | 9 | class RASampler(torch.utils.data.Sampler): 10 | """Sampler that restricts data loading to a subset of the dataset for distributed, 11 | with repeated augmentation. 
12 | It ensures that different each augmented version of a sample will be visible to a 13 | different process (GPU) 14 | Heavily based on torch.utils.data.DistributedSampler 15 | """ 16 | 17 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 18 | if num_replicas is None: 19 | if not dist.is_available(): 20 | raise RuntimeError( 21 | "Requires distributed package to be available") 22 | num_replicas = dist.get_world_size() 23 | if rank is None: 24 | if not dist.is_available(): 25 | raise RuntimeError( 26 | "Requires distributed package to be available") 27 | rank = dist.get_rank() 28 | self.dataset = dataset 29 | self.num_replicas = num_replicas 30 | self.rank = rank 31 | self.epoch = 0 32 | self.num_samples = int( 33 | math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 34 | self.total_size = self.num_samples * self.num_replicas 35 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 36 | self.num_selected_samples = int(math.floor( 37 | len(self.dataset) // 256 * 256 / self.num_replicas)) 38 | self.shuffle = shuffle 39 | 40 | def __iter__(self): 41 | # deterministically shuffle based on epoch 42 | g = torch.Generator() 43 | g.manual_seed(self.epoch) 44 | if self.shuffle: 45 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 46 | else: 47 | indices = list(range(len(self.dataset))) 48 | 49 | # add extra samples to make it evenly divisible 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /networks/classification/data/threeaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3Augment implementation from (https://github.com/facebookresearch/deit/blob/main/augment.py) 3 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 4 | and timm DA(https://github.com/rwightman/pytorch-image-models) 5 | Can be called by adding "--ThreeAugment" to the command line 6 | """ 7 | import torch 8 | from torchvision import transforms 9 | 10 | from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor 11 | 12 | import numpy as np 13 | from torchvision import datasets, transforms 14 | import random 15 | 16 | 17 | 18 | from PIL import ImageFilter, ImageOps 19 | import torchvision.transforms.functional as TF 20 | 21 | 22 | class GaussianBlur(object): 23 | """ 24 | Apply Gaussian Blur to the PIL image. 25 | """ 26 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 27 | self.prob = p 28 | self.radius_min = radius_min 29 | self.radius_max = radius_max 30 | 31 | def __call__(self, img): 32 | do_it = random.random() <= self.prob 33 | if not do_it: 34 | return img 35 | 36 | img = img.filter( 37 | ImageFilter.GaussianBlur( 38 | radius=random.uniform(self.radius_min, self.radius_max) 39 | ) 40 | ) 41 | return img 42 | 43 | class Solarization(object): 44 | """ 45 | Apply Solarization to the PIL image. 
46 | """ 47 | def __init__(self, p=0.2): 48 | self.p = p 49 | 50 | def __call__(self, img): 51 | if random.random() < self.p: 52 | return ImageOps.solarize(img) 53 | else: 54 | return img 55 | 56 | class gray_scale(object): 57 | """ 58 | Apply Solarization to the PIL image. 59 | """ 60 | def __init__(self, p=0.2): 61 | self.p = p 62 | self.transf = transforms.Grayscale(3) 63 | 64 | def __call__(self, img): 65 | if random.random() < self.p: 66 | return self.transf(img) 67 | else: 68 | return img 69 | 70 | 71 | 72 | class horizontal_flip(object): 73 | """ 74 | Apply Solarization to the PIL image. 75 | """ 76 | def __init__(self, p=0.2,activate_pred=False): 77 | self.p = p 78 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 79 | 80 | def __call__(self, img): 81 | if random.random() < self.p: 82 | return self.transf(img) 83 | else: 84 | return img 85 | 86 | 87 | 88 | def new_data_aug_generator(args = None): 89 | img_size = args.input_size 90 | remove_random_resized_crop = False 91 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 92 | primary_tfl = [] 93 | scale=(0.08, 1.0) 94 | interpolation='bicubic' 95 | if remove_random_resized_crop: 96 | primary_tfl = [ 97 | transforms.Resize(img_size, interpolation=3), 98 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'), 99 | transforms.RandomHorizontalFlip() 100 | ] 101 | else: 102 | primary_tfl = [ 103 | RandomResizedCropAndInterpolation( 104 | img_size, scale=scale, interpolation=interpolation), 105 | transforms.RandomHorizontalFlip() 106 | ] 107 | 108 | 109 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 110 | Solarization(p=1.0), 111 | GaussianBlur(p=1.0)])] 112 | 113 | if args.color_jitter is not None and not args.color_jitter==0: 114 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 115 | final_tfl = [ 116 | transforms.ToTensor(), 117 | transforms.Normalize( 118 | mean=torch.tensor(mean), 119 | std=torch.tensor(std)) 120 | ] 121 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl) 122 | -------------------------------------------------------------------------------- /networks/classification/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | 8 | import torch 9 | 10 | from timm.data import Mixup 11 | from timm.utils import accuracy, ModelEma 12 | 13 | from losses import DistillationLoss 14 | import utils 15 | 16 | def set_bn_state(model): 17 | for m in model.modules(): 18 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 19 | m.eval() 20 | 21 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | clip_grad: float = 0, 25 | clip_mode: str = 'norm', 26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 27 | set_training_mode=True, 28 | set_bn_eval=False,): 29 | model.train(set_training_mode) 30 | if set_bn_eval: 31 | set_bn_state(model) 32 | metric_logger = utils.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', utils.SmoothedValue( 34 | window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 100 37 | 38 | for samples, targets in metric_logger.log_every( 39 | data_loader, print_freq, header): 40 | samples = samples.to(device, non_blocking=True) 
41 | targets = targets.to(device, non_blocking=True) 42 | 43 | if mixup_fn is not None: 44 | samples, targets = mixup_fn(samples, targets) 45 | 46 | if True: # with torch.cuda.amp.autocast(): 47 | outputs = model(samples) 48 | loss = criterion(samples, outputs, targets) 49 | 50 | loss_value = loss.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | sys.exit(1) 55 | 56 | optimizer.zero_grad() 57 | 58 | # this attribute is added by timm on one optimizer (adahessian) 59 | is_second_order = hasattr( 60 | optimizer, 'is_second_order') and optimizer.is_second_order 61 | loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, 62 | parameters=model.parameters(), create_graph=is_second_order) 63 | 64 | torch.cuda.synchronize() 65 | if model_ema is not None: 66 | model_ema.update(model) 67 | 68 | metric_logger.update(loss=loss_value) 69 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 70 | # gather the stats from all processes 71 | metric_logger.synchronize_between_processes() 72 | print("Averaged stats:", metric_logger) 73 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 74 | 75 | 76 | @torch.no_grad() 77 | def evaluate(data_loader, model, device): 78 | criterion = torch.nn.CrossEntropyLoss() 79 | 80 | metric_logger = utils.MetricLogger(delimiter=" ") 81 | header = 'Test:' 82 | 83 | # switch to evaluation mode 84 | model.eval() 85 | 86 | for images, target in metric_logger.log_every(data_loader, 10, header): 87 | images = images.to(device, non_blocking=True) 88 | target = target.to(device, non_blocking=True) 89 | 90 | # compute output 91 | with torch.cuda.amp.autocast(): 92 | output = model(images) 93 | loss = criterion(output, target) 94 | 95 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 96 | 97 | batch_size = images.shape[0] 98 | metric_logger.update(loss=loss.item()) 99 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 100 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 101 | # gather the stats from all processes 102 | metric_logger.synchronize_between_processes() 103 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 104 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 105 | 106 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 107 | -------------------------------------------------------------------------------- /networks/classification/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the knowledge distillation loss, proposed in deit 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | class DistillationLoss(torch.nn.Module): 9 | """ 10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 11 | taking a teacher model prediction and using it as additional supervision. 
12 | """ 13 | 14 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 15 | distillation_type: str, alpha: float, tau: float): 16 | super().__init__() 17 | self.base_criterion = base_criterion 18 | self.teacher_model = teacher_model 19 | assert distillation_type in ['none', 'soft', 'hard'] 20 | self.distillation_type = distillation_type 21 | self.alpha = alpha 22 | self.tau = tau 23 | 24 | def forward(self, inputs, outputs, labels): 25 | """ 26 | Args: 27 | inputs: The original inputs that are feed to the teacher model 28 | outputs: the outputs of the model to be trained. It is expected to be 29 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 30 | in the first position and the distillation predictions as the second output 31 | labels: the labels for the base criterion 32 | """ 33 | outputs_kd = None 34 | if not isinstance(outputs, torch.Tensor): 35 | # assume that the model outputs a tuple of [outputs, outputs_kd] 36 | outputs, outputs_kd = outputs 37 | base_loss = self.base_criterion(outputs, labels) 38 | if self.distillation_type == 'none': 39 | return base_loss 40 | 41 | if outputs_kd is None: 42 | raise ValueError("When knowledge distillation is enabled, the model is " 43 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 44 | "class_token and the dist_token") 45 | # don't backprop throught the teacher 46 | with torch.no_grad(): 47 | teacher_outputs = self.teacher_model(inputs) 48 | 49 | if self.distillation_type == 'soft': 50 | T = self.tau 51 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 52 | # with slight modifications 53 | distillation_loss = F.kl_div( 54 | F.log_softmax(outputs_kd / T, dim=1), 55 | F.log_softmax(teacher_outputs / T, dim=1), 56 | reduction='sum', 57 | log_target=True 58 | ) * (T * T) / outputs_kd.numel() 59 | elif self.distillation_type == 'hard': 60 | distillation_loss = F.cross_entropy( 61 | outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss 65 | -------------------------------------------------------------------------------- /networks/classification/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/model/__init__.py -------------------------------------------------------------------------------- /networks/classification/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/classification/model/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/model/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /networks/classification/model/__pycache__/efficientvit.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/classification/model/__pycache__/efficientvit.cpython-38.pyc -------------------------------------------------------------------------------- /networks/classification/model/build.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Build the EfficientViT model family 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .efficientvit import EfficientViT 8 | from timm.models.registry import register_model 9 | 10 | EfficientViT_m0 = { 11 | 'img_size': 224, 12 | 'patch_size': 16, 13 | 'embed_dim': [64, 128, 192], 14 | 'depth': [1, 2, 3], 15 | 'num_heads': [4, 4, 4], 16 | 'window_size': [7, 7, 7], 17 | 'kernels': [5, 5, 5, 5], 18 | } 19 | 20 | EfficientViT_m1 = { 21 | 'img_size': 224, 22 | 'patch_size': 16, 23 | 'embed_dim': [128, 144, 192], 24 | 'depth': [1, 2, 3], 25 | 'num_heads': [2, 3, 3], 26 | 'window_size': [7, 7, 7], 27 | 'kernels': [7, 5, 3, 3], 28 | } 29 | 30 | EfficientViT_m2 = { 31 | 'img_size': 224, 32 | 'patch_size': 16, 33 | 'embed_dim': [128, 192, 224], 34 | 'depth': [1, 2, 3], 35 | 'num_heads': [4, 3, 2], 36 | 'window_size': [7, 7, 7], 37 | 'kernels': [7, 5, 3, 3], 38 | } 39 | 40 | EfficientViT_m3 = { 41 | 'img_size': 224, 42 | 'patch_size': 16, 43 | 'embed_dim': [128, 240, 320], 44 | 'depth': [1, 2, 3], 45 | 'num_heads': [4, 3, 4], 46 | 'window_size': [7, 7, 7], 47 | 'kernels': [5, 5, 5, 5], 48 | } 49 | 50 | EfficientViT_m4 = { 51 | 'img_size': 224, 52 | 'patch_size': 16, 53 | 'embed_dim': [128, 256, 384], 54 | 'depth': [1, 2, 3], 55 | 'num_heads': [4, 4, 4], 56 | 'window_size': [7, 7, 7], 57 | 'kernels': [7, 5, 3, 3], 58 | } 59 | 60 | EfficientViT_m5 = { 61 | 'img_size': 224, 62 | 'patch_size': 16, 63 | 'embed_dim': [192, 288, 384], 64 | 'depth': [1, 3, 4], 65 | 'num_heads': [3, 3, 4], 66 | 'window_size': [7, 7, 7], 67 | 'kernels': [7, 5, 3, 3], 68 | } 69 | 70 | 71 | @register_model 72 | def EfficientViT_M0(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m0): 73 | model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) 74 | if pretrained: 75 | pretrained = _checkpoint_url_format.format(pretrained) 76 | checkpoint = torch.hub.load_state_dict_from_url( 77 | pretrained, map_location='cpu') 78 | d = checkpoint['model'] 79 | D = model.state_dict() 80 | for k in d.keys(): 81 | if D[k].shape != d[k].shape: 82 | d[k] = d[k][:, :, None, None] 83 | model.load_state_dict(d) 84 | if fuse: 85 | replace_batchnorm(model) 86 | return model 87 | 88 | @register_model 89 | def EfficientViT_M1(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m1): 90 | model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) 91 | if pretrained: 92 | pretrained = _checkpoint_url_format.format(pretrained) 93 | checkpoint = torch.hub.load_state_dict_from_url( 94 | pretrained, map_location='cpu') 95 | d = checkpoint['model'] 96 | D = model.state_dict() 97 | for k in d.keys(): 98 | if D[k].shape != d[k].shape: 99 | d[k] = d[k][:, :, None, None] 100 | model.load_state_dict(d) 101 | if fuse: 102 | replace_batchnorm(model) 103 | return model 104 | 105 | @register_model 106 | def EfficientViT_M2(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m2): 107 | model = 
EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) 108 | if pretrained: 109 | pretrained = _checkpoint_url_format.format(pretrained) 110 | checkpoint = torch.hub.load_state_dict_from_url( 111 | pretrained, map_location='cpu') 112 | d = checkpoint['model'] 113 | D = model.state_dict() 114 | for k in d.keys(): 115 | if D[k].shape != d[k].shape: 116 | d[k] = d[k][:, :, None, None] 117 | model.load_state_dict(d) 118 | if fuse: 119 | replace_batchnorm(model) 120 | return model 121 | 122 | @register_model 123 | def EfficientViT_M3(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m3): 124 | model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) 125 | if pretrained: 126 | pretrained = _checkpoint_url_format.format(pretrained) 127 | checkpoint = torch.hub.load_state_dict_from_url( 128 | pretrained, map_location='cpu') 129 | d = checkpoint['model'] 130 | D = model.state_dict() 131 | for k in d.keys(): 132 | if D[k].shape != d[k].shape: 133 | d[k] = d[k][:, :, None, None] 134 | model.load_state_dict(d) 135 | if fuse: 136 | replace_batchnorm(model) 137 | return model 138 | 139 | @register_model 140 | def EfficientViT_M4(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m4): 141 | model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) 142 | if pretrained: 143 | pretrained = _checkpoint_url_format.format(pretrained) 144 | checkpoint = torch.hub.load_state_dict_from_url( 145 | pretrained, map_location='cpu') 146 | d = checkpoint['model'] 147 | D = model.state_dict() 148 | for k in d.keys(): 149 | if D[k].shape != d[k].shape: 150 | d[k] = d[k][:, :, None, None] 151 | model.load_state_dict(d) 152 | if fuse: 153 | replace_batchnorm(model) 154 | return model 155 | 156 | @register_model 157 | def EfficientViT_M5(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m5): 158 | model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) 159 | if pretrained: 160 | pretrained = _checkpoint_url_format.format(pretrained) 161 | checkpoint = torch.hub.load_state_dict_from_url( 162 | pretrained, map_location='cpu') 163 | d = checkpoint['model'] 164 | D = model.state_dict() 165 | for k in d.keys(): 166 | if D[k].shape != d[k].shape: 167 | d[k] = d[k][:, :, None, None] 168 | model.load_state_dict(d) 169 | if fuse: 170 | replace_batchnorm(model) 171 | return model 172 | 173 | def replace_batchnorm(net): 174 | for child_name, child in net.named_children(): 175 | if hasattr(child, 'fuse'): 176 | setattr(net, child_name, child.fuse()) 177 | elif isinstance(child, torch.nn.BatchNorm2d): 178 | setattr(net, child_name, torch.nn.Identity()) 179 | else: 180 | replace_batchnorm(child) 181 | 182 | _checkpoint_url_format = \ 183 | 'https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/{}.pth' 184 | -------------------------------------------------------------------------------- /networks/classification/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchvision 3 | timm==0.5.4 4 | einops==0.4.1 5 | fvcore 6 | easydict 7 | matplotlib 8 | numpy==1.21.0 9 | yacs 10 | scikit-image==0.19.3 11 | pillow 12 | -------------------------------------------------------------------------------- /networks/classification/speed_test.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Testing the speed of different models 3 | """ 4 | import os 5 | import torch 6 | import torchvision 7 | import time 8 | import timm 9 | from model.build import EfficientViT_M0, EfficientViT_M1, EfficientViT_M2, EfficientViT_M3, EfficientViT_M4, EfficientViT_M5 10 | import torchvision 11 | import utils 12 | torch.autograd.set_grad_enabled(False) 13 | 14 | 15 | T0 = 10 16 | T1 = 60 17 | 18 | 19 | def compute_throughput_cpu(name, model, device, batch_size, resolution=224): 20 | inputs = torch.randn(batch_size, 3, resolution, resolution, device=device) 21 | # warmup 22 | start = time.time() 23 | while time.time() - start < T0: 24 | model(inputs) 25 | 26 | timing = [] 27 | while sum(timing) < T1: 28 | start = time.time() 29 | model(inputs) 30 | timing.append(time.time() - start) 31 | timing = torch.as_tensor(timing, dtype=torch.float32) 32 | print(name, device, batch_size / timing.mean().item(), 33 | 'images/s @ batch size', batch_size) 34 | 35 | def compute_throughput_cuda(name, model, device, batch_size, resolution=224): 36 | inputs = torch.randn(batch_size, 3, resolution, resolution, device=device) 37 | torch.cuda.empty_cache() 38 | torch.cuda.synchronize() 39 | start = time.time() 40 | with torch.cuda.amp.autocast(): 41 | while time.time() - start < T0: 42 | model(inputs) 43 | timing = [] 44 | if device == 'cuda:0': 45 | torch.cuda.synchronize() 46 | with torch.cuda.amp.autocast(): 47 | while sum(timing) < T1: 48 | start = time.time() 49 | model(inputs) 50 | torch.cuda.synchronize() 51 | timing.append(time.time() - start) 52 | timing = torch.as_tensor(timing, dtype=torch.float32) 53 | print(name, device, batch_size / timing.mean().item(), 54 | 'images/s @ batch size', batch_size) 55 | 56 | for device in ['cuda:0', 'cpu']: 57 | 58 | if 'cuda' in device and not torch.cuda.is_available(): 59 | print("no cuda") 60 | continue 61 | 62 | if device == 'cpu': 63 | os.system('echo -n "nb processors "; ' 64 | 'cat /proc/cpuinfo | grep ^processor | wc -l; ' 65 | 'cat /proc/cpuinfo | grep ^"model name" | tail -1') 66 | print('Using 1 cpu thread') 67 | torch.set_num_threads(1) 68 | compute_throughput = compute_throughput_cpu 69 | else: 70 | print(torch.cuda.get_device_name(torch.cuda.current_device())) 71 | compute_throughput = compute_throughput_cuda 72 | 73 | for n, batch_size0, resolution in [ 74 | ('EfficientViT_M0', 2048, 224), 75 | ('EfficientViT_M1', 2048, 224), 76 | ('EfficientViT_M2', 2048, 224), 77 | ('EfficientViT_M3', 2048, 224), 78 | ('EfficientViT_M4', 2048, 224), 79 | ('EfficientViT_M5', 2048, 224), 80 | ]: 81 | 82 | if device == 'cpu': 83 | batch_size = 16 84 | else: 85 | batch_size = batch_size0 86 | torch.cuda.empty_cache() 87 | inputs = torch.randn(batch_size, 3, resolution, 88 | resolution, device=device) 89 | model = eval(n)(num_classes=1000) 90 | utils.replace_batchnorm(model) 91 | model.to(device) 92 | model.eval() 93 | model = torch.jit.trace(model, inputs) 94 | compute_throughput(n, model, device, 95 | batch_size, resolution=resolution) 96 | -------------------------------------------------------------------------------- /networks/efficientViT.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | 4 | sys.path.append('..') 5 | from re import L 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | import cirtorch.functional as LF 10 | import math 11 | import torch.nn.functional as F 12 | import 
torch 13 | import timm 14 | from einops import rearrange, reduce, repeat 15 | from einops.layers.torch import Rearrange, Reduce 16 | import torchvision.models as models 17 | from netvlad import NetVLADLoupe 18 | from torch import Tensor 19 | from classification.model.build import EfficientViT_M4 20 | from torchvision.models import mobilenet_v2 21 | class L2Norm(nn.Module): 22 | def __init__(self, dim=1): 23 | super().__init__() 24 | self.dim = dim 25 | 26 | def forward(self, input): 27 | return F.normalize(input, p=2, dim=self.dim) 28 | 29 | 30 | class GeM(nn.Module): 31 | def __init__(self, p=3, eps=1e-6): 32 | super(GeM, self).__init__() 33 | self.p = Parameter(torch.ones(1) * p) 34 | self.eps = eps 35 | 36 | def forward(self, x): 37 | return LF.gem(x, p=self.p, eps=self.eps) 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 41 | 42 | 43 | class FeedForward(nn.Module): 44 | def __init__(self, d_model, d_ff=1024, dropout=0.1): 45 | super().__init__() 46 | 47 | self.linear_1 = nn.Linear(d_model, d_ff) 48 | self.dropout = nn.Dropout(dropout) 49 | self.linear_2 = nn.Linear(d_ff, d_model) 50 | 51 | def forward(self, x): 52 | x = self.dropout(F.relu(self.linear_1(x))) 53 | x = self.linear_2(x) 54 | return x 55 | 56 | 57 | class Norm(nn.Module): 58 | def __init__(self, d_model, eps=1e-6): 59 | super().__init__() 60 | 61 | self.size = d_model 62 | self.alpha = nn.Parameter(torch.ones(self.size)) 63 | self.bias = nn.Parameter(torch.zeros(self.size)) 64 | self.eps = eps 65 | 66 | def forward(self, x): 67 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 68 | / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 69 | return norm 70 | 71 | 72 | 73 | 74 | 75 | class Shen(nn.Module): #整合Vit和resnet 76 | def __init__(self, opt=None): 77 | super().__init__() 78 | heads = 4 79 | d_model = 512 80 | dropout = 0.1 81 | efficientViT = EfficientViT_M4(pretrained='efficientvit_m4') 82 | featuresefficientViT = list(efficientViT.children())[:-1] 83 | self.backbone = nn.Sequential(*featuresefficientViT) 84 | #self.backbone = efficientViT 85 | #self.backbone = mobilenet_v2 86 | self.linear = nn.Sequential( 87 | nn.Flatten(), # 展平操作 88 | nn.Dropout(p=0.2), # Dropout 层 89 | nn.Linear(in_features=384 * 4 * 4, out_features=512) # 全连接层 90 | ) 91 | 92 | def forward(self, inputs): 93 | #ViT branch 94 | out=self.backbone(inputs) #(B,S,C) 95 | feature = out 96 | out = self.linear(out) 97 | #print(out.shape) 98 | 99 | return out,feature 100 | 101 | 102 | class ClassificationHead(nn.Sequential): 103 | def __init__(self, emb_size: int = 768, n_classes: int = 1000): 104 | super().__init__( 105 | Reduce('b n e -> b e', reduction='mean')) 106 | 107 | class Backbone(nn.Module): 108 | def __init__(self, opt=None): 109 | super().__init__() 110 | 111 | self.sigma_dim = 2048 112 | self.mu_dim = 2048 113 | 114 | self.backbone = Shen() 115 | 116 | 117 | class Stu_Backbone(nn.Module): 118 | def __init__(self): 119 | super(Stu_Backbone, self).__init__() 120 | self.resnet50 = models.resnet50(pretrained=True) 121 | 122 | 123 | def forward(self, inputs): 124 | #Res branch(1*1024) 125 | outRR = self.resnet50(inputs) 126 | 127 | return outRR 128 | 129 | 130 | class TeacherNet(Backbone): 131 | def __init__(self, opt=None): 132 | super().__init__() 133 | self.id = 'teacher' 134 | self.mean_head = nn.Sequential(L2Norm(dim=1)) 135 | 136 | def forward(self, inputs): 137 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 138 | # 
inputs = inputs.view(B * L, C, H, W) # ([B, 3, 224, 224]) 139 | 140 | backbone_output,shen = self.backbone(inputs) # ([B, 2048, 1, 1]) 141 | #print(backbone_output.shape) 142 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) # ([B, 2048]) <= ([B, 2048, 1, 1]) 143 | 144 | return mu, shen 145 | 146 | 147 | class StudentNet(TeacherNet): 148 | def __init__(self, opt=None): 149 | super().__init__() 150 | self.id = 'student' 151 | self.var_head = nn.Sequential(nn.Linear(2048, self.sigma_dim), nn.Sigmoid()) 152 | self.backboneS = Stu_Backbone() 153 | def forward(self, inputs): 154 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 155 | inputs = inputs.view(B, C, H, W) # ([B, 3, 224, 224]) 156 | backbone_output = self.backboneS(inputs) 157 | 158 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 159 | log_sigma_sq = self.var_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 160 | 161 | return mu, log_sigma_sq 162 | 163 | 164 | def deliver_model(opt, id): 165 | if id == 'tea': 166 | return TeacherNet(opt) 167 | elif id == 'stu': 168 | return StudentNet(opt) 169 | 170 | 171 | if __name__ == '__main__': 172 | tea = TeacherNet() 173 | stu = StudentNet() 174 | inputs = torch.rand((1, 3, 224, 224)) 175 | outputs_tea = tea(inputs) 176 | #outputs_stu = stu(inputs) 177 | # print(outputs_stu.shape) 178 | # print(tea.state_dict()) 179 | print(outputs_tea[0].shape, outputs_tea[1].shape) 180 | #print(outputs_stu[0].shape, outputs_stu[1].shape) 181 | num_params = sum(p.numel() for p in tea.parameters()) 182 | print(f"Number of parameters: {num_params}") -------------------------------------------------------------------------------- /networks/eigenplaces.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | import cirtorch.functional as LF 6 | import math 7 | import torch.nn.functional as F 8 | from einops.layers.torch import Rearrange, Reduce 9 | import torchvision.models as models 10 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint 11 | from torch.optim import lr_scheduler, optimizer 12 | import utils 13 | from torch import Tensor 14 | # from dataloaders.GSVCitiesDataloader import GSVCitiesDataModule 15 | from models import helper 16 | from eigenplaces_model import eigenplaces_network 17 | 18 | class L2Norm(nn.Module): 19 | def __init__(self, dim=1): 20 | super().__init__() 21 | self.dim = dim 22 | 23 | def forward(self, input): 24 | return F.normalize(input, p=2, dim=self.dim) 25 | 26 | 27 | class PatchEmbedding(nn.Module): 28 | def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768): 29 | self.patch_size = patch_size 30 | super().__init__() 31 | self.projection = nn.Sequential( 32 | # 使用一个卷积层而不是一个线性层 -> 性能增加 33 | nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size), 34 | # 将卷积操作后的patch铺平 35 | Rearrange('b e h w -> b (h w) e'), 36 | ) 37 | 38 | def forward(self, x: Tensor) -> Tensor: 39 | x = self.projection(x) 40 | return x 41 | 42 | 43 | class GeM(nn.Module): 44 | def __init__(self, p=3, eps=1e-6): 45 | super(GeM, self).__init__() 46 | self.p = Parameter(torch.ones(1) * p) 47 | self.eps = eps 48 | 49 | def forward(self, x): 50 | return LF.gem(x, p=self.p, eps=self.eps) 51 | 52 | def __repr__(self): 53 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', 
' + 'eps=' + str(self.eps) + ')' 54 | 55 | 56 | class FeedForward(nn.Module): 57 | def __init__(self, d_model, d_ff=1024, dropout=0.1): 58 | super().__init__() 59 | 60 | self.linear_1 = nn.Linear(d_model, d_ff) 61 | self.dropout = nn.Dropout(dropout) 62 | self.linear_2 = nn.Linear(d_ff, d_model) 63 | 64 | def forward(self, x): 65 | x = self.dropout(F.relu(self.linear_1(x))) 66 | x = self.linear_2(x) 67 | return x 68 | 69 | 70 | class Norm(nn.Module): 71 | def __init__(self, d_model, eps=1e-6): 72 | super().__init__() 73 | 74 | self.size = d_model 75 | self.alpha = nn.Parameter(torch.ones(self.size)) 76 | self.bias = nn.Parameter(torch.zeros(self.size)) 77 | self.eps = eps 78 | 79 | def forward(self, x): 80 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 81 | / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 82 | return norm 83 | 84 | 85 | class Shen(nn.Module): #整合Vit和resnet 86 | def __init__(self, opt=None): 87 | super().__init__() 88 | self.backbone = eigenplaces_network.GeoLocalizationNet_("ResNet50", 512) 89 | 90 | def forward(self, inputs): 91 | out, feature=self.backbone(inputs) #(B,S,C) 92 | 93 | return out, feature 94 | 95 | 96 | class Backbone(nn.Module): 97 | def __init__(self, opt=None): 98 | super().__init__() 99 | 100 | self.sigma_dim = 2048 101 | self.mu_dim = 2048 102 | 103 | self.backbone = Shen() 104 | 105 | 106 | class Stu_Backbone(nn.Module): 107 | def __init__(self): 108 | super(Stu_Backbone, self).__init__() 109 | self.resnet50 = models.resnet50(pretrained=True) 110 | 111 | 112 | def forward(self, inputs): 113 | #Res branch(1*1024) 114 | outRR = self.resnet50(inputs) 115 | 116 | 117 | return outRR 118 | 119 | 120 | class TeacherNet(Backbone): 121 | def __init__(self, opt=None): 122 | super().__init__() 123 | self.id = 'teacher' 124 | self.mean_head = nn.Sequential(L2Norm(dim=1)) 125 | 126 | def forward(self, inputs): 127 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 128 | # inputs = inputs.view(B * L, C, H, W) # ([B, 3, 224, 224]) 129 | 130 | backbone_output,shen = self.backbone(inputs) # ([B, 2048, 1, 1]) 131 | #print(backbone_output.shape) 132 | #mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 133 | 134 | return backbone_output,shen 135 | 136 | 137 | class StudentNet(TeacherNet): 138 | def __init__(self, opt=None): 139 | super().__init__() 140 | self.id = 'student' 141 | self.var_head = nn.Sequential(nn.Linear(2048, self.sigma_dim), nn.Sigmoid()) 142 | self.backboneS = Stu_Backbone() 143 | def forward(self, inputs): 144 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 145 | inputs = inputs.view(B, C, H, W) # ([B, 3, 224, 224]) 146 | backbone_output = self.backboneS(inputs) 147 | 148 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 149 | log_sigma_sq = self.var_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 150 | 151 | return mu, log_sigma_sq 152 | 153 | 154 | def deliver_model(opt, id): 155 | if id == 'tea': 156 | return TeacherNet(opt) 157 | elif id == 'stu': 158 | return StudentNet(opt) 159 | 160 | if __name__ == '__main__': 161 | tea = TeacherNet() 162 | #stu = StudentNet() 163 | inputs = torch.rand((1, 3, 224, 224)) 164 | #pretrained_weights_path = '../logs/ckpt_best.pth.tar' 165 | #pretrained_state_dict = torch.load(pretrained_weights_path) 166 | #tea.load_state_dict(pretrained_state_dict["state_dict"]) 167 | outputs_tea,shen= tea(inputs) 168 | print(outputs_tea.shape) 169 | print(shen.shape) 
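# Minimal retrieval sketch (hypothetical, not part of the original eigenplaces.py): it assumes the
# first output of TeacherNet above is the global place descriptor and ranks database images for a
# query by cosine similarity. Batch sizes, dummy inputs and the k value are illustrative placeholders.
import torch
import torch.nn.functional as F

model = TeacherNet().eval()
with torch.no_grad():
    db_desc, _ = model(torch.rand(8, 3, 224, 224))   # (8, D) database descriptors
    q_desc, _ = model(torch.rand(1, 3, 224, 224))    # (1, D) query descriptor
db_desc = F.normalize(db_desc, p=2, dim=1)           # L2-normalise so a dot product gives cosine similarity
q_desc = F.normalize(q_desc, p=2, dim=1)
scores = q_desc @ db_desc.t()                        # (1, 8) similarity scores
top3 = torch.topk(scores, k=3, dim=1).indices        # indices of the three closest database places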
-------------------------------------------------------------------------------- /networks/eigenplaces_model.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/eigenplaces_model.zip -------------------------------------------------------------------------------- /networks/mobilenet.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | 4 | sys.path.append('..') 5 | from re import L 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | import cirtorch.functional as LF 10 | import math 11 | import torch.nn.functional as F 12 | import torch 13 | import timm 14 | from einops import rearrange, reduce, repeat 15 | from einops.layers.torch import Rearrange, Reduce 16 | import torchvision.models as models 17 | from netvlad import NetVLADLoupe 18 | from torch import Tensor 19 | from torchvision.models import mobilenet_v2 20 | class L2Norm(nn.Module): 21 | def __init__(self, dim=1): 22 | super().__init__() 23 | self.dim = dim 24 | 25 | def forward(self, input): 26 | return F.normalize(input, p=2, dim=self.dim) 27 | 28 | 29 | class PatchEmbedding(nn.Module): 30 | def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768): 31 | self.patch_size = patch_size 32 | super().__init__() 33 | self.projection = nn.Sequential( 34 | # 使用一个卷积层而不是一个线性层 -> 性能增加 35 | nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size), 36 | # 将卷积操作后的patch铺平 37 | Rearrange('b e h w -> b (h w) e'), 38 | ) 39 | 40 | def forward(self, x: Tensor) -> Tensor: 41 | x = self.projection(x) 42 | return x 43 | 44 | 45 | class GeM(nn.Module): 46 | def __init__(self, p=3, eps=1e-6): 47 | super(GeM, self).__init__() 48 | self.p = Parameter(torch.ones(1) * p) 49 | self.eps = eps 50 | 51 | def forward(self, x): 52 | return LF.gem(x, p=self.p, eps=self.eps) 53 | 54 | def __repr__(self): 55 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 56 | 57 | 58 | class FeedForward(nn.Module): 59 | def __init__(self, d_model, d_ff=1024, dropout=0.1): 60 | super().__init__() 61 | 62 | self.linear_1 = nn.Linear(d_model, d_ff) 63 | self.dropout = nn.Dropout(dropout) 64 | self.linear_2 = nn.Linear(d_ff, d_model) 65 | 66 | def forward(self, x): 67 | x = self.dropout(F.relu(self.linear_1(x))) 68 | x = self.linear_2(x) 69 | return x 70 | 71 | 72 | class Norm(nn.Module): 73 | def __init__(self, d_model, eps=1e-6): 74 | super().__init__() 75 | 76 | self.size = d_model 77 | self.alpha = nn.Parameter(torch.ones(self.size)) 78 | self.bias = nn.Parameter(torch.zeros(self.size)) 79 | self.eps = eps 80 | 81 | def forward(self, x): 82 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 83 | / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 84 | return norm 85 | 86 | 87 | class Shen(nn.Module): #整合Vit和resnet 88 | def __init__(self, opt=None): 89 | super().__init__() 90 | heads = 4 91 | d_model = 512 92 | dropout = 0.1 93 | mobilenet_v2 = models.mobilenet_v2(pretrained=True) 94 | featuresmobile = list(mobilenet_v2.children())[:-1] 95 | #print(featuresmobile) 96 | self.backbone = nn.Sequential(*featuresmobile) 97 | #self.backbone = mobilenet_v2 98 | self.linear = nn.Sequential( 99 | nn.Flatten(), # 展平操作 100 | nn.Dropout(p=0.2), # Dropout 层 101 | nn.Linear(in_features=1280 * 7 * 7, out_features=512) # 全连接层 102 | ) 
103 | #self.net_vlad = NetVLADLoupe(feature_size=49, max_samples=1280, cluster_size=64, 104 | #output_dim=512, gating=True, add_batch_norm=False, 105 | #is_training=True) 106 | def forward(self, inputs): 107 | #ViT branch 108 | out=self.backbone(inputs) #(B,S,C) 109 | #print(out.shape) 110 | feature=out 111 | out=self.linear(out) 112 | #print(out.shape) 113 | #out= self.net_vlad(out) 114 | 115 | #feature_V_enhanced = self.net_vlad_V(feature_V) 116 | 117 | return out,feature 118 | 119 | 120 | class Backbone(nn.Module): 121 | def __init__(self, opt=None): 122 | super().__init__() 123 | 124 | self.sigma_dim = 2048 125 | self.mu_dim = 2048 126 | 127 | self.backbone = Shen() 128 | 129 | 130 | class Stu_Backbone(nn.Module): 131 | def __init__(self): 132 | super(Stu_Backbone, self).__init__() 133 | self.resnet50 = models.resnet50(pretrained=True) 134 | 135 | 136 | def forward(self, inputs): 137 | #Res branch(1*1024) 138 | outRR = self.resnet50(inputs) 139 | 140 | 141 | return outRR 142 | 143 | 144 | class TeacherNet(Backbone): 145 | def __init__(self, opt=None): 146 | super().__init__() 147 | self.id = 'teacher' 148 | self.mean_head = nn.Sequential(L2Norm(dim=1)) 149 | 150 | def forward(self, inputs): 151 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 152 | # inputs = inputs.view(B * L, C, H, W) # ([B, 3, 224, 224]) 153 | 154 | backbone_output,shen = self.backbone(inputs) # ([B, 2048, 1, 1]) 155 | #print(backbone_output.shape) 156 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 157 | 158 | return mu, shen 159 | 160 | 161 | class StudentNet(TeacherNet): 162 | def __init__(self, opt=None): 163 | super().__init__() 164 | self.id = 'student' 165 | self.var_head = nn.Sequential(nn.Linear(2048, self.sigma_dim), nn.Sigmoid()) 166 | self.backboneS = Stu_Backbone() 167 | def forward(self, inputs): 168 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 169 | inputs = inputs.view(B, C, H, W) # ([B, 3, 224, 224]) 170 | backbone_output = self.backboneS(inputs) 171 | 172 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 173 | log_sigma_sq = self.var_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 174 | 175 | return mu, log_sigma_sq 176 | 177 | 178 | def deliver_model(opt, id): 179 | if id == 'tea': 180 | return TeacherNet(opt) 181 | elif id == 'stu': 182 | return StudentNet(opt) 183 | 184 | 185 | if __name__ == '__main__': 186 | tea = TeacherNet() 187 | #stu = StudentNet() 188 | inputs = torch.rand((1, 3, 224, 224)) 189 | #pretrained_weights_path = '../logs/ckpt_best.pth.tar' 190 | #pretrained_state_dict = torch.load(pretrained_weights_path) 191 | #tea.load_state_dict(pretrained_state_dict["state_dict"]) 192 | outputs_tea,shen = tea(inputs) 193 | print(outputs_tea.shape) 194 | print(shen.shape) 195 | #x = model(shen) 196 | #print(x.shape) 197 | #print(shen.shape) 198 | #outputs_stu = stu(inputs) 199 | # print(outputs_stu.shape) 200 | # print(tea.state_dict()) 201 | #print(outputs_tea[0].shape, outputs_tea[1].shape) 202 | #print(outputs_stu[0].shape, outputs_stu[1].shape) 203 | -------------------------------------------------------------------------------- /networks/models.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/UGNA-VPR/1ae3158d84b54affa011236af2cad95e364d8731/networks/models.zip -------------------------------------------------------------------------------- /networks/test.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def pose_tensor_to_pose_representations(pose_tensor): 5 | # Get the shape of the pose tensor 6 | num_poses, _, _ = pose_tensor.shape 7 | 8 | # Initialize the result array 9 | pose_representations_euler = np.zeros((num_poses, 6)) 10 | 11 | for i in range(num_poses): 12 | # Extract the translation (position) 13 | position = pose_tensor[i, :3, 3] 14 | 15 | # Extract the rotation matrix 16 | rotation_matrix = pose_tensor[i, :3, :3] 17 | 18 | # Euler angles (roll, pitch, yaw) in degrees 19 | euler_angles = np.array([0, 0, 0]) 20 | euler_angles = np.degrees(np.around(np.array([ 21 | np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]), 22 | np.arctan2(-rotation_matrix[2, 0], np.sqrt(rotation_matrix[2, 1]**2 + rotation_matrix[2, 2]**2)), 23 | np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) 24 | ]), decimals=6)) 25 | pose_representations_euler[i] = np.concatenate((position, euler_angles)) 26 | 27 | return pose_representations_euler 28 | 29 | # Example usage 30 | pose_matrices = np.random.rand(10, 4, 4) # random array of 10 pose matrices (random values only exercise shapes; they are not valid rotations) 31 | euler_poses = pose_tensor_to_pose_representations(pose_matrices) 32 | 33 | print("Euler Pose Representations:") 34 | print(euler_poses.shape) 35 | 36 | 37 | -------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------