├── LICENSE ├── README.md ├── config ├── __init__.py ├── config.yaml └── config_read.py ├── data ├── __init__.py └── data_read.py ├── main.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── loss.cpython-39.pyc │ ├── mc_nerf.cpython-39.pyc │ ├── net_block.cpython-39.pyc │ └── net_utils.cpython-39.pyc ├── external │ └── pohsun_ssim │ │ ├── LICENSE.txt │ │ ├── README.md │ │ ├── einstein.png │ │ ├── max_ssim.gif │ │ ├── max_ssim.py │ │ ├── pytorch_ssim │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── __init__.cpython-39.pyc │ │ ├── setup.cfg │ │ └── setup.py ├── loss.py ├── mc_nerf.py ├── net_block.py └── net_utils.py ├── requirements.yaml ├── synthetic_dataset_code ├── Array.py ├── Ball.py ├── HalfBall.py └── Room.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc ├── distributed_init.cpython-39.pyc ├── log_init.cpython-39.pyc └── tensorboard_init.cpython-39.pyc ├── distributed_init.py ├── log_init.py └── tensorboard_init.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 SkylerGao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Version_1.0: 2 | 1. This version is only intended for use with synthetic datasets. 3 | 4 | # MC_NeRF: Multi-Camera Neural Radiance Fields for Multi-Camera Image Acquisition Systems 5 | Project page: https://in2-viaun.github.io/MC-NeRF/ 6 | arXiv preprint: https://arxiv.org/abs/2309.07846 7 | 8 | ## Abstract 9 | Neural Radiance Fields (NeRF) employ multi-view images for 3D scene representation and have shown remarkable performance. As one of the primary sources of multi-view images, multi-camera systems encounter challenges such as varying intrinsic parameters and frequent pose changes. Most previous NeRF-based methods assume a single global camera and seldom consider scenarios with multiple cameras. Besides, some pose-robust methods remain susceptible to suboptimal solutions when poses are poorly initialized. In this paper, we propose MC-NeRF, a method that can jointly optimize both intrinsic and extrinsic parameters for bundle-adjusting Neural Radiance Fields. 
Firstly, we conduct a theoretical analysis to tackle the degenerate case and coupling issue that arise from the joint optimization between intrinsic and extrinsic parameters. Secondly, based on the proposed solutions, we introduce an efficient calibration image acquisition scheme for multi-camera systems, including the design of the calibration object. Lastly, we present a global end-to-end network with a training sequence that enables the regression of intrinsic and extrinsic parameters, along with the rendering network. Moreover, since most existing datasets are designed for a unique camera, we create a new dataset that includes four different styles of multi-camera acquisition systems, allowing readers to generate custom datasets. Experiments confirm the effectiveness of our method when each image corresponds to different camera parameters. Specifically, we adopt up to 110 images with 110 different sets of intrinsic and extrinsic parameters to achieve 3D scene representation without providing initial poses. 10 | 11 | ## Overview 12 | ![image](https://github.com/IN2-ViAUn/MC-NeRF/blob/main/image/overview.png) 13 | 14 | 15 | ## Prerequisites 16 | This code is developed with `python3.9.13`. PyTorch 2.0.1 and CUDA 11.7 are required. 17 | It is recommended to use `Anaconda` to set up the environment. Install the dependencies and activate the environment `mc-env` with 18 | ``` 19 | conda env create --file requirements.yaml python=3.9.13 20 | conda activate mc-env 21 | ``` 22 | 23 | ## Dataset 24 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/1VKElczwt7TdWOyiWnHZIaxKYlycA-dPZ). Place the downloaded dataset according to the following directory structure (the following folders are created in the root directory): 25 | ``` 26 | ├── config 27 | │   ├── ... 28 | │ 29 | ├── data 30 | │   ├── dataset_Array 31 | │   │   ├── Array_Computer 32 | │   │   ├── Array_Ficus 33 | │   │   ├── Array_Gate 34 | │   │   ├── Array_Lego 35 | │   │   ├── Array_Materials 36 | │   │   ├── Array_Snowtruck 37 | │   │   ├── Array_Statue 38 | │   │   └── Array_Train 39 | │   ├── dataset_Ball 40 | │   │   ├── Ball_Computer 41 | │   │   ├── Ball_Ficus 42 | │   │   └── ... 43 | │   ├── dataset_HalfBall 44 | │   │   ├── HalfBall_Computer 45 | │   │   ├── HalfBall_Ficus 46 | │   │   └── ... 47 | │   ├── dataset_Room 48 | │   │   ├── Room_Computer 49 | │   │   ├── Room_Ficus 50 | │   │   └── ... 51 | │   ├── ... 52 | ``` 53 | The folder `synthetic_dataset_code` contains the Blender scripts for customizing a synthetic multi-camera dataset. Readers can modify the types of objects, as well as the number and parameters of the cameras, according to their needs. 
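As a rough, hypothetical illustration of the kind of edits these scripts expect — camera count, placement, and per-camera FOV — the sketch below builds a small ring of Blender cameras. Every name and value in it (`NUM_CAMS`, `RADIUS`, `BASE_FOV`, the `cam_XXX` naming) is an assumption for illustration, not code copied from `Array.py`/`Ball.py`/`HalfBall.py`/`Room.py`.
```python
# Hypothetical sketch only (run inside Blender, e.g. as a script in the Text editor):
# place NUM_CAMS cameras on a ring around the origin, each with its own FOV,
# so that every rendered view carries different intrinsics and extrinsics.
import math

import bpy
from mathutils import Vector

NUM_CAMS = 8      # assumed camera count, edit as needed
RADIUS = 4.0      # assumed ring radius around the object
BASE_FOV = 40.0   # assumed per-camera field of view in degrees

for i in range(NUM_CAMS):
    theta = 2.0 * math.pi * i / NUM_CAMS
    cam_data = bpy.data.cameras.new(name=f"cam_{i:03d}")
    cam_data.lens_unit = 'FOV'
    # vary the FOV slightly per camera so every view has its own intrinsics
    cam_data.angle = math.radians(BASE_FOV + 0.5 * i)
    cam_obj = bpy.data.objects.new(f"cam_{i:03d}", cam_data)
    cam_obj.location = (RADIUS * math.cos(theta), RADIUS * math.sin(theta), 1.5)
    # aim the camera at the world origin, where the object sits
    look_dir = Vector((0.0, 0.0, 0.0)) - cam_obj.location
    cam_obj.rotation_euler = look_dir.to_track_quat('-Z', 'Y').to_euler()
    bpy.context.scene.collection.objects.link(cam_obj)
```
Whatever script you use, the per-frame `camera_angle_x` written to the `transforms_*.json` files is what `data/data_read.py` later converts to an intrinsic matrix (fx = (W/2)/tan(fov/2)), so giving each camera its own FOV is how each image ends up with its own intrinsics.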
54 | ## Running the code 55 | To train MC_NeRF (two-GPU mode is recommended): 56 | ``` 57 | # {config_path}, {device_id} and {gpu_numb} can be set as you like 58 | # replace {Style} with Array | Ball | HalfBall | Room 59 | # replace {Dataset} with Computer | Ficus | Gate | Lego | Materials | Snowtruck | Statue | Train 60 | 61 | # single-GPU mode 62 | python main.py --train --config={config_path} --root_data=./data/dataset_{Style} --data_name={Style}_{Dataset} --start_device={device_id} 63 | eg: python main.py --train --root_data=./data/dataset_Ball --data_name=Ball_Computer --start_device=1 64 | 65 | # multi-GPU mode 66 | python -m torch.distributed.launch --nproc_per_node={gpu_numb} --use_env main.py --train --config={config_path} --root_data=./data/dataset_{Style} --data_name={Style}_{Dataset} --start_device={device_id} 67 | eg: python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py --train --root_data=./data/dataset_Ball --data_name=Ball_Computer --start_device=1 68 | ``` 69 | --- 70 | To test MC_NeRF (single-GPU mode only): 71 | ``` 72 | # {config_path} and {device_id} can be set as you like 73 | # replace {Style} with Array | Ball | HalfBall | Room 74 | # replace {Dataset} with Computer | Ficus | Gate | Lego | Materials | Snowtruck | Statue | Train 75 | 76 | python main.py --demo --config={config_path} --root_data=./data/dataset_{Style} --data_name={Style}_{Dataset} --start_device={device_id} 77 | eg: python main.py --demo --root_data=./data/dataset_Ball --data_name=Ball_Computer --start_device=1 78 | ``` 79 | --- 80 | All the results will be stored in the directory `results`; all the neural network weight parameters will be stored in the directory `weights`. 81 | 82 | If you want to save log information to a log.txt file, add `--log`. 83 | If you want to use tensorboard tools to show training results, add `--tensorboard`. 84 | 85 | ## Citation 86 | If you find this implementation or the pre-trained models helpful, please consider citing: 87 | ``` 88 | @misc{gao2023mcnerf, 89 | title={MC-NeRF: Multi-Camera Neural Radiance Fields for Multi-Camera Image Acquisition Systems}, 90 | author={Yu Gao and Lutong Su and Hao Liang and Yufeng Yue and Yi Yang and Mengyin Fu}, 91 | year={2023}, 92 | eprint={2309.07846}, 93 | archivePrefix={arXiv}, 94 | primaryClass={cs.CV} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_read import Load_config -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # Parameters for MC-NeRF 2 | system: 3 | data: 4 | # root path for data 5 | # root_data: '/home/a_datasets0/MC_NeRF/Array' 6 | # data_name: 'lego' 7 | # data seed 8 | seed: 42 9 | # device 10 | device: 11 | dev: 'cuda' 12 | # epoch for different stage 13 | epoch: 14 | # Camera Parameter Initial Stage 15 | cam_param_stage: 20 16 | # Global Optimization Stage 17 | global_opt_stage: 16 18 | # Fine-tuning Stage 19 | fine_tune_stage: 16 20 | train_params: 21 | stage_1_lr: 0.1 22 | # suggest: 23 | # Ball/HalfBall 0.0005 24 | # Room 0.0005 25 | # Array 0.00025~0.0001 26 | stage_2_lr: 0.0005 27 | stage_3_lr: 0.00025 28 | weight_decay: 0.0004 29 | warmup_epoch: 100 30 | batch: 7000 31 | test_params: 32 | # test model file 33 | nerf_model_name: "weights/train/lego-EPOCH-51-2023-09-19-18-49-50.ckpt" 34 | # forward resolution 35 | resolution_h: 800 36 | resolution_w: 800 37 | # weights save path 38 | weights_params: 39 | root_weights: './weights' 40 | # train/test 
rendering save path 41 | out_params: 42 | root_out: './results' 43 | test_enerf_pth: './img_rendered' 44 | log_params: 45 | logpath: './log' 46 | tensorboard_params: 47 | tb_pth: './tensorboard' 48 | # delete old files 49 | del_mode: False 50 | # size of Apriltag, including white boundary 51 | apriltag: 52 | tag_size: 1.0 53 | 54 | # NeRF config 55 | model: 56 | barf: 57 | # when set True, import barf mask to encode 58 | barf_mask: False 59 | # this range only available in Global Optimization Stage 60 | barf_start: 0.0 61 | barf_end: 1.0 62 | nerf: 63 | near: 1 64 | far: 8 65 | samples: 128 66 | sample_scale: 5 67 | # S:128, M:256, L:512, X:1024 68 | grid_nerf: 384 69 | sigma_init: 30.0 70 | sigma_default: -20.0 71 | weight_thresh: 0.001 72 | global_boader_min: -3.5 73 | global_boader_max: 3.5 74 | white_back: True 75 | emb_freqs_xyz: 10 76 | coarse_MLP_depth: 4 77 | coarse_MLP_width: 128 78 | coarse_MLP_skip: [2] 79 | fine_MLP_depth: 8 80 | fine_MLP_width: 256 81 | fine_MLP_skip: [4] 82 | MLP_deg: 2 83 | 84 | -------------------------------------------------------------------------------- /config/config_read.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | from pathlib import Path 5 | from utils import Log_config 6 | from utils import Distributed_config 7 | 8 | class Load_config(): 9 | def __init__(self, args): 10 | # config file root 11 | self.root_config = args.config 12 | # terminal load 13 | self.system_info = self.load_terminal(args) 14 | # yaml load 15 | self.load_yaml() 16 | # log config 17 | Log_config(self.system_info) 18 | # muti GPU config 19 | Distributed_config(self.system_info) 20 | 21 | def load_yaml(self): 22 | path_yaml = os.path.join(Path(self.root_config), Path('config.yaml')) 23 | with open(path_yaml, 'r', encoding='utf-8') as f: 24 | cfg_info = yaml.load(f, Loader = yaml.FullLoader) 25 | 26 | ################# System ##################### 27 | self.system_info['log_pth'] = cfg_info['system']['log_params']['logpath'] 28 | self.system_info['root_weight'] = cfg_info['system']['weights_params']['root_weights'] 29 | self.system_info['root_out'] = cfg_info['system']['out_params']['root_out'] 30 | self.system_info['device_type'] = cfg_info['system']['device']['dev'] 31 | self.system_info['seed'] = cfg_info['system']['data']['seed'] 32 | self.system_info['tb_pth'] = cfg_info['system']['tensorboard_params']['tb_pth'] 33 | self.system_info['tb_del'] = cfg_info['system']['tensorboard_params']['del_mode'] 34 | 35 | ################# MC-NeRF ##################### 36 | self.system_info["stage1_epoch"] = cfg_info['system']['epoch']['cam_param_stage'] 37 | self.system_info["stage2_epoch"] = cfg_info['system']['epoch']['global_opt_stage'] 38 | self.system_info["stage3_epoch"] = cfg_info['system']['epoch']['fine_tune_stage'] 39 | self.system_info["stage1_lr"] = cfg_info['system']['train_params']['stage_1_lr'] 40 | self.system_info["stage2_lr"] = cfg_info['system']['train_params']['stage_2_lr'] 41 | self.system_info["stage3_lr"] = cfg_info['system']['train_params']['stage_3_lr'] 42 | self.system_info["batch"] = cfg_info['system']['train_params']['batch'] 43 | self.system_info["weight_d"] = cfg_info['system']['train_params']['weight_decay'] 44 | self.system_info["warmup_epoch"] = cfg_info['system']['train_params']['warmup_epoch'] 45 | self.system_info["barf_mask"] = cfg_info['model']['barf']['barf_mask'] 46 | self.system_info["barf_start"] = cfg_info['model']['barf']['barf_start'] 47 | self.system_info["barf_end"] 
= cfg_info['model']['barf']['barf_end'] 48 | self.system_info["apriltag_size"] = cfg_info['system']['apriltag']['tag_size'] 49 | self.system_info['res_h'] = cfg_info['system']['test_params']['resolution_h'] 50 | self.system_info['res_w'] = cfg_info['system']['test_params']['resolution_w'] 51 | self.system_info['demo_ckpt'] = cfg_info['system']['test_params']['nerf_model_name'] 52 | self.system_info['demo_render_pth'] = os.path.join(Path(self.system_info['root_out']), 53 | Path(cfg_info['system']['out_params']['test_enerf_pth'])) 54 | 55 | ################# NeRF ##################### 56 | self.system_info["near"] = cfg_info['model']['nerf']['near'] 57 | self.system_info["far"] = cfg_info['model']['nerf']['far'] 58 | self.system_info["samples"] = cfg_info['model']['nerf']['samples'] 59 | self.system_info["scale"] = cfg_info['model']['nerf']['sample_scale'] 60 | self.system_info["sample_weight_thresh"] = cfg_info['model']['nerf']['weight_thresh'] 61 | self.system_info["grid_nerf"] = cfg_info['model']['nerf']['grid_nerf'] 62 | self.system_info["boader_min"] = cfg_info['model']['nerf']['global_boader_min'] 63 | self.system_info["boader_max"] = cfg_info['model']['nerf']['global_boader_max'] 64 | self.system_info["sigma_init"] = cfg_info['model']['nerf']['sigma_init'] 65 | self.system_info["sigma_default"] = cfg_info['model']['nerf']['sigma_default'] 66 | self.system_info["white_back"] = cfg_info['model']['nerf']['white_back'] 67 | self.system_info["emb_freqs_xyz"] = cfg_info['model']['nerf']['emb_freqs_xyz'] 68 | self.system_info['coarse_MLP_depth'] = cfg_info['model']['nerf']['coarse_MLP_depth'] 69 | self.system_info['coarse_MLP_width'] = cfg_info['model']['nerf']['coarse_MLP_width'] 70 | self.system_info['coarse_MLP_skip'] = cfg_info['model']['nerf']['coarse_MLP_skip'] 71 | self.system_info['fine_MLP_depth'] = cfg_info['model']['nerf']['fine_MLP_depth'] 72 | self.system_info['fine_MLP_width'] = cfg_info['model']['nerf']['fine_MLP_width'] 73 | self.system_info['fine_MLP_skip'] = cfg_info['model']['nerf']['fine_MLP_skip'] 74 | self.system_info["MLP_deg"] = cfg_info['model']['nerf']['MLP_deg'] 75 | 76 | def load_terminal(self, args): 77 | system_info = {} 78 | for mode, flag in enumerate([args.train, args.demo]): 79 | if flag is True: 80 | system_info['mode'] = mode 81 | break 82 | system_info['log'] = args.log 83 | system_info['start_device'] = args.start_device 84 | system_info['tb_available'] = args.tensorboard 85 | 86 | self.data_root = args.root_data 87 | self.data_name = args.data_name 88 | # root path of data 89 | system_info['data_name'] = self.data_name 90 | # name of data 91 | system_info['data_root'] = os.path.join(Path(self.data_root), Path(self.data_name)) 92 | 93 | return system_info -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_read import Data_set 2 | from .data_read import Data_loader 3 | -------------------------------------------------------------------------------- /data/data_read.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import logging 4 | import os 5 | import json 6 | import math 7 | import apriltag 8 | import random 9 | 10 | import numpy as np 11 | 12 | from PIL import Image 13 | from pathlib import Path 14 | from torchvision import transforms as T 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | 17 | class 
Data_set(torch.utils.data.Dataset): 18 | def __init__(self, system_param): 19 | self.system_param = system_param 20 | self.data_root = self.system_param["data_root"] 21 | self.data_name = self.system_param["data_name"] 22 | self.batch = self.system_param['batch'] 23 | self.test_img_h = self.system_param['res_h'] 24 | self.test_img_w = self.system_param['res_w'] 25 | self.tag_size = self.system_param["apriltag_size"] 26 | self.transform = T.ToTensor() 27 | self.each_epoch, self.total_epoch, self.barf_start, self.barf_end = self.get_squence_info() 28 | # world points of calibration cube 29 | self.tag_world_pts = self.apriltag_gt_pts() 30 | 31 | # load render datasets 32 | self.rgbs_train, self.pose_train, self.intr_train, self.train_numb,\ 33 | self.rgbs_test, self.pose_test, self.intr_test, self.test_numb,\ 34 | self.rgbs_val, self.pose_val, self.intr_val, self.val_numb = self.load_blender_data_info() 35 | 36 | self.intr_inv_train = self.inverse_intrinsic(self.intr_train) 37 | self.intr_inv_test = self.inverse_intrinsic(self.intr_test) 38 | self.intr_inv_val = self.inverse_intrinsic(self.intr_val) 39 | 40 | self.intr = [self.intr_train, self.intr_test, self.intr_val] 41 | self.intr_inv = [self.intr_inv_train, self.intr_inv_test, self.intr_inv_val] 42 | self.data_numb = [self.train_numb, self.test_numb, self.val_numb] 43 | 44 | self.train_idx = torch.arange(0, self.train_numb) 45 | self.test_idx = torch.arange(0, self.test_numb) 46 | self.val_idx = torch.arange(0, self.val_numb) 47 | 48 | # train mode 49 | if self.system_param["mode"] == 0: 50 | self.rgbs_expd, self.idx_expd = self.expand_data_length(self.rgbs_train, 51 | self.train_h, 52 | self.train_w, 53 | self.train_idx, 54 | times=50) 55 | # load calibration info 56 | self.intr_wpts, self.intr_pts,\ 57 | self.extr_wpts, self.extr_pts,\ 58 | self.muti_tag_imgs = self.load_apriltag_json(self.data_root, self.expd_times) 59 | # test mode 60 | else: 61 | self.img_h = self.test_img_h 62 | self.img_w = self.test_img_w 63 | 64 | self.update_system_param = self.update_param(system_param) 65 | 66 | def __len__(self): 67 | if self.system_param["mode"] == 0: 68 | return len(self.rgbs_expd) 69 | else: 70 | return self.test_numb 71 | 72 | def __getitem__(self, idx): 73 | if self.system_param["mode"] == 0: 74 | return self.rgbs_expd[idx], self.idx_expd[idx], self.intr_wpts[idx], self.intr_pts[idx],\ 75 | self.extr_wpts[idx], self.extr_pts[idx] 76 | else: 77 | return self.rgbs_test[idx], self.test_idx[idx] 78 | 79 | # load MC datasets 80 | def load_blender_data_info(self): 81 | logging.info("Loading blender datasets...") 82 | logging.info("Current object:{}".format(self.data_name)) 83 | # json file path 84 | self.train_json = os.path.join(Path(self.data_root), Path("transforms_train.json")) 85 | test_json = os.path.join(Path(self.data_root), Path("transforms_test.json")) 86 | val_json = os.path.join(Path(self.data_root), Path("transforms_val.json")) 87 | # json to camera and image info 88 | fov_train, img_pth_train, pose_train, train_numb = self.load_blender_json(self.train_json, self.data_root) 89 | fov_test, img_pth_test, pose_test, test_numb = self.load_blender_json(test_json, self.data_root) 90 | fov_val, img_pth_val, pose_val, val_numb = self.load_blender_json(val_json, self.data_root) 91 | # rgba images to rgb 92 | rgbs_train, self.train_h, self.train_w = self.preprocess_blender_images(img_pth_train) 93 | rgbs_test, test_h, test_w = self.preprocess_blender_images(img_pth_test) 94 | rgbs_val, val_h, val_w = 
self.preprocess_blender_images(img_pth_val) 95 | # camera fov to intrinsic mat 96 | intr_train = self.blender_fov_to_intrinsic(fov_train, self.train_h, self.train_w) 97 | intr_test = self.blender_fov_to_intrinsic(fov_test, test_h, test_w) 98 | intr_val = self.blender_fov_to_intrinsic(fov_val, val_h, val_w) 99 | 100 | return rgbs_train, pose_train, intr_train, train_numb,\ 101 | rgbs_test, pose_test, intr_test, test_numb,\ 102 | rgbs_val, pose_val, intr_val, val_numb 103 | 104 | def load_blender_json(self, json_path, root_path, mode="extr"): 105 | with open(json_path,'r') as f: 106 | json_file = json.load(f) 107 | pose_list = [] 108 | path_list = [] 109 | fov_list = [] 110 | 111 | for i,data in enumerate(json_file['frames']): 112 | img_path = os.path.join(Path(root_path), Path(data["file_path"] + ".png")) 113 | cam_angle_x = data["camera_angle_x"] 114 | if mode == "extr": 115 | pose = np.array(data['transform_matrix']) 116 | pose = torch.tensor(pose) 117 | pose = self.blender_pose_transform(pose) # [3, 4] 118 | pose_list += [pose] 119 | path_list += [img_path] 120 | fov_list += [cam_angle_x] 121 | if mode == "extr": 122 | pose_list = torch.stack(pose_list, 0) 123 | fov_tensor = torch.tensor(fov_list) 124 | data_numb = len(path_list) 125 | 126 | return fov_tensor, path_list, pose_list, data_numb 127 | 128 | # rgba to rgb 129 | def preprocess_blender_images(self, img_path): 130 | rgbs_list = [] 131 | for pth in img_path: 132 | img = Image.open(pth) 133 | img = self.transform(img) 134 | img_h, img_w = img.shape[1], img.shape[2] 135 | img = img.reshape(4, -1).permute(1, 0) # (h*w, 4) RGBA 136 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) 137 | rgbs_list += [img] 138 | rgbs_tensor = torch.stack(rgbs_list, 0) 139 | return rgbs_tensor, img_h, img_w 140 | 141 | def blender_fov_to_intrinsic(self, fov, img_h, img_w): 142 | intr_mat_list = [] 143 | for f in fov: 144 | fx = (img_w/2)/(math.tan(f/2)) 145 | fy = (img_h/2)/(math.tan(f/2)) 146 | intr_mat = torch.tensor([[fx, 0, img_w/2], 147 | [0, fy, img_h/2], 148 | [0, 0, 1]]) 149 | intr_mat_list += [intr_mat] 150 | intr_mat = torch.stack(intr_mat_list, 0) 151 | 152 | return intr_mat 153 | 154 | # load calibration datasets 155 | def load_apriltag_json(self, apriltag_root, times=1): 156 | logging.info("Loading calibration packages...") 157 | calib_json = os.path.join(Path(apriltag_root), Path("transforms_calib.json")) 158 | _, path_calib, _, _ = self.load_blender_json(calib_json, apriltag_root, mode="intr") 159 | coord_json = os.path.join(Path(apriltag_root), Path("transforms_coord.json")) 160 | _, path_coord, _, _ = self.load_blender_json(coord_json, apriltag_root, mode="extr") 161 | tag_info_calib, tag_info_id_calib, muti_tag_id_calib, self.img_h, self.img_w = self.apriltag_detection(path_calib, check=True) 162 | tag_info_coord, tag_info_id_coord, muti_tag_id_coord, self.img_h, self.img_w = self.apriltag_detection(path_coord) 163 | # calibration data for intrinsic parameters 164 | intr_wpts, intr_pts = self.get_cam_train_data(tag_info_calib, param="intr", times=times) 165 | # calibration data for extrinsic parameters 166 | extr_wpts, extr_pts, extr_id_sq = self.get_cam_train_data(tag_info_coord, param="extr", times=times) 167 | 168 | return intr_wpts, intr_pts, extr_wpts, extr_pts, muti_tag_id_calib 169 | 170 | # apriltag detection, when check=True, muti-Apriltag detection is activated 171 | def apriltag_detection(self, apriltag_pth, check=False): 172 | all_tag_info = {} 173 | muti_tag_id = {1:[], 2:[], 3:[], 4:[], 5:[], 6:[]} 174 | all_id_info = 
{0:{'id':[],'pts':[]},\ 175 | 1:{'id':[],'pts':[]},\ 176 | 2:{'id':[],'pts':[]},\ 177 | 3:{'id':[],'pts':[]},\ 178 | 4:{'id':[],'pts':[]},\ 179 | 5:{'id':[],'pts':[]}} 180 | detect_flag = 0 181 | detector = apriltag.Detector(apriltag.DetectorOptions(families='tag36h11')) 182 | for img_id, pth in enumerate(apriltag_pth): 183 | cur_img = cv2.imread(pth) 184 | img_h, img_w = cur_img.shape[0], cur_img.shape[1] 185 | gray_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2GRAY) 186 | gray_img = cv2.normalize(gray_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) 187 | tags = detector.detect(gray_img) 188 | if len(tags) != 0: 189 | detect_flag += 1 190 | muti_tag_id[len(tags)] += [img_id] 191 | else: 192 | logging.info("Apriltags in image{} are not detected !!".format(img_id)) 193 | tag_ids = [] 194 | tag_pts = [] 195 | for tag in tags: 196 | tag_id = tag.tag_id 197 | center_p = tag.center # format:[x, y], also (w, h) 198 | corner_p = tag.corners # order:[lt, rt, rb, lb] 199 | points_tag = np.concatenate([center_p.reshape([1, -1]), corner_p], 0) 200 | tag_pts += [points_tag] 201 | tag_ids += [tag_id] 202 | all_id_info[tag_id]['id'] += [img_id] 203 | all_id_info[tag_id]['pts'] += [points_tag] 204 | if check and (len(tag_ids) < 2): 205 | logging.info("Muti-Apriltags are not detected in image{} !!".format(img_id)) 206 | all_tag_info[img_id] = [tag_ids, tag_pts] 207 | 208 | if detect_flag == len(apriltag_pth): 209 | logging.info("All images include calibration points....") 210 | else: 211 | logging.info("Unvalid calibration images existing !!") 212 | exit() 213 | 214 | return all_tag_info, all_id_info, muti_tag_id, img_h, img_w 215 | 216 | # generate data for camera parameters training 217 | def get_cam_train_data(self, target_dict, param="intr", times=1): 218 | global_pts_list = [] 219 | global_wpts_list = [] 220 | global_id_list = [] 221 | for i in range(self.train_numb*times): 222 | pts_list = [] 223 | wpts_list = [] 224 | id_list = [] 225 | for img_id, info in target_dict.items(): 226 | idx = random.randint(0, len(info[0])-1) 227 | tag_id = info[0][idx] 228 | tag_pt = info[1][idx] 229 | tag_wpts = self.tag_world_pts[tag_id] 230 | id_list += [tag_id] 231 | wpts_list += [tag_wpts] 232 | pts_list += [torch.from_numpy(tag_pt).to(torch.float32)] 233 | global_id_list += [torch.tensor(id_list)] 234 | global_wpts_list += [torch.stack(wpts_list, 0)] 235 | global_pts_list += [torch.stack(pts_list, 0)] 236 | global_tag_wpts = torch.stack(global_wpts_list, 0) 237 | global_tag_pts = torch.stack(global_pts_list, 0) 238 | tag_id_squence = torch.stack(global_id_list, 0) 239 | 240 | if param == "intr": 241 | return global_tag_wpts, global_tag_pts 242 | else: 243 | return global_tag_wpts, global_tag_pts, tag_id_squence 244 | 245 | # blender格式的位姿转换函数 246 | def blender_pose_transform(self, pose): 247 | pose_R = pose[:3, :3].to(torch.float32) 248 | pose_T = pose[:3, 3:].to(torch.float32) 249 | pose_flip_R = torch.diag(torch.tensor([1.0,-1.0,-1.0])) 250 | pose_flip_T = torch.zeros([3, 1]) 251 | pose_R_new = pose_R @ pose_flip_R 252 | pose_T_new = pose_R @ pose_flip_T + pose_T 253 | pose_R_new_inv = pose_R_new.T 254 | pose_T_new_inv = -pose_R_new_inv @ pose_T_new 255 | new_pose = torch.cat([pose_R_new_inv, pose_T_new_inv], -1) 256 | 257 | return new_pose 258 | 259 | def inverse_intrinsic(self, intr_mats): 260 | intr_inv_list = [] 261 | for intr in intr_mats: 262 | intr_inv = intr.inverse() 263 | intr_inv_list += [intr_inv] 264 | intr_inv = torch.stack(intr_inv_list, 0) 265 | return intr_inv 266 | 267 | def 
update_param(self, system_param): 268 | system_param["intr_mat"] = self.intr 269 | system_param["intr_mat_inv"] = self.intr_inv 270 | system_param["data_numb"] = self.data_numb 271 | system_param["gt_pose"] = self.pose_train 272 | system_param["valid_pose"] = self.pose_val 273 | system_param["test_pose"] = self.pose_test 274 | system_param["valid_rgbs"] = self.rgbs_val 275 | system_param["data_img_h"] = self.img_h 276 | system_param["data_img_w"] = self.img_w 277 | system_param["train_json_file"] = self.train_json 278 | system_param["epoch_squence"] = self.each_epoch 279 | system_param["epoch_numb"] = self.total_epoch 280 | system_param["barf_start"] = self.barf_start 281 | system_param["barf_end"] = self.barf_end 282 | 283 | return system_param 284 | 285 | # expand datasets to have more training data in each epoch 286 | def expand_data_length(self, rgbs, img_h, img_w, idx, times=None, squence=True): 287 | pixel_numb = img_h*img_w 288 | if times is None: 289 | self.expd_times = (pixel_numb // self.batch) + 1 290 | else: 291 | self.expd_times = times 292 | 293 | logging.info("Expanding datasets...") 294 | expd_rgbs = rgbs.repeat(self.expd_times, 1, 1) 295 | expd_idx = idx.repeat(self.expd_times) 296 | 297 | return expd_rgbs, expd_idx 298 | 299 | # generate world points for calibration cube 300 | def apriltag_gt_pts(self): 301 | cube_half = self.tag_size/2 302 | tag_half = self.tag_size*0.8/2 303 | world_tag_pts = {0:[[0.0, -cube_half, 0.0], 304 | [-tag_half, -cube_half, tag_half], 305 | [ tag_half, -cube_half, tag_half], 306 | [ tag_half, -cube_half, -tag_half], 307 | [-tag_half, -cube_half, -tag_half]], 308 | 1:[[ cube_half, 0.0, 0.0], 309 | [ cube_half,-tag_half, tag_half], 310 | [ cube_half, tag_half, tag_half], 311 | [ cube_half, tag_half, -tag_half], 312 | [ cube_half,-tag_half, -tag_half]], 313 | 2:[[ 0.0, cube_half, 0.0], 314 | [ tag_half, cube_half, tag_half], 315 | [-tag_half, cube_half, tag_half], 316 | [-tag_half, cube_half, -tag_half], 317 | [ tag_half, cube_half, -tag_half]], 318 | 3:[[ -cube_half, 0.0, 0.0], 319 | [ -cube_half, tag_half, tag_half], 320 | [ -cube_half,-tag_half, tag_half], 321 | [ -cube_half,-tag_half, -tag_half], 322 | [ -cube_half, tag_half, -tag_half]], 323 | 4:[[ 0.0, 0.0, cube_half], 324 | [-tag_half, tag_half, cube_half], 325 | [ tag_half, tag_half, cube_half], 326 | [ tag_half, -tag_half, cube_half], 327 | [-tag_half, -tag_half, cube_half]], 328 | 5:[[ 0.0, 0.0, -cube_half], 329 | [-tag_half, -tag_half, -cube_half], 330 | [ tag_half, -tag_half, -cube_half], 331 | [ tag_half, tag_half, -cube_half], 332 | [-tag_half, tag_half, -cube_half]]} 333 | world_tag_pts_tensor = {} 334 | for key in world_tag_pts: 335 | world_tag_pts_tensor[key] = torch.tensor(world_tag_pts[key]) 336 | return world_tag_pts_tensor 337 | 338 | def get_squence_info(self): 339 | stage1_epoch = self.system_param['stage1_epoch'] 340 | stage2_epoch = self.system_param['stage2_epoch'] 341 | stage3_epoch = self.system_param['stage3_epoch'] 342 | each_epoch = torch.tensor([stage1_epoch, stage2_epoch, stage3_epoch], dtype=torch.long) 343 | total_epoch = int(each_epoch.sum()) 344 | barf_start = self.system_param["barf_start"] 345 | barf_end = self.system_param["barf_end"] 346 | global_barf_start = float(stage1_epoch)/float(total_epoch) + barf_start 347 | global_barf_end = float(stage1_epoch + stage2_epoch)/float(total_epoch) 348 | ratio = (global_barf_end - global_barf_start)*barf_end 349 | global_barf_end = global_barf_start + ratio 350 | 351 | return each_epoch, total_epoch, global_barf_start, 
global_barf_end 352 | 353 | 354 | class Data_loader(): 355 | def __init__(self, dataset, sys_param): 356 | self.dataset = dataset 357 | self.sys_param = sys_param 358 | if sys_param['distributed']: 359 | self.sampler = DistributedSampler(dataset, shuffle=True) 360 | self.sampler_no_shuffle = DistributedSampler(dataset, shuffle=False) 361 | else: 362 | self.sampler = torch.utils.data.RandomSampler(dataset) 363 | self.sampler_no_shuffle = torch.utils.data.SequentialSampler(dataset) 364 | 365 | self.dataloader = self.pkg_dataloader() 366 | 367 | def pkg_dataloader(self, batch=1): 368 | self.batch_sampler_train = torch.utils.data.BatchSampler(self.sampler, batch, drop_last=True) 369 | self.batch_sampler_val = torch.utils.data.BatchSampler(self.sampler_no_shuffle, batch, drop_last=False) 370 | 371 | loader_train = DataLoader(self.dataset, 372 | batch_sampler=self.batch_sampler_train, 373 | num_workers=12, 374 | pin_memory=True) 375 | 376 | loader_val = DataLoader(self.dataset, 377 | batch_sampler=self.batch_sampler_val, 378 | num_workers=12, 379 | pin_memory=True) 380 | 381 | return {"Shuffle_loader": loader_train, "Squence_loader": loader_val} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | import torch 5 | import random 6 | import time 7 | import lpips 8 | 9 | import numpy as np 10 | 11 | from torchvision import transforms 12 | from tqdm import tqdm 13 | from pathlib import Path 14 | 15 | from config import Load_config 16 | from data import Data_set 17 | from data import Data_loader 18 | from model import MC_Model 19 | from model import MC_NeRF_Loss 20 | from model import RAdam 21 | from model import apply_depth_colormap 22 | from model.external.pohsun_ssim import pytorch_ssim 23 | from utils import get_rank 24 | from utils import Tensorboard_config 25 | 26 | 27 | class Model_Engine(): 28 | def __init__(self, sys_param): 29 | self.mode_numb = sys_param["mode"] 30 | # dataset initialize 31 | dataset = Data_set(sys_param) 32 | self.sys_param = dataset.update_system_param 33 | self.device = self.sys_param["device_type"] 34 | # epoch for different stage 35 | self.cam_epoch = self.sys_param["stage1_epoch"] 36 | self.optim_epoch = self.sys_param["stage2_epoch"] 37 | self.fine_tune_epoch = self.sys_param["stage3_epoch"] 38 | self.each_epoch = self.sys_param["epoch_squence"] 39 | self.total_epoch = self.sys_param["epoch_numb"] 40 | # loading model 41 | self.loader = Data_loader(dataset, self.sys_param) 42 | self.mc_nerf = MC_Model(self.sys_param) 43 | # tensorboard initialize 44 | self.tblogger = Tensorboard_config(self.sys_param).tblogger 45 | # loss function 46 | self.loss_func = MC_NeRF_Loss(sys_param, self.tblogger).to(self.device) 47 | 48 | def forward(self): 49 | if self.mode_numb == 0: 50 | self.train_model() 51 | else: 52 | self.test_model() 53 | 54 | def train_model(self): 55 | train_loader = self.loader.dataloader["Shuffle_loader"] 56 | sampler_train = self.loader.sampler 57 | each_epoch_step = len(train_loader) 58 | total_step = self.total_epoch*each_epoch_step 59 | cur_step = 0 60 | if self.sys_param["distributed"]: 61 | mc_nerf = torch.nn.parallel.DistributedDataParallel(self.mc_nerf.to(self.device), device_ids=[sys_param['gpu']], find_unused_parameters=True) 62 | self.enerf_model_without_ddp = mc_nerf.module 63 | else: 64 | mc_nerf = self.mc_nerf.to(self.device) 65 | self.enerf_model_without_ddp = mc_nerf 66 
| opt_list, sched_list = self.generate_optimizer(each_epoch_step) 67 | # start training 68 | for epoch in range(self.total_epoch): 69 | # current epoch name 70 | epoch_type = self.which_stage(self.each_epoch, epoch) 71 | running_loss = 0 72 | if sys_param["distributed"]: 73 | sampler_train.set_epoch(epoch) 74 | with tqdm(total = len(train_loader), 75 | desc='{}:{}'.format(epoch_type, epoch), 76 | bar_format='{desc} |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt} {postfix}]', 77 | ncols=150) as bar: 78 | for step, data in enumerate(train_loader): 79 | optimizer = opt_list[self.enerf_model_without_ddp.opt_idx] 80 | cur_ratio = cur_step/total_step 81 | optimizer.zero_grad() 82 | loss_dict, intr_show, pose_show, rays_valid = mc_nerf(data, epoch, epoch_type, cur_ratio) 83 | loss_value = self.loss_func(loss_dict, epoch_type) 84 | loss_value.backward() 85 | optimizer.step() 86 | running_loss += loss_value.item() 87 | ave_loss = running_loss/(step + 1) 88 | cur_step += 1 89 | sched_list[self.enerf_model_without_ddp.opt_idx].step() 90 | bar.set_postfix_str('AveLoss:{:^7.9f}, LR:{:^7.5f}'.format(ave_loss, optimizer.param_groups[0]['lr'])) 91 | bar.update() 92 | self.enerf_model_without_ddp.nerf.save_model(self.mc_nerf, epoch) 93 | self.enerf_model_without_ddp.show_estimate_param(intr_show, pose_show, epoch, epoch_type) 94 | self.enerf_model_without_ddp.show_RT_est_results(epoch, epoch_type, mode='epoch') 95 | self.enerf_model_without_ddp.nerf.valid_train(epoch, rays_valid, epoch_type) 96 | 97 | @torch.no_grad() 98 | def test_model(self): 99 | test_loader = self.loader.dataloader["Squence_loader"] 100 | self.mc_nerf.to(self.device) 101 | self.mc_nerf.eval() 102 | res_rgbs = [] 103 | res_invdepth = [] 104 | img_h = sys_param["res_h"] 105 | img_w = sys_param["res_w"] 106 | img_res = img_h*img_w 107 | img_name_idx = 0 108 | with tqdm(total = len(test_loader), 109 | desc='Rendering:', 110 | bar_format='{desc} |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt} {postfix}]', 111 | ncols=150) as bar: 112 | # 训练 113 | for step, data in enumerate(test_loader): 114 | gt_rgbs, img_idx = data 115 | # 开始前项计算 116 | pd_rgbs, pd_depth, pd_opacity = self.mc_nerf(img_idx.to(self.device)) 117 | invdepth = 1/(pd_depth/pd_opacity + 1e-10) * 2 118 | invdepth = apply_depth_colormap(invdepth, cmap="inferno") 119 | res_rgbs += [pd_rgbs] 120 | res_invdepth += [invdepth] 121 | bar.update() 122 | 123 | rgbs_cat = torch.cat(res_rgbs, 0) 124 | dept_cat = torch.cat(res_invdepth, 0) 125 | nowtime = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 126 | save_pth_pd = os.path.join(Path(sys_param["demo_render_pth"] + "_" + nowtime), Path("pred")) 127 | save_pth_depth = os.path.join(Path(sys_param["demo_render_pth"] + "_" + nowtime), Path("depth")) 128 | save_pth_gt = os.path.join(Path(sys_param["demo_render_pth"] + "_" + nowtime), Path("gt")) 129 | 130 | if not os.path.exists(save_pth_gt): 131 | os.makedirs(save_pth_gt) 132 | if not os.path.exists(save_pth_pd): 133 | os.makedirs(save_pth_pd) 134 | if not os.path.exists(save_pth_depth): 135 | os.makedirs(save_pth_depth) 136 | 137 | PSNR_score = 0 138 | SSIM_score = 0 139 | LPIP_score = 0 140 | 141 | for i in range(0, rgbs_cat.shape[0], img_res): 142 | rgb_img_rays = rgbs_cat[i:i+img_res] 143 | dep_img_rays = dept_cat[i:i+img_res] 144 | gt_img = gt_rgbs.reshape(img_h, img_w, 3).cpu() 145 | img = rgb_img_rays.reshape(img_h, img_w, 3).cpu() 146 | dep = dep_img_rays.reshape(img_h, img_w, 3).cpu() 147 | gt_img = gt_img.permute(2, 0, 1) # (3, H, W) 148 | img = 
img.permute(2, 0, 1) # (3, H, W) 149 | dep = dep.permute(2, 0, 1) # (3, H, W) 150 | psnr = self.psnr_score(img, gt_img) 151 | ssim = self.ssim_score(img, gt_img) 152 | lpip = self.lpips_score(img, gt_img) 153 | gt_name = str(img_name_idx).zfill(4) + "gt" + ".png" 154 | img_name = str(img_name_idx).zfill(4) + ".png" 155 | dep_name = str(img_name_idx).zfill(4) + "depth" + ".png" 156 | gt_pth = os.path.join(Path(save_pth_gt), Path(gt_name)) 157 | img_pth = os.path.join(Path(save_pth_pd), Path(img_name)) 158 | dep_pth = os.path.join(Path(save_pth_depth), Path(dep_name)) 159 | # 转换图片格式 160 | transforms.ToPILImage()(gt_img).convert("RGB").save(gt_pth) 161 | transforms.ToPILImage()(img).convert("RGB").save(img_pth) 162 | transforms.ToPILImage()(dep).convert("RGB").save(dep_pth) 163 | img_name_idx += 1 164 | PSNR_score += psnr 165 | SSIM_score += ssim 166 | LPIP_score += lpip 167 | 168 | print('Results ({})'.format(sys_param['data_name'])) 169 | print('PSNR: {}'.format(PSNR_score/200)) 170 | print('SSIM: {}'.format(SSIM_score/200)) 171 | print('LPIP: {}'.format(LPIP_score/200)) 172 | 173 | torch.cuda.empty_cache() 174 | 175 | # optimizer config 176 | def generate_optimizer(self, each_epoch_step): 177 | cam_stage_lr = self.sys_param["stage1_lr"] 178 | optim_stage_lr = self.sys_param["stage2_lr"] 179 | fine_tune_lr = self.sys_param["stage3_lr"] 180 | weight_d = self.sys_param["weight_d"] 181 | # camera parameters initial stage 182 | for key in self.enerf_model_without_ddp.named_parameters(): 183 | if key[0].split(".")[0] == "nerf": 184 | key[1].requires_grad_(False) 185 | else: 186 | key[1].requires_grad_(True) 187 | opt_cam = RAdam(filter(lambda p: p.requires_grad, self.enerf_model_without_ddp.parameters()), lr=cam_stage_lr, eps=1e-8, weight_decay=weight_d) 188 | gamma_cam = (0.005/cam_stage_lr)**(1./(each_epoch_step*self.cam_epoch)) 189 | opt_cam_sched = torch.optim.lr_scheduler.ExponentialLR(opt_cam, gamma_cam) 190 | # global optimization stage 191 | for key in self.enerf_model_without_ddp.named_parameters(): 192 | key[1].requires_grad_(True) 193 | opt_global = RAdam(filter(lambda p: p.requires_grad, self.enerf_model_without_ddp.parameters()), lr=optim_stage_lr, eps=1e-8, weight_decay=weight_d) 194 | gamma_global = (optim_stage_lr/optim_stage_lr)**(1./(each_epoch_step*self.optim_epoch)) 195 | opt_global_sched = torch.optim.lr_scheduler.ExponentialLR(opt_global, gamma_global) 196 | # fine tuning stage 197 | for key in self.enerf_model_without_ddp.named_parameters(): 198 | key[1].requires_grad_(True) 199 | self.enerf_model_without_ddp.weights_pose.requires_grad_(False) 200 | opt_fine_tune = RAdam(filter(lambda p: p.requires_grad, self.enerf_model_without_ddp.parameters()), lr=fine_tune_lr, eps=1e-8, weight_decay=weight_d) 201 | gamma_fine_tune = (fine_tune_lr/fine_tune_lr)**(1./(each_epoch_step*self.fine_tune_epoch)) 202 | opt_fine_sched = torch.optim.lr_scheduler.ExponentialLR(opt_fine_tune, gamma_fine_tune) 203 | # activate all weights before training 204 | for key in self.enerf_model_without_ddp.named_parameters(): 205 | key[1].requires_grad_(True) 206 | 207 | return [opt_cam, opt_global, opt_fine_tune], [opt_cam_sched, opt_global_sched, opt_fine_sched] 208 | 209 | # get current stage info 210 | def which_stage(self, epoch_list, cur_epoch): 211 | epoch_name = ['CAM_PARAM_EPOCH', 'GLOBAL_OPTIM_EPOCH', 'FINE_TUNE_EPOCH'] 212 | name_idx = torch.arange(0, len(epoch_name)) 213 | epoch_list = torch.cumsum(epoch_list, 0) 214 | for idx, epoch in enumerate(epoch_list): 215 | if cur_epoch in range(epoch): 
216 | cur_idx = name_idx[idx] 217 | return epoch_name[cur_idx] 218 | 219 | # get PSNR score 220 | def psnr_score(self, image_pred, image_gt, valid_mask=None, reduction='mean'): 221 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 222 | value = (image_pred-image_gt)**2 223 | if valid_mask is not None: 224 | value = value[valid_mask] 225 | if reduction == 'mean': 226 | return torch.mean(value) 227 | return value 228 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 229 | 230 | # get SSIM score 231 | def ssim_score(self, image_pred, image_gt): 232 | image_pred = image_pred.unsqueeze(0) 233 | image_gt = image_gt.unsqueeze(0) 234 | ssim = pytorch_ssim.ssim(image_pred, image_gt).item() 235 | return ssim 236 | 237 | # get lpips score 238 | def lpips_score(self, image_pred, image_gt): 239 | lpips_loss = lpips.LPIPS(net="alex") 240 | lpips_value = lpips_loss(image_pred*2-1, image_gt*2-1).item() 241 | return lpips_value 242 | 243 | 244 | if __name__ == "__main__": 245 | parser = argparse.ArgumentParser() 246 | # config file path 247 | parser.add_argument('--config', type=str, default="./config", 248 | help='root path of config file') 249 | # root path for data 250 | parser.add_argument('--root_data', type=str, default='./data/dataset_Ball', 251 | help='root path of data') 252 | parser.add_argument('--data_name', type=str, default='Ball_Computer', 253 | help='name of data') 254 | # work mode: train for train and valid, demo for test 255 | parser.add_argument('--demo', action='store_true', 256 | help='nerf rendering forward with test mode') 257 | parser.add_argument('--train', action='store_true', 258 | help='train mode') 259 | # save log file or not 260 | parser.add_argument('--log', action='store_true', 261 | help='save log information to log.txt file') 262 | # GPU number, which start, available in muti-GPU training 263 | parser.add_argument('--start_device', type=int, default=0, 264 | help='start training device for distributed mode') 265 | # active Tensorboard or not 266 | parser.add_argument('--tensorboard', action='store_true', 267 | help='use tensorboard tools to show training results') 268 | args = parser.parse_args() 269 | 270 | config_info = Load_config(args) 271 | sys_param = config_info.system_info 272 | logging.info("System Parameters:\n {} \n".format(sys_param)) 273 | # fix seed for reproduction 274 | seed = sys_param["seed"] + get_rank() 275 | torch.manual_seed(seed) 276 | np.random.seed(seed) 277 | random.seed(seed) 278 | 279 | # model engine 280 | model = Model_Engine(sys_param) 281 | model.forward() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .mc_nerf import MC_Model 2 | from .net_utils import RAdam 3 | from .loss import MC_NeRF_Loss 4 | from .external.pohsun_ssim import pytorch_ssim 5 | from .net_utils import apply_depth_colormap -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/mc_nerf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/__pycache__/mc_nerf.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/net_block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/__pycache__/net_block.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/net_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/__pycache__/net_utils.cpython-39.pyc -------------------------------------------------------------------------------- /model/external/pohsun_ssim/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT 2 | -------------------------------------------------------------------------------- /model/external/pohsun_ssim/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-ssim 2 | 3 | ### Differentiable structural similarity (SSIM) index. 4 | ![einstein](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/einstein.png) ![Max_ssim](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/max_ssim.gif) 5 | 6 | ## Installation 7 | 1. Clone this repo. 8 | 2. Copy "pytorch_ssim" folder in your project. 
9 | 10 | ## Example 11 | ### basic usage 12 | ```python 13 | import pytorch_ssim 14 | import torch 15 | from torch.autograd import Variable 16 | 17 | img1 = Variable(torch.rand(1, 1, 256, 256)) 18 | img2 = Variable(torch.rand(1, 1, 256, 256)) 19 | 20 | if torch.cuda.is_available(): 21 | img1 = img1.cuda() 22 | img2 = img2.cuda() 23 | 24 | print(pytorch_ssim.ssim(img1, img2)) 25 | 26 | ssim_loss = pytorch_ssim.SSIM(window_size = 11) 27 | 28 | print(ssim_loss(img1, img2)) 29 | 30 | ``` 31 | ### maximize ssim 32 | ```python 33 | import pytorch_ssim 34 | import torch 35 | from torch.autograd import Variable 36 | from torch import optim 37 | import cv2 38 | import numpy as np 39 | 40 | npImg1 = cv2.imread("einstein.png") 41 | 42 | img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0 43 | img2 = torch.rand(img1.size()) 44 | 45 | if torch.cuda.is_available(): 46 | img1 = img1.cuda() 47 | img2 = img2.cuda() 48 | 49 | 50 | img1 = Variable( img1, requires_grad=False) 51 | img2 = Variable( img2, requires_grad = True) 52 | 53 | 54 | # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) 55 | ssim_value = pytorch_ssim.ssim(img1, img2).data[0] 56 | print("Initial ssim:", ssim_value) 57 | 58 | # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) 59 | ssim_loss = pytorch_ssim.SSIM() 60 | 61 | optimizer = optim.Adam([img2], lr=0.01) 62 | 63 | while ssim_value < 0.95: 64 | optimizer.zero_grad() 65 | ssim_out = -ssim_loss(img1, img2) 66 | ssim_value = - ssim_out.data[0] 67 | print(ssim_value) 68 | ssim_out.backward() 69 | optimizer.step() 70 | 71 | ``` 72 | 73 | ## Reference 74 | https://ece.uwaterloo.ca/~z70wang/research/ssim/ 75 | -------------------------------------------------------------------------------- /model/external/pohsun_ssim/einstein.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/external/pohsun_ssim/einstein.png -------------------------------------------------------------------------------- /model/external/pohsun_ssim/max_ssim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/external/pohsun_ssim/max_ssim.gif -------------------------------------------------------------------------------- /model/external/pohsun_ssim/max_ssim.py: -------------------------------------------------------------------------------- 1 | import pytorch_ssim 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import optim 5 | import cv2 6 | import numpy as np 7 | 8 | npImg1 = cv2.imread("einstein.png") 9 | 10 | img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0 11 | img2 = torch.rand(img1.size()) 12 | 13 | if torch.cuda.is_available(): 14 | img1 = img1.cuda() 15 | img2 = img2.cuda() 16 | 17 | 18 | img1 = Variable( img1, requires_grad=False) 19 | img2 = Variable( img2, requires_grad = True) 20 | 21 | 22 | # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) 23 | ssim_value = pytorch_ssim.ssim(img1, img2).data[0] 24 | print("Initial ssim:", ssim_value) 25 | 26 | # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) 27 | ssim_loss = pytorch_ssim.SSIM() 28 | 29 | optimizer = optim.Adam([img2], lr=0.01) 30 | 31 | while ssim_value < 0.95: 32 | optimizer.zero_grad() 33 | ssim_out = -ssim_loss(img1, img2) 34 | 
ssim_value = - ssim_out.data[0] 35 | print(ssim_value) 36 | ssim_out.backward() 37 | optimizer.step() 38 | -------------------------------------------------------------------------------- /model/external/pohsun_ssim/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- /model/external/pohsun_ssim/pytorch_ssim/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/external/pohsun_ssim/pytorch_ssim/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/external/pohsun_ssim/pytorch_ssim/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/model/external/pohsun_ssim/pytorch_ssim/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/external/pohsun_ssim/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /model/external/pohsun_ssim/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | setup( 3 | name = 'pytorch_ssim', 4 | packages = ['pytorch_ssim'], # this must be the same as the name above 5 | version = '0.1', 6 | description = 'Differentiable structural similarity (SSIM) index', 7 | author = 'Po-Hsun (Evan) Su', 8 | author_email = 'evan.pohsun.su@gmail.com', 9 | url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo 10 | download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second 11 | keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords 12 | classifiers = [], 13 | ) 14 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MC_NeRF_Loss(nn.Module): 5 | def __init__(self, sys_param, tblogger=None): 6 | super(MC_NeRF_Loss, self).__init__() 7 | self.sys_param = sys_param 8 | self.loss_l2 = nn.MSELoss(reduction='mean') 9 | # tensorboard 10 | self.tblogger = tblogger 11 | self.global_step = 0 12 | self.img_h = self.sys_param["data_img_h"] 13 | self.img_w = self.sys_param["data_img_w"] 14 | 15 | def forward(self, loss_dict, epoch_type): 16 | final_loss = 0.0 17 | self.global_step += 1 18 | if "intr" in loss_dict: 19 | loss_intr = self.get_reproject_loss(loss_dict["intr"]) 20 | if epoch_type == "CAM_PARAM_EPOCH": 21 | final_loss += loss_intr 22 | else: 23 | final_loss += loss_intr/(loss_intr.detach() + 1e-8) 24 | if "extr" in loss_dict: 25 | loss_extr = self.get_reproject_loss(loss_dict["extr"]) 26 | final_loss += loss_extr 27 | if "rgb" in loss_dict: 28 | loss_rgb = self.get_rgb_loss(loss_dict["rgb"]) 29 | final_loss += loss_rgb 30 | 31 | return final_loss 32 | 33 | def get_rgb_loss(self, rgbs_list): 34 | rgbs_c = rgbs_list[0] 35 | rgbs_f = rgbs_list[1] 36 | rgbs_gt = rgbs_list[2] 37 | if rgbs_f is None: 38 | loss_rgb_f = 0 39 | else: 40 | loss_rgb_f = self.loss_l2(rgbs_f, rgbs_gt) 41 | loss_rgb_c = self.loss_l2(rgbs_c, rgbs_gt) 42 | loss = loss_rgb_c + loss_rgb_f 43 | return loss 44 | 45 | def get_reproject_loss(self, rpro_list): 46 | # pts format [x, y] 47 | pd_pts = rpro_list[0] 48 | gt_pts = rpro_list[1] 49 | 50 | pd_pts_nx = pd_pts[..., 0] 51 | pd_pts_ny = pd_pts[..., 1] 52 | gt_pts_nx = gt_pts[..., 0] 53 | gt_pts_ny = gt_pts[..., 1] 54 | 55 | proj_loss_x = self.loss_l2(pd_pts_nx/self.img_w, gt_pts_nx/self.img_w) 56 | proj_loss_y = self.loss_l2(pd_pts_ny/self.img_h, gt_pts_ny/self.img_h) 57 | 58 | return proj_loss_x + proj_loss_y 59 | 60 | -------------------------------------------------------------------------------- /model/mc_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 
3 | import os 4 | import time 5 | import cv2 6 | import json 7 | import lpips 8 | 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.distributed as dist 12 | import prettytable as pt 13 | import matplotlib.pyplot as plt 14 | 15 | from pathlib import Path 16 | from torchvision import transforms 17 | from mpl_toolkits.mplot3d import Axes3D 18 | 19 | from .net_block import SinCosEmbedding 20 | from .net_block import CorseFine_NeRF 21 | from .net_utils import get_rank 22 | from .external.pohsun_ssim import pytorch_ssim 23 | 24 | class MC_Model(nn.Module): 25 | def __init__(self, sys_param): 26 | logging.info('Creating MC-NeRF Model...') 27 | super(MC_Model, self).__init__() 28 | self.sys_param = sys_param 29 | self.mode = self.sys_param["mode"] 30 | self.device = self.sys_param["device_type"] 31 | self.batch = self.sys_param["batch"] 32 | self.bound_min = self.sys_param["boader_min"] 33 | self.bound_max = self.sys_param["boader_max"] 34 | self.intr = self.sys_param["intr_mat"] 35 | self.intr_inv = self.sys_param["intr_mat_inv"] 36 | self.intr_train, self.intr_test, self.intr_val = self.intr[0], self.intr[1], self.intr[2] 37 | self.intr_train_inv, self.intr_test_inv, self.intr_val_inv = self.intr_inv[0], self.intr_inv[1], self.intr_inv[2] 38 | self.gt_pose = self.sys_param["gt_pose"].to(self.device) 39 | self.test_pose = self.sys_param["test_pose"].to(self.device) 40 | self.valid_pose = self.sys_param["valid_pose"].to(self.device) 41 | self.valid_rgbs = self.sys_param["valid_rgbs"].to(self.device) 42 | self.img_h = self.sys_param["data_img_h"] 43 | self.img_w = self.sys_param["data_img_w"] 44 | self.train_img_pth = self.sys_param["demo_render_pth"] 45 | self.data_name = self.sys_param["data_name"] 46 | self.data_numb = self.sys_param["data_numb"] 47 | self.train_numb, self.test_numb, self.val_numb = self.data_numb[0], self.data_numb[1], self.data_numb[2] 48 | self.train_json_pth = sys_param["train_json_file"] 49 | self.register_parameters() 50 | self.nerf = NeRF_Model(self.sys_param).to(self.device) 51 | self.table = pt.PrettyTable(['EPOCH', 'LOSS_FX', 'LOSS_FY', 'LOSS_UX', 'LOSS_UY', 'LOSS_K', 'LOSS_R', 'LOSS_T']) 52 | self.count_rays = 0 53 | self.init_show_figure(show_info=False) 54 | self.opt_idx = 0 55 | self.last_epoch_type = 0 56 | self.wait_reset = 0 57 | 58 | def forward(self, *args): 59 | if self.sys_param["mode"] == 0: 60 | loss_dict = {} 61 | gt_rgbs, img_id, intr_wpts, intr_pts,\ 62 | extr_wpts, extr_pts, epoch, epoch_type, cur_ratio = self.data2device(*args) 63 | # camera parameters initial stage 64 | if epoch_type == "CAM_PARAM_EPOCH": 65 | self.nerf.emmbedding_xyz.barf_mode = False 66 | self.intr_adj, self.pose_adj, self.calib_pose_adj = self.add_weights2param(intr=True, extr=True, calib_extr=True) 67 | reproj_pts_intr = self.get_reproject_pixels(intr_wpts, self.intr_adj, self.calib_pose_adj) 68 | reproj_pts_extr = self.get_reproject_pixels(extr_wpts, self.intr_adj, self.pose_adj) 69 | loss_dict["intr"] = [reproj_pts_intr, intr_pts] 70 | loss_dict["extr"] = [reproj_pts_extr, extr_pts] 71 | self.opt_idx = 0 72 | # global optimization stage 73 | elif epoch_type == "GLOBAL_OPTIM_EPOCH": 74 | self.nerf.emmbedding_xyz.barf_mode = True 75 | self.intr_adj, self.pose_adj, self.calib_pose_adj = self.add_weights2param(intr=True, extr=True, calib_extr=True) 76 | reproj_pts_intr = self.get_reproject_pixels(intr_wpts, self.intr_adj, self.calib_pose_adj) 77 | rays_d, rays_o = self.get_rays(self.pose_adj, img_id, self.inverse_intrinsic(self.intr_adj)) 78 | sample_rays_d, 
sample_rays_o, rand_idx = self.generate_rand_rays(rays_d, rays_o, rand=True) 79 | rgbs_c, rgbs_f = self.nerf(sample_rays_d, sample_rays_o, epoch, cur_ratio) 80 | gt_rgbs = gt_rgbs.reshape(-1, 3)[rand_idx] 81 | loss_dict["intr"] = [reproj_pts_intr, intr_pts] 82 | loss_dict["rgb"] = [rgbs_c, rgbs_f, gt_rgbs] 83 | self.opt_idx = 1 84 | # fine-tuning stage 85 | else: 86 | self.nerf.emmbedding_xyz.barf_mode = False 87 | self.intr_adj, self.pose_adj, self.calib_pose_adj = self.add_weights2param(intr=True, extr=False, calib_extr=True) 88 | reproj_pts_intr = self.get_reproject_pixels(intr_wpts, self.intr_adj, self.calib_pose_adj) 89 | rays_d, rays_o = self.get_rays(self.pose_adj, img_id, self.inverse_intrinsic(self.intr_adj)) 90 | sample_rays_d, sample_rays_o, rand_idx = self.generate_rand_rays(rays_d, rays_o, rand=True) 91 | rgbs_c, rgbs_f = self.nerf(sample_rays_d, sample_rays_o, epoch, 1) 92 | gt_rgbs = gt_rgbs.reshape(-1, 3)[rand_idx] 93 | loss_dict["intr"] = [reproj_pts_intr, intr_pts] 94 | loss_dict["rgb"] = [rgbs_c, rgbs_f, gt_rgbs] 95 | self.opt_idx = 2 96 | # generate valid data 97 | rays_dv, rays_ov = self.get_rays(self.valid_pose, img_id, self.intr_val_inv.to(self.device)) 98 | rgbs_v = self.valid_rgbs[img_id] 99 | rays_valid = [rays_dv.detach(), rays_ov.detach(), rgbs_v.detach()] 100 | 101 | intr_show = [self.intr_train.to(self.device).detach(), self.intr_adj.detach()] 102 | pose_show = [self.gt_pose.to(self.device).detach(), self.pose_adj.detach()] 103 | self.last_epoch_type = epoch_type 104 | 105 | return loss_dict, intr_show, pose_show, rays_valid 106 | else: 107 | img_id = args 108 | rgbs = [] 109 | depth = [] 110 | opacity = [] 111 | rays_d, rays_o = self.get_rays(self.test_pose, img_id, self.intr_test_inv.to(self.device)) 112 | for ii in range(0, rays_d.shape[0], self.batch): 113 | cur_rays_d = rays_d[ii:ii+self.batch] 114 | cur_rays_o = rays_o[ii:ii+self.batch] 115 | pd_rgbs, pd_depth, pd_opacity = self.nerf(cur_rays_d, cur_rays_o) 116 | rgbs += [pd_rgbs.detach().cpu()] 117 | depth += [pd_depth.detach().cpu()] 118 | opacity += [pd_opacity.detach().cpu()] 119 | rgbs = torch.cat(rgbs, 0) 120 | depth = torch.cat(depth, 0) 121 | opacity = torch.cat(opacity, 0) 122 | return rgbs, depth, opacity 123 | 124 | def get_rays(self, pose, img_id, intr_inv): 125 | select_pose = pose[img_id] 126 | with torch.no_grad(): 127 | y_range = torch.arange(self.img_h, dtype=torch.float32, device=self.device).add_(0.5) 128 | x_range = torch.arange(self.img_w, dtype=torch.float32, device=self.device).add_(0.5) 129 | Y,X = torch.meshgrid(y_range,x_range, indexing='ij') # [H,W] 130 | xy_grid = torch.stack([X,Y],dim=-1).reshape(-1,2) # [HW,2] 131 | xy_grid = xy_grid.unsqueeze(0) # [1,HW,2] 132 | pix_cord = self.pix2hom(xy_grid) 133 | cam_cord = self.pix2cam(pix_cord, intr_inv[img_id]) 134 | cam_orig = torch.zeros_like(cam_cord) 135 | cam_cord_hom = self.cam2hom(cam_cord) 136 | cam_orig_hom = self.cam2hom(cam_orig) 137 | world_cord = self.cam2world(cam_cord_hom, select_pose) 138 | rays_o = self.cam2world(cam_orig_hom, select_pose) 139 | rays_d = world_cord - rays_o 140 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 141 | 142 | rays_d = rays_d.reshape(-1, 3) 143 | rays_o = rays_o.reshape(-1, 3) 144 | 145 | return rays_d, rays_o 146 | 147 | def get_reproject_pixels(self, tag_wpts, intr_adj, pose_adj): 148 | world_pts = self.world2hom(tag_wpts) 149 | proj_cam_pts = self.world2cam(world_pts, pose_adj.unsqueeze(0)) 150 | proj_pts = self.cam2pix(proj_cam_pts, intr_adj.unsqueeze(0)) 151 | 152 | return 
proj_pts 153 | 154 | # transform parameters to learnable parameters 155 | def add_weights2param(self, intr=True, extr=True, calib_extr=False): 156 | if intr: 157 | intr_adj = self.add_weights2intr(self.img_h, self.img_w) 158 | else: 159 | intr_adj = self.add_weights2intr(self.img_h, self.img_w, adj=False) 160 | if extr: 161 | pose_adj = self.add_weights2pose() 162 | else: 163 | pose_adj = self.add_weights2pose(adj=False) 164 | if calib_extr: 165 | calib_pose_adj = self.add_weights2calib_pose() 166 | else: 167 | calib_pose_adj = self.add_weights2calib_pose(adj=False) 168 | 169 | return intr_adj, pose_adj, calib_pose_adj 170 | 171 | def add_weights2intr(self, img_h, img_w, adj=True): 172 | intr_init = torch.tensor([[img_w, 0, img_w/2], 173 | [0, img_w, img_h/2], 174 | [0, 0, 1]], device=self.device).expand(self.train_numb, 3, 3) 175 | intr_adj = intr_init.clone() 176 | if adj: 177 | intr_adj[:, 0, 0] = torch.abs(intr_init[:, 0, 0]*self.weights_fx.requires_grad_(True)) 178 | intr_adj[:, 1, 1] = torch.abs(intr_init[:, 1, 1]*self.weights_fy.requires_grad_(True)) 179 | intr_adj[:, 0, 2] = torch.abs(intr_init[:, 0, 2]*self.weights_ux.requires_grad_(True)) 180 | intr_adj[:, 1, 2] = torch.abs(intr_init[:, 1, 2]*self.weights_uy.requires_grad_(True)) 181 | else: 182 | intr_adj[:, 0, 0] = torch.abs(intr_init[:, 0, 0]*self.weights_fx.requires_grad_(False)) 183 | intr_adj[:, 1, 1] = torch.abs(intr_init[:, 1, 1]*self.weights_fy.requires_grad_(False)) 184 | intr_adj[:, 0, 2] = torch.abs(intr_init[:, 0, 2]*self.weights_ux.requires_grad_(False)) 185 | intr_adj[:, 1, 2] = torch.abs(intr_init[:, 1, 2]*self.weights_uy.requires_grad_(False)) 186 | return intr_adj 187 | 188 | def add_weights2pose(self, adj=True): 189 | if adj: 190 | weights_RT = self.se3_to_SE3(self.weights_pose.requires_grad_(True)) 191 | else: 192 | weights_RT = self.se3_to_SE3(self.weights_pose.requires_grad_(False)) 193 | 194 | return weights_RT 195 | 196 | def add_weights2calib_pose(self, adj=True): 197 | if adj: 198 | weights_RT = self.se3_to_SE3(self.weights_pose_intr.requires_grad_(True)) 199 | else: 200 | weights_RT = self.se3_to_SE3(self.weights_pose_intr.requires_grad_(False)) 201 | 202 | return weights_RT 203 | 204 | def inverse_intrinsic(self, intr_mats): 205 | intr_inv_list = [] 206 | for intr in intr_mats: 207 | intr_inv = intr.inverse() 208 | intr_inv_list += [intr_inv] 209 | intr_inv = torch.stack(intr_inv_list, 0) 210 | return intr_inv 211 | 212 | # [Batch, HW, 2]->[Batch, HW, 3] 213 | def pix2hom(self, pixel_cord): 214 | X_hom = torch.cat([pixel_cord, torch.ones_like(pixel_cord[...,:1])], dim=-1) 215 | return X_hom 216 | 217 | # [Batch, HW, 3]->[Batch, HW, 4] 218 | def cam2hom(self, cam_cord): 219 | X_hom = torch.cat([cam_cord, torch.ones_like(cam_cord[...,:1])], dim=-1) 220 | return X_hom 221 | 222 | # [Batch, HW, 3]->[Batch, HW, 4] 223 | def world2hom(self, world_cord): 224 | X_hom = torch.cat([world_cord, torch.ones_like(world_cord[...,:1])], dim=-1) 225 | return X_hom 226 | 227 | # pix_cord: [batch, ..., 3] 228 | # intr_inv_mat: [batch, ..., 3, 3] 229 | def pix2cam(self, pix_cord, intr_inv_mat): 230 | intr_inv_mat = intr_inv_mat.transpose(-2, -1) 231 | cam_cord = pix_cord @ intr_inv_mat 232 | return cam_cord 233 | 234 | # cam_cord: [batch, ..., 4] 235 | # intr_mat: [batch, ..., 3, 3] 236 | def cam2pix(self, cam_cord, intr_mat): 237 | hom_intr_mat = torch.cat([intr_mat, torch.zeros_like(intr_mat[...,:1])], dim=-1) 238 | pix_cord = hom_intr_mat @ cam_cord 239 | pix_cord = pix_cord[...,:2,:]/pix_cord[...,2:,:] 240 | pix_cord 
= pix_cord.transpose(-2, -1) 241 | return pix_cord 242 | 243 | # cam_cord: [batch, ..., 4] 244 | # pose: [batch, ..., 4] 245 | def cam2world(self, cam_cord, pose): 246 | pose_R = pose[..., :3] 247 | pose_T = pose[..., 3:] 248 | # R is orthogonal, so its transpose equals its inverse 249 | pose_R_inv = pose_R.transpose(-2, -1) 250 | pose_T_inv = (-pose_R_inv @ pose_T) 251 | # [batch, 3, 4] 252 | pose_inv = torch.cat([pose_R_inv, pose_T_inv], -1) 253 | # [batch, HW, 3] 254 | world_cord = cam_cord @ pose_inv.transpose(-2, -1) 255 | 256 | return world_cord 257 | 258 | # world_cord: [batch, ..., 4] 259 | # pose: [batch, ..., 3, 4] 260 | def world2cam(self, world_cord, pose): 261 | shape = pose.shape 262 | supply_pose = torch.tensor([0, 0, 0, 1], device=self.device) 263 | supply_pose = supply_pose.expand(shape)[...,:1,:] 264 | hom_pose = torch.cat([pose, supply_pose], dim=-2) 265 | cam_cord = hom_pose @ world_cord.transpose(-2, -1) 266 | 267 | return cam_cord 268 | 269 | def se3_to_SE3(self, wu): # [...,3] 270 | w,u = wu.split([3,3],dim=-1) 271 | wx = self.skew_symmetric(w) 272 | theta = w.norm(dim=-1)[...,None,None] 273 | I = torch.eye(3, device=self.device) 274 | A = self.taylor_A(theta) 275 | B = self.taylor_B(theta) 276 | C = self.taylor_C(theta) 277 | R = I+A*wx+B*wx@wx 278 | V = I+B*wx+C*wx@wx 279 | Rt = torch.cat([R,(V@u[...,None])],dim=-1) 280 | 281 | return Rt 282 | 283 | def skew_symmetric(self, w): 284 | w0,w1,w2 = w.unbind(dim=-1) 285 | O = torch.zeros_like(w0) 286 | wx = torch.stack([torch.stack([O,-w2,w1],dim=-1), 287 | torch.stack([w2,O,-w0],dim=-1), 288 | torch.stack([-w1,w0,O],dim=-1)],dim=-2) 289 | return wx 290 | 291 | def taylor_A(self,x,nth=10): 292 | # Taylor expansion of sin(x)/x 293 | ans = torch.zeros_like(x) 294 | denom = 1. 295 | for i in range(nth+1): 296 | if i>0: denom *= (2*i)*(2*i+1) 297 | ans = ans+(-1)**i*x**(2*i)/denom 298 | return ans 299 | 300 | def taylor_B(self,x,nth=10): 301 | # Taylor expansion of (1-cos(x))/x**2 302 | ans = torch.zeros_like(x) 303 | denom = 1. 304 | for i in range(nth+1): 305 | denom *= (2*i+1)*(2*i+2) 306 | ans = ans+(-1)**i*x**(2*i)/denom 307 | return ans 308 | 309 | def taylor_C(self,x,nth=10): 310 | # Taylor expansion of (x-sin(x))/x**3 311 | ans = torch.zeros_like(x) 312 | denom = 1.
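# The loop below accumulates the Taylor series (x - sin x)/x^3 = 1/3! - x^2/5! + x^4/7! - ...,
# i.e. the i-th term is (-1)^i * x^(2i) / (2i+3)!  (denom is multiplied by (2i+2)*(2i+3) at
# step i, so it always equals (2i+3)!). Together with taylor_A = sin(x)/x and
# taylor_B = (1-cos(x))/x^2, these are the coefficients of the SE(3) exponential map used in
# se3_to_SE3 above:
#     R = I + A*[w]x + B*[w]x^2,    t = (I + B*[w]x + C*[w]x^2) @ u,
# with [w]x the skew-symmetric matrix of w and theta = ||w||. Evaluating A, B and C as series
# keeps them well-defined (and differentiable) as theta -> 0; for instance, a zero 6-vector
# se3_to_SE3(torch.zeros(1, 6)) maps to the identity pose [I | 0].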
313 | for i in range(nth+1): 314 | denom *= (2*i+2)*(2*i+3) 315 | ans = ans+(-1)**i*x**(2*i)/denom 316 | return ans 317 | 318 | def compose_param2pose(self, param, pose): 319 | R_a,t_a = param[...,:3], param[...,3:] 320 | R_b,t_b = pose[...,:3], pose[...,3:] 321 | R_new = R_b @ R_a 322 | t_new = (R_b @ t_a + t_b) 323 | pose_new = torch.cat([R_new, t_new], -1) 324 | 325 | return pose_new 326 | 327 | def generate_rand_rays(self, rays_d, rays_o, rand=True): 328 | if rand: 329 | rand_idx = torch.randperm(rays_d.shape[0], device = self.device)[:self.batch] 330 | else: 331 | slides = torch.arange(0, rays_d.shape[0], self.batch, device=self.device) 332 | if self.count_rays == len(slides): 333 | self.count_rays = 0 334 | if slides[self.count_rays] + self.batch > rays_d.shape[0]: 335 | end_idx = rays_d.shape[0] 336 | else: 337 | end_idx = slides[self.count_rays] + self.batch 338 | rand_idx = torch.arange(slides[self.count_rays], end_idx, device=self.device) 339 | 340 | sample_rays_d = rays_d[rand_idx] 341 | sample_rays_o = rays_o[rand_idx] 342 | 343 | self.count_rays += 1 344 | 345 | return sample_rays_d, sample_rays_o, rand_idx 346 | 347 | def register_parameters(self): 348 | self.register_parameter( 349 | name="weights_pose", 350 | param=nn.Parameter(torch.ones([self.train_numb, 6], 351 | device=self.device), requires_grad=True)) 352 | self.register_parameter( 353 | name="weights_pose_intr", 354 | param=nn.Parameter(torch.ones([self.train_numb, 6], 355 | device=self.device), requires_grad=True)) 356 | self.register_parameter( 357 | name="weights_ux", 358 | param=nn.Parameter(torch.ones([self.train_numb], 359 | device=self.device), requires_grad=True)) 360 | self.register_parameter( 361 | name="weights_uy", 362 | param=nn.Parameter(torch.ones([self.train_numb], 363 | device=self.device), requires_grad=True)) 364 | self.register_parameter( 365 | name="weights_fx", 366 | param=nn.Parameter(torch.ones([self.train_numb], 367 | device=self.device), requires_grad=True)) 368 | self.register_parameter( 369 | name="weights_fy", 370 | param=nn.Parameter(torch.ones([self.train_numb], 371 | device=self.device), requires_grad=True)) 372 | 373 | # data to GPU 374 | def data2device(self, *args): 375 | gt_rgbs, img_id, intr_wpts, intr_pts, extr_wpts, extr_pts = args[0] 376 | epoch = args[1] 377 | epoch_type = args[2] 378 | cur_ratio = args[3] 379 | gt_rgbs = gt_rgbs.to(self.device) 380 | img_id = img_id.to(self.device) 381 | intr_wpts = intr_wpts.to(self.device) 382 | intr_pts = intr_pts.to(self.device) 383 | extr_wpts = extr_wpts.to(self.device) 384 | extr_pts = extr_pts.to(self.device) 385 | 386 | return gt_rgbs, img_id, intr_wpts, intr_pts, extr_wpts, extr_pts, epoch, epoch_type, cur_ratio 387 | 388 | def show_estimate_param(self, intr_show, pose_show, epoch, epoch_type): 389 | intr_loss = torch.abs(intr_show[0] - intr_show[1]) 390 | pose_loss = torch.abs(pose_show[0] - pose_show[1]) 391 | ave_loss_fx = intr_loss[:,0,0].mean() 392 | ave_loss_fy = intr_loss[:,1,1].mean() 393 | ave_loss_ux = intr_loss[:,0,2].mean() 394 | ave_loss_uy = intr_loss[:,1,2].mean() 395 | ave_loss_K = intr_loss.mean() 396 | 397 | ave_loss_R = pose_loss[:,:3,:3].mean() 398 | ave_loss_T = pose_loss[:,:3,3:].mean() 399 | self.table.add_row([int(epoch), round(float(ave_loss_fx), 4),\ 400 | round(float(ave_loss_fy), 4),\ 401 | round(float(ave_loss_ux), 4),\ 402 | round(float(ave_loss_uy), 4),\ 403 | round(float(ave_loss_K), 4),\ 404 | round(float(ave_loss_R), 4),\ 405 | round(float(ave_loss_T), 4)]) 406 | 407 | print(self.table) 408 | 409 | 
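    # Summary of the learnable calibration parameters registered above:
    #   * weights_fx / weights_fy / weights_ux / weights_uy are per-image scale factors on the
    #     initial intrinsic guess K0 = [[W, 0, W/2], [0, W, H/2], [0, 0, 1]] (see add_weights2intr),
    #     so for a camera whose ground-truth intrinsics are (fx, fy, ux, uy) the ideal weights are
    #     fx/W, fy/W, ux/(W/2) and uy/(H/2) respectively.
    #   * weights_pose and weights_pose_intr are per-image se(3) vectors [w | u] that se3_to_SE3
    #     maps to 3x4 poses [R | t] (see add_weights2pose / add_weights2calib_pose); all of them
    #     are initialized to ones and optimized jointly with the NeRF through the reprojection and
    #     RGB losses.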
def init_show_figure(self, show_info=True): 410 | self.all_fig = plt.figure(figsize=(4,4)) 411 | plt.rcParams['font.family'] = 'serif' 412 | plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif'] 413 | plt.rcParams['mathtext.default'] = 'regular' 414 | self.ax = Axes3D(self.all_fig, auto_add_to_figure=False) 415 | self.all_fig.add_axes(self.ax) 416 | self.ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0)) 417 | self.ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0)) 418 | self.ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0)) 419 | 420 | if show_info: 421 | self.ax.set_xlabel("X Axis") 422 | self.ax.set_ylabel("Y Axis") 423 | self.ax.set_zlabel("Z Axis") 424 | self.ax.set_xlim(-3.5, 3.5) 425 | self.ax.set_ylim(-3.5, 3.5) 426 | self.ax.set_zlim(-3.5, 3.5) 427 | else: 428 | self.ax.grid(False) 429 | self.ax.axis(False) 430 | 431 | plt.ion() 432 | plt.gca().set_box_aspect((1, 1, 1)) 433 | 434 | def origin_pose_transform(self, pose): 435 | pose_R_new_inv = pose[..., :3] 436 | pose_T_new_inv = pose[..., 3:] 437 | pose_R_new = pose_R_new_inv.transpose(-2, -1) 438 | pose_T_new = -pose_R_new @ pose_T_new_inv 439 | pose_flip_R_inv = torch.diag(torch.tensor([1.0,-1.0,-1.0], device=self.device)).T.expand(pose.shape[0], 3, 3) 440 | pose_flip_T = torch.zeros([3, 1], device=self.device).expand(pose.shape[0], 3, 1) 441 | 442 | pose_R = pose_R_new @ pose_flip_R_inv 443 | pose_T = pose_T_new - pose_R @ pose_flip_T 444 | pose_ori = torch.cat([pose_R, pose_T], -1) 445 | 446 | return pose_ori 447 | 448 | def show_RT_est_results(self, epoch, epoch_type, mode='epoch', show_info=True): 449 | if epoch_type in ['INTR_EPOCH']: 450 | return 0 451 | 452 | # make sure the save path exists 453 | save_path = os.path.join(Path(self.train_img_pth), Path(self.data_name), Path("cam_pose")) 454 | if not os.path.exists(save_path): 455 | os.makedirs(save_path) 456 | 457 | if mode == 'epoch': 458 | plt.cla() 459 | color_gt = (0.7,0.2,0.7) 460 | color_pd = (0,0.6,0.7) 461 | # [84, 3, 4] 462 | gt_ori_pose = self.origin_pose_transform(self.gt_pose.detach()) 463 | pd_ori_pose = self.origin_pose_transform(self.pose_adj.detach()) 464 | self.draw_camera_shape(gt_ori_pose, self.intr_train, color_gt, cam_size=0.3) 465 | self.draw_camera_shape(pd_ori_pose, self.intr_adj.detach(), color_pd, cam_size=0.3) 466 | 467 | if show_info: 468 | self.ax.set_xlim(-3.5, 3.5) 469 | self.ax.set_ylim(-3.5, 3.5) 470 | self.ax.set_zlim(-3.5, 3.5) 471 | else: 472 | self.ax.grid(False) 473 | self.ax.axis(False) 474 | # self.ax.txt 475 | 476 | file_path = os.path.join(Path(save_path), Path("epoch_"+ str(epoch) + ".png")) 477 | plt.savefig(file_path) 478 | else: 479 | if epoch % 100 == 0: 480 | plt.cla() 481 | color_gt = (0.7,0.2,0.7) 482 | color_pd = (0,0.6,0.7) 483 | gt_ori_pose = self.origin_pose_transform(self.gt_pose.detach()) 484 | pd_ori_pose = self.origin_pose_transform(self.pose_adj.detach()) 485 | self.draw_camera_shape(gt_ori_pose, self.intr_train, color_gt, cam_size=0.3) 486 | self.draw_camera_shape(pd_ori_pose, self.intr_adj.detach(), color_pd, cam_size=0.3) 487 | 488 | if show_info: 489 | self.ax.set_xlim(-3.5, 3.5) 490 | self.ax.set_ylim(-3.5, 3.5) 491 | self.ax.set_zlim(-3.5, 3.5) 492 | else: 493 | self.ax.grid(False) 494 | self.ax.axis(False) 495 | 496 | file_path = os.path.join(Path(save_path), Path("step_"+ str(epoch) + ".png")) 497 | plt.savefig(file_path, dpi=500) 498 | 499 | def draw_camera_shape(self, extr_mat, intr_mat, color, cam_size=0.25): 500 | # extr_mat: [84, 3, 4] 501 | # intr_mat: [84, 3, 3] 502 | cam_line = cam_size 503 | focal =
intr_mat[:,0,0]*cam_line/self.img_w 504 | cam_pts_1 = torch.stack([-torch.ones_like(focal)*cam_line/2, 505 | -torch.ones_like(focal)*cam_line/2, 506 | -focal], -1)[:,None,:].to(extr_mat.device) 507 | cam_pts_2 = torch.stack([-torch.ones_like(focal)*cam_line/2, 508 | torch.ones_like(focal)*cam_line/2, 509 | -focal], -1)[:,None,:].to(extr_mat.device) 510 | cam_pts_3 = torch.stack([ torch.ones_like(focal)*cam_line/2, 511 | torch.ones_like(focal)*cam_line/2, 512 | -focal], -1)[:,None,:].to(extr_mat.device) 513 | cam_pts_4 = torch.stack([ torch.ones_like(focal)*cam_line/2, 514 | -torch.ones_like(focal)*cam_line/2, 515 | -focal], -1)[:,None,:].to(extr_mat.device) 516 | cam_pts_1 = cam_pts_1 @ extr_mat[:, :3, :3].transpose(-2,-1) + extr_mat[:, :3, 3][:,None,:] 517 | cam_pts_2 = cam_pts_2 @ extr_mat[:, :3, :3].transpose(-2,-1) + extr_mat[:, :3, 3][:,None,:] 518 | cam_pts_3 = cam_pts_3 @ extr_mat[:, :3, :3].transpose(-2,-1) + extr_mat[:, :3, 3][:,None,:] 519 | cam_pts_4 = cam_pts_4 @ extr_mat[:, :3, :3].transpose(-2,-1) + extr_mat[:, :3, 3][:,None,:] 520 | cam_pts = torch.cat([cam_pts_1, cam_pts_2, cam_pts_3, cam_pts_4, cam_pts_1], dim=-2) 521 | for i in range(4): 522 | # [84, 2, 3] 523 | cur_line_pts = torch.stack([cam_pts[:,i,:], cam_pts[:,i+1,:]], dim=-2).to('cpu') 524 | for each_cam in cur_line_pts: 525 | self.ax.plot(each_cam[:,0],each_cam[:,1],each_cam[:,2],color=color,linewidth=0.5) 526 | extr_T = extr_mat[:, :3, 3] 527 | for i in range(4): 528 | # [84, 2, 3] 529 | cur_line_pts = torch.stack([extr_T, cam_pts[:,i,:]], dim=-2).to('cpu') 530 | for each_cam in cur_line_pts: 531 | self.ax.plot(each_cam[:,0],each_cam[:,1],each_cam[:,2],color=color,linewidth=0.5) 532 | extr_T = extr_T.to('cpu') 533 | 534 | self.ax.scatter(extr_T[:,0],extr_T[:,1],extr_T[:,2],color=color,s=5) 535 | 536 | def listify_matrix(self, matrix): 537 | matrix_list = [] 538 | for row in matrix: 539 | matrix_list.append(list(row)) 540 | return matrix_list 541 | 542 | # NeRF rendering model 543 | class NeRF_Model(nn.Module): 544 | def __init__(self, sys_param): 545 | logging.info('Creating NeRF Model...') 546 | super(NeRF_Model, self).__init__() 547 | self.sys_param = sys_param 548 | self.mode = self.sys_param["mode"] 549 | self.device = self.sys_param["device_type"] 550 | self.near = self.sys_param["near"] 551 | self.far = self.sys_param["far"] 552 | self.samples_c = self.sys_param["samples"] 553 | self.sample_scale = self.sys_param["scale"] 554 | self.samples_f = self.samples_c * self.sample_scale 555 | self.dim_sh = 3 * (self.sys_param["MLP_deg"] + 1)**2 556 | self.white_back = self.sys_param["white_back"] 557 | self.weights_pth = self.sys_param['root_weight'] 558 | self.train_img_pth = self.sys_param["demo_render_pth"] 559 | self.batch_test = self.sys_param["batch"] 560 | self.xyz_min = self.sys_param["boader_min"] 561 | self.xyz_max = self.sys_param["boader_max"] 562 | self.xyz_scope = self.xyz_max - self.xyz_min 563 | self.grid_nerf = self.sys_param["grid_nerf"] 564 | self.sigma_init = self.sys_param["sigma_init"] 565 | self.sigma_default = self.sys_param["sigma_default"] 566 | self.warmup_epoch = self.sys_param["warmup_epoch"] 567 | self.weight_thresh = self.sys_param['sample_weight_thresh'] 568 | self.render_h = self.sys_param["res_h"] 569 | self.render_w = self.sys_param["res_w"] 570 | self.z_vals_c = torch.linspace(self.near, self.far, self.samples_c, device=self.device) 571 | self.z_vals_f = torch.linspace(self.near, self.far, self.samples_f, device=self.device) 572 | self.global_step = 0 573 | self.emmbedding_xyz =
SinCosEmbedding(self.sys_param) 574 | self.nerf_coarse = CorseFine_NeRF(self.sys_param, type="coarse") 575 | self.nerf_fine = CorseFine_NeRF(self.sys_param, type="fine") 576 | self.data_name = self.sys_param["data_name"] 577 | if self.mode != 0: 578 | self.nerf_ckpt_name = self.sys_param['demo_ckpt'] 579 | self.nerf_ckpt_file = torch.load(Path(self.nerf_ckpt_name), map_location = self.device) 580 | self.nerf_ckpt_coarse = self.rewrite_nerf_ckpt(self.nerf_ckpt_file, coarse=True) 581 | self.nerf_ckpt_fine = self.rewrite_nerf_ckpt(self.nerf_ckpt_file) 582 | self.nerf_coarse.load_state_dict(self.nerf_ckpt_coarse) 583 | self.nerf_fine.load_state_dict(self.nerf_ckpt_fine) 584 | logging.info("Loading weights:{}".format(self.nerf_ckpt_name)) 585 | 586 | def forward(self, *args): 587 | self.global_step += 1 588 | if self.mode == 0: 589 | rays_d, rays_o, cur_epoch, step_r = args 590 | rgb_c, rgb_f = self.render_rays_train(rays_d, rays_o, cur_epoch, step_r, only_coarse=False) 591 | return rgb_c, rgb_f 592 | else: 593 | rays_d, rays_o = args 594 | results = self.render_rays_test(rays_d, rays_o, self.nerf_coarse, self.nerf_fine) 595 | 596 | return results 597 | 598 | def render_rays_train(self, rays_d, rays_o, cur_epoch, step_r, only_coarse = False): 599 | z_vals_samples_c = self.z_vals_c.clone().expand(rays_d.shape[0], -1) 600 | delta_z_vals_init = torch.empty(rays_d.shape[0], 1, device=self.device).uniform_(0.0, (self.far - self.near)/self.samples_c) 601 | z_vals_samples_c = z_vals_samples_c + delta_z_vals_init 602 | xyz_samples_c = rays_o.unsqueeze(1) + rays_d.unsqueeze(1)*z_vals_samples_c.unsqueeze(2) 603 | # [chunk,3]/[chunk, 256] 604 | rgb_coarse, sigmas_coarse, xyz_idx, depth_coarse, _ = \ 605 | self.inference(self.nerf_coarse, 606 | self.emmbedding_xyz, 607 | step_r, 608 | xyz_samples_c, 609 | rays_d, 610 | z_vals_samples_c) 611 | if only_coarse: 612 | return rgb_coarse, None, depth_coarse 613 | with torch.no_grad(): 614 | # [self.batch, 127] 615 | deltas_coarse = z_vals_samples_c[:, 1:] - z_vals_samples_c[:, :-1] 616 | delta_inf = 1e10 * torch.ones_like(deltas_coarse[:, :1]) 617 | # [self.batch, 128] 618 | deltas_coarse = torch.cat([deltas_coarse, delta_inf], -1) 619 | weights_coarse = self.sigma2weights(deltas_coarse, sigmas_coarse) 620 | # [self.batch, 128] 621 | weights_coarse = weights_coarse.detach() 622 | # [X, 2] 623 | idx_render = torch.nonzero(weights_coarse >= min(self.weight_thresh, weights_coarse.max().item())) 624 | # [X, 5, 2] 625 | idx_render = idx_render.unsqueeze(1).expand(-1, self.sample_scale, -1) 626 | # [X, 5, 2] 627 | idx_render_fine = idx_render.clone() 628 | idx_render_fine[..., 1] = idx_render[..., 1] * self.sample_scale + (torch.arange(self.sample_scale, device=self.device)).reshape(1, self.sample_scale) 629 | idx_render_fine = idx_render_fine.reshape(-1, 2) 630 | if idx_render_fine.shape[0] > rays_d.shape[0] * 128: 631 | indices = torch.randperm(idx_render_fine.shape[0])[:rays_d.shape[0] * 128] 632 | idx_render_fine = idx_render_fine[indices] 633 | z_vals_samples_f = self.z_vals_f.clone().expand(rays_d.shape[0], -1) 634 | z_vals_samples_f = z_vals_samples_f + delta_z_vals_init 635 | xyz_samples_f = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals_samples_f.unsqueeze(2) 636 | # [self.batch,3]/[self.batch, 256] 637 | rgb_fine, sigmas_fine, xyz_idx, depth_fine, _ = \ 638 | self.inference(self.nerf_fine, 639 | self.emmbedding_xyz, 640 | step_r, 641 | xyz_samples_f, 642 | rays_d, 643 | z_vals_samples_f, 644 | idx_render=idx_render_fine, 645 | coarse=False) 646 | return 
rgb_coarse, rgb_fine 647 | 648 | def render_rays_test(self, rays_d, rays_o, model_coarse, model_fine): 649 | step_r = 1 650 | z_vals_samples_c = self.z_vals_c.clone().expand(rays_d.shape[0], -1) 651 | xyz_samples_c = rays_o.unsqueeze(1) + rays_d.unsqueeze(1)*z_vals_samples_c.unsqueeze(2) 652 | rgb_coarse, sigmas_coarse, _, depth_coarse, opacity_coarse = self.inference(model_coarse, 653 | self.emmbedding_xyz, 654 | step_r, 655 | xyz_samples_c, 656 | rays_d, 657 | z_vals_samples_c) 658 | deltas_samples = z_vals_samples_c[:, 1:] - z_vals_samples_c[:, :-1] 659 | delta_inf_samples = 1e10 * torch.ones_like(deltas_samples[:, :1]) 660 | # [self.batch, 128] 661 | deltas = torch.cat([deltas_samples, delta_inf_samples], -1) 662 | weights = self.sigma2weights(deltas, sigmas_coarse) 663 | idx_render = torch.nonzero(weights >= min(self.weight_thresh, weights.max().item())) 664 | idx_render = idx_render.unsqueeze(1).expand(-1, self.sample_scale, -1) 665 | idx_render_fine = idx_render.clone() 666 | idx_render_fine[..., 1] = idx_render[..., 1] * self.sample_scale + (torch.arange(self.sample_scale, device=self.device)).reshape(1, self.sample_scale) 667 | idx_render_fine = idx_render_fine.reshape(-1, 2) 668 | z_vals_samples_f = self.z_vals_f.clone().expand(rays_d.shape[0], -1) 669 | xyz_samples_f = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals_samples_f.unsqueeze(2) 670 | # [self.batch,3]/[self.batch, 256] 671 | rgb_fine, _, _, depth_fine, opacity_fine = self.inference(model_fine, 672 | self.emmbedding_xyz, 673 | step_r, 674 | xyz_samples_f, 675 | rays_d, 676 | z_vals_samples_f, 677 | idx_render=idx_render_fine, 678 | coarse=False) 679 | # rgb_final = rgb_coarse + rgb_fine 680 | return rgb_fine, depth_fine, opacity_fine 681 | 682 | def inference(self, model, embedding_xyz, step_r, xyz, rays_d, z_vals, idx_render=None, coarse=True): 683 | if coarse: 684 | sample_numb = self.samples_c 685 | else: 686 | sample_numb = self.samples_f 687 | batch_numb = rays_d.shape[0] 688 | view_dir = rays_d.unsqueeze(1).expand(-1, sample_numb, -1) 689 | if idx_render != None: 690 | view_dir = view_dir[idx_render[:, 0], idx_render[:, 1]] 691 | xyz = xyz[idx_render[:, 0], idx_render[:, 1]] 692 | out_rgb = torch.full((batch_numb, sample_numb, 3), 1.0, device=self.device) 693 | out_sigma = torch.full((batch_numb, sample_numb, 1), self.sigma_default, device=self.device) 694 | out_defaults = torch.cat([out_sigma, out_rgb], dim=2) 695 | else: 696 | xyz = xyz.reshape(-1, 3) 697 | view_dir = view_dir.reshape(-1, 3) 698 | input_encode_xyz = embedding_xyz(xyz, step_r) 699 | nerf_model_out = model(input_encode_xyz, view_dir) 700 | if idx_render != None: 701 | out_defaults[idx_render[:, 0], idx_render[:, 1]] = nerf_model_out 702 | else: 703 | # TODO: needs revision 704 | out_defaults = nerf_model_out.reshape(batch_numb, sample_numb, 4) 705 | sigmas, rgbs = torch.split(out_defaults, (1, 3), dim=-1) 706 | sigmas = sigmas.squeeze(-1) 707 | rays_length = rays_d.norm(dim = -1, keepdim = True) 708 | deltas = z_vals[:, 1:] - z_vals[:, :-1] 709 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) 710 | deltas = torch.cat([deltas, delta_inf], -1) 711 | dist_samples = deltas * rays_length 712 | sigma_delta = torch.nn.Softplus()(sigmas) * dist_samples 713 | alpha = 1 - torch.exp(-sigma_delta) 714 | T = torch.exp(-torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim = 1).cumsum(dim = 1)) 715 | prob = (T * alpha)[..., None] 716 | opacity = prob.sum(dim = 1) 717 | depth_samples = z_vals.unsqueeze(-1) 718 | depth = (depth_samples *
prob).sum(dim = 1) 719 | weights = self.sigma2weights(deltas, sigmas) 720 | weights_sum = weights.sum(1) 721 | rgbs_weights = weights.unsqueeze(-1)*rgbs 722 | rgb_final = torch.sum(rgbs_weights, 1) 723 | 724 | if self.white_back: 725 | rgb_final = rgb_final + 1 - weights_sum.unsqueeze(-1) 726 | 727 | return rgb_final, sigmas, xyz, depth, opacity 728 | 729 | def sigma2weights(self, deltas, sigmas): 730 | noise = torch.randn(sigmas.shape, device=self.device) 731 | sigmas = sigmas + noise 732 | # [self.batch, 128] 733 | alphas = 1-torch.exp(-deltas*torch.nn.Softplus()(sigmas)) 734 | alphas_shifted = torch.cat([torch.ones_like(alphas[:, :1]), 1-alphas+1e-10], -1) # [1, a1, a2, ...] 735 | weights = alphas * torch.cumprod(alphas_shifted, -1)[:, :-1] 736 | return weights 737 | 738 | def save_model(self, model, epoch): 739 | save_path = os.path.join(Path(self.weights_pth), Path("train")) 740 | net = "{}-EPOCH-{}-".format(self.data_name, epoch) 741 | save_dict = {'model_nerf': model.state_dict()} 742 | nowtime = time.strftime("%Y-%m-%d-%H-%M-%S.ckpt", time.localtime()) 743 | self.model_name = net + nowtime 744 | self.file_path = os.path.join(Path(save_path), Path(self.model_name)) 745 | if not os.path.exists(save_path): 746 | os.makedirs(save_path) 747 | if self.sys_param["distributed"]: 748 | if dist.get_rank() == 0: 749 | torch.save(save_dict, self.file_path) 750 | else: 751 | torch.save(save_dict, self.file_path) 752 | logging.info('\nSave model:{}'.format(self.model_name)) 753 | 754 | def valid_train(self, epoch, val_data, epoch_type): 755 | if epoch_type in ['CAM_PARAM_EPOCH']: 756 | return 0 757 | torch.cuda.empty_cache() 758 | rank = get_rank() 759 | if rank == 0: 760 | rays_d, rays_o, gt_rgbs = val_data 761 | render_ckpt = torch.load(self.file_path, map_location = self.device) 762 | logging.info("Loading model:{}".format(self.model_name)) 763 | val_nerf_coarse = CorseFine_NeRF(self.sys_param, type="coarse").to(self.device) 764 | val_nerf_fine = CorseFine_NeRF(self.sys_param, type="fine").to(self.device) 765 | nerf_ckpt_coarse = self.rewrite_nerf_ckpt(render_ckpt, coarse=True) 766 | nerf_ckpt_fine = self.rewrite_nerf_ckpt(render_ckpt, coarse=False) 767 | val_nerf_coarse.load_state_dict(nerf_ckpt_coarse) 768 | val_nerf_fine.load_state_dict(nerf_ckpt_fine) 769 | val_nerf_coarse.eval() 770 | val_nerf_fine.eval() 771 | rgb_cat = [] 772 | depth_cat = [] 773 | with torch.no_grad(): 774 | logging.info("Rendering...") 775 | for ii in range(0, rays_d.shape[0], self.batch_test): 776 | batch_rays_d = rays_d[ii:ii+self.batch_test] 777 | batch_rays_o = rays_o[ii:ii+self.batch_test] 778 | pred_rays_rgbs, pred_rays_depth, _ = self.render_rays_test(batch_rays_d, batch_rays_o, val_nerf_coarse, val_nerf_fine) 779 | rgb_cat += [pred_rays_rgbs] 780 | depth_cat += [pred_rays_depth] 781 | rgb_cat = torch.cat(rgb_cat, 0) 782 | depth_cat = torch.cat(depth_cat, 0) 783 | logging.info("Saving image...") 784 | 785 | img = rgb_cat.view(self.render_h, self.render_w, 3).cpu() 786 | dep = depth_cat.view(self.render_h, self.render_w, 1).cpu() 787 | gt = gt_rgbs.view(self.render_h, self.render_w, 3).cpu() 788 | 789 | img = img.permute(2, 0, 1) # (3, H, W) 790 | dep = dep.permute(2, 0, 1) 791 | gt = gt.permute(2, 0, 1) # (3, H, W) 792 | 793 | img_path = os.path.join(Path(self.train_img_pth), Path(self.data_name), Path("epoch_"+ str(epoch) + ".png")) 794 | gt_path = os.path.join(Path(self.train_img_pth), Path(self.data_name), Path("epoch_"+ str(epoch) + "_gt.png")) 795 | dep_path = os.path.join(Path(self.train_img_pth), 
Path(self.data_name), Path("epoch_"+ str(epoch) + "_depth.png")) 796 | os.makedirs(os.path.dirname(img_path), exist_ok=True) 797 | 798 | transforms.ToPILImage()(img).convert("RGB").save(img_path) 799 | transforms.ToPILImage()(gt).convert("RGB").save(gt_path) 800 | transforms.ToPILImage()(dep).convert("L").save(dep_path) 801 | 802 | psnr_score = self.psnr_score(img, gt) 803 | lpips_score = self.lpips_score(img, gt) 804 | ssim_score = self.ssim_score(img, gt) 805 | 806 | logging.info("PSNR:{}".format(psnr_score)) 807 | logging.info("LPIPS:{}".format(lpips_score)) 808 | logging.info("SSIM:{}".format(ssim_score)) 809 | 810 | if self.sys_param["distributed"]: 811 | dist.barrier() 812 | 813 | torch.cuda.empty_cache() 814 | 815 | def rewrite_nerf_ckpt(self, nerf_ckpt_dict, coarse=False): 816 | state_dict = nerf_ckpt_dict["model_nerf"] 817 | new_state_dict = state_dict.copy() 818 | if coarse: 819 | net_name = "nerf_coarse" 820 | else: 821 | net_name = "nerf_fine" 822 | for key in state_dict: 823 | new_key_name = "" 824 | if net_name in key.split("."): 825 | del(new_state_dict[key]) 826 | idx_hash_nerf = key.split(".").index(net_name) 827 | new_key_name_list = key.split(".")[(idx_hash_nerf+1):] 828 | for idx, ele in enumerate(new_key_name_list): 829 | new_key_name += ele 830 | if idx != (len(new_key_name_list) - 1): 831 | new_key_name += "." 832 | 833 | new_state_dict[new_key_name] = state_dict[key] 834 | else: 835 | del(new_state_dict[key]) 836 | 837 | return new_state_dict 838 | 839 | def psnr_score(self, image_pred, image_gt, valid_mask=None, reduction='mean'): 840 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 841 | value = (image_pred-image_gt)**2 842 | if valid_mask is not None: 843 | value = value[valid_mask] 844 | if reduction == 'mean': 845 | return torch.mean(value) 846 | return value 847 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 848 | 849 | def lpips_score(self, image_pred, image_gt): 850 | lpips_loss = lpips.LPIPS(net="alex") 851 | lpips_value = lpips_loss(image_pred*2-1, image_gt*2-1).item() 852 | return lpips_value 853 | def ssim_score(self, image_pred, image_gt): 854 | image_pred = image_pred.unsqueeze(0) 855 | image_gt = image_gt.unsqueeze(0) 856 | ssim = pytorch_ssim.ssim(image_pred, image_gt).item() 857 | return ssim 858 | 859 | def query_sigma(self, xyz): 860 | ijk_coarse = ((xyz - self.xyz_min) / self.xyz_scope * self.grid_nerf).long().clamp(min=0, max=self.grid_nerf-1) 861 | sigmas = self.sigma_voxels[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] 862 | 863 | return sigmas 864 | def update_sigma(self, xyz, sigma, beta): 865 | ijk_coarse = ((xyz - self.xyz_min) / self.xyz_scope * self.grid_nerf).long().clamp(min=0, max=self.grid_nerf-1) 866 | self.sigma_voxels[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] \ 867 | = (1 - beta)*self.sigma_voxels[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] + beta*sigma 868 | -------------------------------------------------------------------------------- /model/net_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .net_utils import eval_sh 5 | 6 | class SinCosEmbedding(nn.Module): 7 | def __init__(self, sys_params): 8 | super(SinCosEmbedding, self).__init__() 9 | self.sys_param = sys_params 10 | self.device = self.sys_param["device_type"] 11 | self.n_freqs = self.sys_param["emb_freqs_xyz"] 12 | self.barf_mode = self.sys_param["barf_mask"] 13 | self.barf_start = 
self.sys_param["barf_start"] 14 | self.barf_end = self.sys_param["barf_end"] 15 | self.in_channels = 3 16 | self.funcs = [torch.sin, torch.cos] 17 | self.out_channels = self.in_channels*(len(self.funcs)*self.n_freqs+1) 18 | self.freq_bands = 2**torch.linspace(0, self.n_freqs-1, self.n_freqs, device=self.device) 19 | 20 | def forward(self, x, step_r): 21 | shape = x.shape 22 | spectrum = x[...,None]*self.freq_bands 23 | sin,cos = spectrum.sin(), spectrum.cos() 24 | input_enc = torch.stack([sin, cos],dim=-2) 25 | x_enc = input_enc.view(shape[0],-1) 26 | if self.barf_mode: 27 | alpha = (step_r - self.barf_start)/(self.barf_end - self.barf_start)*self.n_freqs 28 | k = torch.arange(self.n_freqs, dtype=torch.float32, device=self.device) 29 | weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(math.pi).cos_())/2 30 | shape = x_enc.shape 31 | x_enc = x_enc.view(-1, self.n_freqs)*weight 32 | x_enc = x_enc.view(shape[0], -1) 33 | x_enc = torch.cat([x, x_enc],dim=-1) 34 | 35 | return x_enc 36 | 37 | class CorseFine_NeRF(nn.Module): 38 | def __init__(self, sys_params, type="coarse"): 39 | super(CorseFine_NeRF, self).__init__() 40 | self.in_channels_xyz = 3*(2*sys_params["emb_freqs_xyz"] + 1) #63 41 | self.deg = sys_params["MLP_deg"] 42 | if type == "coarse": 43 | self.depth = sys_params["coarse_MLP_depth"] 44 | self.width = sys_params["coarse_MLP_width"] 45 | self.skips = sys_params["coarse_MLP_skip"] 46 | elif type == "fine": 47 | self.depth = sys_params["fine_MLP_depth"] 48 | self.width = sys_params["fine_MLP_width"] 49 | self.skips = sys_params["fine_MLP_skip"] 50 | 51 | for i in range(self.depth): 52 | if i == 0: 53 | layer = nn.Linear(self.in_channels_xyz, self.width) 54 | elif i in self.skips: 55 | layer = nn.Linear(self.width + self.in_channels_xyz, self.width) 56 | else: 57 | layer = nn.Linear(self.width, self.width) 58 | layer = nn.Sequential(layer, nn.ReLU(True)) 59 | setattr(self, f"xyz_encoding_{i+1}", layer) 60 | self.sigma = nn.Sequential(nn.Linear(self.width, self.width), 61 | nn.ReLU(True), 62 | nn.Linear(self.width, 1)) 63 | self.sh = nn.Sequential(nn.Linear(self.width, self.width), 64 | nn.ReLU(True), 65 | nn.Linear(self.width, 3 * (self.deg + 1)**2)) 66 | 67 | def forward(self, x, dirs): 68 | xyz_ = x 69 | for i in range(self.depth): 70 | if i in self.skips: 71 | xyz_ = torch.cat([x, xyz_], -1) 72 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 73 | sigma = self.sigma(xyz_) 74 | sh = self.sh(xyz_) 75 | rgb = eval_sh(deg=self.deg, sh=sh.reshape(-1, 3, (self.deg + 1)**2), dirs=dirs) # sh: [..., C, (deg + 1) ** 2] 76 | rgb = torch.sigmoid(rgb) 77 | out = torch.cat([sigma, rgb], -1) 78 | 79 | return out -------------------------------------------------------------------------------- /model/net_utils.py: -------------------------------------------------------------------------------- 1 | from matplotlib import cm 2 | import torch 3 | import math 4 | 5 | import torch.distributed as dist 6 | 7 | from torch.optim.optimizer import Optimizer 8 | 9 | 10 | class RAdam(Optimizer): 11 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 12 | if not 0.0 <= lr: 13 | raise ValueError("Invalid learning rate: {}".format(lr)) 14 | if not 0.0 <= eps: 15 | raise ValueError("Invalid epsilon value: {}".format(eps)) 16 | if not 0.0 <= betas[0] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 18 | if not 0.0 <= betas[1] < 1.0: 19 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 20 | 21 | 
self.degenerated_to_sgd = degenerated_to_sgd 22 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 23 | for param in params: 24 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 25 | param['buffer'] = [[None, None, None] for _ in range(10)] 26 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 27 | super(RAdam, self).__init__(params, defaults) 28 | 29 | def __setstate__(self, state): 30 | super(RAdam, self).__setstate__(state) 31 | 32 | def step(self, closure=None): 33 | 34 | loss = None 35 | if closure is not None: 36 | loss = closure() 37 | 38 | for group in self.param_groups: 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad.data.float() 44 | if grad.is_sparse: 45 | raise RuntimeError('RAdam does not support sparse gradients') 46 | 47 | p_data_fp32 = p.data.float() 48 | 49 | state = self.state[p] 50 | 51 | if len(state) == 0: 52 | state['step'] = 0 53 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 54 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 55 | else: 56 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 57 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 58 | 59 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 60 | beta1, beta2 = group['betas'] 61 | 62 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 63 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 64 | 65 | # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 66 | # exp_avg.mul_(beta1).add_(1 - beta1, grad) 67 | 68 | state['step'] += 1 69 | buffered = group['buffer'][int(state['step'] % 10)] 70 | if state['step'] == buffered[0]: 71 | N_sma, step_size = buffered[1], buffered[2] 72 | else: 73 | buffered[0] = state['step'] 74 | beta2_t = beta2 ** state['step'] 75 | N_sma_max = 2 / (1 - beta2) - 1 76 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 77 | buffered[1] = N_sma 78 | 79 | # more conservative since it's an approximated value 80 | if N_sma >= 5: 81 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 82 | elif self.degenerated_to_sgd: 83 | step_size = 1.0 / (1 - beta1 ** state['step']) 84 | else: 85 | step_size = -1 86 | buffered[2] = step_size 87 | 88 | # more conservative since it's an approximated value 89 | if N_sma >= 5: 90 | if group['weight_decay'] != 0: 91 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 92 | denom = exp_avg_sq.sqrt().add_(group['eps']) 93 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 94 | p.data.copy_(p_data_fp32) 95 | elif step_size > 0: 96 | if group['weight_decay'] != 0: 97 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 98 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 99 | p.data.copy_(p_data_fp32) 100 | 101 | return loss 102 | 103 | def eval_sh(deg, sh, dirs): 104 | 105 | C0 = 0.28209479177387814 106 | C1 = 0.4886025119029199 107 | C2 = [ 108 | 1.0925484305920792, 109 | -1.0925484305920792, 110 | 0.31539156525252005, 111 | -1.0925484305920792, 112 | 0.5462742152960396 113 | ] 114 | C3 = [ 115 | -0.5900435899266435, 116 | 2.890611442640554, 117 | -0.4570457994644658, 118 | 0.3731763325901154, 119 | -0.4570457994644658, 120 | 1.445305721320277, 121 | -0.5900435899266435 122 | ] 123 | C4 = [ 124 | 2.5033429417967046, 125 | 
-1.7701307697799304, 126 | 0.9461746957575601, 127 | -0.6690465435572892, 128 | 0.10578554691520431, 129 | -0.6690465435572892, 130 | 0.47308734787878004, 131 | -1.7701307697799304, 132 | 0.6258357354491761, 133 | ] 134 | 135 | """ 136 | Evaluate spherical harmonics at unit directions 137 | using hardcoded SH polynomials. 138 | Works with torch/np/jnp. 139 | ... Can be 0 or more batch dimensions. 140 | 141 | Args: 142 | deg: int SH deg. Currently, 0-3 supported 143 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 144 | dirs: jnp.ndarray unit directions [..., 3] 145 | 146 | Returns: 147 | [..., C] 148 | """ 149 | 150 | assert deg <= 4 and deg >= 0 151 | assert (deg + 1) ** 2 == sh.shape[-1] 152 | C = sh.shape[-2] 153 | 154 | result = C0 * sh[..., 0] 155 | if deg > 0: 156 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 157 | result = (result - 158 | C1 * y * sh[..., 1] + 159 | C1 * z * sh[..., 2] - 160 | C1 * x * sh[..., 3]) 161 | if deg > 1: 162 | xx, yy, zz = x * x, y * y, z * z 163 | xy, yz, xz = x * y, y * z, x * z 164 | result = (result + 165 | C2[0] * xy * sh[..., 4] + 166 | C2[1] * yz * sh[..., 5] + 167 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 168 | C2[3] * xz * sh[..., 7] + 169 | C2[4] * (xx - yy) * sh[..., 8]) 170 | 171 | if deg > 2: 172 | result = (result + 173 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 174 | C3[1] * xy * z * sh[..., 10] + 175 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 176 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 177 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 178 | C3[5] * z * (xx - yy) * sh[..., 14] + 179 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 180 | if deg > 3: 181 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 182 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 183 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 184 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 185 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 186 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 187 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 188 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 189 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 190 | 191 | return result 192 | 193 | def get_rank(): 194 | if not is_dist_avail_and_initialized(): 195 | return 0 196 | return dist.get_rank() 197 | 198 | def is_dist_avail_and_initialized(): 199 | if not dist.is_available(): 200 | return False 201 | if not dist.is_initialized(): 202 | return False 203 | return True 204 | 205 | def apply_colormap(image, cmap="viridis"): 206 | colormap = cm.get_cmap(cmap) 207 | colormap = torch.tensor(colormap.colors).to(image.device) # type: ignore 208 | image_long = (image * 255).long() 209 | 210 | image_long[image_long<63]=63 211 | image_long[image_long>255]=255 212 | 213 | image_long_min = torch.min(image_long) 214 | image_long_max = torch.max(image_long) 215 | assert image_long_min >= 0, f"the min value is {image_long_min}" 216 | assert image_long_max <= 255, f"the max value is {image_long_max}" 217 | return colormap[image_long[..., 0]] 218 | 219 | def apply_depth_colormap( 220 | depth, 221 | accumulation= None, 222 | near_plane= None, 223 | far_plane= None, 224 | cmap="turbo",): 225 | 226 | depth = torch.clip(depth, 0 , 1) 227 | colored_image = apply_colormap(depth, cmap=cmap) 228 | 229 | if accumulation is not None: 230 | colored_image = colored_image * accumulation + (1 - accumulation) 231 | 232 | return colored_image -------------------------------------------------------------------------------- /requirements.yaml: 
-------------------------------------------------------------------------------- 1 | name: mc-env 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.3=he6710b0_2 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1w=h7f8727e_0 15 | - pip=23.3.1=py39h06a4308_0 16 | - python=3.9.13=haa1d7c7_2 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - tzdata=2023c=h04d1e81_0 22 | - wheel=0.41.2=py39h06a4308_0 23 | - xz=5.4.5=h5eee18b_0 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - absl-py==2.0.0 27 | - apriltag==0.0.16 28 | - cachetools==5.3.2 29 | - certifi==2023.11.17 30 | - charset-normalizer==3.3.2 31 | - cmake==3.27.9 32 | - contourpy==1.2.0 33 | - cycler==0.12.1 34 | - filelock==3.13.1 35 | - fonttools==4.46.0 36 | - google-auth==2.25.1 37 | - google-auth-oauthlib==1.1.0 38 | - grpcio==1.59.3 39 | - idna==3.6 40 | - importlib-metadata==7.0.0 41 | - importlib-resources==6.1.1 42 | - jinja2==3.1.2 43 | - kiwisolver==1.4.5 44 | - lit==17.0.6 45 | - lpips==0.1.4 46 | - markdown==3.5.1 47 | - markupsafe==2.1.3 48 | - matplotlib==3.8.2 49 | - mpmath==1.3.0 50 | - networkx==3.2.1 51 | - numpy==1.26.2 52 | - nvidia-cublas-cu11==11.10.3.66 53 | - nvidia-cuda-cupti-cu11==11.7.101 54 | - nvidia-cuda-nvrtc-cu11==11.7.99 55 | - nvidia-cuda-runtime-cu11==11.7.99 56 | - nvidia-cudnn-cu11==8.5.0.96 57 | - nvidia-cufft-cu11==10.9.0.58 58 | - nvidia-curand-cu11==10.2.10.91 59 | - nvidia-cusolver-cu11==11.4.0.1 60 | - nvidia-cusparse-cu11==11.7.4.91 61 | - nvidia-nccl-cu11==2.14.3 62 | - nvidia-nvtx-cu11==11.7.91 63 | - oauthlib==3.2.2 64 | - opencv-python==4.8.1.78 65 | - packaging==23.2 66 | - pillow==10.1.0 67 | - prettytable==3.9.0 68 | - protobuf==4.23.4 69 | - pyasn1==0.5.1 70 | - pyasn1-modules==0.3.0 71 | - pyparsing==3.1.1 72 | - python-dateutil==2.8.2 73 | - pyyaml==6.0.1 74 | - requests==2.31.0 75 | - requests-oauthlib==1.3.1 76 | - rsa==4.9 77 | - scipy==1.11.4 78 | - six==1.16.0 79 | - sympy==1.12 80 | - tensorboard==2.15.1 81 | - tensorboard-data-server==0.7.2 82 | - tensorboardx==2.6.2.2 83 | - torch==2.0.1 84 | - torchaudio==2.0.2 85 | - torchvision==0.15.2 86 | - tqdm==4.66.1 87 | - triton==2.0.0 88 | - typing-extensions==4.8.0 89 | - urllib3==2.1.0 90 | - wcwidth==0.2.12 91 | - werkzeug==3.0.1 92 | - zipp==3.17.0 93 | -------------------------------------------------------------------------------- /synthetic_dataset_code/Array.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import math 3 | import json 4 | import os 5 | import random 6 | import cv2 7 | 8 | from pathlib import Path 9 | import numpy as np 10 | 11 | from mathutils import * 12 | 13 | class Array_Dataset(): 14 | def __init__(self, obj_name, seed): 15 | self.object_name = obj_name 16 | self.fov_min = 40 17 | self.fov_max = 80 18 | self.res_x = 800 19 | self.res_y = 800 20 | self.root_path = "F:\\NERF_Dataset" 21 | self.array_x = 3 22 | self.array_y = 3 23 | self.radius_array = 4 24 | self.theta_array = 45 25 | self.numb_x = 10 26 | self.numb_y = 10 27 | self.numb_train_val = self.numb_x*self.numb_y 28 | self.numb_test = 200 29 | random.seed(seed) 30 | self.cam_list = self.init_camera() 31 | self.train_fov, self.val_fov, self.test_fov = 
self.get_cam_fov_array() 32 | self.loc_train, self.loc_val, self.loc_test,\ 33 | self.rot_train, self.rot_val, self.rot_test = self.get_cam_pose_array() 34 | self.render_set() 35 | 36 | def render_images(self): 37 | collect_objects = bpy.data.collections 38 | for collects in collect_objects: 39 | collects.hide_render = True 40 | collect_objects["Object"].hide_render = False 41 | collect_objects[self.object_name.split("_")[-1]].hide_render = False 42 | self.render_process() 43 | self.cam_clear() 44 | 45 | def render_process(self): 46 | train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("train")) 47 | val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("val")) 48 | test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("test")) 49 | if not os.path.exists(train_path): 50 | os.makedirs(train_path) 51 | if not os.path.exists(val_path): 52 | os.makedirs(val_path) 53 | if not os.path.exists(test_path): 54 | os.makedirs(test_path) 55 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 56 | fov, cam = cam_info 57 | bpy.context.scene.camera = cam 58 | loc, rot = self.loc_train[idx], self.rot_train[idx] 59 | cam.rotation_mode = 'XYZ' 60 | cam.location = loc 61 | cam.rotation_euler = rot 62 | cam.data.angle = self.angle2rad(fov) 63 | cur_pose = cam.matrix_world 64 | save_path = os.path.join(Path(train_path), Path("r_{}".format(idx))) 65 | bpy.context.scene.render.filepath = save_path 66 | bpy.ops.render.render(write_still=True) 67 | frame_data = {'file_path': "./train/r_{}".format(idx), 68 | "camera_angle_x": self.angle2rad(fov), 69 | 'transform_matrix': self.listify_matrix(cur_pose)} 70 | self.out_data_train['frames'].append(frame_data) 71 | with open(self.json_train_path, 'w') as out_file: 72 | json.dump(self.out_data_train, out_file, indent=4) 73 | for idx, cam_info in enumerate(zip(self.val_fov, self.cam_list[:self.numb_train_val])): 74 | fov, cam = cam_info 75 | bpy.context.scene.camera = cam 76 | loc, rot = self.loc_val[idx], self.rot_val[idx] 77 | cam.rotation_mode = 'XYZ' 78 | cam.location = loc 79 | cam.rotation_euler = rot 80 | cam.data.angle = self.angle2rad(fov) 81 | cur_pose = cam.matrix_world 82 | save_path = os.path.join(Path(val_path), Path("r_{}".format(idx))) 83 | bpy.context.scene.render.filepath = save_path 84 | bpy.ops.render.render(write_still=True) 85 | frame_data = {'file_path': "./val/r_{}".format(idx), 86 | "camera_angle_x": self.angle2rad(fov), 87 | 'transform_matrix': self.listify_matrix(cur_pose)} 88 | self.out_data_val['frames'].append(frame_data) 89 | with open(self.json_val_path, 'w') as out_file: 90 | json.dump(self.out_data_val, out_file, indent=4) 91 | for idx, cam_info in enumerate(zip(self.test_fov, self.cam_list)): 92 | fov, cam = cam_info 93 | bpy.context.scene.camera = cam 94 | loc, rot = self.loc_test[idx], self.rot_test[idx] 95 | cam.rotation_mode = 'XYZ' 96 | cam.location = loc 97 | cam.rotation_euler = rot 98 | cam.data.angle = self.angle2rad(fov) 99 | cur_pose = cam.matrix_world 100 | save_path = os.path.join(Path(test_path), Path("r_{}".format(idx))) 101 | bpy.context.scene.render.filepath = save_path 102 | bpy.ops.render.render(write_still=True) 103 | frame_data = {'file_path': "./test/r_{}".format(idx), 104 | "camera_angle_x": self.angle2rad(fov), 105 | 'transform_matrix': self.listify_matrix(cur_pose)} 106 | self.out_data_test['frames'].append(frame_data) 107 | with open(self.json_test_path, 'w') as out_file: 108 | json.dump(self.out_data_test, 
out_file, indent=4) 109 | 110 | def angle2rad(self, angle): 111 | return angle*math.pi/180 112 | 113 | def rad2angle(self, rad): 114 | return rad*180/math.pi 115 | 116 | def listify_matrix(self, matrix): 117 | matrix_list = [] 118 | for row in matrix: 119 | matrix_list.append(list(row)) 120 | return matrix_list 121 | 122 | def Rot_Z(self, angle_theta): 123 | rad_theta = self.angle2rad(angle_theta) 124 | rot_M_z = np.array([[np.cos(rad_theta), np.sin(rad_theta), 0], 125 | [-np.sin(rad_theta), np.cos(rad_theta), 0], 126 | [0, 0, 1]]) 127 | return rot_M_z 128 | 129 | def Rot_X(self, angle_phi): 130 | rad_phi = -self.angle2rad(angle_phi) 131 | rot_M_x = np.array([[1, 0, 0], 132 | [0, np.cos(rad_phi), np.sin(rad_phi)], 133 | [0, -np.sin(rad_phi), np.cos(rad_phi)]]) 134 | return rot_M_x 135 | 136 | def init_camera(self): 137 | cam_list = [] 138 | for idx in range(max(self.numb_train_val, self.numb_test)): 139 | bpy.ops.object.camera_add(enter_editmode=False, 140 | align='VIEW', 141 | location=(0, 0, 0), 142 | rotation=(0, 0, 0), 143 | scale=(1, 1, 1)) 144 | cur_cam = bpy.context.selected_objects[0] 145 | cur_cam.name = "Array_{}".format(idx) 146 | cur_cam.data.type = 'PERSP' 147 | cur_cam.data.lens_unit = 'FOV' 148 | cam_list += [cur_cam] 149 | 150 | return cam_list 151 | 152 | def get_cam_fov_array(self): 153 | fov_angle_train = [] 154 | fov_angle_val = [] 155 | for i in range(self.numb_train_val): 156 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 157 | while cur_fov_train in [fov_angle_train]: 158 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 159 | fov_angle_train += [cur_fov_train] 160 | fov_angle_val = fov_angle_train 161 | fov_angle_test = list(np.linspace(self.fov_max, self.fov_min, self.numb_test//2)) 162 | fov_angle_test_inv = fov_angle_test.copy() 163 | fov_angle_test_inv.sort() 164 | fov_angle_test = fov_angle_test + fov_angle_test_inv 165 | assert len(fov_angle_test) == self.numb_test, "Length Error for test fov !!!" 
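        # FOV schedule: each of the numb_x*numb_y (= 100) array cameras gets an integer train FOV
        # drawn uniformly from [fov_min, fov_max] = [40, 80] degrees, and the val renders reuse the
        # train FOVs. The test trajectory sweeps the FOV from fov_max down to fov_min over the first
        # numb_test//2 frames and back up over the remaining frames, so the test video zooms in and
        # out while the camera moves. (The membership test above compares against [fov_angle_train],
        # a list containing the list, so duplicate train FOVs are not actually rejected; with 100
        # cameras and only 41 integer values available they could not all be distinct anyway.)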
166 | 167 | return fov_angle_train, fov_angle_val, fov_angle_test 168 | 169 | def get_cam_pose_array(self): 170 | loc_train, rot_train = self.get_array_pose_train() 171 | loc_val, rot_val = self.get_array_pose_val() 172 | loc_test, rot_test = self.get_array_pose_test() 173 | 174 | return loc_train, loc_val, loc_test, rot_train, rot_val, rot_test 175 | 176 | def get_array_pose_train(self): 177 | rot_M_x_90 = self.Rot_X(90) 178 | rot_M_z = self.Rot_Z(self.theta_array) 179 | x_range = np.linspace(-self.array_x/2 , self.array_x/2, self.numb_x) 180 | y_range = np.linspace(-self.array_y/2 , self.array_y/2, self.numb_y) 181 | loc_x, loc_y = np.meshgrid(x_range, y_range) 182 | loc_z = -np.ones_like(loc_x)*self.radius_array 183 | cord_array = np.stack([loc_x, loc_y, loc_z], -1) 184 | cord_array = cord_array.reshape(-1, 3) 185 | cord_array = np.matmul(cord_array, rot_M_x_90) 186 | cord_array = np.matmul(cord_array, rot_M_z) 187 | 188 | loc_train = [cord for cord in cord_array] 189 | rot_train = [self.get_rot_from_loc(loc) for loc in loc_train] 190 | 191 | return loc_train, rot_train 192 | 193 | def get_array_pose_val(self): 194 | cord_list = [] 195 | rot_M_x_90 = self.Rot_X(90) 196 | rot_M_z = self.Rot_Z(self.theta_array) 197 | for i in range(self.numb_train_val): 198 | cur_x = random.uniform(-self.array_x/2 , self.array_x/2) 199 | cur_y = random.uniform(-self.array_y/2 , self.array_y/2) 200 | cur_z = -self.radius_array 201 | cord_list += [np.array([cur_x, cur_y, cur_z])] 202 | cord_array = np.stack(cord_list, 0) 203 | cord_array = np.matmul(cord_array, rot_M_x_90) 204 | cord_array = np.matmul(cord_array, rot_M_z) 205 | loc_val = [cord for cord in cord_array] 206 | rot_val = [self.get_rot_from_loc(loc) for loc in loc_val] 207 | 208 | return loc_val, rot_val 209 | 210 | def get_array_pose_test(self): 211 | cord_list = [] 212 | rot_M_x_90 = self.Rot_X(90) 213 | rot_M_z = self.Rot_Z(self.theta_array) 214 | min_r = min(self.array_x, self.array_y) 215 | r_face = np.abs(np.linspace(min_r, -min_r, self.numb_test)) 216 | rot_face = np.linspace(-360, 360, self.numb_test) 217 | for i in range(self.numb_test): 218 | start_pos = np.array([0.0, r_face[i], -self.radius_array]) 219 | roll_mat = self.Rot_Z(rot_face[i]) 220 | next_pose = np.matmul(start_pos, roll_mat) 221 | cord_list += [next_pose] 222 | 223 | cord_array = np.stack(cord_list, 0) 224 | cord_array = np.matmul(cord_array, rot_M_x_90) 225 | cord_array = np.matmul(cord_array, rot_M_z) 226 | 227 | loc_test = [cord for cord in cord_array] 228 | rot_test = [self.get_rot_from_loc(loc) for loc in loc_test] 229 | 230 | return loc_test, rot_test 231 | 232 | def render_set(self): 233 | self.json_train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_train.json")) 234 | self.json_val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_val.json")) 235 | self.json_test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_test.json")) 236 | self.json_coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_coord.json")) 237 | self.json_calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_calib.json")) 238 | bpy.context.scene.render.image_settings.file_format = 'PNG' 239 | bpy.context.scene.render.film_transparent = True 240 | bpy.context.scene.render.resolution_x = self.res_x 241 | bpy.context.scene.render.resolution_y = self.res_y 242 | self.out_data_train = {"frames":[]} 243 | self.out_data_val = 
{"frames":[]} 244 | self.out_data_test = {"frames":[]} 245 | self.out_data_coord = {"frames":[]} 246 | self.out_data_calib = {"frames":[]} 247 | 248 | def get_rot_from_loc(self, loc): 249 | loc_np = np.array(loc) 250 | loc_r = np.linalg.norm(loc_np[:2], ord=2) 251 | rot_phi = np.arctan(loc_np[2]/loc_r) 252 | loc_vect_xy = loc_np[:2]/loc_r 253 | std_vect = np.array([0, -1]) 254 | cos_theta = np.dot(loc_vect_xy, std_vect) / (np.linalg.norm(loc_vect_xy)*np.linalg.norm(std_vect)) 255 | sin_theta = np.cross(loc_vect_xy, std_vect) / (np.linalg.norm(loc_vect_xy)*np.linalg.norm(std_vect)) 256 | if sin_theta > 0: 257 | rot_theta = 2*np.pi - np.arccos(cos_theta) 258 | else: 259 | rot_theta = np.arccos(cos_theta) 260 | rot = np.array([self.angle2rad(90), 0, 0]) + np.array([-rot_phi, 0, 0]) 261 | rot = rot + np.array([0, 0, rot_theta]) 262 | 263 | return rot 264 | 265 | def cam_clear(self): 266 | for cam in self.cam_list: 267 | bpy.data.objects.remove(cam) 268 | 269 | def apriltag_more_than_two(self, detector, save_path): 270 | save_path = save_path + ".png" 271 | cur_img = cv2.imread(save_path) 272 | gray_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2GRAY) 273 | gray_img = cv2.normalize(gray_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) 274 | _, ids, _ = detector.detectMarkers(gray_img) 275 | if (ids is not None) and (len(ids) > 2): 276 | return True 277 | else: 278 | return False 279 | 280 | def render_calibration_images(self): 281 | arucoDict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_APRILTAG_36h11) 282 | arucoParams = cv2.aruco.DetectorParameters() 283 | detector = cv2.aruco.ArucoDetector(arucoDict, arucoParams) 284 | coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("coord")) 285 | calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("calib")) 286 | if not os.path.exists(coord_path): 287 | os.makedirs(coord_path) 288 | if not os.path.exists(calib_path): 289 | os.makedirs(calib_path) 290 | collect_objects = bpy.data.collections 291 | for collects in collect_objects: 292 | collects.hide_render = True 293 | collect_objects["Calibration Object"].hide_render = False 294 | object_cube = bpy.data.objects["Cube"] 295 | object_cube.rotation_euler = [0, 0, 0] 296 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 297 | fov, cam = cam_info 298 | bpy.context.scene.camera = cam 299 | loc, rot = self.loc_train[idx], self.rot_train[idx] 300 | cam.rotation_mode = 'XYZ' 301 | cam.location = loc 302 | cam.rotation_euler = rot 303 | cam.data.angle = self.angle2rad(fov) 304 | cur_pose = cam.matrix_world 305 | save_path = os.path.join(Path(coord_path), Path("r_{}".format(idx))) 306 | bpy.context.scene.render.filepath = save_path 307 | bpy.ops.render.render(write_still=True) 308 | frame_data = {'file_path': "./coord/r_{}".format(idx), 309 | "camera_angle_x": self.angle2rad(fov), 310 | 'transform_matrix': self.listify_matrix(cur_pose)} 311 | self.out_data_coord['frames'].append(frame_data) 312 | with open(self.json_coord_path, 'w') as out_file: 313 | json.dump(self.out_data_coord, out_file, indent=4) 314 | 315 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 316 | fov, cam = cam_info 317 | bpy.context.scene.camera = cam 318 | loc, rot = self.loc_train[idx], self.rot_train[idx] 319 | cam.rotation_mode = 'XYZ' 320 | cam.location = loc 321 | cam.rotation_euler = rot 322 | cam.data.angle = self.angle2rad(fov) 323 | save_path = os.path.join(Path(calib_path), 
Path("r_{}".format(idx))) 324 | bpy.context.scene.render.filepath = save_path 325 | bpy.ops.render.render(write_still=True) 326 | while not self.apriltag_more_than_two(detector, save_path): 327 | object_cube.rotation_euler[0] = random.uniform(0, 2*math.pi) 328 | object_cube.rotation_euler[1] = random.uniform(0, 2*math.pi) 329 | object_cube.rotation_euler[2] = random.uniform(0, 2*math.pi) 330 | bpy.ops.render.render(write_still=True) 331 | 332 | frame_data = {'file_path': "./calib/r_{}".format(idx), 333 | "camera_angle_x": self.angle2rad(fov)} 334 | self.out_data_calib['frames'].append(frame_data) 335 | with open(self.json_calib_path, 'w') as out_file: 336 | json.dump(self.out_data_calib, out_file, indent=4) 337 | collect_objects["Calibration Object"].hide_render = True 338 | 339 | if __name__ == "__main__": 340 | seed_dict = {"Lego":0, 341 | "Gate":1, 342 | "Materials":2, 343 | "Ficus":3, 344 | "Computer":4, 345 | "Snowtruck":5, 346 | "Statue":6, 347 | "Train":7} 348 | cur_data = "Materials" 349 | dataset = Array_Dataset(obj_name = "Array_{}".format(cur_data), seed=seed_dict[cur_data]) 350 | dataset.render_images() 351 | dataset.render_calibration_images() 352 | dataset.cam_clear() -------------------------------------------------------------------------------- /synthetic_dataset_code/Ball.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import math 3 | import json 4 | import os 5 | import random 6 | import cv2 7 | 8 | from pathlib import Path 9 | import numpy as np 10 | 11 | from mathutils import * 12 | 13 | class Ball_Dataset(): 14 | def __init__(self, obj_name, seed): 15 | self.object_name = obj_name 16 | self.fov_min = 40 17 | self.fov_max = 80 18 | self.numb_train_val = 110 19 | self.numb_test = 200 20 | self.res_x = 800 21 | self.res_y = 800 22 | self.root_path = "F:\\NERF_Dataset" 23 | self.radius = 3 24 | self.start_pos = np.array([0, -self.radius, 0]) 25 | self.start_rot = np.array([self.angle2rad(90), 0, 0]) 26 | random.seed(seed) 27 | self.cam_list = self.init_camera() 28 | self.train_fov, self.val_fov, self.test_fov = self.get_cam_fov_ball() 29 | self.loc_train, self.loc_val, self.loc_test,\ 30 | self.rot_train, self.rot_val, self.rot_test = self.get_cam_pose_ball() 31 | self.render_set() 32 | 33 | 34 | def render_images(self): 35 | collect_objects = bpy.data.collections 36 | for collects in collect_objects: 37 | collects.hide_render = True 38 | collect_objects["Object"].hide_render = False 39 | collect_objects[self.object_name.split("_")[-1]].hide_render = False 40 | self.render_process() 41 | 42 | def render_process(self): 43 | train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("train")) 44 | val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("val")) 45 | test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("test")) 46 | if not os.path.exists(train_path): 47 | os.makedirs(train_path) 48 | if not os.path.exists(val_path): 49 | os.makedirs(val_path) 50 | if not os.path.exists(test_path): 51 | os.makedirs(test_path) 52 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 53 | fov, cam = cam_info 54 | bpy.context.scene.camera = cam 55 | loc, rot = self.loc_train[idx], self.rot_train[idx] 56 | cam.rotation_mode = 'XYZ' 57 | cam.location = loc 58 | cam.rotation_euler = rot 59 | cam.data.angle = self.angle2rad(fov) 60 | cur_pose = cam.matrix_world 61 | save_path = os.path.join(Path(train_path), 
Path("r_{}".format(idx))) 62 | bpy.context.scene.render.filepath = save_path 63 | bpy.ops.render.render(write_still=True) 64 | frame_data = {'file_path': "./train/r_{}".format(idx), 65 | "camera_angle_x": self.angle2rad(fov), 66 | 'transform_matrix': self.listify_matrix(cur_pose)} 67 | self.out_data_train['frames'].append(frame_data) 68 | with open(self.json_train_path, 'w') as out_file: 69 | json.dump(self.out_data_train, out_file, indent=4) 70 | for idx, cam_info in enumerate(zip(self.val_fov, self.cam_list[:self.numb_train_val])): 71 | fov, cam = cam_info 72 | bpy.context.scene.camera = cam 73 | loc, rot = self.loc_val[idx], self.rot_val[idx] 74 | cam.rotation_mode = 'XYZ' 75 | cam.location = loc 76 | cam.rotation_euler = rot 77 | cam.data.angle = self.angle2rad(fov) 78 | cur_pose = cam.matrix_world 79 | save_path = os.path.join(Path(val_path), Path("r_{}".format(idx))) 80 | bpy.context.scene.render.filepath = save_path 81 | bpy.ops.render.render(write_still=True) 82 | frame_data = {'file_path': "./val/r_{}".format(idx), 83 | "camera_angle_x": self.angle2rad(fov), 84 | 'transform_matrix': self.listify_matrix(cur_pose)} 85 | self.out_data_val['frames'].append(frame_data) 86 | with open(self.json_val_path, 'w') as out_file: 87 | json.dump(self.out_data_val, out_file, indent=4) 88 | for idx, cam_info in enumerate(zip(self.test_fov, self.cam_list)): 89 | fov, cam = cam_info 90 | bpy.context.scene.camera = cam 91 | loc, rot = self.loc_test[idx], self.rot_test[idx] 92 | cam.rotation_mode = 'XYZ' 93 | cam.location = loc 94 | cam.rotation_euler = rot 95 | cam.data.angle = self.angle2rad(fov) 96 | cur_pose = cam.matrix_world 97 | save_path = os.path.join(Path(test_path), Path("r_{}".format(idx))) 98 | bpy.context.scene.render.filepath = save_path 99 | bpy.ops.render.render(write_still=True) 100 | frame_data = {'file_path': "./test/r_{}".format(idx), 101 | "camera_angle_x": self.angle2rad(fov), 102 | 'transform_matrix': self.listify_matrix(cur_pose)} 103 | self.out_data_test['frames'].append(frame_data) 104 | with open(self.json_test_path, 'w') as out_file: 105 | json.dump(self.out_data_test, out_file, indent=4) 106 | 107 | def angle2rad(self, angle): 108 | return angle*math.pi/180 109 | 110 | def listify_matrix(self, matrix): 111 | matrix_list = [] 112 | for row in matrix: 113 | matrix_list.append(list(row)) 114 | return matrix_list 115 | 116 | def Rot_Z(self, angle_theta): 117 | rad_theta = self.angle2rad(angle_theta) 118 | rot_M_z = np.array([[np.cos(rad_theta), np.sin(rad_theta), 0], 119 | [-np.sin(rad_theta), np.cos(rad_theta), 0], 120 | [0, 0, 1]]) 121 | return rot_M_z 122 | 123 | def Rot_X(self, angle_phi): 124 | rad_phi = -self.angle2rad(angle_phi) 125 | rot_M_x = np.array([[1, 0, 0], 126 | [0, np.cos(rad_phi), np.sin(rad_phi)], 127 | [0, -np.sin(rad_phi), np.cos(rad_phi)]]) 128 | return rot_M_x 129 | 130 | def init_camera(self): 131 | cam_list = [] 132 | for idx in range(max(self.numb_train_val, self.numb_test)): 133 | bpy.ops.object.camera_add(enter_editmode=False, 134 | align='VIEW', 135 | location=(0, 0, 0), 136 | rotation=(0, 0, 0), 137 | scale=(1, 1, 1)) 138 | cur_cam = bpy.context.selected_objects[0] 139 | cur_cam.name = "Ball_{}".format(idx) 140 | cur_cam.data.type = 'PERSP' 141 | cur_cam.data.lens_unit = 'FOV' 142 | cam_list += [cur_cam] 143 | 144 | return cam_list 145 | 146 | def get_cam_fov_ball(self): 147 | fov_angle_train = [] 148 | fov_angle_val = [] 149 | for i in range(self.numb_train_val): 150 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 151 | while 
cur_fov_train in [fov_angle_train]: 152 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 153 | fov_angle_train += [cur_fov_train] 154 | fov_angle_val = fov_angle_train 155 | fov_angle_test = list(np.linspace(self.fov_max, self.fov_min, self.numb_test//2)) 156 | fov_angle_test_inv = fov_angle_test.copy() 157 | fov_angle_test_inv.sort() 158 | fov_angle_test = fov_angle_test + fov_angle_test_inv 159 | assert len(fov_angle_test) == self.numb_test, "Length Error for test fov !!!" 160 | 161 | return fov_angle_train, fov_angle_val, fov_angle_test 162 | 163 | def get_cam_pose_ball(self): 164 | loc_train, loc_val, loc_test = [], [], [] 165 | rot_train, rot_val, rot_test = [], [], [] 166 | theta_train = [] 167 | phi_train = [] 168 | theta_range = list(np.linspace(0, 360, 12, endpoint=False)) 169 | phi_range = list(np.linspace(-80, 80, 9)) 170 | phi_end = [-90, 90] 171 | for phi in phi_range: 172 | roll_mat = self.Rot_X(phi) 173 | for theta in theta_range: 174 | theta_train += [theta] 175 | phi_train += [phi] 176 | pitch_mat = self.Rot_Z(theta) 177 | next_pose = np.matmul(self.start_pos, roll_mat) 178 | next_pose = np.matmul(next_pose, pitch_mat) 179 | next_rot = self.start_rot + np.array([-self.angle2rad(phi), 0, 0]) 180 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta)]) 181 | loc_train += [next_pose] 182 | rot_train += [next_rot] 183 | for phi in phi_end: 184 | roll_mat = self.Rot_X(phi) 185 | theta_train += [0] # theta is not rotated for the two pole cameras 186 | phi_train += [phi] 187 | next_pose = np.matmul(self.start_pos, roll_mat) 188 | next_rot = self.start_rot + np.array([-self.angle2rad(phi), 0, 0]) 189 | loc_train += [next_pose] 190 | rot_train += [next_rot] 191 | 192 | for i in range(self.numb_train_val): 193 | theta = random.randint(0, 360) 194 | phi = random.randint(0, 90) 195 | while theta in theta_train: 196 | theta = random.randint(0, 360) 197 | while phi in phi_train: 198 | phi = random.randint(0, 90) 199 | pitch_mat = self.Rot_Z(theta) 200 | roll_mat = self.Rot_X(phi) 201 | next_pose = np.matmul(self.start_pos, roll_mat) 202 | next_pose = np.matmul(next_pose, pitch_mat) 203 | next_rot = self.start_rot + np.array([-self.angle2rad(phi), 0, 0]) 204 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta)]) 205 | loc_val += [next_pose] 206 | rot_val += [next_rot] 207 | theta = list(np.linspace(360, -360, self.numb_test)) 208 | phi = list(np.linspace(90, -90, self.numb_test//2)) 209 | phi_inv = phi.copy() 210 | phi_inv.sort() 211 | phi = phi + phi_inv 212 | 213 | assert len(theta) == self.numb_test, "Length Error for test pose !!!"
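# Pose construction: Rot_Z / Rot_X are written for right-multiplication of row vectors
# (pos @ R). Rot_Z(theta) turns a point counter-clockwise about +Z (azimuth), and Rot_X
# negates its argument, so Rot_X(phi) lifts start_pos = (0, -radius, 0) to elevation phi
# above the XY plane. Adding (-phi, 0, 0) and (0, 0, theta) to start_rot = (90 deg, 0, 0)
# keeps every camera looking at the origin from its new position.
# Layout: the train cameras form a 12 x 9 grid over azimuth (0..330 deg) and elevation
# (-80..80 deg) plus the two poles, 110 poses in total; the val cameras are random integer
# (theta, phi) draws with phi in [0, 90] that avoid the exact train angles. The test loop
# below sweeps theta through two full turns (360 -> -360) while phi runs pole to pole and
# back, tracing a spiral around the object.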
214 | for i in range(self.numb_test): 215 | pitch_mat = self.Rot_Z(theta[i]) 216 | roll_mat = self.Rot_X(phi[i]) 217 | next_pose = np.matmul(self.start_pos, roll_mat) 218 | next_pose = np.matmul(next_pose, pitch_mat) 219 | next_rot = self.start_rot + np.array([-self.angle2rad(phi[i]), 0, 0]) 220 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta[i])]) 221 | loc_test += [next_pose] 222 | rot_test += [next_rot] 223 | 224 | return loc_train, loc_val, loc_test, rot_train, rot_val, rot_test 225 | 226 | def render_set(self): 227 | self.json_train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_train.json")) 228 | self.json_val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_val.json")) 229 | self.json_test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_test.json")) 230 | self.json_coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_coord.json")) 231 | self.json_calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_calib.json")) 232 | bpy.context.scene.render.image_settings.file_format = 'PNG' 233 | bpy.context.scene.render.film_transparent = True 234 | bpy.context.scene.render.resolution_x = self.res_x 235 | bpy.context.scene.render.resolution_y = self.res_y 236 | self.out_data_train = {"frames":[]} 237 | self.out_data_val = {"frames":[]} 238 | self.out_data_test = {"frames":[]} 239 | self.out_data_coord = {"frames":[]} 240 | self.out_data_calib = {"frames":[]} 241 | 242 | def cam_clear(self): 243 | for cam in self.cam_list: 244 | bpy.data.objects.remove(cam) 245 | 246 | def apriltag_more_than_two(self, detector, save_path): 247 | save_path = save_path + ".png" 248 | cur_img = cv2.imread(save_path) 249 | gray_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2GRAY) 250 | gray_img = cv2.normalize(gray_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) 251 | _, ids, _ = detector.detectMarkers(gray_img) 252 | if (ids is not None) and (len(ids) > 2): 253 | return True 254 | else: 255 | return False 256 | 257 | def render_calibration_images(self): 258 | arucoDict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_APRILTAG_36h11) 259 | arucoParams = cv2.aruco.DetectorParameters() 260 | detector = cv2.aruco.ArucoDetector(arucoDict, arucoParams) 261 | coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("coord")) 262 | calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("calib")) 263 | if not os.path.exists(coord_path): 264 | os.makedirs(coord_path) 265 | if not os.path.exists(calib_path): 266 | os.makedirs(calib_path) 267 | collect_objects = bpy.data.collections 268 | for collects in collect_objects: 269 | collects.hide_render = True 270 | collect_objects["Calibration Object"].hide_render = False 271 | object_cube = bpy.data.objects["Cube"] 272 | object_cube.rotation_euler = [0, 0, 0] 273 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 274 | fov, cam = cam_info 275 | bpy.context.scene.camera = cam 276 | loc, rot = self.loc_train[idx], self.rot_train[idx] 277 | cam.rotation_mode = 'XYZ' 278 | cam.location = loc 279 | cam.rotation_euler = rot 280 | cam.data.angle = self.angle2rad(fov) 281 | cur_pose = cam.matrix_world 282 | save_path = os.path.join(Path(coord_path), Path("r_{}".format(idx))) 283 | bpy.context.scene.render.filepath = save_path 284 | bpy.ops.render.render(write_still=True) 285 | frame_data = {'file_path': 
"./coord/r_{}".format(idx), 286 | "camera_angle_x": self.angle2rad(fov), 287 | 'transform_matrix': self.listify_matrix(cur_pose)} 288 | self.out_data_coord['frames'].append(frame_data) 289 | with open(self.json_coord_path, 'w') as out_file: 290 | json.dump(self.out_data_coord, out_file, indent=4) 291 | 292 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 293 | fov, cam = cam_info 294 | bpy.context.scene.camera = cam 295 | loc, rot = self.loc_train[idx], self.rot_train[idx] 296 | cam.rotation_mode = 'XYZ' 297 | cam.location = loc 298 | cam.rotation_euler = rot 299 | cam.data.angle = self.angle2rad(fov) 300 | save_path = os.path.join(Path(calib_path), Path("r_{}".format(idx))) 301 | bpy.context.scene.render.filepath = save_path 302 | bpy.ops.render.render(write_still=True) 303 | while not self.apriltag_more_than_two(detector, save_path): 304 | object_cube.rotation_euler[0] = random.uniform(0, 2*math.pi) 305 | object_cube.rotation_euler[1] = random.uniform(0, 2*math.pi) 306 | object_cube.rotation_euler[2] = random.uniform(0, 2*math.pi) 307 | bpy.ops.render.render(write_still=True) 308 | 309 | frame_data = {'file_path': "./calib/r_{}".format(idx), 310 | "camera_angle_x": self.angle2rad(fov)} 311 | self.out_data_calib['frames'].append(frame_data) 312 | with open(self.json_calib_path, 'w') as out_file: 313 | json.dump(self.out_data_calib, out_file, indent=4) 314 | collect_objects["Calibration Object"].hide_render = True 315 | 316 | if __name__ == "__main__": 317 | seed_dict = {"Lego":0, 318 | "Gate":1, 319 | "Materials":2, 320 | "Ficus":3, 321 | "Computer":4, 322 | "Snowtruck":5, 323 | "Statue":6, 324 | "Train":7} 325 | cur_data = "Gate" 326 | dataset = Ball_Dataset(obj_name = "Ball_{}".format(cur_data), seed=seed_dict[cur_data]) 327 | dataset.render_images() 328 | dataset.render_calibration_images() 329 | dataset.cam_clear() -------------------------------------------------------------------------------- /synthetic_dataset_code/HalfBall.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import math 3 | import json 4 | import os 5 | import random 6 | import cv2 7 | 8 | from pathlib import Path 9 | import numpy as np 10 | 11 | from mathutils import * 12 | 13 | class Half_Ball_Dataset(): 14 | def __init__(self, obj_name, seed): 15 | self.object_name = obj_name 16 | self.fov_min = 40 17 | self.fov_max = 80 18 | self.numb_train_val = 100 19 | self.numb_test = 200 20 | self.res_x = 800 21 | self.res_y = 800 22 | self.root_path = "F:\\NERF_Dataset" 23 | self.radius = 3 24 | self.start_pos = np.array([0, -self.radius, 0]) 25 | self.start_rot = np.array([self.angle2rad(90), 0, 0]) 26 | random.seed(seed) 27 | self.cam_list = self.init_camera() 28 | self.train_fov, self.val_fov, self.test_fov = self.get_cam_fov_half_ball() 29 | self.loc_train, self.loc_val, self.loc_test,\ 30 | self.rot_train, self.rot_val, self.rot_test = self.get_cam_pose_half_ball() 31 | self.render_set() 32 | 33 | def render_images(self): 34 | collect_objects = bpy.data.collections 35 | for collects in collect_objects: 36 | collects.hide_render = True 37 | collect_objects["Object"].hide_render = False 38 | collect_objects[self.object_name.split("_")[-1]].hide_render = False 39 | self.render_process() 40 | 41 | def render_process(self): 42 | train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("train")) 43 | val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("val")) 44 | test_path = 
os.path.join(Path(self.root_path), Path(self.object_name), Path("test")) 45 | if not os.path.exists(train_path): 46 | os.makedirs(train_path) 47 | if not os.path.exists(val_path): 48 | os.makedirs(val_path) 49 | if not os.path.exists(test_path): 50 | os.makedirs(test_path) 51 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 52 | fov, cam = cam_info 53 | bpy.context.scene.camera = cam 54 | loc, rot = self.loc_train[idx], self.rot_train[idx] 55 | cam.rotation_mode = 'XYZ' 56 | cam.location = loc 57 | cam.rotation_euler = rot 58 | cam.data.angle = self.angle2rad(fov) 59 | cur_pose = cam.matrix_world 60 | save_path = os.path.join(Path(train_path), Path("r_{}".format(idx))) 61 | bpy.context.scene.render.filepath = save_path 62 | bpy.ops.render.render(write_still=True) 63 | frame_data = {'file_path': "./train/r_{}".format(idx), 64 | "camera_angle_x": self.angle2rad(fov), 65 | 'transform_matrix': self.listify_matrix(cur_pose)} 66 | self.out_data_train['frames'].append(frame_data) 67 | with open(self.json_train_path, 'w') as out_file: 68 | json.dump(self.out_data_train, out_file, indent=4) 69 | for idx, cam_info in enumerate(zip(self.val_fov, self.cam_list[:self.numb_train_val])): 70 | fov, cam = cam_info 71 | bpy.context.scene.camera = cam 72 | loc, rot = self.loc_val[idx], self.rot_val[idx] 73 | cam.rotation_mode = 'XYZ' 74 | cam.location = loc 75 | cam.rotation_euler = rot 76 | cam.data.angle = self.angle2rad(fov) 77 | cur_pose = cam.matrix_world 78 | save_path = os.path.join(Path(val_path), Path("r_{}".format(idx))) 79 | bpy.context.scene.render.filepath = save_path 80 | bpy.ops.render.render(write_still=True) 81 | frame_data = {'file_path': "./val/r_{}".format(idx), 82 | "camera_angle_x": self.angle2rad(fov), 83 | 'transform_matrix': self.listify_matrix(cur_pose)} 84 | self.out_data_val['frames'].append(frame_data) 85 | with open(self.json_val_path, 'w') as out_file: 86 | json.dump(self.out_data_val, out_file, indent=4) 87 | for idx, cam_info in enumerate(zip(self.test_fov, self.cam_list)): 88 | fov, cam = cam_info 89 | bpy.context.scene.camera = cam 90 | loc, rot = self.loc_test[idx], self.rot_test[idx] 91 | cam.rotation_mode = 'XYZ' 92 | cam.location = loc 93 | cam.rotation_euler = rot 94 | cam.data.angle = self.angle2rad(fov) 95 | cur_pose = cam.matrix_world 96 | save_path = os.path.join(Path(test_path), Path("r_{}".format(idx))) 97 | bpy.context.scene.render.filepath = save_path 98 | bpy.ops.render.render(write_still=True) 99 | frame_data = {'file_path': "./test/r_{}".format(idx), 100 | "camera_angle_x": self.angle2rad(fov), 101 | 'transform_matrix': self.listify_matrix(cur_pose)} 102 | self.out_data_test['frames'].append(frame_data) 103 | with open(self.json_test_path, 'w') as out_file: 104 | json.dump(self.out_data_test, out_file, indent=4) 105 | 106 | def angle2rad(self, angle): 107 | return angle*math.pi/180 108 | 109 | def listify_matrix(self, matrix): 110 | matrix_list = [] 111 | for row in matrix: 112 | matrix_list.append(list(row)) 113 | return matrix_list 114 | 115 | def Rot_Z(self, angle_theta): 116 | rad_theta = self.angle2rad(angle_theta) 117 | rot_M_z = np.array([[np.cos(rad_theta), np.sin(rad_theta), 0], 118 | [-np.sin(rad_theta), np.cos(rad_theta), 0], 119 | [0, 0, 1]]) 120 | return rot_M_z 121 | 122 | def Rot_X(self, angle_phi): 123 | rad_phi = -self.angle2rad(angle_phi) 124 | rot_M_x = np.array([[1, 0, 0], 125 | [0, np.cos(rad_phi), np.sin(rad_phi)], 126 | [0, -np.sin(rad_phi), np.cos(rad_phi)]]) 127 | return rot_M_x 128 | 129 | 
def init_camera(self): 130 | cam_list = [] 131 | for idx in range(max(self.numb_train_val, self.numb_test)): 132 | bpy.ops.object.camera_add(enter_editmode=False, 133 | align='VIEW', 134 | location=(0, 0, 0), 135 | rotation=(0, 0, 0), 136 | scale=(1, 1, 1)) 137 | cur_cam = bpy.context.selected_objects[0] 138 | cur_cam.name = "Half_Ball_{}".format(idx) 139 | cur_cam.data.type = 'PERSP' 140 | cur_cam.data.lens_unit = 'FOV' 141 | cam_list += [cur_cam] 142 | 143 | return cam_list 144 | 145 | def get_cam_fov_half_ball(self): 146 | fov_angle_train = [] 147 | fov_angle_val = [] 148 | for i in range(self.numb_train_val): 149 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 150 | while cur_fov_train in [fov_angle_train]: 151 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 152 | fov_angle_train += [cur_fov_train] 153 | fov_angle_val = fov_angle_train 154 | fov_angle_test = list(np.linspace(self.fov_max, self.fov_min, self.numb_test//2)) 155 | fov_angle_test_inv = fov_angle_test.copy() 156 | fov_angle_test_inv.sort() 157 | fov_angle_test = fov_angle_test + fov_angle_test_inv 158 | assert len(fov_angle_test) == self.numb_test, "Length Error for test fov !!!" 159 | 160 | return fov_angle_train, fov_angle_val, fov_angle_test 161 | 162 | def get_cam_pose_half_ball(self): 163 | loc_train, loc_val, loc_test = [], [], [] 164 | rot_train, rot_val, rot_test = [], [], [] 165 | theta_train = [] 166 | phi_train = [] 167 | for i in range(self.numb_train_val): 168 | theta = random.randint(0, 360) 169 | phi = random.randint(0, 90) 170 | theta_train += [theta] 171 | phi_train += [phi] 172 | pitch_mat = self.Rot_Z(theta) 173 | roll_mat = self.Rot_X(phi) 174 | next_pose = np.matmul(self.start_pos, roll_mat) 175 | next_pose = np.matmul(next_pose, pitch_mat) 176 | next_rot = self.start_rot + np.array([-self.angle2rad(phi), 0, 0]) 177 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta)]) 178 | loc_train += [next_pose] 179 | rot_train += [next_rot] 180 | for i in range(self.numb_train_val): 181 | theta = random.randint(0, 360) 182 | phi = random.randint(0, 90) 183 | while theta in theta_train: 184 | theta = random.randint(0, 360) 185 | while phi in phi_train: 186 | phi = random.randint(0, 90) 187 | pitch_mat = self.Rot_Z(theta) 188 | roll_mat = self.Rot_X(phi) 189 | next_pose = np.matmul(self.start_pos, roll_mat) 190 | next_pose = np.matmul(next_pose, pitch_mat) 191 | next_rot = self.start_rot + np.array([-self.angle2rad(phi), 0, 0]) 192 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta)]) 193 | loc_val += [next_pose] 194 | rot_val += [next_rot] 195 | theta = list(np.linspace(360, -360, self.numb_test)) 196 | phi = list(np.linspace(90, 0, self.numb_test//2)) 197 | phi_inv = phi.copy() 198 | phi_inv.sort() 199 | phi = phi + phi_inv 200 | 201 | assert len(theta) == self.numb_test, "Length Error for test pose !!!" 
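# Note: unlike Ball.py, both the train and val cameras here draw random integer
# (theta, phi) pairs with phi in [0, 90], so all poses stay on the upper hemisphere.
# The val loop only rejects exact integer matches with the train angles, so a val pose
# can still land arbitrarily close to a train pose. The test trajectory below mirrors
# Ball.py but keeps the elevation phi within [0, 90].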
202 | for i in range(self.numb_test): 203 | pitch_mat = self.Rot_Z(theta[i]) 204 | roll_mat = self.Rot_X(phi[i]) 205 | next_pose = np.matmul(self.start_pos, roll_mat) 206 | next_pose = np.matmul(next_pose, pitch_mat) 207 | next_rot = self.start_rot + np.array([-self.angle2rad(phi[i]), 0, 0]) 208 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta[i])]) 209 | loc_test += [next_pose] 210 | rot_test += [next_rot] 211 | 212 | return loc_train, loc_val, loc_test, rot_train, rot_val, rot_test 213 | 214 | def render_set(self): 215 | self.json_train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_train.json")) 216 | self.json_val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_val.json")) 217 | self.json_test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_test.json")) 218 | self.json_coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_coord.json")) 219 | self.json_calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_calib.json")) 220 | bpy.context.scene.render.image_settings.file_format = 'PNG' 221 | bpy.context.scene.render.film_transparent = True 222 | bpy.context.scene.render.resolution_x = self.res_x 223 | bpy.context.scene.render.resolution_y = self.res_y 224 | self.out_data_train = {"frames":[]} 225 | self.out_data_val = {"frames":[]} 226 | self.out_data_test = {"frames":[]} 227 | self.out_data_coord = {"frames":[]} 228 | self.out_data_calib = {"frames":[]} 229 | 230 | def cam_clear(self): 231 | for cam in self.cam_list: 232 | bpy.data.objects.remove(cam) 233 | 234 | def apriltag_more_than_two(self, detector, save_path): 235 | save_path = save_path + ".png" 236 | cur_img = cv2.imread(save_path) 237 | gray_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2GRAY) 238 | gray_img = cv2.normalize(gray_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) 239 | _, ids, _ = detector.detectMarkers(gray_img) 240 | if (ids is not None) and (len(ids) > 2): 241 | return True 242 | else: 243 | return False 244 | 245 | def render_calibration_images(self, img_id=None): 246 | arucoDict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_APRILTAG_36h11) 247 | arucoParams = cv2.aruco.DetectorParameters() 248 | detector = cv2.aruco.ArucoDetector(arucoDict, arucoParams) 249 | coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("coord")) 250 | calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("calib")) 251 | if not os.path.exists(coord_path): 252 | os.makedirs(coord_path) 253 | if not os.path.exists(calib_path): 254 | os.makedirs(calib_path) 255 | collect_objects = bpy.data.collections 256 | for collects in collect_objects: 257 | collects.hide_render = True 258 | collect_objects["Calibration Object"].hide_render = False 259 | object_cube = bpy.data.objects["Cube"] 260 | object_cube.rotation_euler = [0, 0, 0] 261 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 262 | fov, cam = cam_info 263 | bpy.context.scene.camera = cam 264 | loc, rot = self.loc_train[idx], self.rot_train[idx] 265 | cam.rotation_mode = 'XYZ' 266 | cam.location = loc 267 | cam.rotation_euler = rot 268 | cam.data.angle = self.angle2rad(fov) 269 | cur_pose = cam.matrix_world 270 | save_path = os.path.join(Path(coord_path), Path("r_{}".format(idx))) 271 | bpy.context.scene.render.filepath = save_path 272 | bpy.ops.render.render(write_still=True) 273 | frame_data = {'file_path': 
"./coord/r_{}".format(idx), 274 | "camera_angle_x": self.angle2rad(fov), 275 | 'transform_matrix': self.listify_matrix(cur_pose)} 276 | self.out_data_coord['frames'].append(frame_data) 277 | 278 | if img_id == None: 279 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 280 | fov, cam = cam_info 281 | bpy.context.scene.camera = cam 282 | loc, rot = self.loc_train[idx], self.rot_train[idx] 283 | cam.rotation_mode = 'XYZ' 284 | cam.location = loc 285 | cam.rotation_euler = rot 286 | cam.data.angle = self.angle2rad(fov) 287 | save_path = os.path.join(Path(calib_path), Path("r_{}".format(idx))) 288 | bpy.context.scene.render.filepath = save_path 289 | bpy.ops.render.render(write_still=True) 290 | while not self.apriltag_more_than_two(detector, save_path): 291 | object_cube.rotation_euler[0] = random.uniform(0, 2*math.pi) 292 | object_cube.rotation_euler[1] = random.uniform(0, 2*math.pi) 293 | object_cube.rotation_euler[2] = random.uniform(0, 2*math.pi) 294 | bpy.ops.render.render(write_still=True) 295 | 296 | frame_data = {'file_path': "./calib/r_{}".format(idx), 297 | "camera_angle_x": self.angle2rad(fov)} 298 | self.out_data_calib['frames'].append(frame_data) 299 | else: 300 | fov = self.train_fov[img_id] 301 | cam = self.cam_list[:self.numb_train_val][img_id] 302 | bpy.context.scene.camera = cam 303 | loc, rot = self.loc_train[img_id], self.rot_train[img_id] 304 | cam.rotation_mode = 'XYZ' 305 | cam.location = loc 306 | cam.rotation_euler = rot 307 | cam.data.angle = self.angle2rad(fov) 308 | cur_pose = cam.matrix_world 309 | save_path = os.path.join(Path(calib_path), Path("r_{}".format(img_id))) 310 | bpy.context.scene.render.filepath = save_path 311 | bpy.ops.render.render(write_still=True) 312 | while not self.apriltag_more_than_two(detector, save_path): 313 | object_cube.rotation_euler[0] = random.uniform(0, 2*math.pi) 314 | object_cube.rotation_euler[1] = random.uniform(0, 2*math.pi) 315 | object_cube.rotation_euler[2] = random.uniform(0, 2*math.pi) 316 | bpy.ops.render.render(write_still=True) 317 | frame_data = {'file_path': "./calib/r_{}".format(img_id), 318 | "camera_angle_x": self.angle2rad(fov)} 319 | self.out_data_calib['frames'].append(frame_data) 320 | 321 | with open(self.json_calib_path, 'w') as out_file: 322 | json.dump(self.out_data_calib, out_file, indent=4) 323 | collect_objects["Calibration Object"].hide_render = True 324 | 325 | if __name__ == "__main__": 326 | seed_dict = {"Lego":0, 327 | "Gate":1, 328 | "Materials":2, 329 | "Ficus":3, 330 | "Computer":4, 331 | "Snowtruck":5, 332 | "Statue":6, 333 | "Train":7} 334 | cur_data = "Train" 335 | dataset = Half_Ball_Dataset(obj_name = "HalfBall_{}".format(cur_data), seed=seed_dict[cur_data]) 336 | dataset.render_images() 337 | dataset.render_calibration_images() 338 | dataset.cam_clear() -------------------------------------------------------------------------------- /synthetic_dataset_code/Room.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import math 3 | import json 4 | import os 5 | import random 6 | import cv2 7 | 8 | from pathlib import Path 9 | import numpy as np 10 | 11 | from mathutils import * 12 | 13 | class Room_Dataset(): 14 | def __init__(self, obj_name, seed): 15 | self.object_name = obj_name 16 | self.fov_min = 40 17 | self.fov_max = 80 18 | self.numb_train_val = 88 19 | self.numb_test = 200 20 | self.res_x = 800 21 | self.res_y = 800 22 | self.root_path = "F:\\NERF_Dataset" 23 | self.room_x = 6 24 | self.room_y 
= 4 25 | self.room_z = 3 26 | self.theta_room = 15 27 | self.round_numb = 7 28 | self.radius = min(self.room_x, self.room_y, self.room_z) 29 | self.start_pos = np.array([0, -self.radius, 0]) 30 | self.start_rot = np.array([self.angle2rad(90), 0, 0]) 31 | random.seed(seed) 32 | self.cam_list = self.init_camera() 33 | self.train_fov, self.val_fov, self.test_fov = self.get_cam_fov_room() 34 | self.loc_train, self.loc_val, self.loc_test,\ 35 | self.rot_train, self.rot_val, self.rot_test = self.get_cam_pose_room() 36 | self.render_set() 37 | 38 | def render_images(self): 39 | collect_objects = bpy.data.collections 40 | for collects in collect_objects: 41 | collects.hide_render = True 42 | collect_objects["Object"].hide_render = False 43 | collect_objects[self.object_name.split("_")[-1]].hide_render = False 44 | self.render_process() 45 | 46 | 47 | def render_process(self): 48 | train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("train")) 49 | val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("val")) 50 | test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("test")) 51 | if not os.path.exists(train_path): 52 | os.makedirs(train_path) 53 | if not os.path.exists(val_path): 54 | os.makedirs(val_path) 55 | if not os.path.exists(test_path): 56 | os.makedirs(test_path) 57 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 58 | fov, cam = cam_info 59 | bpy.context.scene.camera = cam 60 | loc, rot = self.loc_train[idx], self.rot_train[idx] 61 | cam.rotation_mode = 'XYZ' 62 | cam.location = loc 63 | cam.rotation_euler = rot 64 | cam.data.angle = self.angle2rad(fov) 65 | cur_pose = cam.matrix_world 66 | save_path = os.path.join(Path(train_path), Path("r_{}".format(idx))) 67 | bpy.context.scene.render.filepath = save_path 68 | bpy.ops.render.render(write_still=True) 69 | frame_data = {'file_path': "./train/r_{}".format(idx), 70 | "camera_angle_x": self.angle2rad(fov), 71 | 'transform_matrix': self.listify_matrix(cur_pose)} 72 | self.out_data_train['frames'].append(frame_data) 73 | with open(self.json_train_path, 'w') as out_file: 74 | json.dump(self.out_data_train, out_file, indent=4) 75 | for idx, cam_info in enumerate(zip(self.val_fov, self.cam_list[:self.numb_train_val])): 76 | fov, cam = cam_info 77 | bpy.context.scene.camera = cam 78 | loc, rot = self.loc_val[idx], self.rot_val[idx] 79 | cam.rotation_mode = 'XYZ' 80 | cam.location = loc 81 | cam.rotation_euler = rot 82 | cam.data.angle = self.angle2rad(fov) 83 | cur_pose = cam.matrix_world 84 | save_path = os.path.join(Path(val_path), Path("r_{}".format(idx))) 85 | bpy.context.scene.render.filepath = save_path 86 | bpy.ops.render.render(write_still=True) 87 | frame_data = {'file_path': "./val/r_{}".format(idx), 88 | "camera_angle_x": self.angle2rad(fov), 89 | 'transform_matrix': self.listify_matrix(cur_pose)} 90 | self.out_data_val['frames'].append(frame_data) 91 | with open(self.json_val_path, 'w') as out_file: 92 | json.dump(self.out_data_val, out_file, indent=4) 93 | for idx, cam_info in enumerate(zip(self.test_fov, self.cam_list)): 94 | fov, cam = cam_info 95 | bpy.context.scene.camera = cam 96 | loc, rot = self.loc_test[idx], self.rot_test[idx] 97 | cam.rotation_mode = 'XYZ' 98 | cam.location = loc 99 | cam.rotation_euler = rot 100 | cam.data.angle = self.angle2rad(fov) 101 | cur_pose = cam.matrix_world 102 | save_path = os.path.join(Path(test_path), Path("r_{}".format(idx))) 103 | bpy.context.scene.render.filepath = save_path 104 | 
bpy.ops.render.render(write_still=True) 105 | frame_data = {'file_path': "./test/r_{}".format(idx), 106 | "camera_angle_x": self.angle2rad(fov), 107 | 'transform_matrix': self.listify_matrix(cur_pose)} 108 | self.out_data_test['frames'].append(frame_data) 109 | with open(self.json_test_path, 'w') as out_file: 110 | json.dump(self.out_data_test, out_file, indent=4) 111 | 112 | def angle2rad(self, angle): 113 | return angle*math.pi/180 114 | 115 | def rad2angle(self, rad): 116 | return rad*180/math.pi 117 | 118 | def listify_matrix(self, matrix): 119 | matrix_list = [] 120 | for row in matrix: 121 | matrix_list.append(list(row)) 122 | return matrix_list 123 | 124 | def Rot_Z(self, angle_theta): 125 | rad_theta = self.angle2rad(angle_theta) 126 | rot_M_z = np.array([[np.cos(rad_theta), np.sin(rad_theta), 0], 127 | [-np.sin(rad_theta), np.cos(rad_theta), 0], 128 | [0, 0, 1]]) 129 | return rot_M_z 130 | 131 | def Rot_X(self, angle_phi): 132 | rad_phi = -self.angle2rad(angle_phi) 133 | rot_M_x = np.array([[1, 0, 0], 134 | [0, np.cos(rad_phi), np.sin(rad_phi)], 135 | [0, -np.sin(rad_phi), np.cos(rad_phi)]]) 136 | return rot_M_x 137 | 138 | def init_camera(self): 139 | cam_list = [] 140 | for idx in range(max(self.numb_train_val, self.numb_test)): 141 | bpy.ops.object.camera_add(enter_editmode=False, 142 | align='VIEW', 143 | location=(0, 0, 0), 144 | rotation=(0, 0, 0), 145 | scale=(1, 1, 1)) 146 | cur_cam = bpy.context.selected_objects[0] 147 | cur_cam.name = "Room_{}".format(idx) 148 | cur_cam.data.type = 'PERSP' 149 | cur_cam.data.lens_unit = 'FOV' 150 | cam_list += [cur_cam] 151 | 152 | return cam_list 153 | 154 | def get_cam_fov_room(self): 155 | fov_angle_train = [] 156 | fov_angle_val = [] 157 | for i in range(self.numb_train_val): 158 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 159 | while cur_fov_train in [fov_angle_train]: 160 | cur_fov_train = random.randint(self.fov_min, self.fov_max) 161 | fov_angle_train += [cur_fov_train] 162 | fov_angle_val = fov_angle_train 163 | fov_angle_test = list(np.linspace(self.fov_max, self.fov_min, self.numb_test//2)) 164 | fov_angle_test_inv = fov_angle_test.copy() 165 | fov_angle_test_inv.sort() 166 | fov_angle_test = fov_angle_test + fov_angle_test_inv 167 | assert len(fov_angle_test) == self.numb_test, "Length Error for test fov !!!" 168 | 169 | return fov_angle_train, fov_angle_val, fov_angle_test 170 | 171 | def get_cam_pose_room(self): 172 | loc_test = [] 173 | rot_test = [] 174 | numb_rot_loc = 180 // self.theta_room 175 | numb_rot_rot = 360 // self.theta_room 176 | loc_train = self.rect_position(self.round_numb, numb_rot_loc) 177 | rot_train = self.rect_rotation(self.round_numb, numb_rot_rot, loc_train) 178 | loc_val, rot_val = self.surface_position() 179 | theta = list(np.linspace(360, -360, self.numb_test)) 180 | phi = list(np.linspace(90, 0, self.numb_test//2)) 181 | phi_inv = phi.copy() 182 | phi_inv.sort() 183 | phi = phi + phi_inv 184 | assert len(theta) == self.numb_test, "Length Error for test pose !!!" 
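# Pose scheme (Room): the 88 train cameras are placed on the room shell by rect_position
# and oriented by rect_rotation to look back toward the origin -- 24 poses around the
# floor edge, 8 poses on each of the 5 intermediate wall rings, and 24 around the ceiling
# edge. The val cameras come from surface_position: random points on the ceiling and the
# four walls, aimed at the origin via get_rot_from_loc. The test loop below reuses the
# orbiting trajectory of the other scripts on a sphere of radius
# min(room_x, room_y, room_z) around the origin.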
185 | for i in range(self.numb_test): 186 | pitch_mat = self.Rot_Z(theta[i]) 187 | roll_mat = self.Rot_X(phi[i]) 188 | next_pose = np.matmul(self.start_pos, roll_mat) 189 | next_pose = np.matmul(next_pose, pitch_mat) 190 | next_rot = self.start_rot + np.array([-self.angle2rad(phi[i]), 0, 0]) 191 | next_rot = next_rot + np.array([0, 0, self.angle2rad(theta[i])]) 192 | loc_test += [next_pose] 193 | rot_test += [next_rot] 194 | 195 | return loc_train, loc_val, loc_test, rot_train, rot_val, rot_test 196 | 197 | def rect_position(self, numb_round, numb_rot): 198 | pos_m = [] 199 | pos_t = [] 200 | pos_b = [] 201 | step_z = self.room_z/numb_round 202 | for i in range(1, numb_round-1): 203 | cur_z = step_z*i 204 | p1 = np.array([self.room_x/2, 0, cur_z]) 205 | p2 = np.array([self.room_x/2, self.room_y/2, cur_z]) 206 | p3 = np.array([0, self.room_y/2, cur_z]) 207 | p4 = np.array([-self.room_x/2, self.room_y/2, cur_z]) 208 | p5 = np.array([-self.room_x/2, 0, cur_z]) 209 | p6 = np.array([-self.room_x/2, -self.room_y/2, cur_z]) 210 | p7 = np.array([0, -self.room_y/2, cur_z]) 211 | p8 = np.array([self.room_x/2, -self.room_y/2, cur_z]) 212 | pos_m += [p1, p2, p3, p4, p5, p6, p7, p8] 213 | pos_temp1b = [] 214 | pos_temp2b = [] 215 | pos_temp1t = [] 216 | pos_temp2t = [] 217 | for i in range(numb_rot): 218 | rad_theta = self.angle2rad(self.theta_room*i) 219 | if (self.theta_room*i) == 90: 220 | p1b = np.array([0, self.room_y/2, 0]) 221 | p2b = np.array([0, -self.room_y/2, 0]) 222 | p1t = np.array([0, self.room_y/2, self.room_z]) 223 | p2t = np.array([0, -self.room_y/2, self.room_z]) 224 | elif (self.theta_room*i) == 0: 225 | p1b = np.array([self.room_x/2, 0, 0]) 226 | p2b = np.array([-self.room_x/2, 0, 0]) 227 | p1t = np.array([self.room_x/2, 0, self.room_z]) 228 | p2t = np.array([-self.room_x/2, 0, self.room_z]) 229 | else: 230 | x_abs = self.room_y/(2*math.tan(rad_theta)) 231 | symbol = x_abs/np.abs(x_abs) 232 | if symbol > 0: 233 | y_abs = math.tan(rad_theta)*self.room_x/2 234 | else: 235 | y_abs = -math.tan(rad_theta)*self.room_x/2 236 | if np.abs(x_abs) >= self.room_x/2 : 237 | p1b = np.array([ self.room_x/2*symbol, y_abs, 0]) 238 | p2b = np.array([-self.room_x/2*symbol, -y_abs, 0]) 239 | p1t = np.array([ self.room_x/2*symbol, y_abs, self.room_z]) 240 | p2t = np.array([-self.room_x/2*symbol, -y_abs, self.room_z]) 241 | else: 242 | p1b = np.array([ x_abs, self.room_y/2, 0]) 243 | p2b = np.array([-x_abs, -self.room_y/2, 0]) 244 | p1t = np.array([ x_abs, self.room_y/2, self.room_z]) 245 | p2t = np.array([-x_abs, -self.room_y/2, self.room_z]) 246 | pos_temp1b += [p1b] 247 | pos_temp2b += [p2b] 248 | pos_temp1t += [p1t] 249 | pos_temp2t += [p2t] 250 | pos_b = pos_temp1b + pos_temp2b 251 | pos_t = pos_temp1t + pos_temp2t 252 | position = pos_b + pos_m + pos_t 253 | 254 | return position 255 | 256 | def rect_rotation(self, numb_round, numb_rot, location): 257 | rotation_list = [] 258 | rad_phi_list = [] 259 | phi_list = [] 260 | theta_list = [] 261 | rad_theta = self.angle2rad(self.theta_room) 262 | for loc in location: 263 | radius = math.sqrt(loc[0]**2 + loc[1]**2) 264 | rad_phi = math.atan(loc[2]/radius) 265 | phi_list += [round(self.rad2angle(rad_phi), 1)] 266 | rad_phi_list += [-rad_phi] 267 | 268 | for i in range(numb_round): 269 | if (i == 0): 270 | bound = True 271 | start_rot_theta = np.array([self.angle2rad(90), 0, self.angle2rad(90)]) 272 | start_rot_phi = start_rot_theta + np.array([0, 0, 0]) 273 | bound_phi = rad_phi_list[1:numb_rot] 274 | elif(i == numb_round-1): 275 | bound = True 276 | 
start_rot_theta = np.array([self.angle2rad(90), 0, self.angle2rad(90)]) 277 | start_rot_phi = start_rot_theta + np.array([-math.atan(2*self.room_z/self.room_x), 0, 0]) 278 | bound_phi = rad_phi_list[-numb_rot+1:] 279 | else: 280 | bound = False 281 | skip = False 282 | mid_phi = rad_phi_list[numb_rot+(i-1)*8 : numb_rot+(i-1)*8+8] 283 | for j in range(numb_rot): 284 | if bound: 285 | rotation_list += [start_rot_phi] 286 | if j == numb_rot-1: 287 | theta_list += [j*self.theta_room] 288 | continue 289 | start_rot_theta = start_rot_theta + np.array([0, 0, rad_theta]) 290 | start_rot_phi = start_rot_theta + np.array([bound_phi[j], 0, 0]) 291 | theta_list += [j*self.theta_room] 292 | else: 293 | if skip: 294 | pass 295 | else: 296 | theta_t = math.atan(self.room_y/self.room_x) 297 | rot1 = np.array([self.angle2rad(90), 0, self.angle2rad(90)]) + np.array([mid_phi[0], 0, 0]) 298 | rot2 = np.array([self.angle2rad(90), 0, self.angle2rad(90) + theta_t]) + np.array([mid_phi[1], 0, 0]) 299 | rot3 = np.array([self.angle2rad(90), 0, self.angle2rad(180)]) + np.array([mid_phi[2], 0, 0]) 300 | rot4 = np.array([self.angle2rad(90), 0, self.angle2rad(270) - theta_t]) + np.array([mid_phi[3], 0, 0]) 301 | rot5 = np.array([self.angle2rad(90), 0, self.angle2rad(270)]) + np.array([mid_phi[4], 0, 0]) 302 | rot6 = np.array([self.angle2rad(90), 0, self.angle2rad(270) + theta_t]) + np.array([mid_phi[5], 0, 0]) 303 | rot7 = np.array([self.angle2rad(90), 0, self.angle2rad(360)]) + np.array([mid_phi[6], 0, 0]) 304 | rot8 = np.array([self.angle2rad(90), 0, self.angle2rad(450) - theta_t]) + np.array([mid_phi[7], 0, 0]) 305 | rotation_list += [rot1, rot2, rot3, rot4, rot5, rot6, rot7, rot8] 306 | theta_t = theta_t * 180 / math.pi 307 | theta_t = int(theta_t) 308 | theta_list += [0, theta_t, 90, 180-theta_t, 180, 180+theta_t, 270, 360-theta_t] 309 | skip = True 310 | 311 | return rotation_list #, theta_list, phi_list 312 | 313 | def surface_position(self): 314 | loc_list = [] 315 | rot_list = [] 316 | cur_loc_list = [] 317 | room_rx = self.room_x/2 318 | room_ry = self.room_y/2 319 | for i in range(self.numb_train_val): 320 | loc_ax = random.choice([0, 1, 2, 3, 4]) 321 | if loc_ax == 0: 322 | cur_loc = [random.uniform(-room_rx, room_rx), random.uniform(-room_ry, room_ry), self.room_y] 323 | while cur_loc in cur_loc_list: 324 | cur_loc = [random.uniform(-room_rx, room_rx), random.uniform(-room_ry, room_ry), self.room_y] 325 | if loc_ax == 1: 326 | cur_loc = [random.uniform(-room_rx, room_rx), -room_ry, random.uniform(0, self.room_z)] 327 | while cur_loc in cur_loc_list: 328 | cur_loc = [random.uniform(-room_rx, room_rx), -room_ry, random.uniform(0, self.room_z)] 329 | if loc_ax == 2: 330 | cur_loc = [room_rx, random.uniform(-room_ry, room_ry), random.uniform(0, self.room_z)] 331 | while cur_loc in cur_loc_list: 332 | cur_loc = [room_rx, random.uniform(-room_ry, room_ry), random.uniform(0, self.room_z)] 333 | if loc_ax == 3: 334 | cur_loc = [random.uniform(-room_rx, room_rx), room_ry, random.uniform(0, self.room_z)] 335 | while cur_loc in cur_loc_list: 336 | cur_loc = [random.uniform(-room_rx, room_rx), room_ry, random.uniform(0, self.room_z)] 337 | if loc_ax == 4: 338 | cur_loc = [-room_rx, random.uniform(-room_ry, room_ry), random.uniform(0, self.room_z)] 339 | while cur_loc in cur_loc_list: 340 | cur_loc = [-room_rx, random.uniform(-room_ry, room_ry), random.uniform(0, self.room_z)] 341 | 342 | cur_loc_list += [cur_loc] 343 | cur_theta, cur_phi = self.get_rot_from_loc(cur_loc) 344 | cur_rot = self.start_rot + 
np.array([-cur_phi, 0, 0]) 345 | cur_rot = cur_rot + np.array([0, 0, cur_theta]) 346 | loc_list += [cur_loc] 347 | rot_list += [cur_rot] 348 | 349 | return loc_list, rot_list 350 | 351 | def get_rot_from_loc(self, loc): 352 | loc_np = np.array(loc) 353 | loc_r = np.linalg.norm(loc_np[:2], ord=2) 354 | rot_phi = np.arctan(loc_np[2]/loc_r) 355 | loc_vect_xy = loc_np[:2]/loc_r 356 | std_vect = np.array([0, -1]) 357 | cos_theta = np.dot(loc_vect_xy, std_vect) / (np.linalg.norm(loc_vect_xy)*np.linalg.norm(std_vect)) 358 | sin_theta = np.cross(loc_vect_xy, std_vect) / (np.linalg.norm(loc_vect_xy)*np.linalg.norm(std_vect)) 359 | if sin_theta > 0: 360 | rot_theta = 2*np.pi - np.arccos(cos_theta) 361 | else: 362 | rot_theta = np.arccos(cos_theta) 363 | return rot_theta, rot_phi 364 | 365 | def render_set(self): 366 | self.json_train_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_train.json")) 367 | self.json_val_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_val.json")) 368 | self.json_test_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_test.json")) 369 | self.json_coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_coord.json")) 370 | self.json_calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("transforms_calib.json")) 371 | bpy.context.scene.render.image_settings.file_format = 'PNG' 372 | bpy.context.scene.render.film_transparent = True 373 | bpy.context.scene.render.resolution_x = self.res_x 374 | bpy.context.scene.render.resolution_y = self.res_y 375 | self.out_data_train = {"frames":[]} 376 | self.out_data_val = {"frames":[]} 377 | self.out_data_test = {"frames":[]} 378 | self.out_data_coord = {"frames":[]} 379 | self.out_data_calib = {"frames":[]} 380 | 381 | def cam_clear(self): 382 | for cam in self.cam_list: 383 | bpy.data.objects.remove(cam) 384 | 385 | def apriltag_more_than_two(self, detector, save_path): 386 | save_path = save_path + ".png" 387 | cur_img = cv2.imread(save_path) 388 | gray_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2GRAY) 389 | gray_img = cv2.normalize(gray_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) 390 | _, ids, _ = detector.detectMarkers(gray_img) 391 | if (ids is not None) and (len(ids) > 2): 392 | return True 393 | else: 394 | return False 395 | 396 | def render_calibration_images(self, img_id=None): 397 | arucoDict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_APRILTAG_36h11) 398 | arucoParams = cv2.aruco.DetectorParameters() 399 | detector = cv2.aruco.ArucoDetector(arucoDict, arucoParams) 400 | coord_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("coord")) 401 | calib_path = os.path.join(Path(self.root_path), Path(self.object_name), Path("calib")) 402 | if not os.path.exists(coord_path): 403 | os.makedirs(coord_path) 404 | if not os.path.exists(calib_path): 405 | os.makedirs(calib_path) 406 | collect_objects = bpy.data.collections 407 | for collects in collect_objects: 408 | collects.hide_render = True 409 | collect_objects["Calibration Object"].hide_render = False 410 | object_cube = bpy.data.objects["Cube"] 411 | object_cube.rotation_euler = [0, 0, 0] 412 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 413 | fov, cam = cam_info 414 | bpy.context.scene.camera = cam 415 | loc, rot = self.loc_train[idx], self.rot_train[idx] 416 | cam.rotation_mode = 'XYZ' 417 | cam.location = loc 418 | cam.rotation_euler = rot 419 
| cam.data.angle = self.angle2rad(fov) 420 | cur_pose = cam.matrix_world 421 | save_path = os.path.join(Path(coord_path), Path("r_{}".format(idx))) 422 | bpy.context.scene.render.filepath = save_path 423 | bpy.ops.render.render(write_still=True) 424 | frame_data = {'file_path': "./coord/r_{}".format(idx), 425 | "camera_angle_x": self.angle2rad(fov), 426 | 'transform_matrix': self.listify_matrix(cur_pose)} 427 | self.out_data_coord['frames'].append(frame_data) 428 | 429 | if img_id == None: 430 | for idx, cam_info in enumerate(zip(self.train_fov, self.cam_list[:self.numb_train_val])): 431 | fov, cam = cam_info 432 | bpy.context.scene.camera = cam 433 | loc, rot = self.loc_train[idx], self.rot_train[idx] 434 | cam.rotation_mode = 'XYZ' 435 | cam.location = loc 436 | cam.rotation_euler = rot 437 | cam.data.angle = self.angle2rad(fov) 438 | save_path = os.path.join(Path(calib_path), Path("r_{}".format(idx))) 439 | bpy.context.scene.render.filepath = save_path 440 | bpy.ops.render.render(write_still=True) 441 | while not self.apriltag_more_than_two(detector, save_path): 442 | object_cube.rotation_euler[0] = random.uniform(0, 2*math.pi) 443 | object_cube.rotation_euler[1] = random.uniform(0, 2*math.pi) 444 | object_cube.rotation_euler[2] = random.uniform(0, 2*math.pi) 445 | bpy.ops.render.render(write_still=True) 446 | 447 | frame_data = {'file_path': "./calib/r_{}".format(idx), 448 | "camera_angle_x": self.angle2rad(fov)} 449 | self.out_data_calib['frames'].append(frame_data) 450 | else: 451 | fov = self.train_fov[img_id] 452 | cam = self.cam_list[:self.numb_train_val][img_id] 453 | bpy.context.scene.camera = cam 454 | loc, rot = self.loc_train[img_id], self.rot_train[img_id] 455 | cam.rotation_mode = 'XYZ' 456 | cam.location = loc 457 | cam.rotation_euler = rot 458 | cam.data.angle = self.angle2rad(fov) 459 | cur_pose = cam.matrix_world 460 | save_path = os.path.join(Path(calib_path), Path("r_{}".format(img_id))) 461 | bpy.context.scene.render.filepath = save_path 462 | bpy.ops.render.render(write_still=True) 463 | while not self.apriltag_more_than_two(detector, save_path): 464 | object_cube.rotation_euler[0] = random.uniform(0, 2*math.pi) 465 | object_cube.rotation_euler[1] = random.uniform(0, 2*math.pi) 466 | object_cube.rotation_euler[2] = random.uniform(0, 2*math.pi) 467 | bpy.ops.render.render(write_still=True) 468 | frame_data = {'file_path': "./calib/r_{}".format(img_id), 469 | "camera_angle_x": self.angle2rad(fov)} 470 | self.out_data_calib['frames'].append(frame_data) 471 | 472 | with open(self.json_calib_path, 'w') as out_file: 473 | json.dump(self.out_data_calib, out_file, indent=4) 474 | collect_objects["Calibration Object"].hide_render = True 475 | 476 | if __name__ == "__main__": 477 | seed_dict = {"Lego":0, 478 | "Gate":1, 479 | "Materials":2, 480 | "Ficus":3, 481 | "Computer":4, 482 | "Snowtruck":5, 483 | "Statue":6, 484 | "Train":7} 485 | cur_data = "Snowtruck" 486 | dataset = Room_Dataset(obj_name = "Room_{}".format(cur_data), seed=seed_dict[cur_data]) 487 | dataset.render_images() 488 | dataset.render_calibration_images() 489 | dataset.cam_clear() 490 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .log_init import Log_config 2 | from .distributed_init import get_rank 3 | from .distributed_init import Distributed_config 4 | from .tensorboard_init import Tensorboard_config 
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/distributed_init.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/utils/__pycache__/distributed_init.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/log_init.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/utils/__pycache__/log_init.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tensorboard_init.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkylerGao/MC_NeRF/8806f31ef8cf0eba6398bedebde3b2c7dce92ec0/utils/__pycache__/tensorboard_init.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/distributed_init.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | 
5 | import torch.distributed as dist
6 | 
7 | class Distributed_config():
8 |     def __init__(self, sys_param):
9 |         self.init_distributed_mode(sys_param)
10 | 
11 |     def init_distributed_mode(self, sys_param):
12 |         if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
13 |             sys_param['rank'] = int(os.environ["RANK"])
14 |             sys_param['gpu'] = int(os.environ['LOCAL_RANK'])
15 |             sys_param['world_size'] = int(os.environ['WORLD_SIZE'])
16 | 
17 |         else:
18 |             logging.info('Multi-GPU deactivated...')
19 |             sys_param['distributed'] = False
20 |             return
21 | 
22 |         start_device = sys_param["start_device"]
23 |         sys_param['gpu'] = sys_param['gpu'] + start_device
24 | 
25 |         logging.info('Multi-GPU activated: current rank: {}, current GPU: {}'.format(sys_param['rank'], sys_param['gpu']))
26 |         sys_param['distributed'] = True
27 |         torch.cuda.set_device(sys_param['gpu'])
28 |         sys_param['dist_backend'] = 'nccl'
29 |         torch.distributed.init_process_group(backend=sys_param['dist_backend'],
30 |                                              world_size=sys_param['world_size'],
31 |                                              rank=sys_param['rank'])
32 | 
33 |         torch.distributed.barrier(device_ids=[sys_param['gpu']])
34 |         self.setup_for_distributed(sys_param['rank'] == 0)
35 | 
36 |     def setup_for_distributed(self, is_master):
37 |         """
38 |         Disable printing and logging when not in the master process.
39 |         """
40 |         import builtins as __builtin__
41 |         builtin_print = __builtin__.print
42 |         logging_info = logging.info
43 | 
44 |         def print(*args, **kwargs):
45 |             force = kwargs.pop('force', False)
46 |             if is_master or force:
47 |                 builtin_print(*args, **kwargs)
48 | 
49 |         def info(*args, **kwargs):
50 |             force = kwargs.pop('force', False)
51 |             if is_master or force:
52 |                 logging_info(*args, **kwargs)
53 | 
54 |         __builtin__.print = print
55 |         logging.info = info
56 | 
57 | def get_rank():
58 |     if not is_dist_avail_and_initialized():
59 |         return 0
60 |     return dist.get_rank()
61 | 
62 | def is_dist_avail_and_initialized():
63 |     if not dist.is_available():
64 |         return False
65 |     if not dist.is_initialized():
66 |         return False
67 |     return True
--------------------------------------------------------------------------------
/utils/log_init.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import os
4 | 
5 | from pathlib import Path
6 | 
7 | class Log_config():
8 |     def __init__(self, sys_param):
9 |         self.save_mode = sys_param['log']
10 |         self.save_pth = sys_param['log_pth']
11 |         self.log_function_start()
12 | 
13 |     def log_function_start(self):
14 |         if self.save_mode:
15 |             results_log_pth = os.path.join(Path("results"), Path(self.save_pth))
16 |             if not os.path.exists(results_log_pth):
17 |                 os.makedirs(results_log_pth)
18 |             ticks = time.asctime(time.localtime(time.time()))
19 |             ticks = str(ticks).replace(' ', '-').replace(':', '-')
20 |             log_name = '{}.log'.format(os.path.join(results_log_pth, ticks))  # write the log file into the directory created above
21 | 
22 |             logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
23 |                                 datefmt='%m/%d/%Y %H:%M:%S',
24 |                                 level=logging.INFO,
25 |                                 filemode='a',
26 |                                 filename=log_name)
27 |         else:
28 |             logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
29 |                                 datefmt='%m/%d/%Y %H:%M:%S',
30 |                                 level=logging.INFO)
--------------------------------------------------------------------------------
/utils/tensorboard_init.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | from pathlib import Path
5 | from torch.utils.tensorboard import SummaryWriter
6 | 
7 | class Tensorboard_config():
8 |     def __init__(self, sys_param):
9 |         if sys_param['tb_available']:
10 |             self.save_pth = os.path.join(Path("results"), Path(sys_param["tb_pth"]))
11 |             if (sys_param["distributed"]) and (sys_param['rank'] == 0):
12 |                 if (sys_param['tb_del']):
13 |                     if not os.path.exists(self.save_pth):
14 |                         os.makedirs(self.save_pth)
15 |                     else:
16 |                         shutil.rmtree(self.save_pth)
17 |                         os.makedirs(self.save_pth)
18 |                 else:
19 |                     if not os.path.exists(self.save_pth):
20 |                         os.makedirs(self.save_pth)
21 |             self.tblogger = SummaryWriter(self.save_pth)
22 |         else:
23 |             self.tblogger = None
24 | 
25 |     def add_scalar(self, *args, **kwargs):
26 |         if self.tblogger is not None: self.tblogger.add_scalar(*args, **kwargs)  # no-op when TensorBoard is unavailable
--------------------------------------------------------------------------------
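Closing note: the sketch below shows one way the TensorBoard wrapper above might be driven from a training loop. It is not code from this repository; the loop, the `train/loss` tag, and the placeholder loss values are invented, while the `sys_param` keys are the ones read by `Distributed_config` and `Tensorboard_config`.
```
from utils import Distributed_config, Tensorboard_config, get_rank

# Hypothetical usage sketch for Tensorboard_config (not taken from main.py).
sys_param = {"start_device": 0,
             "tb_available": True,   # create a SummaryWriter under results/<tb_pth>
             "tb_pth": "tb_demo",    # hypothetical run directory
             "tb_del": False}
Distributed_config(sys_param)        # fills in sys_param['distributed'] (and rank/gpu under torchrun)
tb = Tensorboard_config(sys_param)

for step in range(10):
    fake_loss = 1.0 / (step + 1)     # stand-in for a real training loss
    if get_rank() == 0:              # only the master process writes events
        tb.add_scalar("train/loss", fake_loss, step)
```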