├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets └── gspose_overview.png ├── config └── inference_cfg.py ├── dataset ├── demo_dataset.py ├── extract_megapose_to_BOP.py ├── inference_datasets.py ├── megapose_dataset.py ├── misc.py └── parse_OnePoseCap_data.py ├── environment.yml ├── gaussian_object ├── __init__.py ├── arguments.py ├── build_3DGaussianObject.py ├── cameras.py ├── dataset_readers.py ├── gaussian_model.py ├── gaussian_render.py ├── gaussian_renderer │ ├── __init__.py │ └── network_gui.py ├── loss_utils.py ├── sh_utils.py └── utils │ ├── camera_utils.py │ ├── general_utils.py │ ├── graphics_utils.py │ ├── image_utils.py │ ├── loss_utils.py │ ├── sh_utils.py │ └── system_utils.py ├── inference.py ├── install_env.sh ├── misc_utils ├── gs_utils.py ├── loss_utils.py ├── metric_utils.py └── warmup_lr.py ├── model ├── blocks.py ├── curope │ ├── __init__.py │ ├── curope.cpp │ ├── curope2d.py │ ├── kernels.cu │ └── setup.py ├── dino_layers │ ├── __init__.py │ ├── attention.py │ ├── block.py │ ├── dino_head.py │ ├── drop_path.py │ ├── efficient_attention.py │ ├── layer_scale.py │ ├── mlp.py │ ├── patch_embed.py │ └── swiglu_ffn.py ├── generalized_mean_pooling.py ├── network.py └── position_encoding.py ├── notebook └── Demo_Example_with_GS-Pose.ipynb ├── three ├── __init__.py ├── batchview.py ├── core.py ├── imutils.py ├── meshutils.py ├── orientation.py ├── pytorch3d_rendering.py ├── quaternion.py ├── rigid.py ├── stats.py ├── torchutils.py └── utils.py └── training └── training.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | 163 | *database 164 | database* 165 | demo_data 166 | *.pth 167 | *.mp4 168 | *.MP4 169 | checkpoints 170 | dataspace 171 | notebook/*.png 172 | notebook/*.json 173 | dataspace 174 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/Connected_components_PyTorch"] 2 | path = submodules/Connected_components_PyTorch 3 | url = https://github.com/zsef123/Connected_components_PyTorch.git 4 | [submodule "submodules/simple-knn"] 5 | path = submodules/simple-knn 6 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 7 | [submodule "submodules/diff-gaussian-rasterization"] 8 | path = submodules/diff-gaussian-rasterization 9 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Dingding 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GS-Pose: Generalizable Segmentation-Based 6D Object Pose Estimation With 3D Gaussian Splatting 2 | - [[Project Page](https://dingdingcai.github.io/gs-pose)] 3 | - [[Paper](https://arxiv.org/abs/2403.10683)] 4 |

4 | <p align="center"> 5 |   <img src="./assets/gspose_overview.png" width="100%" /> 6 | </p>
7 | ## Citation 8 | ``` BibTeX 9 | @article{cai_2024_GSPose, 10 | author = {Cai, Dingding and Heikkil\"a, Janne and Rahtu, Esa}, 11 | title = {GS-Pose: Generalizable Segmentation-Based 6D Object Pose Estimation With 3D Gaussian Splatting}, 12 | journal = {arXiv preprint arXiv:2403.10683}, 13 | year = {2024}, 14 | } 15 | ``` 16 | 17 | ## Setup 18 | Please start by installing [Miniconda3](https://conda.io/projects/conda/en/latest/user-guide/install/linux.html). 19 | This repository contains submodules, and the default environment can be installed as below. 20 | 21 | ``` Bash 22 | git clone git@github.com:dingdingcai/GSPose.git --recursive 23 | cd GSPose 24 | conda env create -f environment.yml 25 | conda activate gspose 26 | 27 | bash install_env.sh 28 | ``` 29 | 30 | ## Pre-trained Model 31 | Download the [pretrained weights](https://drive.google.com/file/d/1VgOAemCrEeW_nT6qQ3R12oz_3UZmQILy/view?usp=sharing) and store the file as ``checkpoints/model_wights.pth``. 32 | 33 | 34 | 35 | ## Demo Example 36 | An example of using GS-Pose for both pose estimation and tracking is provided in the [``notebook``](./notebook/Demo_Example_with_GS-Pose.ipynb). 37 | 38 | 39 | ## Datasets 40 | Our evaluation is conducted on the LINEMOD and OnePose-LowTexture datasets. 41 | - For comparison with Gen6D, download [``LINEMOD_Gen6D``](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/yuanly_connect_hku_hk/EkWESLayIVdEov4YlVrRShQBkOVTJwgK0bjF7chFg2GrBg?e=Y8UpXu). 42 | - For comparison with OnePose++, download [``lm``](https://bop.felk.cvut.cz/datasets) and the YOLOv5 detection results [``lm_yolo_detection``](https://zjueducn-my.sharepoint.com/:u:/g/personal/12121064_zju_edu_cn/EdodUdKGwHpCuvw3Cio5DYoBTntYLQuc7vNg9DkytWuJAQ?e=sAXp4B). 43 | - Download the [OnePose-LowTexture](https://github.com/zju3dv/OnePose_Plus_Plus/blob/main/doc/dataset_document.md) dataset and store it under the directory ``onepose_dataset``. 44 | 45 | 46 | 47 | All datasets are organised under the ``dataspace`` directory, as follows. 48 | ``` 49 | dataspace/ 50 | ├── LINEMOD_Gen6D 51 | │ 52 | ├── bop_dataset/ 53 | │ ├── lm 54 | │ └── lm_yolo_detection 55 | │ 56 | ├── onepose_dataset/ 57 | │ ├── scanned_model 58 | │ └── lowtexture_test_data 59 | │ 60 | └── README.md 61 | ``` 62 | 63 | ## Evaluation 64 | Evaluation on the subset of LINEMOD (comparison with Gen6D, Cas6D, etc.): 65 | - ``python inference.py --dataset_name LINEMOD_SUBSET --database_dir LMSubSet_database --outpose_dir LMSubSet_pose`` 66 | 67 | Evaluation on all objects of LINEMOD using the built-in detector: 68 | - ``python inference.py --dataset_name LINEMOD --database_dir LM_database --outpose_dir LM_pose`` 69 | 70 | Evaluation on all objects of LINEMOD using the YOLOv5 detections (comparison with OnePose/OnePose++): 71 | - ``python inference.py --dataset_name LINEMOD --database_dir LM_database --outpose_dir LM_yolo_pose`` 72 | 73 | Evaluation on the scanned objects of OnePose-LowTexture: 74 | - ``python inference.py --dataset_name LOWTEXTUREVideo --database_dir LTVideo_database --outpose_dir LTVideo_pose`` 75 | 76 | ## Training 77 | We utilize a subset (``gso_1M``) of the MegaPose dataset for training.
78 | Please download [``MegaPose/gso_1M``](https://www.paris.inria.fr/archive_ylabbeprojectsdata/megapose/webdatasets/) and [``MegaPose/google_scanned_objects.zip``](https://www.paris.inria.fr/archive_ylabbeprojectsdata/megapose/tars/) to the directory ``dataspace``, and organize the data as follows. 79 | ``` 80 | dataspace/ 81 | ├── MegaPose/ 82 | │ ├── webdatasets/gso_1M 83 | │ └── google_scanned_objects 84 | ... 85 | ``` 86 | Next, execute the following script under the [``MegaPose``](https://github.com/megapose6d/megapose6d?tab=readme-ov-file) environment to prepare the training data. 87 | - ``python dataset/extract_megapose_to_BOP.py`` 88 | 89 | Then, train the network via 90 | - ``python training/training.py`` 91 | 92 | ## Acknowledgement 93 | - The code is partially based on [DINOv2](https://github.com/facebookresearch/dinov2), [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file), [MegaPose](https://github.com/megapose6d/megapose6d), and [SC6D](https://github.com/dingdingcai/SC6D-pose). 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /assets/gspose_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingdingcai/GSPose/dc88a8965e2b48436371f297f9288fba8313ab83/assets/gspose_overview.png -------------------------------------------------------------------------------- /config/inference_cfg.py: -------------------------------------------------------------------------------- 1 | cosim_topk = -1 2 | refer_view_num = 8 3 | DINO_PATCH_SIZE = 14 4 | zoom_image_margin = 0 5 | zoom_image_scale = 224 6 | query_longside_scale = 672 7 | query_shortside_scale = query_longside_scale * 3 // 4 8 | 9 | coarse_threshold = 0.05 10 | coarse_bbox_padding = 1.25 11 | finer_threshold = 0.5 12 | finer_bbox_padding = 1.5 13 | enable_fine_detection = True 14 | 15 | save_reference_mask = True 16 | #### configure for 3D-GS-Refiner #### 17 | ROT_TOPK = 1 # single rotation proposal 18 | 19 | WARMUP = 10 20 | END_LR = 0 21 | START_LR = 5e-3 22 | MAX_STEPS = 400 23 | GS_RENDER_SIZE = 224 24 | EARLY_STOP_MIN_STEPS = 5 25 | EARLY_STOP_LOSS_GRAD_NORM = 1e-4 26 | 27 | USE_SSIM = True 28 | USE_MS_SSIM = True 29 | 30 | BINARIZE_MASK = False 31 | USE_YOLO_BBOX = True 32 | USE_ALLOCENTRIC = True 33 | APPLY_ZOOM_AND_CROP = True 34 | CC_INCLUDE_SUPMASK = False 35 | 36 | -------------------------------------------------------------------------------- /dataset/demo_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import mediapy as media 6 | 7 | from transforms3d import affines, quaternions 8 | 9 | from misc_utils import gs_utils 10 | 11 | class OnePoseCap_Dataset(torch.utils.data.Dataset): 12 | def __init__(self, obj_data_dir, num_grid_points=4096, extract_RGB=False, use_binarized_mask=False, obj_database_dir=None): 13 | 14 | self.extract_RGB = extract_RGB 15 | self.obj_data_dir = obj_data_dir 16 | self.num_grid_points = num_grid_points 17 | self.obj_database_dir = obj_database_dir 18 | self.use_binarized_mask = use_binarized_mask 19 | 20 | self.arkit_box_path = os.path.join(self.obj_data_dir, 'Box.txt') 21 | self.arkit_pose_path = os.path.join(self.obj_data_dir, 'ARposes.txt') 22 | self.arkit_video_path = os.path.join(self.obj_data_dir, 'Frames.m4v') 23 | self.arkit_intrin_path = os.path.join(self.obj_data_dir, 'Frames.txt') 24 | 25 | #### read
the ARKit pose info 26 | with open(self.arkit_pose_path, 'r') as pf: 27 | self.arkit_poses = [row.strip() for row in pf.readlines() if len(row) > 0 and row[0] != '#'] 28 | 29 | with open(self.arkit_intrin_path, 'r') as cf: 30 | self.arkit_camKs = [row.strip() for row in cf.readlines() if len(row) > 0 and row[0] != '#'] 31 | 32 | ### read the video 33 | if self.extract_RGB: 34 | RGB_dir = os.path.join(self.obj_data_dir, 'RGB') 35 | if not os.path.exists(RGB_dir): 36 | os.makedirs(RGB_dir) 37 | cap = cv2.VideoCapture(self.arkit_video_path) 38 | index = 0 39 | while True: 40 | ret, image = cap.read() 41 | if not ret: 42 | break 43 | cv2.imwrite(os.path.join(RGB_dir, f'{index}.png'), image) 44 | index += 1 45 | else: 46 | self.video_frames = media.read_video(self.arkit_video_path) # NxHxWx3 47 | 48 | 49 | assert(len(self.arkit_poses) == len(self.arkit_camKs)) 50 | 51 | #### preprocess the ARKit 3D object bounding box 52 | with open(self.arkit_box_path, 'r') as f: 53 | lines = f.readlines() 54 | box_data = [float(e) for e in lines[1].strip().split(',')] 55 | ex, ey, ez = box_data[3:6] 56 | self.obj_bbox3d = np.array([ 57 | [-ex, -ey, -ez], # Front-top-left corner 58 | [ex, -ey, -ez], # Front-top-right corner 59 | [ex, ey, -ez], # Front-bottom-right corner 60 | [-ex, ey, -ez], # Front-bottom-left corner 61 | [-ex, -ey, ez], # Back-top-left corner 62 | [ex, -ey, ez], # Back-top-right corner 63 | [ex, ey, ez], # Back-bottom-right corner 64 | [-ex, ey, ez], # Back-bottom-left corner 65 | ]) * 0.5 66 | obj_bbox3D_dims = np.array([ex, ey, ez], dtype=np.float32) 67 | grid_cube_size = (np.prod(obj_bbox3D_dims, axis=0) / self.num_grid_points)**(1/3) 68 | xnum, ynum, znum = np.ceil(obj_bbox3D_dims / grid_cube_size).astype(np.int64) 69 | xmin, ymin, zmin = self.obj_bbox3d.min(axis=0) 70 | xmax, ymax, zmax = self.obj_bbox3d.max(axis=0) 71 | zgrid, ygrid, xgrid = np.meshgrid(np.linspace(zmin, zmax, znum), 72 | np.linspace(ymin, ymax, ynum), 73 | np.linspace(xmin, xmax, xnum), 74 | indexing='ij') 75 | self.bbox3d_grid_points = np.stack([xgrid, ygrid, zgrid], axis=-1).reshape(-1, 3) 76 | self.bbox3d_diameter = np.linalg.norm(obj_bbox3D_dims) 77 | 78 | bbox3d_position = np.array(box_data[0:3], dtype=np.float32) 79 | bbox3D_rot_quat = np.array(box_data[6:10], dtype=np.float32) 80 | bbox3D_rot_mat = quaternions.quat2mat(bbox3D_rot_quat) 81 | T_O2W = affines.compose(bbox3d_position, bbox3D_rot_mat, np.ones(3)) # object-to-world 82 | 83 | self.camKs = list() 84 | self.poses = list() 85 | self.allo_poses = list() 86 | self.image_IDs = list() 87 | Xaxis_Rmat = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=np.float32) 88 | 89 | for frame_idx, pose_info in enumerate(self.arkit_poses): 90 | camk_info = self.arkit_camKs[frame_idx] # [time, index, fx, fy, cx, cy] 91 | camk_dat = [float(c) for c in camk_info.split(',')] 92 | camk = np.eye(3) 93 | camk[0, 0] = camk_dat[-4] 94 | camk[1, 1] = camk_dat[-3] 95 | camk[0, 2] = camk_dat[-2] 96 | camk[1, 2] = camk_dat[-1] 97 | self.camKs.append(camk) 98 | 99 | pose_dat = [float(p) for p in pose_info.split(',')] 100 | bbox_pos = pose_dat[1:4] 101 | bbox_quat = pose_dat[4:] 102 | rot_mat = quaternions.quat2mat(bbox_quat) 103 | rot_mat = rot_mat @ Xaxis_Rmat.copy() # conversion (X-right, Y-up, Z-back) right-hand 104 | T_C2W = affines.compose(bbox_pos, rot_mat, np.ones(3)) # camera-to-world 105 | T_W2C = np.linalg.inv(T_C2W) # world-to-camera 106 | pose_RT = T_W2C @ T_O2W # object-to-camera 107 | 108 | allo_pose = pose_RT.copy() 109 | allo_pose[:3, :3] = 
gs_utils.egocentric_to_allocentric(allo_pose)[:3, :3] 110 | self.allo_poses.append(allo_pose) 111 | 112 | self.poses.append(pose_RT) 113 | self.image_IDs.append(frame_idx) 114 | 115 | def __len__(self): 116 | return len(self.poses) 117 | 118 | def __getitem__(self, idx): 119 | data_dict = dict() 120 | camK = self.camKs[idx] 121 | pose = self.poses[idx] 122 | allo_pose = self.allo_poses[idx] 123 | image_ID = self.image_IDs[idx] 124 | 125 | if self.extract_RGB: 126 | image = cv2.imread(os.path.join(self.obj_data_dir, 'RGB', f'{image_ID}.png')) 127 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0 128 | else: 129 | image = np.array(self.video_frames[idx]) / 255.0 130 | 131 | data_dict['image_ID'] = image_ID 132 | data_dict['camK'] = torch.as_tensor(camK, dtype=torch.float32) 133 | data_dict['pose'] = torch.as_tensor(pose, dtype=torch.float32) 134 | data_dict['image'] = torch.as_tensor(image, dtype=torch.float32) 135 | data_dict['allo_pose'] = torch.as_tensor(allo_pose, dtype=torch.float32) 136 | 137 | if self.obj_database_dir is not None: 138 | data_dict['coseg_mask_path'] = os.path.join(self.obj_database_dir, 'pred_coseg_mask', '{:06d}.png'.format(image_ID)) 139 | else: 140 | data_dict['coseg_mask_path'] = os.path.join(self.obj_data_dir, 'pred_coseg_mask', '{:06d}.png'.format(image_ID)) 141 | 142 | return data_dict 143 | 144 | def collate_fn(self, batch): 145 | """ 146 | batchify the data 147 | """ 148 | new_batch = dict() 149 | for each_dat in batch: 150 | for key, val in each_dat.items(): 151 | if key not in new_batch: 152 | new_batch[key] = list() 153 | new_batch[key].append(val) 154 | 155 | for key, val in new_batch.items(): 156 | new_batch[key] = torch.stack(val, dim=0) 157 | 158 | return new_batch 159 | -------------------------------------------------------------------------------- /dataset/parse_OnePoseCap_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import tqdm 4 | import numpy as np 5 | import os.path as osp 6 | from pathlib import Path 7 | from transforms3d import affines, quaternions 8 | 9 | 10 | def parse_box(box_path): 11 | with open(box_path, 'r') as f: 12 | lines = f.readlines() 13 | data = [float(e) for e in lines[1].strip().split(',')] 14 | position = data[:3] 15 | quaternion = data[6:] 16 | rot_mat = quaternions.quat2mat(quaternion) 17 | T_ow = affines.compose(position, rot_mat, np.ones(3)) 18 | return T_ow 19 | 20 | def get_bbox3d(box_path): 21 | assert Path(box_path).exists() 22 | with open(box_path, 'r') as f: 23 | lines = f.readlines() 24 | box_data = [float(e) for e in lines[1].strip().split(',')] 25 | ex, ey, ez = box_data[3: 6] 26 | bbox_3d = np.array([ 27 | [-ex, -ey, -ez], 28 | [ex, -ey, -ez], 29 | [ex, -ey, ez], 30 | [-ex, -ey, ez], 31 | [-ex, ey, -ez], 32 | [ ex, ey, -ez], 33 | [ ex, ey, ez], 34 | [-ex, ey, ez] 35 | ]) * 0.5 36 | bbox_3d_homo = np.concatenate([bbox_3d, np.ones((8, 1))], axis=1) 37 | return bbox_3d, bbox_3d_homo 38 | 39 | 40 | 41 | def data_process_anno(data_dir): 42 | arkit_box_path = osp.join(data_dir, 'Box.txt') 43 | arkit_pose_path = osp.join(data_dir, 'ARposes.txt') 44 | arkit_intrin_path = osp.join(data_dir, 'Frames.txt') 45 | 46 | out_pose_dir = osp.join(data_dir, 'poses') 47 | Path(out_pose_dir).mkdir(parents=True, exist_ok=True) 48 | out_intrin_path = osp.join(data_dir, 'intrinsics.txt') 49 | out_bbox3D_path = osp.join(osp.dirname(data_dir), 'box3d_corners.txt') 50 | 51 | ##### read the ARKit 3D bounding box and convert to box corners 52 | bbox_3d, 
bbox_3d_homo = get_bbox3d(arkit_box_path) 53 | np.savetxt(out_bbox3D_path, bbox_3d) 54 | 55 | ##### read the ARKit camera intrinsics 56 | with open(arkit_intrin_path, 'r') as f: 57 | lines = [l.strip() for l in f.readlines() if len(l) > 0 and l[0] != '#'] 58 | arkit_camk = np.array([[float(e) for e in l.split(',')] for l in lines]) 59 | fx, fy, cx, cy = np.average(arkit_camk, axis=0)[2:] 60 | with open(out_intrin_path, 'w') as f: 61 | f.write('fx: {0}\nfy: {1}\ncx: {2}\ncy: {3}'.format(fx, fy, cx, cy)) 62 | 63 | ##### read the ARKit camera poses 64 | T_O2W = parse_box(arkit_box_path) # 3D object bounding box is defined w.r.t. the world coordinate system 65 | 66 | with open(arkit_pose_path, 'r') as f: 67 | lines = [l.strip() for l in f.readlines()] 68 | index = 0 69 | for line in tqdm.tqdm(lines): 70 | if len(line) == 0 or line[0] == '#': 71 | continue 72 | 73 | eles = line.split(',') 74 | data = [float(e) for e in eles] 75 | 76 | position = data[1:4] 77 | quaternion = data[4:] 78 | rot_mat = quaternions.quat2mat(quaternion) 79 | rot_mat = rot_mat @ np.array([ 80 | [1, 0, 0], 81 | [0, -1, 0], 82 | [0, 0, -1]]) 83 | T_C2W = affines.compose(position, rot_mat, np.ones(3)) 84 | T_W2C = np.linalg.inv(T_C2W) 85 | T_O2C = T_W2C @ T_O2W 86 | 87 | pose_save_path = osp.join(out_pose_dir, '{}.txt'.format(index)) 88 | np.savetxt(pose_save_path, T_O2C) 89 | index += 1 90 | 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | args = parse_args() 96 | data_dir = args.scanned_object_path 97 | assert osp.exists(data_dir), f"Scanned object path:{data_dir} not exists!" 98 | seq_dirs = os.listdir(data_dir) 99 | 100 | for seq_dir in seq_dirs: 101 | if '-refer' in seq_dir: 102 | ################ Parse scanned reference sequence ################ 103 | print('=> Processing train sequence: ', seq_dir) 104 | video_path = osp.join(data_dir, seq_dir, 'Frames.m4v') 105 | print('=> parse video: ', video_path) 106 | data_process_anno(osp.join(data_dir, seq_dir), downsample_rate=1, hw=512) 107 | 108 | elif '-query' in seq_dir: 109 | ################ Parse scanned test sequence ################ 110 | print('=> Processing test sequence: ', seq_dir) 111 | data_process_test(osp.join(data_dir, seq_dir), downsample_rate=1) 112 | pass 113 | 114 | else: 115 | 116 | continue 117 | 118 | 119 | # python parse_scanned_data.py --scanned_object_path /home/dingding/Workspace/Others/OnePose_Plus_Plus/data/demo/teacan -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gspose 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | - xformers 7 | dependencies: 8 | - python==3.8.5 9 | - plyfile==0.8.1 10 | - cudatoolkit==11.7 11 | - pip: 12 | - mmcv==1.7.1 13 | - tqdm==4.66.2 14 | - pillow==9.3.0 15 | - mediapy==1.1.4 16 | - open3d==0.16.0 17 | - trimesh==3.18.0 18 | - ninja==1.11.1.1 19 | - structlog==23.1.0 20 | - pycocotools==2.0.6 21 | - webdataset==0.2.48 22 | - transforms3d==0.4.1 23 | - tensorboard==2.11.2 24 | - scikit-image==0.19.3 25 | - opencv-python==4.8.0.76 26 | - pytorch-msssim==1.0.0 -------------------------------------------------------------------------------- /gaussian_object/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import json 14 | import random 15 | 16 | # import the customized modules 17 | from .arguments import ModelParams 18 | from .gaussian_model import GaussianModel 19 | from .dataset_readers import readObjectInfo 20 | from .utils.system_utils import searchForMaxIteration 21 | from .utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 22 | 23 | 24 | class Scene: 25 | 26 | gaussians : GaussianModel 27 | 28 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 29 | """b 30 | :param path: Path to colmap scene main folder. 31 | """ 32 | self.model_path = args.model_path 33 | self.loaded_iter = None 34 | self.gaussians = gaussians 35 | 36 | if load_iteration: 37 | if load_iteration == -1: 38 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 39 | else: 40 | self.loaded_iter = load_iteration 41 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 42 | 43 | self.train_cameras = {} 44 | self.test_cameras = {} 45 | 46 | scene_info = readObjectInfo(train_dataset=args.referloader, 47 | test_dataset=args.queryloader, 48 | model_path=args.model_path, 49 | zoom_scale=args.zoom_scale, margin=args.margin, 50 | random_points3D=args.random_points3D, 51 | num_points=args.num_points, 52 | ) 53 | 54 | # if os.path.exists(os.path.join(args.source_path, "sparse")): 55 | # scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 56 | # elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 57 | # print("Found transforms_train.json file, assuming Blender data set!") 58 | # scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 59 | # else: 60 | # assert False, "Could not recognize scene type!" 
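# NOTE: in GS-Pose, scene_info is built directly from the reference/query dataloaders via readObjectInfo() above, so the original 3DGS COLMAP/Blender scene-loading branch is kept only as commented-out reference, as is the input.ply/cameras.json export below.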
61 | 62 | 63 | # if not self.loaded_iter: 64 | # with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 65 | # dest_file.write(src_file.read()) 66 | # json_cams = [] 67 | # camlist = [] 68 | # if scene_info.test_cameras: 69 | # camlist.extend(scene_info.test_cameras) 70 | # if scene_info.train_cameras: 71 | # camlist.extend(scene_info.train_cameras) 72 | # for id, cam in enumerate(camlist): 73 | # json_cams.append(camera_to_JSON(id, cam)) 74 | # with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 75 | # json.dump(json_cams, file) 76 | 77 | if shuffle: 78 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 79 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 80 | 81 | self.cameras_extent = scene_info.nerf_normalization["radius"] 82 | 83 | for resolution_scale in resolution_scales: 84 | print("Loading Training Cameras") 85 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 86 | print("Loading Test Cameras") 87 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 88 | 89 | if self.loaded_iter: 90 | self.gaussians.load_ply(os.path.join(self.model_path, 91 | "point_cloud", 92 | "iteration_" + str(self.loaded_iter), 93 | "point_cloud.ply")) 94 | else: 95 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 96 | 97 | def save(self, iteration): 98 | # point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 99 | self.gaussians.save_ply(os.path.join(self.model_path, "3DGO_model.ply")) 100 | 101 | def getTrainCameras(self, scale=1.0): 102 | return self.train_cameras[scale] 103 | 104 | def getTestCameras(self, scale=1.0): 105 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /gaussian_object/arguments.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from argparse import ArgumentParser, Namespace 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = False 55 | self.data_device = "cuda" 56 | self.eval = False 57 | 58 | self.margin = 0.0 59 | self.zoom_scale = 512 60 | self.num_points = 4096 61 | self.random_points3D = False 62 | self.referloader = None 63 | self.queryloader = None 64 | 65 | super().__init__(parser, "Loading Parameters", sentinel) 66 | 67 | def extract(self, args): 68 | g = super().extract(args) 69 | g.source_path = os.path.abspath(g.source_path) 70 | return g 71 | 72 | class PipelineParams(ParamGroup): 73 | def __init__(self, parser): 74 | self.convert_SHs_python = False 75 | self.compute_cov3D_python = False 76 | self.debug = False 77 | super().__init__(parser, "Pipeline Parameters") 78 | 79 | class OptimizationParams(ParamGroup): 80 | def __init__(self, parser): 81 | self.iterations = 30_000 82 | self.position_lr_init = 0.00016 83 | self.position_lr_final = 0.0000016 84 | self.position_lr_delay_mult = 0.01 85 | self.position_lr_max_steps = 30_000 86 | self.feature_lr = 0.0025 87 | self.opacity_lr = 0.05 88 | self.scaling_lr = 0.005 89 | self.rotation_lr = 0.001 90 | self.percent_dense = 0.01 91 | self.lambda_dssim = 0.2 92 | self.densification_interval = 100 93 | self.opacity_reset_interval = 3000 94 | self.densify_from_iter = 500 95 | self.densify_until_iter = 15_000 96 | self.densify_grad_threshold = 0.0002 97 | self.random_background = False 98 | super().__init__(parser, "Optimization Parameters") 99 | 100 | def get_combined_args(parser : ArgumentParser): 101 | cmdlne_string = sys.argv[1:] 102 | cfgfile_string = "Namespace()" 103 | args_cmdline = parser.parse_args(cmdlne_string) 104 | 105 | try: 106 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 107 | print("Looking for config file in", cfgfilepath) 108 | with open(cfgfilepath) as cfg_file: 109 | print("Config file found: {}".format(cfgfilepath)) 110 | cfgfile_string = cfg_file.read() 111 | except TypeError: 112 | print("Config file not found at") 113 | pass 114 | args_cfgfile = eval(cfgfile_string) 115 | 116 | merged_dict = vars(args_cfgfile).copy() 117 | for k,v in vars(args_cmdline).items(): 118 | if v != None: 119 | merged_dict[k] = v 120 | return 
Namespace(**merged_dict) 121 | -------------------------------------------------------------------------------- /gaussian_object/build_3DGaussianObject.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | import glob 15 | import uuid 16 | import torch 17 | from tqdm import tqdm 18 | from random import randint 19 | from argparse import ArgumentParser, Namespace 20 | try: 21 | from torch.utils.tensorboard import SummaryWriter 22 | TENSORBOARD_FOUND = True 23 | except ImportError: 24 | TENSORBOARD_FOUND = False 25 | 26 | from gaussian_object.utils.image_utils import psnr 27 | # from gaussian_object.utils.general_utils import safe_state 28 | # from gaussian_object.gaussian_renderer import network_gui 29 | from gaussian_object.gaussian_renderer import render 30 | 31 | # import the customized modules 32 | from gaussian_object import Scene 33 | from gaussian_object.loss_utils import l1_loss, ssim 34 | from gaussian_object.gaussian_model import GaussianModel 35 | from gaussian_object.arguments import ModelParams, PipelineParams, OptimizationParams 36 | 37 | 38 | # def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 39 | def create_3D_Gaussian_object(dataset, opt, pipe, testing_iterations=[30_000], 40 | saving_iterations=[30_000], 41 | checkpoint_iterations=[], 42 | checkpoint=None, 43 | debug_from=-1, 44 | return_gaussian=True, 45 | ): 46 | first_iter = 0 47 | tb_writer = prepare_output_and_logger(dataset) 48 | gaussians = GaussianModel(dataset.sh_degree) 49 | scene = Scene(dataset, gaussians) 50 | gaussians.training_setup(opt) 51 | if checkpoint: 52 | (model_params, first_iter) = torch.load(checkpoint) 53 | gaussians.restore(model_params, opt) 54 | 55 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 56 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 57 | 58 | iter_start = torch.cuda.Event(enable_timing = True) 59 | iter_end = torch.cuda.Event(enable_timing = True) 60 | 61 | viewpoint_stack = None 62 | ema_loss_for_log = 0.0 63 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="3DGO modeling progress") 64 | first_iter += 1 65 | for iteration in range(first_iter, opt.iterations + 1): 66 | 67 | # if network_gui.conn == None: 68 | # network_gui.try_connect() 69 | # while network_gui.conn != None: 70 | # try: 71 | # net_image_bytes = None 72 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 73 | # if custom_cam != None: 74 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 75 | # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 76 | # network_gui.send(net_image_bytes, dataset.source_path) 77 | # if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 78 | # break 79 | # except Exception as e: 80 | # network_gui.conn = None 81 | 82 | iter_start.record() 83 | 84 | gaussians.update_learning_rate(iteration) 85 | 86 | # Every 1000 its we increase the levels of SH up to a maximum 
degree 87 | if iteration % 1000 == 0: 88 | gaussians.oneupSHdegree() 89 | 90 | # Pick a random Camera 91 | if not viewpoint_stack: 92 | viewpoint_stack = scene.getTrainCameras().copy() 93 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 94 | 95 | # Render 96 | if (iteration - 1) == debug_from: 97 | pipe.debug = True 98 | 99 | bg = torch.rand((3), device="cuda") if opt.random_background else background 100 | 101 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 102 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 103 | 104 | # Loss 105 | gt_image = viewpoint_cam.original_image.cuda() 106 | 107 | trunc_FG_mask = (gt_image.sum(dim=0, keepdim=True) > 0).type(torch.float32) 108 | 109 | # Ll1 = l1_loss(image, gt_image) 110 | # loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 111 | 112 | Ll1 = (l1_loss(image, gt_image, size_average=True) * trunc_FG_mask).mean() 113 | ssim_score = (ssim(image, gt_image, size_average=True) * trunc_FG_mask).mean() 114 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_score) 115 | 116 | loss.backward() 117 | iter_end.record() 118 | 119 | with torch.no_grad(): 120 | # Progress bar 121 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 122 | if iteration % 10 == 0: 123 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 124 | progress_bar.update(10) 125 | if iteration == opt.iterations: 126 | progress_bar.close() 127 | 128 | # Log and save 129 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 130 | if (iteration in saving_iterations): 131 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 132 | scene.save(iteration) 133 | 134 | # Densification 135 | if iteration < opt.densify_until_iter: 136 | # Keep track of max radii in image-space for pruning 137 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 138 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 139 | 140 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 141 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 142 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 143 | 144 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 145 | gaussians.reset_opacity() 146 | 147 | # Optimizer step 148 | if iteration < opt.iterations: 149 | gaussians.optimizer.step() 150 | gaussians.optimizer.zero_grad(set_to_none = True) 151 | 152 | if (iteration in checkpoint_iterations): 153 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 154 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 155 | 156 | if return_gaussian: 157 | return gaussians 158 | 159 | def prepare_output_and_logger(args): 160 | if not args.model_path: 161 | if os.getenv('OAR_JOB_ID'): 162 | unique_str=os.getenv('OAR_JOB_ID') 163 | else: 164 | unique_str = str(uuid.uuid4()) 165 | args.model_path = os.path.join("./output/", unique_str[0:10]) 166 | 167 | # Set up output folder 168 | print("Output folder: {}".format(args.model_path)) 169 | logs_file = f'{args.model_path}/events.out.tfevents.*' 
170 | if len(glob.glob(logs_file)) > 0: 171 | os.system(f"rm -r {logs_file}") # remove the old events 172 | 173 | os.makedirs(args.model_path, exist_ok = True) 174 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 175 | cfg_log_f.write(str(Namespace(**vars(args)))) 176 | 177 | # Create Tensorboard writer 178 | tb_writer = None 179 | if TENSORBOARD_FOUND: 180 | tb_writer = SummaryWriter(args.model_path) 181 | else: 182 | print("Tensorboard not available: not logging progress") 183 | return tb_writer 184 | 185 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 186 | if tb_writer: 187 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 188 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 189 | tb_writer.add_scalar('iter_time', elapsed, iteration) 190 | 191 | # Report test and samples of training set 192 | if iteration in testing_iterations: 193 | torch.cuda.empty_cache() 194 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 195 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 196 | 197 | for config in validation_configs: 198 | if config['cameras'] and len(config['cameras']) > 0: 199 | l1_test = 0.0 200 | psnr_test = 0.0 201 | for idx, viewpoint in enumerate(config['cameras']): 202 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 203 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 204 | if tb_writer and (idx < 5): 205 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 206 | if iteration == testing_iterations[0]: 207 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 208 | l1_test += l1_loss(image, gt_image).mean().double() 209 | psnr_test += psnr(image, gt_image).mean().double() 210 | psnr_test /= len(config['cameras']) 211 | l1_test /= len(config['cameras']) 212 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 213 | if tb_writer: 214 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 215 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 216 | 217 | if tb_writer: 218 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 219 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 220 | torch.cuda.empty_cache() 221 | 222 | -------------------------------------------------------------------------------- /gaussian_object/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import os 12 | import sys 13 | import torch 14 | import numpy as np 15 | from torch import nn 16 | 17 | # import the customized modules 18 | from gaussian_object.utils.graphics_utils import getWorld2View2, getProjectionMatrix 19 | 20 | 21 | class Camera(nn.Module): 22 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 23 | image_name, uid, 24 | cx_offset, cy_offset, mask=None, 25 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 26 | ): 27 | super(Camera, self).__init__() 28 | 29 | self.uid = uid 30 | self.colmap_id = colmap_id 31 | self.R = R 32 | self.T = T 33 | self.FoVx = FoVx 34 | self.FoVy = FoVy 35 | self.cx_offset = cx_offset 36 | self.cy_offset = cy_offset 37 | self.image_name = image_name 38 | 39 | try: 40 | self.data_device = torch.device(data_device) 41 | except Exception as e: 42 | print(e) 43 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 44 | self.data_device = torch.device("cuda") 45 | 46 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 47 | self.image_width = self.original_image.shape[2] 48 | self.image_height = self.original_image.shape[1] 49 | 50 | if gt_alpha_mask is not None: 51 | self.original_image *= gt_alpha_mask.to(self.data_device) 52 | else: 53 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 54 | 55 | self.zfar = 100.0 56 | self.znear = 0.01 57 | 58 | self.trans = trans 59 | self.scale = scale 60 | 61 | # self.world_view_transform = torch.tensor(getWorld2View2(self.R, self.T, self.trans, self.scale)).transpose(0, 1).cuda() 62 | # self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, 63 | # fovX=self.FoVx, fovY=self.FoVy, 64 | # cx_offset=self.cx_offset, 65 | # cy_offset=self.cy_offset, 66 | # ).transpose(0,1).cuda() 67 | # self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 68 | # self.camera_center = self.world_view_transform.inverse()[3, :3] 69 | 70 | @property 71 | def world_view_transform(self): 72 | return torch.tensor(getWorld2View2(self.R, self.T, self.trans, self.scale)).transpose(0, 1).cuda() 73 | 74 | @property 75 | def projection_matrix(self): 76 | return getProjectionMatrix(znear=self.znear, zfar=self.zfar, 77 | fovX=self.FoVx, fovY=self.FoVy, 78 | cx_offset=self.cx_offset, 79 | cy_offset=self.cy_offset, 80 | ).transpose(0,1).cuda() 81 | 82 | @property 83 | def full_proj_transform(self): 84 | return (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 85 | 86 | @property 87 | def camera_center(self): 88 | return self.world_view_transform.inverse()[3, :3] 89 | 90 | class MiniCam: 91 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 92 | self.image_width = width 93 | self.image_height = height 94 | self.FoVy = fovy 95 | self.FoVx = fovx 96 | self.znear = znear 97 | self.zfar = zfar 98 | self.world_view_transform = world_view_transform 99 | self.full_proj_transform = full_proj_transform 100 | view_inv = torch.inverse(self.world_view_transform) 101 | self.camera_center = view_inv[3][:3] 102 | 103 | -------------------------------------------------------------------------------- /gaussian_object/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO 
research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import cv2 14 | import sys 15 | import json 16 | import torch 17 | 18 | import numpy as np 19 | from PIL import Image 20 | from pathlib import Path 21 | from typing import NamedTuple 22 | # from plyfile import PlyData, PlyElement 23 | from pytorch3d import transforms as py3d_transform 24 | 25 | 26 | # import the customized modules 27 | from misc_utils import gs_utils 28 | from gaussian_object.utils.sh_utils import SH2RGB 29 | from gaussian_object.gaussian_model import BasicPointCloud 30 | from gaussian_object.utils.graphics_utils import getWorld2View2, focal2fov 31 | 32 | class CameraInfo(NamedTuple): 33 | uid: int 34 | R: np.array 35 | T: np.array 36 | FovY: np.array 37 | FovX: np.array 38 | image: np.array 39 | image_path: str 40 | image_name: str 41 | width: int 42 | height: int 43 | cx_offset: np.array = 0 44 | cy_offset: np.array = 0 45 | 46 | class SceneInfo(NamedTuple): 47 | point_cloud: BasicPointCloud 48 | train_cameras: list 49 | test_cameras: list 50 | nerf_normalization: dict 51 | ply_path: str 52 | 53 | def getNerfppNorm(cam_info): 54 | def get_center_and_diag(cam_centers): 55 | cam_centers = np.hstack(cam_centers) 56 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 57 | center = avg_cam_center 58 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 59 | diagonal = np.max(dist) 60 | return center.flatten(), diagonal 61 | 62 | cam_centers = [] 63 | 64 | for cam in cam_info: 65 | W2C = getWorld2View2(cam.R, cam.T) 66 | C2W = np.linalg.inv(W2C) 67 | cam_centers.append(C2W[:3, 3:4]) 68 | 69 | center, diagonal = get_center_and_diag(cam_centers) 70 | radius = diagonal * 1.1 71 | 72 | translate = -center 73 | 74 | return {"translate": translate, "radius": radius} 75 | 76 | # def fetchPly(path): 77 | # plydata = PlyData.read(path) 78 | # vertices = plydata['vertex'] 79 | # positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 80 | # colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 81 | # normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 82 | # return BasicPointCloud(points=positions, colors=colors, normals=normals) 83 | 84 | # def storePly(path, xyz, rgb): 85 | # # Define the dtype for the structured array 86 | # dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 87 | # ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 88 | # ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 89 | 90 | # normals = np.zeros_like(xyz) 91 | 92 | # elements = np.empty(xyz.shape[0], dtype=dtype) 93 | # attributes = np.concatenate((xyz, normals, rgb), axis=1) 94 | # elements[:] = list(map(tuple, attributes)) 95 | 96 | # # Create the PlyData object and write to file 97 | # vertex_element = PlyElement.describe(elements, 'vertex') 98 | # ply_data = PlyData([vertex_element]) 99 | # ply_data.write(path) 100 | 101 | def readCameras(dataloader, zoom_scale=512, margin=0.0, frame_sample_interval=1): 102 | cam_infos = [] 103 | use_binarized_mask = dataloader.use_binarized_mask 104 | bbox3d_diameter = dataloader.bbox3d_diameter 105 | for frame_idx in range(len(dataloader)): 106 | if frame_idx % frame_sample_interval != 0: 107 | continue 108 | obj_data = dataloader[frame_idx] 109 | camK = np.array(obj_data['camK']) 110 | pose = 
np.array(obj_data['pose']) 111 | R = np.transpose(pose[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 112 | T = pose[:3, 3] 113 | 114 | if 'image_path' not in obj_data: 115 | image_path = None 116 | image = (obj_data['image'] * 255).numpy() 117 | image = Image.fromarray(image.astype(np.uint8)) 118 | image_name = f'{frame_idx}.png' 119 | else: 120 | image_path = obj_data['image_path'] 121 | image = Image.open(image_path) 122 | image_name = os.path.basename(image_path) 123 | 124 | image = torch.from_numpy(np.array(image)) 125 | raw_height, raw_width = image.shape[:2] 126 | 127 | out = gs_utils.zoom_in_and_crop_with_offset(image, t=T, K=camK, 128 | radius=bbox3d_diameter/2, 129 | margin=margin, target_size=zoom_scale) 130 | image = out['zoom_image'].squeeze() 131 | height, width = image.shape[:2] 132 | 133 | # if 'coseg_mask_path' not in obj_data: 134 | # mask = np.ones((height, width, 1), dtype=np.float32) 135 | try: 136 | mask = Image.open(obj_data['coseg_mask_path']) 137 | mask = torch.from_numpy(np.array(mask, dtype=np.float32)) / 255.0 138 | mask = gs_utils.zoom_in_and_crop_with_offset( 139 | mask, t=T, K=camK, radius=bbox3d_diameter/2, 140 | margin=margin, target_size=zoom_scale 141 | )['zoom_image'].squeeze() 142 | if mask.dim() == 2: 143 | mask = mask[:, :, None] 144 | if use_binarized_mask: 145 | mask = mask.round() 146 | except Exception as e: 147 | print(e) 148 | mask = np.ones((height, width, 1), dtype=np.float32) 149 | 150 | image = (image * mask).type(torch.uint8).numpy() 151 | image = Image.fromarray(image.astype(np.uint8)) 152 | 153 | zoom_camk = out['zoom_camK'].squeeze().numpy() 154 | zoom_offset = out['zoom_offset'].squeeze().numpy() 155 | cx_offset = zoom_offset[0] 156 | cy_offset = zoom_offset[1] 157 | cam_fx = zoom_camk[0, 0] 158 | cam_fy = zoom_camk[1, 1] 159 | FovX = focal2fov(cam_fx, width) 160 | FovY = focal2fov(cam_fy, height) 161 | 162 | cam_info = CameraInfo(R=R, T=T, FovY=FovY, FovX=FovX, 163 | cx_offset=cx_offset, cy_offset=cy_offset, 164 | uid=frame_idx, image=image, 165 | image_path=image_path, image_name=image_name, 166 | width=width, height=height) 167 | cam_infos.append(cam_info) 168 | return cam_infos 169 | 170 | 171 | def readObjectInfo(train_dataset, test_dataset, model_path, num_points=4096, zoom_scale=512, margin=0.0, random_points3D=False): 172 | 173 | print(f"Reading {len(train_dataset)} training image ...") 174 | train_cam_infos = readCameras(train_dataset, zoom_scale=zoom_scale, margin=margin, frame_sample_interval=1) 175 | num_training_samples = len(train_cam_infos) 176 | print(f"{num_training_samples} training samples") 177 | print(f"-----------------------------------------") 178 | 179 | test_interval = len(test_dataset) // 3 180 | test_cam_infos = readCameras(test_dataset, zoom_scale=zoom_scale, margin=margin, frame_sample_interval=test_interval) 181 | num_test_samples = len(test_cam_infos) 182 | print(f"{num_test_samples} testing samples") 183 | print(f"----------------------------------------") 184 | 185 | nerf_normalization = getNerfppNorm(train_cam_infos) 186 | ply_path = os.path.join(model_path, "3DGS_points3d.ply") 187 | if not random_points3D: 188 | obj_bbox3D = train_dataset.obj_bbox3d 189 | # obj_bbox3D = np.loadtxt(os.path.join(path, 'corners.txt')) 190 | min_3D_corner = obj_bbox3D.min(axis=0) 191 | max_3D_corner = obj_bbox3D.max(axis=0) 192 | obj_bbox3D_dims = max_3D_corner - min_3D_corner 193 | grid_cube_size = (np.prod(obj_bbox3D_dims, axis=0) / num_points)**(1/3) 194 | 195 | xnum, ynum, znum = np.ceil(obj_bbox3D_dims / 
grid_cube_size).astype(np.int64) 196 | xmin, ymin, zmin = min_3D_corner 197 | xmax, ymax, zmax = max_3D_corner 198 | zgrid, ygrid, xgrid = np.meshgrid(np.linspace(zmin, zmax, znum), 199 | np.linspace(ymin, ymax, ynum), 200 | np.linspace(xmin, xmax, xnum), 201 | indexing='ij') 202 | xyz = np.stack([xgrid, ygrid, zgrid], axis=-1).reshape(-1, 3) 203 | num_pts = xyz.shape[0] 204 | shs = np.random.random((num_pts, 3)) / 255.0 205 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 206 | 207 | # storePly(ply_path, xyz, SH2RGB(shs) * 255) 208 | 209 | if random_points3D: 210 | # Since this data set has no colmap data, we start with random points 211 | # num_pts = 100_000 212 | print(f"Generating random point cloud ({num_points})...") 213 | 214 | # We create random points inside the bounds of the synthetic Blender scenes 215 | xyz = np.random.random((num_points, 3)) #* 2.6 - 1.3 216 | shs = np.random.random((num_points, 3)) #/ 255.0 217 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_points, 3))) 218 | 219 | # storePly(ply_path, xyz, SH2RGB(shs) * 255) 220 | 221 | # try: 222 | # pcd = fetchPly(ply_path) 223 | # except: 224 | # pcd = None 225 | 226 | object_info = SceneInfo(point_cloud=pcd, 227 | train_cameras=train_cam_infos, 228 | test_cameras=test_cam_infos, 229 | nerf_normalization=nerf_normalization, 230 | ply_path=ply_path) 231 | return object_info 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /gaussian_object/gaussian_render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import math 12 | import torch 13 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 14 | 15 | from gaussian_object.utils.sh_utils import eval_sh 16 | from gaussian_object.gaussian_model import GaussianModel 17 | 18 | 19 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 20 | """ 21 | Render the scene. 22 | 23 | Background tensor (bg_color) must be on GPU! 24 | """ 25 | 26 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 27 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 28 | try: 29 | screenspace_points.retain_grad() 30 | except: 31 | pass 32 | 33 | # Set up rasterization configuration 34 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 35 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 36 | 37 | raster_settings = GaussianRasterizationSettings( 38 | image_height=int(viewpoint_camera.image_height), 39 | image_width=int(viewpoint_camera.image_width), 40 | tanfovx=tanfovx, 41 | tanfovy=tanfovy, 42 | bg=bg_color, 43 | scale_modifier=scaling_modifier, 44 | viewmatrix=viewpoint_camera.world_view_transform, 45 | projmatrix=viewpoint_camera.full_proj_transform, 46 | sh_degree=pc.active_sh_degree, 47 | campos=viewpoint_camera.camera_center, 48 | prefiltered=False, 49 | debug=pipe.debug 50 | ) 51 | 52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 53 | 54 | means3D = pc.get_xyz 55 | means2D = screenspace_points 56 | opacity = pc.get_opacity 57 | 58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 59 | # scaling / rotation by the rasterizer. 60 | scales = None 61 | rotations = None 62 | cov3D_precomp = None 63 | if pipe.compute_cov3D_python: 64 | cov3D_precomp = pc.get_covariance(scaling_modifier) 65 | else: 66 | scales = pc.get_scaling 67 | rotations = pc.get_rotation 68 | 69 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 70 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 71 | shs = None 72 | colors_precomp = None 73 | if override_color is None: 74 | if pipe.convert_SHs_python: 75 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 76 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 77 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 78 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 79 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 80 | else: 81 | shs = pc.get_features 82 | else: 83 | colors_precomp = override_color 84 | 85 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 86 | rendered_image, radii = rasterizer( 87 | means3D = means3D, 88 | means2D = means2D, 89 | shs = shs, 90 | colors_precomp = colors_precomp, 91 | opacities = opacity, 92 | scales = scales, 93 | rotations = rotations, 94 | cov3D_precomp = cov3D_precomp) 95 | 96 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 97 | # They will be excluded from value updates used in the splitting criteria. 98 | return {"render": rendered_image, 99 | "viewspace_points": screenspace_points, 100 | "visibility_filter" : radii > 0, 101 | "radii": radii} 102 | -------------------------------------------------------------------------------- /gaussian_object/gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | 16 | from gaussian_object.gaussian_model import GaussianModel 17 | from gaussian_object.utils.sh_utils import eval_sh 18 | 19 | 20 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 21 | """ 22 | Render the scene. 23 | 24 | Background tensor (bg_color) must be on GPU! 25 | """ 26 | 27 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 28 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 29 | try: 30 | screenspace_points.retain_grad() 31 | except: 32 | pass 33 | 34 | # Set up rasterization configuration 35 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 36 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 37 | 38 | raster_settings = GaussianRasterizationSettings( 39 | image_height=int(viewpoint_camera.image_height), 40 | image_width=int(viewpoint_camera.image_width), 41 | tanfovx=tanfovx, 42 | tanfovy=tanfovy, 43 | bg=bg_color, 44 | scale_modifier=scaling_modifier, 45 | viewmatrix=viewpoint_camera.world_view_transform, 46 | projmatrix=viewpoint_camera.full_proj_transform, 47 | sh_degree=pc.active_sh_degree, 48 | campos=viewpoint_camera.camera_center, 49 | prefiltered=False, 50 | debug=pipe.debug 51 | ) 52 | 53 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 54 | 55 | means3D = pc.get_xyz 56 | means2D = screenspace_points 57 | opacity = pc.get_opacity 58 | 59 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 60 | # scaling / rotation by the rasterizer. 61 | scales = None 62 | rotations = None 63 | cov3D_precomp = None 64 | if pipe.compute_cov3D_python: 65 | cov3D_precomp = pc.get_covariance(scaling_modifier) 66 | else: 67 | scales = pc.get_scaling 68 | rotations = pc.get_rotation 69 | 70 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 71 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 72 | shs = None 73 | colors_precomp = None 74 | if override_color is None: 75 | if pipe.convert_SHs_python: 76 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 77 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 78 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 79 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 80 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 81 | else: 82 | shs = pc.get_features 83 | else: 84 | colors_precomp = override_color 85 | 86 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 87 | rendered_image, radii = rasterizer( 88 | means3D = means3D, 89 | means2D = means2D, 90 | shs = shs, 91 | colors_precomp = colors_precomp, 92 | opacities = opacity, 93 | scales = scales, 94 | rotations = rotations, 95 | cov3D_precomp = cov3D_precomp) 96 | 97 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 98 | # They will be excluded from value updates used in the splitting criteria. 
99 | return {"render": rendered_image, 100 | "viewspace_points": screenspace_points, 101 | "visibility_filter" : radii > 0, 102 | "radii": radii} 103 | -------------------------------------------------------------------------------- /gaussian_object/gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from gaussian_object.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /gaussian_object/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
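# A sketch of the wire format read() above expects from a GUI client: a 4-byte
# little-endian length prefix followed by UTF-8 JSON. The field names mirror those
# unpacked in receive(); the values here are illustrative placeholders.
import json

message = {
    "resolution_x": 800, "resolution_y": 800,
    "train": True, "fov_y": 0.8, "fov_x": 0.8,
    "z_near": 0.01, "z_far": 100.0,
    "shs_python": False, "rot_scale_python": False,
    "keep_alive": True, "scaling_modifier": 1.0,
    "view_matrix": [0.0] * 16,              # placeholder row-major 4x4
    "view_projection_matrix": [0.0] * 16,   # placeholder row-major 4x4
}
payload = json.dumps(message).encode("utf-8")
packet = len(payload).to_bytes(4, "little") + payload   # what conn.recv() reassembles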
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt, size_average=True): 18 | loss = torch.abs((network_output - gt)) 19 | if size_average: 20 | return loss.mean() 21 | return loss 22 | 23 | def l2_loss(network_output, gt): 24 | return ((network_output - gt) ** 2).mean() 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 28 | return gauss / gauss.sum() 29 | 30 | def create_window(window_size, channel): 31 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 32 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 33 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 34 | return window 35 | 36 | def ssim(img1, img2, window_size=11, size_average=True): 37 | channel = img1.size(-3) 38 | window = create_window(window_size, channel) 39 | 40 | if img1.is_cuda: 41 | window = window.cuda(img1.get_device()) 42 | window = window.type_as(img1) 43 | 44 | return _ssim(img1, img2, window, window_size, channel, size_average) 45 | 46 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 47 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 48 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 49 | 50 | mu1_sq = mu1.pow(2) 51 | mu2_sq = mu2.pow(2) 52 | mu1_mu2 = mu1 * mu2 53 | 54 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 55 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 56 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 57 | 58 | C1 = 0.01 ** 2 59 | C2 = 0.03 ** 2 60 | 61 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 62 | 63 | if size_average: 64 | return ssim_map.mean() 65 | else: 66 | return ssim_map#.mean(1).mean(1).mean(1) 67 | 68 | -------------------------------------------------------------------------------- /gaussian_object/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
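# A tiny sanity check (illustrative, assuming the repository root is on the Python
# path) for the ssim() helper defined in gaussian_object/loss_utils.py above:
# identical images score exactly 1.0 and the score drops once noise is added.
import torch
from gaussian_object.loss_utils import ssim

img = torch.rand(1, 3, 64, 64)
assert torch.isclose(ssim(img, img), torch.tensor(1.0), atol=1e-4)
noisy = (img + 0.1 * torch.randn_like(img)).clamp(0.0, 1.0)
print(float(ssim(img, noisy)))   # typically well below 1.0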
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /gaussian_object/utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO 
research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import numpy as np 12 | 13 | from gaussian_object.cameras import Camera 14 | from gaussian_object.utils.general_utils import PILtoTorch 15 | from gaussian_object.utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 24 | else: # should be a type that converts to float 25 | if args.resolution == -1: 26 | if orig_w > 1600: 27 | global WARNED 28 | if not WARNED: 29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 31 | WARNED = True 32 | global_down = orig_w / 1600 33 | else: 34 | global_down = 1 35 | else: 36 | global_down = orig_w / args.resolution 37 | 38 | scale = float(global_down) * float(resolution_scale) 39 | resolution = (int(orig_w / scale), int(orig_h / scale)) 40 | 41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 42 | 43 | gt_image = resized_image_rgb[:3, ...] 44 | loaded_mask = None 45 | 46 | # if cam_info.mask is not None: 47 | # mask = PILtoTorch(cam_info.mask, resolution) 48 | # else: 49 | # mask = torch.ones_like(gt_image)[..., 0][None] 50 | 51 | if resized_image_rgb.shape[1] == 4: 52 | loaded_mask = resized_image_rgb[3:4, ...] 53 | 54 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 55 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 56 | # mask=mask, 57 | cx_offset=cam_info.cx_offset, 58 | cy_offset=cam_info.cy_offset, 59 | image=gt_image, gt_alpha_mask=loaded_mask, 60 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 61 | 62 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 63 | camera_list = [] 64 | 65 | for id, c in enumerate(cam_infos): 66 | camera_list.append(loadCam(args, id, c, resolution_scale)) 67 | 68 | return camera_list 69 | 70 | def camera_to_JSON(id, camera : Camera): 71 | Rt = np.zeros((4, 4)) 72 | Rt[:3, :3] = camera.R.transpose() 73 | Rt[:3, 3] = camera.T 74 | Rt[3, 3] = 1.0 75 | 76 | W2C = np.linalg.inv(Rt) 77 | pos = W2C[:3, 3] 78 | rot = W2C[:3, :3] 79 | serializable_array_2d = [x.tolist() for x in rot] 80 | camera_entry = { 81 | 'id' : id, 82 | 'img_name' : camera.image_name, 83 | 'width' : camera.width, 84 | 'height' : camera.height, 85 | 'position': pos.tolist(), 86 | 'rotation': serializable_array_2d, 87 | 'fy' : fov2focal(camera.FovY, camera.height), 88 | 'fx' : fov2focal(camera.FovX, camera.width) 89 | } 90 | return camera_entry 91 | -------------------------------------------------------------------------------- /gaussian_object/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
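# A condensed sketch of the resolution rule implemented by loadCam() above.
# pick_resolution is a hypothetical helper, not part of the repository: factors
# 1/2/4/8 divide the image directly, -1 caps very wide images at 1600 px, and any
# other positive value is interpreted as a target width.
def pick_resolution(orig_w, orig_h, resolution, resolution_scale=1.0):
    if resolution in (1, 2, 4, 8):
        scale = resolution_scale * resolution
        return round(orig_w / scale), round(orig_h / scale)
    if resolution == -1:
        global_down = orig_w / 1600 if orig_w > 1600 else 1
    else:
        global_down = orig_w / resolution
    scale = float(global_down) * float(resolution_scale)
    return int(orig_w / scale), int(orig_h / scale)

print(pick_resolution(1920, 1080, -1))   # -> (1600, 900)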
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " 
[{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /gaussian_object/utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY, cx_offset=0, cy_offset=0): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) + cx_offset 67 | P[1, 2] = (top + bottom) / (top - bottom) + cy_offset 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /gaussian_object/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
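# A quick numeric check (illustrative, assuming the repository root is on the Python
# path) of the fov <-> focal relations defined in graphics_utils.py above:
# focal = pixels / (2 * tan(fov / 2)), so the two conversions round-trip.
import math
from gaussian_object.utils.graphics_utils import fov2focal, focal2fov

width = 640
fov_x = math.radians(60.0)
focal_x = fov2focal(fov_x, width)               # ~554.3 px for a 60-degree horizontal FoV
assert abs(focal2fov(focal_x, width) - fov_x) < 1e-9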
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /gaussian_object/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /gaussian_object/utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /gaussian_object/utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /install_env.sh: -------------------------------------------------------------------------------- 1 | pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html 2 | pip install "git+https://github.com/facebookresearch/pytorch3d.git" 3 | pip install xformers==0.0.22 4 | 5 | cd submodules/Connected_components_PyTorch 6 | python setup.py install 7 | 8 | cd ../diff-gaussian-rasterization 9 | python setup.py install 10 | 11 | cd ../simple-knn 12 | python setup.py install 13 | 14 | cd ../../model/curope 15 | python setup.py install 16 | cd ../.. 
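# An optional post-install sanity check (a sketch, not part of install_env.sh) run as
# Python after the script above finishes; the module names follow the imports used
# elsewhere in this repository, and any failure usually points to a CUDA/toolchain mismatch.
import importlib
import torch

print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())
for name in ("pytorch3d", "xformers", "diff_gaussian_rasterization", "curope"):
    try:
        importlib.import_module(name)
        print(name, "OK")
    except ImportError as err:
        print(name, "missing:", err)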
17 | -------------------------------------------------------------------------------- /misc_utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from math import exp 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | 17 | def l1_loss(network_output, gt, size_average=True): 18 | loss = torch.abs((network_output - gt)) 19 | if size_average: 20 | return loss.mean() 21 | return loss 22 | 23 | def l2_loss(network_output, gt): 24 | return ((network_output - gt) ** 2).mean() 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 28 | return gauss / gauss.sum() 29 | 30 | def create_window(window_size, channel): 31 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 32 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 33 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 34 | return window 35 | 36 | def ssim(img1, img2, window_size=11, size_average=True): 37 | channel = img1.size(-3) 38 | window = create_window(window_size, channel) 39 | 40 | if img1.is_cuda: 41 | window = window.cuda(img1.get_device()) 42 | window = window.type_as(img1) 43 | 44 | return _ssim(img1, img2, window, window_size, channel, size_average) 45 | 46 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 47 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 48 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 49 | 50 | mu1_sq = mu1.pow(2) 51 | mu2_sq = mu2.pow(2) 52 | mu1_mu2 = mu1 * mu2 53 | 54 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 55 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 56 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 57 | 58 | C1 = 0.01 ** 2 59 | C2 = 0.03 ** 2 60 | 61 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 62 | 63 | if size_average: 64 | return ssim_map.mean() 65 | else: 66 | return ssim_map#.mean(1).mean(1).mean(1) 67 | 68 | -------------------------------------------------------------------------------- /misc_utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def calc_pose_error(pred_RT, gt_RT, unit='cm'): 5 | pred_R = pred_RT[:3, :3] 6 | pred_t = pred_RT[:3, 3] 7 | gt_R = gt_RT[:3, :3] 8 | gt_t = gt_RT[:3, 3] 9 | R_err = np.arccos(np.clip(np.trace(pred_R @ gt_R.T) / 2 - 0.5, -1, 1.0)) / np.pi * 180 10 | t_err = np.linalg.norm(pred_t - gt_t) 11 | 12 | if unit == 'm': 13 | unit_factor = 1 14 | elif unit == 'cm': 15 | unit_factor = 100 16 | elif unit == 'mm': 17 | unit_factor = 1000 18 | else: 19 | raise ValueError('Invalid unit') 20 | 21 | t_err *= unit_factor 22 | return R_err, t_err 23 | 24 | def calc_add_metric(model_3D_pts, diameter, pose_pred, pose_target, percentage=0.1, return_error=False, syn=False, model_unit='m'): 25 | from scipy 
import spatial 26 | # Dim check: 27 | if pose_pred.shape[0] == 4: 28 | pose_pred = pose_pred[:3] 29 | if pose_target.shape[0] == 4: 30 | pose_target = pose_target[:3] 31 | 32 | diameter_thres = diameter * percentage 33 | model_pred = np.dot(model_3D_pts, pose_pred[:, :3].T) + pose_pred[:, 3] 34 | model_target = np.dot(model_3D_pts, pose_target[:, :3].T) + pose_target[:, 3] 35 | 36 | if syn: 37 | mean_dist_index = spatial.cKDTree(model_pred) 38 | mean_dist, _ = mean_dist_index.query(model_target, k=1) 39 | mean_dist = np.mean(mean_dist) 40 | else: 41 | mean_dist = np.mean(np.linalg.norm(model_pred - model_target, axis=-1)) 42 | 43 | if return_error: 44 | return mean_dist 45 | elif mean_dist < diameter_thres: 46 | return True 47 | else: 48 | return False 49 | 50 | def calc_projection_2d_error(model_3D_pts, pose_pred, pose_targets, K, pixels=5, return_error=True): 51 | def project(xyz, K, RT): 52 | """ 53 | NOTE: need to use original K 54 | xyz: [N, 3] 55 | K: [3, 3] 56 | RT: [3, 4] 57 | """ 58 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 59 | xyz = np.dot(xyz, K.T) 60 | xy = xyz[:, :2] / xyz[:, 2:] 61 | return xy 62 | 63 | # Dim check: 64 | if pose_pred.shape[0] == 4: 65 | pose_pred = pose_pred[:3] 66 | if pose_targets.shape[0] == 4: 67 | pose_targets = pose_targets[:3] 68 | 69 | model_2d_pred = project(model_3D_pts, K, pose_pred) # pose_pred: 3*4 70 | model_2d_targets = project(model_3D_pts, K, pose_targets) 71 | proj_mean_diff = np.mean(np.linalg.norm(model_2d_pred - model_2d_targets, axis=-1)) 72 | if return_error: 73 | return proj_mean_diff 74 | elif proj_mean_diff < pixels: 75 | return True 76 | else: 77 | return False 78 | 79 | def calc_bbox_IOU(pd_bbox, gt_bbox, iou_threshold=0.5, return_iou=False): 80 | px1, py1, px2, py2 = pd_bbox.squeeze() 81 | gx1, gy1, gx2, gy2 = gt_bbox.squeeze() 82 | inter_wid = np.maximum(np.minimum(px2, gx2) - np.maximum(px1, gx1), 0) 83 | inter_hei = np.maximum(np.minimum(py2, gy2) - np.maximum(py1, gy1), 0) 84 | inter_area = inter_wid * inter_hei 85 | outer_wid = np.maximum(px2, gx2) - np.minimum(px1, gx1) 86 | outer_hei = np.maximum(py2, gy2) - np.minimum(py1, gy1) 87 | outer_area = outer_wid * outer_hei 88 | iou = inter_area / outer_area 89 | if return_iou: 90 | return iou 91 | elif iou > iou_threshold: 92 | return True 93 | else: 94 | return False 95 | 96 | def aggregate_metrics(metrics, pose_thres=[1, 3, 5], proj2d_thres=5): 97 | """ Aggregate metrics for the whole dataset: 98 | (This method should be called once per dataset) 99 | """ 100 | R_errs = metrics["R_errs"] 101 | t_errs = metrics["t_errs"] 102 | 103 | agg_metric = {} 104 | for pose_threshold in pose_thres: 105 | agg_metric[f"{pose_threshold}˚@{pose_threshold}cm"] = np.mean( 106 | (np.array(R_errs) < pose_threshold) & (np.array(t_errs) < pose_threshold) 107 | ) 108 | agg_metric[f"{pose_threshold}cm"] = np.mean((np.array(t_errs) < pose_threshold)) 109 | agg_metric[f"{pose_threshold}˚"] = np.mean((np.array(R_errs) < pose_threshold)) 110 | 111 | if "ADD_metric" in metrics: 112 | ADD_metric = metrics['ADD_metric'] 113 | agg_metric["ADD"] = np.mean(ADD_metric) 114 | 115 | if "Proj2D" in metrics: 116 | proj2D_metric = metrics['Proj2D'] 117 | agg_metric[f"pix@{proj2d_thres}"] = np.mean(np.array(proj2D_metric) < proj2d_thres) 118 | 119 | return agg_metric 120 | 121 | -------------------------------------------------------------------------------- /misc_utils/warmup_lr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from 
torch.optim.lr_scheduler import _LRScheduler 4 | 5 | class CosineAnnealingWarmupRestarts(_LRScheduler): 6 | """ 7 | optimizer (Optimizer): Wrapped optimizer. 8 | first_cycle_steps (int): First cycle step size. 9 | cycle_mult(float): Cycle steps magnification. Default: -1. 10 | max_lr(float): First cycle's max learning rate. Default: 0.1. 11 | min_lr(float): Min learning rate. Default: 0.001. 12 | warmup_steps(int): Linear warmup step size. Default: 0. 13 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 14 | last_epoch (int): The index of last epoch. Default: -1. 15 | """ 16 | 17 | def __init__(self, 18 | optimizer : torch.optim.Optimizer, 19 | first_cycle_steps : int, 20 | cycle_mult : float = 1., 21 | max_lr : float = 0.1, 22 | min_lr : float = 0.001, 23 | warmup_steps : int = 0, 24 | gamma : float = 1., 25 | last_epoch : int = -1 26 | ): 27 | assert warmup_steps < first_cycle_steps 28 | 29 | self.first_cycle_steps = first_cycle_steps # first cycle step size 30 | self.cycle_mult = cycle_mult # cycle steps magnification 31 | self.base_max_lr = max_lr # first max learning rate 32 | self.max_lr = max_lr # max learning rate in the current cycle 33 | self.min_lr = min_lr # min learning rate 34 | self.warmup_steps = warmup_steps # warmup step size 35 | self.gamma = gamma # decrease rate of max learning rate by cycle 36 | 37 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 38 | self.cycle = 0 # cycle count 39 | self.step_in_cycle = last_epoch # step size of the current cycle 40 | 41 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 42 | 43 | # set learning rate min_lr 44 | self.init_lr() 45 | 46 | def init_lr(self): 47 | self.base_lrs = [] 48 | for param_group in self.optimizer.param_groups: 49 | param_group['lr'] = self.min_lr 50 | self.base_lrs.append(self.min_lr) 51 | 52 | def get_lr(self): 53 | if self.step_in_cycle == -1: 54 | return self.base_lrs 55 | elif self.step_in_cycle < self.warmup_steps: 56 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 57 | else: 58 | return [base_lr + (self.max_lr - base_lr) \ 59 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 60 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 61 | for base_lr in self.base_lrs] 62 | 63 | def step(self, epoch=None): 64 | if epoch is None: 65 | epoch = self.last_epoch + 1 66 | self.step_in_cycle = self.step_in_cycle + 1 67 | if self.step_in_cycle >= self.cur_cycle_steps: 68 | self.cycle += 1 69 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 70 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 71 | else: 72 | if epoch >= self.first_cycle_steps: 73 | if self.cycle_mult == 1.: 74 | self.step_in_cycle = epoch % self.first_cycle_steps 75 | self.cycle = epoch // self.first_cycle_steps 76 | else: 77 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 78 | self.cycle = n 79 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 80 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 81 | else: 82 | self.cur_cycle_steps = self.first_cycle_steps 83 | self.step_in_cycle = epoch 84 | 85 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 86 | self.last_epoch = math.floor(epoch) 87 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 88 | param_group['lr'] = 
lr 89 | -------------------------------------------------------------------------------- /model/curope/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | from .curope2d import cuRoPE2D 5 | -------------------------------------------------------------------------------- /model/curope/curope.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | 8 | // forward declaration 9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 4; 17 | 18 | auto tok = tokens.accessor(); 19 | auto pos = positions.accessor(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 2; x++) { // y and then x (2d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_2d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,2 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); 59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_2d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_2d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); 69 | } 70 | -------------------------------------------------------------------------------- /model/curope/curope2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from . 
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE2D_func (torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, tokens, positions, base, F0=1): 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_2d( tokens, positions, base, F0 ) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_2d( grad_res, positions, base, -F0 ) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE2D(torch.nn.Module): 33 | def __init__(self, freq=100.0, F0=1.0): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) 40 | return tokens -------------------------------------------------------------------------------- /model/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_2d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32 tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 4; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] 41 | // u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ? 
0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); 94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/4); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { 102 | rope_2d_cuda_kernel <<>> ( 103 | //tokens.data_ptr(), 104 | tokens.packed_accessor32(), 105 | pos.data_ptr(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } 109 | -------------------------------------------------------------------------------- /model/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
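# A pure-PyTorch reference (an illustrative sketch, not part of the repository) of the
# 2D rotary update implemented by rope_2d_cpu and the CUDA kernel above. It assumes the
# (B, N, H, D) token layout with D % 4 == 0 and the [u_Y, v_Y, u_X, v_X] quarter split
# described in the kernel comments; base=100.0 mirrors the cuRoPE2D default.
import torch

def rope_2d_reference(tokens: torch.Tensor, positions: torch.Tensor,
                      base: float = 100.0, fwd: float = 1.0) -> torch.Tensor:
    B, N, H, D = tokens.shape
    Q = D // 4
    d = torch.arange(Q, dtype=tokens.dtype, device=tokens.device)
    inv_freq = fwd / (base ** (d / Q))                 # (Q,)
    out = tokens.clone()
    for axis in range(2):                              # 0: y positions, 1: x positions
        p = positions[..., axis].to(tokens.dtype)      # (B, N)
        angle = p[..., None] * inv_freq                # (B, N, Q)
        c = torch.cos(angle)[:, :, None, :]            # broadcast over heads
        s = torch.sin(angle)[:, :, None, :]
        lo = 2 * Q * axis
        u = tokens[..., lo:lo + Q]
        v = tokens[..., lo + Q:lo + 2 * Q]
        out[..., lo:lo + Q] = u * c - v * s
        out[..., lo + Q:lo + 2 * Q] = v * c + u * s
    return out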
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | # all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | # ] 17 | 18 | setup( 19 | name = 'curope', 20 | ext_modules = [ 21 | CUDAExtension( 22 | name='curope', 23 | sources=[ 24 | "curope.cpp", 25 | "kernels.cu", 26 | ], 27 | extra_compile_args = dict( 28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 29 | cxx=['-O3']) 30 | ) 31 | ], 32 | cmdclass = { 33 | 'build_ext': BuildExtension 34 | }) 35 | -------------------------------------------------------------------------------- /model/dino_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | # from .patch_embed import PatchEmbed 10 | # from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | # from .block import NestedTensorBlock 12 | # from .attention import Attention, MemEffAttention 13 | from .efficient_attention import * 14 | -------------------------------------------------------------------------------- /model/dino_layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- /model/dino_layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
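# A small shape check (illustrative, assuming the repository root is on the Python path)
# for the Attention block defined in attention.py above; MemEffAttention behaves the same
# but routes through xFormers' memory_efficient_attention when that package is available.
import torch
from model.dino_layers.attention import Attention

tokens = torch.randn(2, 197, 384)                  # (batch, sequence length, embed dim)
attn = Attention(dim=384, num_heads=6, qkv_bias=True)
out = attn(tokens)
assert out.shape == tokens.shape                   # self-attention preserves (B, N, C)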
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 
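# Illustrative note: both stochastic-depth helpers above run the residual branch on a
# random subset of the batch and scatter the result back with torch.index_add, rescaled
# by b / sample_subset_size so the expected residual matches a full forward pass.
# A minimal usage sketch for the single-tensor variant, assuming `blk` is a Block
# instance and x has shape (B, N, D):
#
#   attn_branch = lambda t: blk.ls1(blk.attn(blk.norm1(t)))
#   x = drop_add_residual_stochastic_depth(x, attn_branch, sample_drop_ratio=0.2)
#
# which mirrors Block.forward in training whenever drop_path is larger than 0.1.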
203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /model/dino_layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /model/dino_layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /model/dino_layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /model/dino_layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /model/dino_layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /model/dino_layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /model/generalized_mean_pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class GeM3D(nn.Module): 6 | def __init__(self, p=1, eps=1e-6): 7 | super(GeM3D,self).__init__() 8 | self.p = nn.Parameter(torch.ones(1) * p) 9 | self.eps = eps 10 | 11 | def forward(self, x): 12 | return self.gem(x.float(), p=self.p, eps=self.eps) 13 | 14 | def gem(self, x, p=1, eps=1e-6): 15 | return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(-3), x.size(-2), x.size(-1))).pow(1./p) 16 | 17 | def __repr__(self): 18 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 19 | 20 | 21 | class GeM2D(nn.Module): 22 | def __init__(self, p=3, eps=1e-6): 23 | super(GeM2D,self).__init__() 24 | self.p = nn.Parameter(torch.ones(1) * p) 25 | self.eps = eps 26 | 27 | def forward(self, x): 28 | return self.gem(x.float(), p=self.p, eps=self.eps) 29 | 30 | def gem(self, x, p=1, eps=1e-6): 31 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 32 | 33 | def __repr__(self): 34 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 35 | 36 | 37 | class GeM1D(nn.Module): 38 | def __init__(self, p=3, channel_last=True, eps=1e-6): 39 | super(GeM1D, self).__init__() 40 | self.p = nn.Parameter(torch.ones(1) * p) 41 | self.eps = eps 42 | self.channel_last = channel_last 43 | 44 | def forward(self, x): 45 | return self.gem(x.float(), p=self.p, eps=self.eps) 46 | 47 | def gem(self, x, p=1, eps=1e-6): 48 | """ 49 | x: (B, C, L) -> (B, C) 50 | """ 51 | if self.channel_last: 52 | x = x.permute(0, 2, 1) 53 | 54 | # return F.avg_pool1d(x.clamp(min=eps), 
x.size(-1)) 55 | return F.avg_pool1d(x.clamp(min=eps).pow(p), x.size(-1)).pow(1./p) 56 | 57 | def __repr__(self): 58 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 59 | 60 | -------------------------------------------------------------------------------- /model/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | try: 7 | from .curope import cuRoPE2D 8 | RoPE2D = cuRoPE2D 9 | except ImportError: 10 | print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') 11 | 12 | class RoPE2D(torch.nn.Module): 13 | 14 | def __init__(self, freq=100.0, F0=1.0): 15 | super().__init__() 16 | self.base = freq 17 | self.F0 = F0 18 | self.cache = {} 19 | 20 | def get_cos_sin(self, D, seq_len, device, dtype): 21 | if (D,seq_len,device,dtype) not in self.cache: 22 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) 23 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) 24 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) 25 | freqs = torch.cat((freqs, freqs), dim=-1) 26 | cos = freqs.cos() # (Seq, Dim) 27 | sin = freqs.sin() 28 | self.cache[D,seq_len,device,dtype] = (cos,sin) 29 | return self.cache[D,seq_len,device,dtype] 30 | 31 | @staticmethod 32 | def rotate_half(x): 33 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 34 | return torch.cat((-x2, x1), dim=-1) 35 | 36 | def apply_rope1d(self, tokens, pos1d, cos, sin): 37 | assert pos1d.ndim==2 38 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] 39 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] 40 | return (tokens * cos) + (self.rotate_half(tokens) * sin) 41 | 42 | def forward(self, tokens, positions): 43 | """ 44 | input: 45 | * tokens: batch_size x nheads x ntokens x dim 46 | * positions: batch_size x ntokens x 2 (y and x position of each token) 47 | output: 48 | * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) 49 | """ 50 | assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two" 51 | D = tokens.size(3) // 2 52 | assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 53 | cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) 54 | # split features into two along the feature dimension, and apply rope1d on each half 55 | y, x = tokens.chunk(2, dim=-1) 56 | y = self.apply_rope1d(y, positions[:,:,0], cos, sin) 57 | x = self.apply_rope1d(x, positions[:,:,1], cos, sin) 58 | tokens = torch.cat((y, x), dim=-1) 59 | return tokens 60 | 61 | 62 | class PositionEncodingSine1D(nn.Module): 63 | def __init__(self, d_model, max_seq_len=20_000): 64 | super(PositionEncodingSine1D, self).__init__() 65 | self.d_model = d_model 66 | # Create positional encoding matrix 67 | pe = torch.zeros(max_seq_len, d_model) 68 | position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) 69 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 70 | pe[:, 0::2] = torch.sin(position * div_term) 71 | pe[:, 1::2] = torch.cos(position * div_term) 72 | pe = pe.unsqueeze(0) 73 | self.register_buffer('pe', pe) # 1 x max_seq_len x d_model 74 | 75 | def forward(self, x): 76 | """ 77 | x: BxLxC 78 | """ 79 | x = x + self.pe[:, :x.size(1)] 80 | return x 81 | 82 | 83 | # Position encoding for query 
image 84 | class PositionEncodingSine2D(nn.Module): 85 | """ 86 | This is a sinusoidal position encoding that generalized to 2-dimensional images 87 | """ 88 | 89 | def __init__(self, d_model, max_shape=(1280, 960)): 90 | """ 91 | Args: 92 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 93 | """ 94 | super().__init__() 95 | 96 | max_shape = tuple(max_shape) 97 | 98 | pe = torch.zeros((d_model, *max_shape)) 99 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 100 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 101 | div_term = torch.exp( 102 | torch.arange(0, d_model // 2, 2).float() 103 | * (-math.log(10000.0) / d_model // 2) 104 | ) 105 | div_term = div_term[:, None, None] # [C//4, 1, 1] 106 | pe[0::4, :, :] = torch.sin(x_position * div_term) 107 | pe[1::4, :, :] = torch.cos(x_position * div_term) 108 | pe[2::4, :, :] = torch.sin(y_position * div_term) 109 | pe[3::4, :, :] = torch.cos(y_position * div_term) 110 | 111 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W] 112 | 113 | def forward(self, x): 114 | """ 115 | Args: 116 | x: [N, C, H, W] 117 | """ 118 | return x + self.pe[:, :, :x.size(2), :x.size(3)] 119 | 120 | 121 | class PositionEncodingSine3D(nn.Module): 122 | """ 123 | This is a sinusoidal position encoding that generalized to 3-dimensional cubes 124 | """ 125 | 126 | def __init__(self, d_model, max_shape=(128, 128, 128)): 127 | """ 128 | Args: 129 | max_shape (tuple): for DxHxW cube 130 | """ 131 | super().__init__() 132 | 133 | assert(d_model % 6 == 0), "d_model must be divisible by 6 for 3D sinusoidal position encoding" 134 | max_shape = tuple(max_shape) 135 | 136 | pe = torch.zeros((d_model, *max_shape)) # CxDxHxW 137 | z_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 138 | y_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 139 | x_position = torch.ones(max_shape).cumsum(2).float().unsqueeze(0) 140 | div_term = torch.exp( 141 | torch.arange(0, d_model // 2, 3).float() 142 | * (-math.log(10000.0) / d_model // 2) 143 | ) 144 | div_term = div_term[:, None, None, None] # [C//6, 1, 1, 1] 145 | 146 | pe[0::6, :, :, :] = torch.sin(x_position * div_term) 147 | pe[1::6, :, :, :] = torch.cos(x_position * div_term) 148 | pe[2::6, :, :, :] = torch.sin(y_position * div_term) 149 | pe[3::6, :, :, :] = torch.cos(y_position * div_term) 150 | pe[4::6, :, :, :] = torch.sin(z_position * div_term) 151 | pe[5::6, :, :, :] = torch.cos(z_position * div_term) 152 | 153 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, D, H, W] 154 | 155 | def forward(self, x): 156 | """ 157 | Args: 158 | x: [N, C, D, H, W] 159 | """ 160 | assert(x.dim() == 5), "Input must be 5-dimensional" 161 | return x + self.pe[:, :, :x.size(-3), :x.size(-2), :x.size(-1)] 162 | 163 | # Position encoding for 3D points 164 | class PositionEncodingLinear3D(nn.Module): 165 | """ Joint encoding of visual appearance and location using MLPs """ 166 | 167 | def __init__(self, inp_dim, feature_dim, layers, norm_method="batchnorm"): 168 | super().__init__() 169 | self.encoder = self.MLP([inp_dim] + list(layers) + [feature_dim], norm_method) 170 | 171 | self.encoder 172 | nn.init.constant_(self.encoder[-1].bias, 0.0) 173 | 174 | def forward(self, kpts, descriptors): 175 | """ 176 | kpts: B*L*3 or B*L*4 177 | descriptors: B*C*L 178 | """ 179 | # inputs = kpts # B*L*3 180 | 181 | return descriptors + self.encoder(kpts).transpose(2, 1).expand_as(descriptors) # B*C*L 182 | 183 | def 
MLP(self, channels: list, norm_method="batchnorm"): 184 | """ Multi-layer perceptron""" 185 | n = len(channels) 186 | layers = [] 187 | for i in range(1, n): 188 | layers.append(nn.Linear(channels[i - 1], channels[i], bias=True)) 189 | if i < n - 1: 190 | if norm_method == "batchnorm": 191 | layers.append(nn.BatchNorm1d(channels[i])) 192 | elif norm_method == "layernorm": 193 | layers.append(nn.LayerNorm(channels[i])) 194 | elif norm_method == "instancenorm": 195 | layers.append(nn.InstanceNorm1d(channels[i])) 196 | else: 197 | raise NotImplementedError 198 | # layers.append(nn.GroupNorm(channels[i], channels[i])) # group norm 199 | layers.append(nn.ReLU()) 200 | return nn.Sequential(*layers) 201 | 202 | 203 | # Position encoding for 3D points 204 | class KeypointEncoding_linear(nn.Module): 205 | """ Joint encoding of visual appearance and location using MLPs """ 206 | 207 | def __init__(self, inp_dim, feature_dim, layers, norm_method="batchnorm"): 208 | super().__init__() 209 | self.encoder = self.MLP([inp_dim] + list(layers) + [feature_dim], norm_method) 210 | nn.init.constant_(self.encoder[-1].bias, 0.0) 211 | 212 | def forward(self, kpts, descriptors): 213 | """ 214 | kpts: B*L*3 or B*L*4 215 | descriptors: B*C*L 216 | """ 217 | # inputs = kpts # B*L*3 218 | return descriptors + self.encoder(kpts).transpose(2, 1) # B*C*L 219 | 220 | def MLP(self, channels: list, norm_method="batchnorm"): 221 | """ Multi-layer perceptron""" 222 | n = len(channels) 223 | layers = [] 224 | for i in range(1, n): 225 | layers.append(nn.Linear(channels[i - 1], channels[i], bias=True)) 226 | if i < n - 1: 227 | if norm_method == "batchnorm": 228 | layers.append(nn.BatchNorm1d(channels[i])) 229 | elif norm_method == "layernorm": 230 | layers.append(nn.LayerNorm(channels[i])) 231 | elif norm_method == "instancenorm": 232 | layers.append(nn.InstanceNorm1d(channels[i])) 233 | else: 234 | raise NotImplementedError 235 | # layers.append(nn.GroupNorm(channels[i], channels[i])) # group norm 236 | layers.append(nn.ReLU()) 237 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /three/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from . import stats 3 | from . import quaternion 4 | from .rigid import * 5 | from .batchview import * 6 | from . import orientation 7 | from . 
import utils 8 | -------------------------------------------------------------------------------- /three/batchview.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def bvmm(a, b): 6 | if a.shape[0] != b.shape[0]: 7 | raise ValueError("batch dimension must match") 8 | if a.shape[1] != b.shape[1]: 9 | raise ValueError("view dimension must match") 10 | 11 | nbatch, nview, nrow, ncol = a.shape 12 | a = a.view(-1, nrow, ncol) 13 | b = b.view(-1, nrow, ncol) 14 | out = torch.bmm(a, b) 15 | out = out.view(nbatch, nview, out.shape[1], out.shape[2]) 16 | return out 17 | 18 | 19 | def bv2b(x): 20 | if not x.is_contiguous(): 21 | return x.reshape(-1, *x.shape[2:]) 22 | return x.view(-1, *x.shape[2:]) 23 | 24 | 25 | def b2bv(x, num_view=-1, batch_size=-1): 26 | if num_view == -1 and batch_size == -1: 27 | raise ValueError('One of num_view or batch_size must be non-negative.') 28 | return x.view(batch_size, num_view, *x.shape[1:]) 29 | 30 | 31 | def vcat(tensors, batch_size): 32 | tensors = [b2bv(t, batch_size=batch_size) for t in tensors] 33 | return bv2b(torch.cat(tensors, dim=1)) 34 | 35 | 36 | def vsplit(tensor, sections): 37 | num_view = sum(sections) 38 | tensor = b2bv(tensor, num_view=num_view) 39 | return tuple(bv2b(t) for t in torch.split(tensor, sections, dim=1)) 40 | -------------------------------------------------------------------------------- /three/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def acos_safe(t, eps: float = 1e-7): 6 | return torch.acos(torch.clamp(t, min=-1.0 + eps, max=1.0 - eps)) 7 | 8 | 9 | @torch.jit.script 10 | def ensure_batch_dim(tensor, num_dims: int): 11 | unsqueezed = False 12 | if len(tensor.shape) == num_dims: 13 | tensor = tensor.unsqueeze(0) 14 | unsqueezed = True 15 | 16 | return tensor, unsqueezed 17 | 18 | 19 | @torch.jit.script 20 | def normalize(vector, dim: int = -1): 21 | """ 22 | Normalizes the vector to a unit vector using the p-norm. 
23 | Args: 24 | vector (tensor): the vector to normalize of size (*, 3) 25 | p (int): the norm order to use 26 | 27 | Returns: 28 | (tensor): A unit normalized vector of size (*, 3) 29 | """ 30 | return vector / torch.norm(vector, p=2.0, dim=dim, keepdim=True) 31 | 32 | 33 | @torch.jit.script 34 | def uniform(n: int, min_val: float, max_val: float): 35 | return (max_val - min_val) * torch.rand(n) + min_val 36 | 37 | 38 | def uniform_unit_vector(n): 39 | return normalize(torch.randn(n, 3), dim=1) 40 | 41 | 42 | def inner_product(a, b): 43 | return (a * b).sum(dim=-1) 44 | 45 | 46 | @torch.jit.script 47 | def homogenize(coords): 48 | ones = torch.ones_like(coords[..., 0, None]) 49 | return torch.cat((coords, ones), dim=-1) 50 | 51 | 52 | @torch.jit.script 53 | def dehomogenize(coords): 54 | return coords[..., :coords.size(-1) - 1] / coords[..., -1, None] 55 | 56 | 57 | def transform_coord_grid(grid, transform): 58 | if transform.size(0) != grid.size(0): 59 | raise ValueError('Batch dimensions must match.') 60 | 61 | out_shape = (*grid.shape[:-1], transform.size(1)) 62 | 63 | grid = homogenize(grid) 64 | coords = grid.view(grid.size(0), -1, grid.size(-1)) 65 | coords = transform @ coords.transpose(1, 2) 66 | coords = coords.transpose(1, 2) 67 | return dehomogenize(coords.view(*out_shape)) 68 | 69 | 70 | @torch.jit.script 71 | def transform_coords(coords, transform): 72 | coords, unsqueezed = ensure_batch_dim(coords, 2) 73 | 74 | coords = homogenize(coords) 75 | coords = transform @ coords.transpose(1, 2) 76 | coords = coords.transpose(1, 2) 77 | coords = dehomogenize(coords) 78 | if unsqueezed: 79 | coords = coords.squeeze(0) 80 | 81 | return coords 82 | 83 | 84 | @torch.jit.script 85 | def grid_to_coords(grid): 86 | return grid.view(grid.size(0), -1, grid.size(-1)) 87 | 88 | 89 | def spherical_to_cartesian(theta, phi, r=1.0): 90 | x = r * torch.cos(theta) * torch.sin(phi) 91 | y = r * torch.sin(theta) * torch.sin(phi) 92 | z = r * torch.cos(theta) 93 | return torch.stack((x, y, z), dim=-1) 94 | 95 | 96 | def points_bound(points): 97 | min_dim = torch.min(points, dim=0)[0] 98 | max_dim = torch.max(points, dim=0)[0] 99 | return torch.stack((min_dim, max_dim), dim=1) 100 | 101 | 102 | def points_radius(points): 103 | bounds = points_bound(points) 104 | centroid = bounds.mean(dim=1).unsqueeze(0) 105 | max_radius = torch.norm(points - centroid, dim=1).max() 106 | return max_radius 107 | 108 | 109 | def points_diameter(points): 110 | return 2* points_radius(points) 111 | 112 | 113 | def points_centroid(points): 114 | return points_bound(points).mean(dim=1) 115 | 116 | 117 | def points_bounding_size(points): 118 | bounds = points_bound(points) 119 | return torch.norm(bounds[:, 1] - bounds[:, 0]) 120 | -------------------------------------------------------------------------------- /three/imutils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from skimage import morphology 5 | 6 | 7 | def keep_largest_object(mask): 8 | labels, num_labels = morphology.label(mask, return_num=True) 9 | best_mask = None 10 | best_count = -1 11 | for i in range(1, num_labels + 1): 12 | cur_mask = (labels == i) 13 | cur_count = cur_mask.sum() 14 | if cur_count > best_count: 15 | best_mask = cur_mask 16 | best_count = cur_count 17 | 18 | if best_mask is None: 19 | return np.zeros_like(mask) 20 | 21 | return best_mask 22 | 23 | 24 | def mask_chroma(image, hue_min=(40, 65, 65), hue_max=(180, 255, 255), 25 | use_bgr=False): 26 | 
image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV if use_bgr else cv2.COLOR_RGB2HSV) 27 | hue_min = np.array(hue_min) 28 | hue_max = np.array(hue_max) 29 | mask = ~cv2.inRange(image_hsv, hue_min, hue_max) 30 | mask = morphology.binary_closing(mask, selem=morphology.disk(5)) 31 | return mask 32 | 33 | 34 | def grabcut(image, fg_init_mask, bg_init_mask=None): 35 | if image.dtype == np.float32 or image.dtype == np.double: 36 | image = (image * 255.0).astype(np.uint8) 37 | # Initialize mask based on sparse pointcloud. 38 | mask = np.full(image.shape[:2], fill_value=cv2.GC_PR_BGD, dtype=np.uint8) 39 | mask[fg_init_mask] = cv2.GC_PR_FGD 40 | if bg_init_mask is not None: 41 | mask[bg_init_mask] = cv2.GC_BGD 42 | 43 | # Perform grab cut. 44 | bg_model = np.zeros((1, 65), np.float64) 45 | fg_model = np.zeros((1, 65), np.float64) 46 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 47 | cv2.grabCut(image, mask, None, bg_model, fg_model, 3, cv2.GC_INIT_WITH_MASK) 48 | 49 | # Post process mask. 50 | out_mask = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8') 51 | out_mask = morphology.binary_closing(out_mask, selem=morphology.disk(5)) 52 | 53 | return out_mask 54 | 55 | 56 | def mean_color(image, mask): 57 | return (image * mask).sum(dim=(-2, -1)) / mask.sum(dim=(-2, -1)) 58 | 59 | 60 | def dilate(labels, iters, kernel_size): 61 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) 62 | morphed_labels = [] 63 | for label in labels: 64 | morphed_labels.append( 65 | cv2.dilate(label.squeeze(0).numpy(), kernel, iterations=iters)) 66 | return torch.tensor(np.stack(morphed_labels, axis=0), 67 | dtype=torch.float32).unsqueeze(1) 68 | 69 | 70 | def erode(labels, iters, kernel_size): 71 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) 72 | morphed_labels = [] 73 | for label in labels: 74 | morphed_labels.append( 75 | cv2.erode(label.squeeze(0).numpy(), kernel, iterations=iters)) 76 | return torch.tensor(np.stack(morphed_labels, axis=0), 77 | dtype=torch.float32).unsqueeze(1) 78 | -------------------------------------------------------------------------------- /three/meshutils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | import trimesh 7 | import trimesh.remesh 8 | from trimesh.visual.material import SimpleMaterial 9 | 10 | EPS = 10e-10 11 | 12 | 13 | def compute_vertex_normals(vertices, faces): 14 | normals = np.ones_like(vertices) 15 | triangles = vertices[faces] 16 | triangle_normals = np.cross(triangles[:, 1] - triangles[:, 0], 17 | triangles[:, 2] - triangles[:, 0]) 18 | triangle_normals /= (linalg.norm(triangle_normals, axis=1)[:, None] + EPS) 19 | normals[faces[:, 0]] += triangle_normals 20 | normals[faces[:, 1]] += triangle_normals 21 | normals[faces[:, 2]] += triangle_normals 22 | normals /= (linalg.norm(normals, axis=1)[:, None] + 0) 23 | 24 | return normals 25 | 26 | 27 | def are_trimesh_normals_corrupt(trimesh): 28 | corrupt_normals = linalg.norm(trimesh.vertex_normals, axis=1) == 0.0 29 | return corrupt_normals.sum() > 0 30 | 31 | 32 | def subdivide_mesh(mesh): 33 | attributes = {} 34 | if hasattr(mesh.visual, 'uv'): 35 | attributes = {'uv': mesh.visual.uv} 36 | vertices, faces, attributes = trimesh.remesh.subdivide( 37 | mesh.vertices, mesh.faces, attributes=attributes) 38 | mesh.vertices = vertices 39 | mesh.faces = faces 40 | if 'uv' in attributes: 41 | mesh.visual.uv = attributes['uv'] 42 | 43 | return 
mesh 44 | 45 | 46 | class Object3D(object): 47 | """Represents a graspable object.""" 48 | 49 | def __init__(self, path, load_materials=False): 50 | scene = trimesh.load(str(path)) 51 | if isinstance(scene, trimesh.Trimesh): 52 | scene = trimesh.Scene(scene) 53 | 54 | self.meshes: typing.List[trimesh.Trimesh] = list(scene.dump()) 55 | 56 | self.path = path 57 | self.scale = 1.0 58 | 59 | def to_scene(self): 60 | return trimesh.Scene(self.meshes) 61 | 62 | def are_normals_corrupt(self): 63 | for mesh in self.meshes: 64 | if are_trimesh_normals_corrupt(mesh): 65 | return True 66 | 67 | return False 68 | 69 | def recompute_normals(self): 70 | for mesh in self.meshes: 71 | mesh.vertex_normals = compute_vertex_normals(mesh.vertices, mesh.faces) 72 | 73 | return self 74 | 75 | def rescale(self, scale=1.0): 76 | """Set scale of object mesh. 77 | 78 | :param scale 79 | """ 80 | self.scale = scale 81 | for mesh in self.meshes: 82 | mesh.apply_scale(self.scale) 83 | 84 | return self 85 | 86 | def resize(self, size, ref='diameter'): 87 | """Set longest of all three lengths in Cartesian space. 88 | 89 | :param size 90 | """ 91 | if ref == 'diameter': 92 | ref_scale = self.bounding_diameter 93 | else: 94 | ref_scale = self.bounding_size 95 | 96 | self.scale = size / ref_scale 97 | for mesh in self.meshes: 98 | mesh.apply_scale(self.scale) 99 | 100 | return self 101 | 102 | @property 103 | def centroid(self): 104 | return self.bounds.mean(axis=0) 105 | 106 | @property 107 | def bounding_size(self): 108 | return max(self.extents) 109 | 110 | @property 111 | def bounding_diameter(self): 112 | centroid = self.bounds.mean(axis=0) 113 | max_radius = linalg.norm(self.vertices - centroid, axis=1).max() 114 | return max_radius * 2 115 | 116 | @property 117 | def bounding_radius(self): 118 | return self.bounding_diameter / 2.0 119 | 120 | @property 121 | def extents(self): 122 | min_dim = np.min(self.vertices, axis=0) 123 | max_dim = np.max(self.vertices, axis=0) 124 | return max_dim - min_dim 125 | 126 | @property 127 | def bounds(self): 128 | min_dim = np.min(self.vertices, axis=0) 129 | max_dim = np.max(self.vertices, axis=0) 130 | return np.stack((min_dim, max_dim), axis=0) 131 | 132 | def recenter(self, method='bounds'): 133 | if method == 'mean': 134 | # Center the mesh. 
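            # ('mean' shifts the vertex centroid to the origin; the 'bounds' branch
            #  below uses the midpoint of the axis-aligned bounding box instead.)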
135 | vertex_mean = np.mean(self.vertices, 0) 136 | translation = -vertex_mean 137 | elif method == 'bounds': 138 | center = self.bounds.mean(axis=0) 139 | translation = -center 140 | else: 141 | raise ValueError(f"Unknown method {method!r}") 142 | 143 | for mesh in self.meshes: 144 | mesh.apply_translation(translation) 145 | 146 | return self 147 | 148 | @property 149 | def vertices(self): 150 | return np.concatenate([mesh.vertices for mesh in self.meshes]) 151 | -------------------------------------------------------------------------------- /three/orientation.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | # from latentfusion import three 6 | # from latentfusion.three import quaternion as q 7 | 8 | 9 | import three 10 | from three import quaternion as q 11 | 12 | 13 | def spiral_orbit(n, c=16): 14 | phi = torch.linspace(0, math.pi, n) 15 | theta = c * phi 16 | rot_quat = q.from_spherical(phi, theta) 17 | return rot_quat 18 | 19 | 20 | def _check_up(up, n): 21 | if not torch.is_tensor(up): 22 | up = torch.tensor(up, dtype=torch.float32) 23 | if len(up.shape) == 1: 24 | up = up.expand(n, -1) 25 | return three.normalize(up) 26 | 27 | 28 | def _is_ray_in_segment(ray, up, min_angle, max_angle): 29 | angle = torch.acos(three.inner_product(up, ray)) 30 | return (min_angle <= angle) & (angle <= max_angle) 31 | 32 | 33 | def sample_segment_rays(n, up, min_angle, max_angle): 34 | up = _check_up(up, n) 35 | 36 | # Sample random new 'up' orientation. 37 | rays = three.normalize(torch.randn(n, 3)) 38 | num_invalid = n 39 | while num_invalid > 0: 40 | valid = _is_ray_in_segment(rays, up, min_angle, max_angle) 41 | num_invalid = (~valid).sum().item() 42 | rays[~valid] = three.normalize(torch.randn(num_invalid, 3)) 43 | 44 | return three.normalize(rays) 45 | 46 | 47 | def sample_hemisphere_rays(n, up): 48 | """ 49 | Samples a ray in the upper hemisphere (defined by `up`). 50 | 51 | Implemented by sampling a uniform random ray on the unit sphere and reflecting 52 | the vector to be on the same side as the up vector. 53 | 54 | Args: 55 | n (int): number of rays to sample 56 | up (tensor, tuple): the up direction defining the hemisphere 57 | 58 | Returns: 59 | (tensor): `n` rays uniformly sampled on the hemisphere 60 | 61 | """ 62 | up = _check_up(up, n) 63 | 64 | # Sample random new 'up' orientation. 65 | rays = three.normalize(torch.randn(n, 3)) 66 | 67 | # Reflect to upper hemisphere. 68 | dot = (up * rays).sum(dim=-1) 69 | rays[dot < 0] = rays[dot < 0] - 2 * dot[dot < 0, None] * up[dot < 0] 70 | 71 | return rays 72 | 73 | 74 | def random_quat_from_ray(forward, up=None): 75 | """ 76 | Sample uniformly random quaternions that orients the camera forward direction. 77 | 78 | Args: 79 | forward: a vector representing the forward direction. 80 | 81 | Returns: 82 | 83 | """ 84 | n = forward.shape[0] 85 | if up is None: 86 | down = three.uniform_unit_vector(n) 87 | else: 88 | up = torch.tensor(up).unsqueeze(0).expand(n, 3) 89 | up = up + forward 90 | down = -up 91 | right = three.normalize(torch.cross(down, forward)) 92 | down = three.normalize(torch.cross(forward, right)) 93 | 94 | mat = torch.stack([right, down, forward], dim=1) 95 | 96 | return three.quaternion.mat_to_quat(mat) 97 | 98 | 99 | def sample_segment_quats(n, up, min_angle, max_angle): 100 | """ 101 | Sample a quaternion where the resulting `up` direction is constrained to a segment of the sphere. 
102 | 103 | This is performed by first sampling a 'yaw' angle and then sampling a random 'up' direction in the segment. 104 | 105 | The sphere segment is defined as being [min_angle,max_angle] radians away from the 'up' direction. 106 | 107 | Args: 108 | n (int): number of rays to sample 109 | up (tensor, tuple): the up direction defining sphere segment 110 | min_angle (float): the min angle from the up direction defining the sphere segment 111 | max_angle (float): the max angle from the up direction defining the sphere segment 112 | 113 | Returns: 114 | (tensor): a batch of sampled quaternions 115 | """ 116 | up = _check_up(up, n) 117 | 118 | yaw_angle = torch.rand(n) * math.pi * 2.0 119 | yaw_quat = q.from_axis_angle(up, yaw_angle) 120 | 121 | rays = sample_segment_rays(n, up, min_angle, max_angle) 122 | 123 | pivot = torch.cross(up, rays) 124 | angles = torch.acos(three.inner_product(up, rays)) 125 | quat = q.from_axis_angle(pivot, angles) 126 | 127 | return q.qmul(quat, yaw_quat) 128 | 129 | 130 | def evenly_distributed_points(n: int, hemisphere=False, pole=(0.0, 0.0, 1.0)): 131 | """ 132 | Uses the sunflower method to sample points on a sphere that are 133 | roughly evenly distributed. 134 | 135 | Reference: 136 | https://stackoverflow.com/questions/9600801/evenly-distributing-n-points-on-a-sphere/44164075#44164075 137 | """ 138 | indices = torch.arange(0, n, dtype=torch.float32) + 0.5 139 | 140 | if hemisphere: 141 | phi = torch.acos(1 - 2 * indices / n / 2) 142 | else: 143 | phi = torch.acos(1 - 2 * indices / n) 144 | theta = math.pi * (1 + 5 ** 0.5) * indices 145 | 146 | points = torch.stack([ 147 | torch.cos(theta) * torch.sin(phi), 148 | torch.sin(theta) * torch.sin(phi), 149 | torch.cos(phi), 150 | ], dim=1) 151 | 152 | if hemisphere: 153 | default_pole = torch.tensor([(0.0, 0.0, 1.0)]).expand(n, 3) 154 | pole = torch.tensor([pole]).expand(n, 3) 155 | if (default_pole[0] + pole[0]).abs().sum() < 1e-5: 156 | # If the pole is the opposite side just flip. 157 | points = -points 158 | elif (default_pole[0] - pole[0]).abs().sum() < 1e-5: 159 | points = points 160 | else: 161 | # Otherwise take the cross product as the rotation axis. 
162 | rot_axis = torch.cross(pole, default_pole) 163 | rot_angle = torch.acos(three.inner_product(pole, default_pole)) 164 | rot_quat = three.quaternion.from_axis_angle(rot_axis, rot_angle) 165 | points = three.quaternion.rotate_vector(rot_quat, points) 166 | 167 | return points 168 | 169 | 170 | def evenly_distributed_quats(n: int, hemisphere=False, hemisphere_pole=(0.0, 0.0, 1.0), 171 | upright=False, upright_up=(0.0, 0.0, 1.0)): 172 | rays = evenly_distributed_points(n, hemisphere, hemisphere_pole) 173 | return random_quat_from_ray(-rays, upright_up if upright else None) 174 | 175 | 176 | @torch.jit.script 177 | def disk_sample_quats(n: int, min_angle: float, max_tries: int = 64): 178 | 179 | quats = q.random(1) 180 | 181 | num_tries = 0 182 | while quats.shape[0] < n: 183 | new_quat = q.random(1) 184 | angles = q.angular_distance(quats, new_quat) 185 | if torch.all(angles >= min_angle) or num_tries > max_tries: 186 | quats = torch.cat((quats, new_quat), dim=0) 187 | num_tries = 0 188 | else: 189 | num_tries += 1 190 | 191 | return quats 192 | -------------------------------------------------------------------------------- /three/pytorch3d_rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def batchfy_mv_meshes(meshes, num_views, batch_size=1): 4 | from pytorch3d import structures as py3d_struct 5 | mv_meshes = list() 6 | for mesh in meshes.split([1 for _ in range(batch_size)]): 7 | mv_meshes.append(mesh.extend(num_views)) 8 | mv_meshes = py3d_struct.join_meshes_as_batch(mv_meshes) 9 | return mv_meshes 10 | 11 | def generate_mesh_model_from_3Dbbox(bbox3D): 12 | from pytorch3d import structures as py3d_struct 13 | bbox3D = bbox3D.squeeze() 14 | assert(bbox3D.dim() == 2 and bbox3D.shape[1] == 3), bbox3D.shape 15 | if len(bbox3D) == 2: # 2x3 16 | xmin, ymin, zmin = bbox3D[0] 17 | xmax, ymax, zmax = bbox3D[1] 18 | bbox3D_corners = torch.tensor([ 19 | [xmin, ymin, zmin], 20 | [xmax, ymin, zmin], 21 | [xmax, ymax, zmin], 22 | [xmin, ymax, zmin], 23 | [xmin, ymin, zmax], 24 | [xmax, ymin, zmax], 25 | [xmax, ymax, zmax], 26 | [xmin, ymax, zmax] 27 | ]) 28 | elif len(bbox3D) == 8: # 8x3 29 | bbox3D_corners = bbox3D 30 | else: 31 | print(bbox3D) 32 | raise NotImplementedError 33 | 34 | assert(bbox3D_corners.dim() == 2 and len(bbox3D_corners) == 8), bbox3D_corners.shape 35 | bbox3D_faces = torch.tensor([ 36 | [0, 1, 2, 3], 37 | [4, 5, 6, 7], 38 | [0, 1, 5, 4], 39 | [1, 2, 6, 5], 40 | [2, 3, 7, 6], 41 | [3, 0, 4, 7],]) 42 | # Convert quadrilateral faces to triangles 43 | bbox3D_faces = torch.cat([ 44 | bbox3D_faces[:, [0, 1, 2]], 45 | bbox3D_faces[:, [2, 3, 0]] 46 | ], dim=0) 47 | 48 | # Create the mesh 49 | bbox3D_mesh = py3d_struct.Meshes(verts=[bbox3D_corners], faces=[bbox3D_faces]) 50 | return bbox3D_mesh 51 | 52 | def compute_ROI_camera_intrinsic(camK, img_size, camera_dist=None, scaled_T=None, 53 | bbox_center=None, bbox_scale=None): 54 | """ 55 | camK: Bx3x3 56 | scaled_T: Bx3 57 | bbox_center: Bx2 58 | bbox_scale: B 59 | camera_dist: B 60 | """ 61 | squeeze = False 62 | if not isinstance(camK, torch.Tensor): 63 | camK = torch.tensor(camK) 64 | 65 | if scaled_T is not None and not isinstance(scaled_T, torch.Tensor): 66 | scaled_T = torch.tensor(scaled_T) 67 | if bbox_center is not None and not isinstance(bbox_center, torch.Tensor): 68 | bbox_center = torch.tensor(bbox_center) 69 | if bbox_scale is not None and not isinstance(bbox_scale, torch.Tensor): 70 | bbox_scale = torch.tensor(bbox_scale) 71 | 72 | device = camK.device 
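    # The ROI can be given in two ways: either an object translation `scaled_T` plus a
    # virtual `camera_dist`, from which the 2D crop centre and pixel size are derived
    # through the intrinsics, or an explicit 2D `bbox_center` / `bbox_scale`. Either way
    # the square crop is remapped to an img_size x img_size view and camK is warped to
    # the new ROI coordinates below.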
73 | assert(scaled_T is not None or bbox_center is not None) 74 | if scaled_T is not None: 75 | assert(camera_dist is not None) 76 | if scaled_T.dim() == 1: 77 | squeeze = True 78 | scaled_T = scaled_T[None, :] 79 | assert(scaled_T.dim() == 2), scaled_T.shape 80 | obj_2D_center = torch.einsum('bij,bj->bi', camK, scaled_T.to(device)) 81 | bbox_center = obj_2D_center[:, :2] / obj_2D_center[:, 2:3] # 82 | bbox_scale = camera_dist * img_size / scaled_T[:, 2].to(device) 83 | # print(bbox_center, bbox_scale) 84 | elif bbox_center is not None: 85 | assert(bbox_scale is not None) 86 | 87 | if bbox_center.dim() == 1: 88 | bbox_center = bbox_center[None, :] 89 | squeeze = True 90 | if bbox_scale.dim() == 0: 91 | bbox_scale = bbox_scale[None] 92 | if camK.dim() == 2: 93 | camK = camK[None, :, :] 94 | squeeze = True 95 | 96 | assert(bbox_center.dim() == 2 and bbox_scale.dim() == 1), (bbox_center.shape, bbox_scale.shape) 97 | 98 | bbox_x1y1 = bbox_center - bbox_scale[:, None] / 2 99 | bbox_rescaling_factor = img_size / bbox_scale 100 | T_cam2roi = torch.eye(3)[None, :, :].repeat(len(bbox_rescaling_factor), 1, 1).to(device) 101 | T_cam2roi[:, :2, 2] = -bbox_x1y1.to(device) 102 | T_cam2roi[:, :2, :] *= bbox_rescaling_factor[:, None, None].to(device) 103 | new_camK = torch.einsum('bij,bjk->bik', T_cam2roi, camK) 104 | if squeeze: 105 | new_camK = new_camK.squeeze(0) 106 | return new_camK 107 | 108 | def generate_3D_coordinate_map_from_depth(depth, camK, obj_RT): 109 | if depth.squeeze().dim() == 2: 110 | depth = depth.squeeze()[None, None, :, :] 111 | elif depth.squeeze().dim() == 3: 112 | depth = depth.squeeze()[:, None, :, :] 113 | 114 | if obj_RT.dim() == 2: 115 | obj_RT = obj_RT[None, :, :] 116 | if camK.dim() == 2: 117 | camK = camK[None, :, :] 118 | 119 | if len(camK) != len(obj_RT): 120 | camK = camK.repeat(len(obj_RT), 1, 1) 121 | 122 | assert(depth.size(0) == obj_RT.size(0)) 123 | assert(len(depth) == len(obj_RT)), (depth.shape, obj_RT.shape) 124 | 125 | device = depth.device 126 | im_hei, im_wid = depth.shape[-2:] 127 | YY, XX = torch.meshgrid(torch.arange(im_hei), torch.arange(im_wid), indexing='ij') 128 | XYZ_map = torch.stack([XX, YY, torch.ones_like(XX)], dim=0).to(device) # 3xHxW 129 | XYZ_map = XYZ_map[None, :, :, :] * depth # 1x3xHxW, Bx1xHxW 130 | 131 | XYZ_map = torch.einsum('bij,bjhw->bihw', torch.inverse(camK).to(device), XYZ_map) 132 | Rs = obj_RT[:, :3, :3].to(device) 133 | Ts = obj_RT[:, :3, 3].to(device) 134 | 135 | XYZ_map = torch.einsum('bij,bjhw->bihw', torch.inverse(Rs), XYZ_map - Ts[:, :, None, None]) 136 | 137 | return XYZ_map 138 | 139 | def render_depth_from_mesh_model(mesh, obj_RT, camK, img_hei, img_wid, return_coordinate_map=False): 140 | """ 141 | Pytorch3D: K_4x4 = [ 142 | [fx, 0, px, 0], 143 | [0, fy, py, 0], 144 | [0, 0, 0, 1], 145 | [0, 0, 1, 0], 146 | ] 147 | """ 148 | from pytorch3d import renderer as py3d_renderer 149 | from pytorch3d import transforms as py3d_transform 150 | from pytorch3d.transforms import euler_angles_to_matrix 151 | 152 | device = obj_RT.device 153 | if obj_RT.dim() == 2: 154 | obj_RT = obj_RT[None, :, :] 155 | 156 | if camK.dim() == 2: 157 | camK = camK[None, :, :] 158 | 159 | if len(mesh) != len(obj_RT): 160 | mesh = mesh.extend(len(obj_RT)) 161 | 162 | assert(len(mesh) == len(obj_RT)), (len(mesh), obj_RT.shape) 163 | 164 | Rz_mat = torch.eye(4).to(device) 165 | Rz_mat[:3, :3] = py3d_transform.euler_angles_to_matrix(torch.as_tensor([0, 0, torch.pi]), 'XYZ') 166 | py3d_RT = torch.einsum('ij,bjk->bik', Rz_mat, obj_RT) 167 | cam_R = 
py3d_RT[:, :3, :3].transpose(-2, -1) 168 | cam_T = py3d_RT[:, :3, 3] 169 | 170 | fxfy = torch.stack([camK[:, 0, 0], camK[:, 1, 1]], dim=1) 171 | pxpy = torch.stack([camK[:, 0, 2], camK[:, 1, 2]], dim=1) 172 | cameras = py3d_renderer.PerspectiveCameras(R=cam_R, 173 | T=cam_T, 174 | image_size=((img_hei, img_wid),), 175 | focal_length=fxfy, 176 | principal_point=pxpy, 177 | in_ndc=False, 178 | device=device) 179 | # Define rasterizer settings 180 | raster_settings = py3d_renderer.RasterizationSettings( 181 | image_size=(img_hei, img_wid), 182 | blur_radius=0.0, 183 | faces_per_pixel=1, 184 | bin_size=0, 185 | ) 186 | 187 | rasterizer = py3d_renderer.MeshRasterizer( 188 | cameras=cameras, 189 | raster_settings=raster_settings, 190 | ) 191 | 192 | fragments = rasterizer(mesh.to(device)) 193 | depth_map = fragments.zbuf[..., 0].unsqueeze(1) # Bx1xHxW 194 | depth_mask = torch.zeros_like(depth_map) 195 | depth_mask[depth_map>0] = 1 196 | depth_map *= depth_mask 197 | 198 | if return_coordinate_map: 199 | XYZ_map = generate_3D_coordinate_map_from_depth(depth_map, camK, obj_RT) 200 | return depth_map, depth_mask, XYZ_map * depth_mask 201 | 202 | return depth_map, depth_mask 203 | 204 | -------------------------------------------------------------------------------- /three/quaternion.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from . import core 6 | 7 | __all__ = ['normalize', 'quat_to_mat'] 8 | 9 | 10 | def identity(n: int, device: str = 'cpu'): 11 | return torch.tensor((1.0, 0.0, 0.0, 0.0), device=device).view(1, 4).expand(n, 4) 12 | 13 | 14 | def normalize(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: 15 | r"""Normalizes a quaternion. 16 | The quaternion should be in (x, y, z, w) format. 17 | 18 | Args: 19 | quaternion (torch.Tensor): a tensor containing a quaternion to be 20 | normalized. The tensor can be of shape :math:`(*, 4)`. 21 | eps (Optional[bool]): small value to avoid division by zero. 22 | Default: 1e-12. 23 | 24 | Return: 25 | torch.Tensor: the normalized quaternion of shape :math:`(*, 4)`. 26 | 27 | """ 28 | if not isinstance(quaternion, torch.Tensor): 29 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 30 | type(quaternion))) 31 | 32 | if not quaternion.shape[-1] == 4: 33 | raise ValueError( 34 | "Input must be a tensor of shape (*, 4). Got {}".format( 35 | quaternion.shape)) 36 | return F.normalize(quaternion, p=2.0, dim=-1, eps=eps) 37 | 38 | 39 | def quat_to_mat(quaternion: torch.Tensor) -> torch.Tensor: 40 | """ 41 | Converts a quaternion to a rotation matrix. 42 | The quaternion should be in (w, x, y, z) format. 43 | Adapted from: 44 | https://github.com/kornia/kornia/blob/d729d7c4357ca73e4915a42285a0771bca4436ce/kornia/geometry/conversions.py#L235 45 | Args: 46 | quaternion (torch.Tensor): a tensor containing a quaternion to be 47 | converted. The tensor can be of shape :math:`(*, 4)`. 48 | Return: 49 | torch.Tensor: the rotation matrix of shape :math:`(*, 3, 3)`. 50 | Example: 51 | >>> quaternion = torch.tensor([0., 0., 1., 0.]) 52 | >>> quat_to_mat(quaternion) 53 | tensor([[[-1., 0., 0.], 54 | [ 0., -1., 0.], 55 | [ 0., 0., 1.]]]) 56 | """ 57 | quaternion, unsqueezed = core.ensure_batch_dim(quaternion, 1) 58 | 59 | if not quaternion.shape[-1] == 4: 60 | raise ValueError( 61 | "Input must be a tensor of shape (*, 4). 
Got {}".format( 62 | quaternion.shape)) 63 | # normalize the input quaternion 64 | quaternion_norm = normalize(quaternion) 65 | 66 | # unpack the normalized quaternion components 67 | w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1) 68 | 69 | # compute the actual conversion 70 | tx: torch.Tensor = 2.0 * x 71 | ty: torch.Tensor = 2.0 * y 72 | tz: torch.Tensor = 2.0 * z 73 | twx: torch.Tensor = tx * w 74 | twy: torch.Tensor = ty * w 75 | twz: torch.Tensor = tz * w 76 | txx: torch.Tensor = tx * x 77 | txy: torch.Tensor = ty * x 78 | txz: torch.Tensor = tz * x 79 | tyy: torch.Tensor = ty * y 80 | tyz: torch.Tensor = tz * y 81 | tzz: torch.Tensor = tz * z 82 | one: torch.Tensor = torch.tensor(1.) 83 | 84 | matrix: torch.Tensor = torch.stack([ 85 | one - (tyy + tzz), txy - twz, txz + twy, 86 | txy + twz, one - (txx + tzz), tyz - twx, 87 | txz - twy, tyz + twx, one - (txx + tyy) 88 | ], dim=-1).view(-1, 3, 3) 89 | 90 | if unsqueezed: 91 | matrix = matrix.squeeze(0) 92 | 93 | return matrix 94 | 95 | 96 | def mat_to_quat(rotation_matrix: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: 97 | """ 98 | Convert 3x3 rotation matrix to 4d quaternion vector. 99 | The quaternion vector has components in (w, x, y, z) format. 100 | Adapted From: 101 | https://github.com/kornia/kornia/blob/d729d7c4357ca73e4915a42285a0771bca4436ce/kornia/geometry/conversions.py#L235 102 | Args: 103 | rotation_matrix (torch.Tensor): the rotation matrix to convert. 104 | eps (float): small value to avoid zero division. Default: 1e-8. 105 | Return: 106 | torch.Tensor: the rotation in quaternion. 107 | Shape: 108 | - Input: :math:`(*, 3, 3)` 109 | - Output: :math:`(*, 4)` 110 | """ 111 | rotation_matrix, unsqueezed = core.ensure_batch_dim(rotation_matrix, 2) 112 | 113 | if not isinstance(rotation_matrix, torch.Tensor): 114 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 115 | type(rotation_matrix))) 116 | 117 | if not rotation_matrix.shape[-2:] == (3, 3): 118 | raise ValueError( 119 | "Input size must be a (*, 3, 3) tensor. Got {}".format( 120 | rotation_matrix.shape)) 121 | 122 | def safe_zero_division(numerator: torch.Tensor, 123 | denominator: torch.Tensor) -> torch.Tensor: 124 | eps = torch.finfo(numerator.dtype).tiny 125 | return numerator / torch.clamp(denominator, min=eps) 126 | 127 | if not rotation_matrix.is_contiguous(): 128 | rotation_matrix_vec: torch.Tensor = rotation_matrix.reshape( 129 | *rotation_matrix.shape[:-2], 9) 130 | else: 131 | rotation_matrix_vec: torch.Tensor = rotation_matrix.view( 132 | *rotation_matrix.shape[:-2], 9) 133 | 134 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.chunk( 135 | rotation_matrix_vec, chunks=9, dim=-1) 136 | 137 | trace: torch.Tensor = m00 + m11 + m22 138 | 139 | def trace_positive_cond(): 140 | sq = torch.sqrt(trace + 1.0) * 2. # sq = 4 * qw. 141 | qw = 0.25 * sq 142 | qx = safe_zero_division(m21 - m12, sq) 143 | qy = safe_zero_division(m02 - m20, sq) 144 | qz = safe_zero_division(m10 - m01, sq) 145 | return torch.cat([qw, qx, qy, qz], dim=-1) 146 | 147 | def cond_1(): 148 | sq = torch.sqrt(1.0 + m00 - m11 - m22 + eps) * 2. # sq = 4 * qx. 149 | qw = safe_zero_division(m21 - m12, sq) 150 | qx = 0.25 * sq 151 | qy = safe_zero_division(m01 + m10, sq) 152 | qz = safe_zero_division(m02 + m20, sq) 153 | return torch.cat([qw, qx, qy, qz], dim=-1) 154 | 155 | def cond_2(): 156 | sq = torch.sqrt(1.0 + m11 - m00 - m22 + eps) * 2. # sq = 4 * qy. 
157 | qw = safe_zero_division(m02 - m20, sq) 158 | qx = safe_zero_division(m01 + m10, sq) 159 | qy = 0.25 * sq 160 | qz = safe_zero_division(m12 + m21, sq) 161 | return torch.cat([qw, qx, qy, qz], dim=-1) 162 | 163 | def cond_3(): 164 | sq = torch.sqrt(1.0 + m22 - m00 - m11 + eps) * 2. # sq = 4 * qz. 165 | qw = safe_zero_division(m10 - m01, sq) 166 | qx = safe_zero_division(m02 + m20, sq) 167 | qy = safe_zero_division(m12 + m21, sq) 168 | qz = 0.25 * sq 169 | return torch.cat([qw, qx, qy, qz], dim=-1) 170 | 171 | where_2 = torch.where(m11 > m22, cond_2(), cond_3()) 172 | where_1 = torch.where((m00 > m11) & (m00 > m22), cond_1(), where_2) 173 | 174 | quaternion: torch.Tensor = torch.where( 175 | trace > 0., trace_positive_cond(), where_1) 176 | 177 | if unsqueezed: 178 | quaternion = quaternion.squeeze(0) 179 | 180 | return quaternion 181 | 182 | 183 | @torch.jit.script 184 | def random(k: int = 1, device: str = 'cpu'): 185 | """Return uniform random unit quaternion. 186 | rand: array like or None 187 | Three independent random variables that are uniformly distributed 188 | between 0 and 1. 189 | 190 | """ 191 | rand = torch.rand(k, 3, device=device) 192 | r1 = torch.sqrt(1.0 - rand[:, 0]) 193 | r2 = torch.sqrt(rand[:, 0]) 194 | pi2 = math.pi * 2.0 195 | t1 = pi2 * rand[:, 1] 196 | t2 = pi2 * rand[:, 2] 197 | 198 | return torch.stack([ 199 | torch.cos(t2) * r2, 200 | torch.sin(t1) * r1, 201 | torch.cos(t1) * r1, 202 | torch.sin(t2) * r2 203 | ], dim=1) 204 | 205 | 206 | def qmul(q1, q2): 207 | """ 208 | Quaternion multiplication. 209 | 210 | Use the Hamilton product to perform quaternion multiplication. 211 | 212 | References: 213 | http://en.wikipedia.org/wiki/Quaternions#Hamilton_product 214 | https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/quaternions.py 215 | """ 216 | assert q1.shape[-1] == 4 217 | assert q2.shape[-1] == 4 218 | 219 | ham_prod = torch.bmm(q2.view(-1, 4, 1), q1.view(-1, 1, 4)) 220 | 221 | w = ham_prod[:, 0, 0] - ham_prod[:, 1, 1] - ham_prod[:, 2, 2] - ham_prod[:, 3, 3] 222 | x = ham_prod[:, 0, 1] + ham_prod[:, 1, 0] - ham_prod[:, 2, 3] + ham_prod[:, 3, 2] 223 | y = ham_prod[:, 0, 2] + ham_prod[:, 1, 3] + ham_prod[:, 2, 0] - ham_prod[:, 3, 1] 224 | z = ham_prod[:, 0, 3] - ham_prod[:, 1, 2] + ham_prod[:, 2, 1] + ham_prod[:, 3, 0] 225 | 226 | return torch.stack((w, x, y, z), dim=1).view(q1.shape) 227 | 228 | 229 | def rotate_vector(quat, vector): 230 | """ 231 | References: 232 | https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/quaternions.py#L419 233 | """ 234 | assert quat.shape[-1] == 4 235 | assert vector.shape[-1] == 3 236 | assert quat.shape[:-1] == vector.shape[:-1] 237 | 238 | original_shape = list(vector.shape) 239 | quat = quat.view(-1, 4) 240 | vector = vector.view(-1, 3) 241 | 242 | pure_quat = quat[:, 1:] 243 | uv = torch.cross(pure_quat, vector, dim=1) 244 | uuv = torch.cross(pure_quat, uv, dim=1) 245 | return (vector + 2 * (quat[:, :1] * uv + uuv)).view(original_shape) 246 | 247 | 248 | def from_spherical(theta, phi, r=1.0): 249 | x = torch.cos(theta) * torch.sin(phi) 250 | y = torch.sin(theta) * torch.sin(phi) 251 | z = r * torch.cos(phi) 252 | w = torch.zeros_like(x) 253 | 254 | return torch.stack((w, x, y, z), dim=-1) 255 | 256 | 257 | def from_axis_angle(axis, angle): 258 | """ 259 | Compute a quaternion from the axis angle representation. 
260 | 261 | Reference: 262 | https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation 263 | Args: 264 | axis: axis to rotate about 265 | angle: angle to rotate by 266 | 267 | Returns: 268 | Tensor of shape (*, 4) representing a quaternion. 269 | """ 270 | if torch.is_tensor(axis) and isinstance(angle, float): 271 | angle = torch.tensor(angle, dtype=axis.dtype, device=axis.device) 272 | angle = angle.expand(axis.shape[0]) 273 | 274 | axis = axis / torch.norm(axis, dim=-1, keepdim=True) 275 | 276 | c = torch.cos(angle / 2.0) 277 | s = torch.sin(angle / 2.0) 278 | 279 | w = c 280 | x = s * axis[..., 0] 281 | y = s * axis[..., 1] 282 | z = s * axis[..., 2] 283 | 284 | return torch.stack((w, x, y, z), dim=-1) 285 | 286 | 287 | def qexp(q, eps=1e-8): 288 | """ 289 | Computes the quaternion exponent. 290 | 291 | Reference: 292 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions 293 | Args: 294 | q (tensor): the quaternion to compute the exponent of 295 | Returns: 296 | (tensor): Tensor of shape (*, 4) representing exp(q) 297 | """ 298 | 299 | if q.shape[1] == 4: 300 | # Let q = (s; v). 301 | s, v = torch.split(q, (1, 3), dim=-1) 302 | else: 303 | s = torch.zeros_like(q[:, :1]) 304 | v = q 305 | 306 | theta = torch.norm(v, dim=-1, keepdim=True) 307 | exp_s = torch.exp(s) 308 | w = torch.cos(theta) 309 | xyz = 1.0 / theta.clamp(min=eps) * torch.sin(theta) * v 310 | 311 | return exp_s * torch.cat((w, xyz), dim=-1) 312 | 313 | 314 | def qlog(q, eps=1e-8): 315 | """ 316 | Computes the quaternion logarithm. 317 | 318 | Reference: 319 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions 320 | https://users.aalto.fi/~ssarkka/pub/quat.pdf 321 | Args: 322 | q (tensor): the quaternion to compute the logarithm of 323 | Returns: 324 | (tensor): Tensor of shape (*, 4) representing ln(q) 325 | """ 326 | 327 | mag = torch.norm(q, dim=-1, keepdim=True) 328 | # Let q = (s; v). 329 | s, v = torch.split(q, (1, 3), dim=-1) 330 | w = torch.log(mag) 331 | xyz = (v / torch.norm(v, dim=-1, keepdim=True).clamp(min=eps) 332 | * core.acos_safe(s / mag.clamp(min=eps))) 333 | 334 | return torch.cat((w, xyz), dim=-1) 335 | 336 | 337 | def qdelta(n, std, device=None): 338 | omega = torch.cat((torch.zeros(n, 1, device=device), 339 | torch.randn(n, 3, device=device)), dim=-1) 340 | delta_q = qexp(std / 2.0 * omega) 341 | return delta_q 342 | 343 | 344 | def perturb(q, std): 345 | """ 346 | Perturbs the unit quaternion `q`. 
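    The perturbation is a small random rotation drawn via the quaternion
    exponential map (see `qdelta`); its angle scales with `std` (in radians)
    and it is composed with `q` through `qmul`.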
347 | 348 | References: 349 | https://math.stackexchange.com/questions/2992016/how-to-linearize-quaternions 350 | http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_aa10_appendix.pdf 351 | https://math.stackexchange.com/questions/473736/small-angular-displacements-with-a-quaternion-representation 352 | 353 | Args: 354 | q (tensor): the quaternion to perturb (the mean of the perturbation) 355 | std (float): the stadnard deviation of the perturbation 356 | 357 | Returns: 358 | (tensor): Tensor of shape (*, 4), the perturbed quaternion 359 | """ 360 | q, unsqueezed = core.ensure_batch_dim(q, num_dims=1) 361 | 362 | n = q.shape[0] 363 | delta_q = qdelta(n, std, device=q.device) 364 | q_out = qmul(delta_q, q) 365 | 366 | if unsqueezed: 367 | q_out = q_out.squeeze(0) 368 | 369 | return q_out 370 | 371 | 372 | def angular_distance(q1, q2, eps: float = 1e-7): 373 | q1 = normalize(q1) 374 | q2 = normalize(q2) 375 | dot = q1 @ q2.t() 376 | dist = 2 * core.acos_safe(dot.abs(), eps=eps) 377 | return dist 378 | -------------------------------------------------------------------------------- /three/rigid.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | # from latentfusion import three 7 | # from latentfusion.three import ensure_batch_dim 8 | 9 | import three 10 | from three import quaternion as q 11 | from three import ensure_batch_dim 12 | 13 | def intrinsic_to_3x4(matrix): 14 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2) 15 | 16 | zeros = torch.zeros(1, 3, 1, 17 | dtype=matrix.dtype, 18 | device=matrix.device 19 | ).expand(matrix.shape[0], -1, -1) 20 | mat = torch.cat((matrix, zeros), dim=-1) 21 | 22 | if unsqueezed: 23 | mat = mat.squeeze(0) 24 | 25 | return mat 26 | 27 | 28 | @torch.jit.script 29 | def matrix_3x3_to_4x4(matrix): 30 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2) 31 | 32 | mat = F.pad(matrix, [0, 1, 0, 1]) 33 | mat[:, -1, -1] = 1.0 34 | 35 | if unsqueezed: 36 | mat = mat.squeeze(0) 37 | 38 | return mat 39 | 40 | 41 | @torch.jit.script 42 | def rotation_to_4x4(matrix): 43 | return matrix_3x3_to_4x4(matrix) 44 | 45 | 46 | @torch.jit.script 47 | def translation_to_4x4(translation): 48 | translation, unsqueezed = ensure_batch_dim(translation, num_dims=1) 49 | 50 | eye = torch.eye(4, device=translation.device) 51 | mat = F.pad(translation.unsqueeze(2), [3, 0, 0, 1]) + eye 52 | 53 | if unsqueezed: 54 | mat = mat.squeeze(0) 55 | 56 | return mat 57 | 58 | 59 | def translate_matrix(matrix, offset): 60 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2) 61 | 62 | out = inverse_transform(matrix) 63 | out[:, :3, 3] += offset 64 | out = inverse_transform(out) 65 | 66 | if unsqueezed: 67 | out = out.squeeze(0) 68 | 69 | return out 70 | 71 | 72 | def scale_matrix(matrix, scale): 73 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2) 74 | 75 | out = inverse_transform(matrix) 76 | out[:, :3, 3] *= scale 77 | out = inverse_transform(out) 78 | 79 | if unsqueezed: 80 | out = out.squeeze(0) 81 | 82 | return out 83 | 84 | 85 | def decompose(matrix): 86 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2) 87 | 88 | # Extract rotation matrix. 89 | origin = (torch.tensor([0.0, 0.0, 0.0, 1.0], device=matrix.device, dtype=matrix.dtype) 90 | .unsqueeze(1) 91 | .unsqueeze(0)) 92 | origin = origin.expand(matrix.size(0), -1, -1) 93 | R = torch.cat((matrix[:, :, :3], origin), dim=-1) 94 | 95 | # Extract translation matrix. 
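    # (For a rigid transform whose last row is [0, 0, 0, 1], the input factors as matrix == T @ R.)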
96 | eye = torch.eye(4, 3, device=matrix.device).unsqueeze(0).expand(matrix.size(0), -1, -1) 97 | T = torch.cat((eye, matrix[:, :, 3].unsqueeze(-1)), dim=-1) 98 | 99 | if unsqueezed: 100 | R = R.squeeze(0) 101 | T = T.squeeze(0) 102 | 103 | return R, T 104 | 105 | 106 | def inverse_transform(matrix): 107 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2) 108 | 109 | R, T = decompose(matrix) 110 | R_inv = R.transpose(1, 2) 111 | t = T[:, :4, 3].unsqueeze(2) 112 | t_inv = (R_inv @ t)[:, :3].squeeze(2) 113 | 114 | out = torch.zeros_like(matrix) 115 | out[:, :3, :3] = R_inv[:, :3, :3] 116 | out[:, :3, 3] = -t_inv 117 | out[:, 3, 3] = 1 118 | 119 | if unsqueezed: 120 | out = out.squeeze(0) 121 | 122 | return out 123 | 124 | 125 | def extrinsic_to_position(extrinsic): 126 | extrinsic, unsqueezed = ensure_batch_dim(extrinsic, num_dims=2) 127 | 128 | rot_mat, trans_mat = decompose(extrinsic) 129 | position = rot_mat.transpose(2, 1) @ trans_mat[:, :, 3, None] 130 | position = three.dehomogenize(position.squeeze(-1)) 131 | 132 | if unsqueezed: 133 | position = position.squeeze(0) 134 | return position 135 | 136 | 137 | @torch.jit.script 138 | def random_translation(n: int, 139 | x_bound: Tuple[float, float], 140 | y_bound: Tuple[float, float], 141 | z_bound: Tuple[float, float]): 142 | trans_x = three.uniform(n, *x_bound) 143 | trans_y = three.uniform(n, *y_bound) 144 | trans_z = three.uniform(n, *z_bound) 145 | translation = torch.stack((trans_x, trans_y, trans_z), dim=-1) 146 | return translation 147 | 148 | 149 | @torch.jit.script 150 | def to_extrinsic_matrix(translation, quaternion): 151 | rot_mat = three.quaternion.quat_to_mat(quaternion) 152 | rot_mat = rotation_to_4x4(rot_mat) 153 | trans_mat = translation_to_4x4(translation) 154 | extrinsic = trans_mat @ rot_mat 155 | return extrinsic 156 | 157 | 158 | def extrinsic_to_quat(extrinsic): 159 | rot_mat, _ = decompose(extrinsic) 160 | rot_mat = rot_mat[..., :3, :3] 161 | return three.quaternion.mat_to_quat(rot_mat) 162 | -------------------------------------------------------------------------------- /three/stats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mad(tensor, dim=0): 5 | median, _ = tensor.median(dim=dim) 6 | return torch.median(torch.abs(tensor - median), dim=dim)[0] 7 | 8 | 9 | def mask_outliers_mad(data, m=2.0): 10 | median = data.median() 11 | mad = torch.median(torch.abs(data - median)) 12 | mask = torch.abs(data - median) / mad < m 13 | return mask 14 | 15 | 16 | def reject_outliers_mad(data, m=2.0): 17 | return data[mask_outliers_mad(data, m)] 18 | 19 | 20 | def mask_outliers(data, m=2.0): 21 | mean = data.mean() 22 | std = torch.std(data) 23 | mask = torch.abs(data - mean) / std < m 24 | return mask 25 | 26 | 27 | def reject_outliers(data, m=2.0): 28 | return data[mask_outliers(data, m)] 29 | 30 | 31 | def robust_mean(data, m=2.0): 32 | return torch.mean(reject_outliers(data, m)) 33 | 34 | 35 | def robust_mean_mad(data, m=2.0): 36 | return torch.mean(reject_outliers_mad(data, m)) 37 | -------------------------------------------------------------------------------- /three/torchutils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import abc 4 | import copy 5 | import random 6 | from contextlib import contextmanager 7 | from functools import partial 8 | from itertools import chain 9 | 10 | import structlog 11 | import torch 12 | import torch.autograd 13 | from torch import nn 14 | from 
torch.nn.parallel.scatter_gather import scatter_kwargs 15 | from torch.utils.data import Sampler, DataLoader 16 | 17 | logger = structlog.get_logger(__name__) 18 | 19 | 20 | def dict_to(d, device): 21 | return {k: v.to(device) for k, v in d.items()} 22 | 23 | 24 | def deparameterize_module(module): 25 | module = module.clone() 26 | for name, parameter in module.named_parameters(): 27 | parameter = parameter.detach() 28 | setattr(module, name, parameter) 29 | return module 30 | 31 | 32 | @contextmanager 33 | def manual_seed(seed): 34 | torch_state = torch.get_rng_state() 35 | torch.manual_seed(seed) 36 | yield 37 | torch.set_rng_state(torch_state) 38 | 39 | 40 | def module_device(module): 41 | return next(module.parameters()).device 42 | 43 | 44 | def save_checkpoint(save_dir, name, state): 45 | if not save_dir.exists(): 46 | logger.info("creating directory", path=save_dir) 47 | save_dir.mkdir(parents=True) 48 | 49 | path = save_dir / f'{name}.pth' 50 | logger.info("saving checkpoint", name=name, path=path) 51 | 52 | with path.open('wb') as f: 53 | torch.save(state, f) 54 | 55 | 56 | def save_if_better(save_dir, state, meters, key, bigger_is_better=False): 57 | best_key = f'best-{key}'.replace('/', '-') 58 | if best_key not in state: 59 | state[best_key] = -1 if bigger_is_better else float('inf') 60 | 61 | if bigger_is_better: 62 | better = meters[key].mean >= state[best_key] 63 | else: 64 | better = meters[key].mean <= state[best_key] 65 | 66 | if better: 67 | state[best_key] = meters[key].mean 68 | save_checkpoint(save_dir, best_key, state) 69 | 70 | 71 | class DeterministicShuffledSampler(Sampler): 72 | """Shuffles the dataset once and then samples deterministically.""" 73 | 74 | def __init__(self, data_source, replacement=False, num_samples=None): 75 | super().__init__(data_source) 76 | 77 | self.data_source = data_source 78 | self.replacement = replacement 79 | self.num_samples = num_samples 80 | 81 | if self.num_samples is not None and replacement is False: 82 | raise ValueError( 83 | "With replacement=False, num_samples should not be specified, " 84 | "since a random permute will be performed.") 85 | 86 | if self.num_samples is None: 87 | self.num_samples = len(self.data_source) 88 | 89 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 90 | raise ValueError("num_samples should be a positive integeral " 91 | "value, but got num_samples={}".format( 92 | self.num_samples)) 93 | if not isinstance(self.replacement, bool): 94 | raise ValueError("replacement should be a boolean value, but got " 95 | "replacement={}".format(self.replacement)) 96 | 97 | n = len(self.data_source) 98 | if self.replacement: 99 | self.permutation = torch.randint( 100 | high=n, size=(self.num_samples,), dtype=torch.int64).tolist() 101 | else: 102 | self.permutation = torch.randperm(n).tolist() 103 | 104 | def __iter__(self): 105 | return iter(self.permutation) 106 | 107 | def __len__(self): 108 | return len(self.data_source) 109 | 110 | 111 | class Scatterable(abc.ABC): 112 | """ 113 | A mixin to make an object scatterable across GPUs for data-parallelism. 114 | 115 | The object should be serialized to a dictionary in the `to_kwargs` method. Each entry should be 116 | a scatterable tensor object. 117 | 118 | The `from_kwargs` method should be able to take the dictionary format and reconstruct the original 119 | class object. 
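    Implementation note: `MyDataParallel.scatter` below calls `to_kwargs` on every
    Scatterable input, scatters the resulting dictionaries across devices with
    `scatter_kwargs`, and rebuilds each shard with `from_kwargs`.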
120 | """ 121 | 122 | @classmethod 123 | @abc.abstractmethod 124 | def to_kwargs(self): 125 | pass 126 | 127 | @classmethod 128 | @abc.abstractmethod 129 | def from_kwargs(cls, kwargs): 130 | pass 131 | 132 | 133 | class MyDataParallel(nn.DataParallel): 134 | """ 135 | A Scatterable-aware data parallel class. 136 | """ 137 | 138 | def scatter(self, inputs, kwargs, device_ids): 139 | _inputs = [] 140 | _kwargs = {} 141 | input_constructors = {} 142 | kwargs_constructors = {} 143 | for i, item in enumerate(inputs): 144 | if isinstance(item, Scatterable): 145 | input_constructors[i] = item.from_kwargs 146 | _inputs.append(item.to_kwargs()) 147 | else: 148 | input_constructors[i] = lambda x: x 149 | _inputs.append(item) 150 | 151 | for key, item in kwargs.items(): 152 | if isinstance(item, Scatterable): 153 | kwargs_constructors[key] = item.from_kwargs 154 | _kwargs[key] = item.to_kwargs() 155 | else: 156 | kwargs_constructors[key] = lambda x: x 157 | _kwargs[key] = item 158 | 159 | _inputs, _kwargs = scatter_kwargs(_inputs, _kwargs, device_ids, dim=self.dim) 160 | 161 | _inputs = [ 162 | [input_constructors[i](item) for i, item in enumerate(_input)] 163 | for _input in _inputs 164 | ] 165 | _kwargs = [ 166 | {k: kwargs_constructors[k](item) for k, item in _kwarg.items()} 167 | for _kwarg in _kwargs 168 | ] 169 | 170 | return _inputs, _kwargs 171 | 172 | 173 | class ListSampler(Sampler): 174 | r"""Samples given elements sequentially, always in the same order. 175 | Arguments: 176 | data_source (Dataset): dataset to sample from 177 | indices (iterable): list of indices to sample 178 | """ 179 | 180 | def __init__(self, data_source, indices): 181 | super().__init__(data_source) 182 | if indices is None: 183 | self.indices = list(range(len(data_source))) 184 | else: 185 | self.indices = indices 186 | 187 | def __iter__(self): 188 | return iter(self.indices) 189 | 190 | def __len__(self): 191 | return len(self.indices) 192 | 193 | 194 | class ShuffledSubsetSampler(Sampler): 195 | r"""Samples given elements sequentially, always in the same order. 196 | Arguments: 197 | data_source (Dataset): dataset to sample from 198 | indices (iterable): list of indices to sample 199 | """ 200 | 201 | def __init__(self, data_source, indices): 202 | super().__init__(data_source) 203 | self.indices = copy.deepcopy(indices) 204 | random.shuffle(self.indices) 205 | 206 | def __iter__(self): 207 | return iter(self.indices) 208 | 209 | def __len__(self): 210 | return len(self.indices) 211 | 212 | 213 | class IndexedDataLoader(DataLoader): 214 | 215 | def __init__(self, dataset, batch_size, num_workers=4, indices=None, drop_last=False, 216 | pin_memory=False, shuffle=False): 217 | if shuffle: 218 | sampler = ShuffledSubsetSampler(dataset, indices) 219 | else: 220 | sampler = ListSampler(dataset, indices) 221 | 222 | super().__init__( 223 | dataset, 224 | batch_size=batch_size, 225 | num_workers=num_workers, 226 | sampler=sampler, 227 | drop_last=drop_last, 228 | pin_memory=pin_memory) 229 | 230 | 231 | class SequentialDataLoader(IndexedDataLoader): 232 | 233 | def __init__(self, *args, **kwargs): 234 | super().__init__(*args, **kwargs, shuffle=False) 235 | 236 | 237 | class _InfiniteSampler(object): 238 | """ Sampler that repeats forever. 239 | 240 | Hack to force dataloader to use same workers. 
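    Because the wrapped sampler is re-yielded forever, a DataLoader built on it never
    exhausts its iterator, so its worker processes stay alive across epochs instead of
    being respawned.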
241 | 242 | Args: 243 | sampler (Sampler) 244 | """ 245 | 246 | def __init__(self, sampler): 247 | self.sampler = sampler 248 | 249 | def __len__(self): 250 | return len(self.sampler) 251 | 252 | def __iter__(self): 253 | while True: 254 | yield from iter(self.sampler) 255 | 256 | 257 | class WorkerPreservingDataLoader(DataLoader): 258 | """ 259 | Hack to force dataloader to use same workers. 260 | """ 261 | 262 | def __init__(self, *args, **kwargs): 263 | super().__init__(*args, **kwargs) 264 | self.batch_sampler = _InfiniteSampler(self.batch_sampler) 265 | self.iterator = super().__iter__() 266 | 267 | def __iter__(self): 268 | for i in range(len(self)): 269 | yield next(self.iterator) 270 | 271 | 272 | @contextmanager 273 | def profile(): 274 | with torch.autograd.profiler.profile(use_cuda=True) as prof: 275 | yield 276 | print(prof.key_averages().table(sort_by="self_cpu_time_total")) 277 | 278 | 279 | @contextmanager 280 | def measure_time(s: str): 281 | torch.cuda.synchronize() 282 | start = time.time() 283 | yield 284 | torch.cuda.synchronize() 285 | print(f"[{s}] elapsed: {time.time() - start:.02f}") 286 | 287 | -------------------------------------------------------------------------------- /three/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def farthest_points(data, n_clusters: int, dist_func, return_center_indexes=False, 5 | return_distances=False, verbose=False): 6 | """ 7 | Performs farthest point sampling on data points. 8 | 9 | Args: 10 | data (torch.tensor): data points. 11 | n_clusters (int): number of clusters. 12 | dist_func (Callable): distance function that is used to compare two data points. 13 | return_center_indexes (bool): if True, returns the indexes of the center of clusters. 14 | return_distances (bool): if True, return distances of each point from centers. 15 | 16 | Returns clusters, [centers, distances]: 17 | clusters (torch.tensor): the cluster index for each element in data. 18 | centers (torch.tensor): the integer index of each center. 19 | distances (torch.tensor): closest distances of each point to any of the cluster centers. 
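    Example (illustrative only, assuming 2-D points and a Euclidean distance function):
        >>> pts = torch.rand(100, 2)
        >>> clusters, centers = farthest_points(
        ...     pts, n_clusters=4,
        ...     dist_func=lambda a, b: torch.norm(a - b, dim=-1),
        ...     return_center_indexes=True)
        >>> centers.shape
        torch.Size([4])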
20 | """ 21 | if n_clusters >= data.shape[0]: 22 | if return_center_indexes: 23 | return (torch.arange(data.shape[0], dtype=torch.long), 24 | torch.arange(data.shape[0], dtype=torch.long)) 25 | 26 | return torch.arange(data.shape[0], dtype=torch.long) 27 | 28 | clusters = torch.full((data.shape[0],), fill_value=-1, dtype=torch.long) 29 | distances = torch.full((data.shape[0],), fill_value=1e7, dtype=torch.float32) 30 | centers = torch.zeros(n_clusters, dtype=torch.long) 31 | for i in range(n_clusters): 32 | center_idx = torch.argmax(distances) 33 | centers[i] = center_idx 34 | 35 | broadcasted_data = data[center_idx].unsqueeze(0).expand(data.shape[0], -1) 36 | new_distances = dist_func(broadcasted_data, data) 37 | distances = torch.min(distances, new_distances) 38 | clusters[distances == new_distances] = i 39 | if verbose: 40 | print('farthest points max distance : {}'.format(torch.max(distances))) 41 | 42 | if return_center_indexes: 43 | if return_distances: 44 | return clusters, centers, distances 45 | return clusters, centers 46 | 47 | return clusters 48 | -------------------------------------------------------------------------------- /training/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | gpu_id = 0 3 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 4 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 5 | 6 | import sys 7 | import time 8 | import glob 9 | import torch 10 | import shutil 11 | import numpy as np 12 | from torch import optim 13 | import matplotlib.pyplot as plt 14 | 15 | from torch.cuda.amp import autocast, GradScaler 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | PROJ_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # ./../../ 19 | sys.path.append(PROJ_ROOT) 20 | 21 | from dataset import misc 22 | from misc_utils import warmup_lr 23 | from model.network import model_arch as ModelNet 24 | from dataset.megapose_dataset import MegaPose_Dataset as Dataset 25 | device = torch.device('cuda:0') 26 | 27 | def batchify_cuda_device(data_dict, batch_size, flatten_multiview=True, use_cuda=True): 28 | for key, val in data_dict.items(): 29 | for sub_key, sub_val in val.items(): 30 | if use_cuda: 31 | try: 32 | data_dict[key][sub_key] = sub_val.cuda(non_blocking=True) 33 | except: 34 | pass 35 | if flatten_multiview: 36 | try: 37 | if data_dict[key][sub_key].shape[0] == batch_size: 38 | data_dict[key][sub_key] = data_dict[key][sub_key].flatten(0, 1) 39 | except: 40 | pass 41 | 42 | img_size = 224 43 | batch_size = 2 44 | que_view_num = 4 45 | refer_view_num = 8 46 | random_view_num = 24 # 8 + 24 = 32 47 | nnb_Rmat_threshold = 30 48 | num_train_iters = 100_000 49 | 50 | DATA_DIR = os.path.join(PROJ_ROOT, 'dataspace', 'MegaPose') 51 | 52 | dataset = Dataset(data_dir=DATA_DIR, 53 | query_view_num=que_view_num, 54 | refer_view_num=refer_view_num, 55 | rand_view_num=random_view_num, 56 | nnb_Rmat_threshold=nnb_Rmat_threshold, 57 | ) 58 | 59 | print('num_objects: ', len(dataset.selected_objIDs)) 60 | 61 | model_net = ModelNet().to(device) 62 | CKPT_ROOT = os.path.join(PROJ_ROOT, 'checkpoints') 63 | checkpoints = os.path.join(CKPT_ROOT, 'checkpoints') 64 | tb_dir = os.path.join(checkpoints, 'tb') 65 | tb_old = tb_dir.replace('tb', 'tb_old') 66 | if os.path.exists(tb_old): 67 | shutil.rmtree(tb_old) 68 | if not os.path.exists(tb_dir): 69 | os.makedirs(tb_dir) 70 | shutil.move(tb_dir, tb_old) 71 | tb_writer = SummaryWriter(tb_dir) 72 | 73 | data_loader = 
torch.utils.data.DataLoader(dataset, 74 | shuffle=True, 75 | num_workers=8, 76 | batch_size=batch_size, 77 | collate_fn=dataset.collate_fn, 78 | pin_memory=False, drop_last=False) 79 | 80 | END_LR = 1e-6 81 | START_LR = 1e-4 82 | max_steps = num_train_iters 83 | 84 | iter_steps = 0 85 | TB_SKIP_STEPS = 5 86 | short_log_interval = 100 87 | long_log_interval = 1_000 88 | checkpoint_interval = 10_000 89 | enable_FP16_training = True 90 | LOSS_WEIGHTS = { 91 | 'rm_loss': 1.0, 92 | 'cm_loss': 10.0, 93 | 'qm_loss': 10.0, 94 | 'Remb_loss': 1.0, 95 | } 96 | 97 | optimizer = optim.AdamW(model_net.parameters(), lr=START_LR) 98 | lr_scheduler = warmup_lr.CosineAnnealingWarmupRestarts(optimizer, max_steps, max_lr=START_LR, min_lr=END_LR) 99 | 100 | losses_dict = {} 101 | model_net.train() 102 | scaler = GradScaler() 103 | start_timer = time.time() 104 | data_iterator = iter(data_loader) 105 | 106 | print('total training max_steps: {}'.format(max_steps)) 107 | print('enable_FP16_training: ', enable_FP16_training) 108 | for iter_steps in range(1, max_steps+1): 109 | lr_scheduler.step() 110 | optimizer.zero_grad() 111 | try: 112 | batch_data = next(data_iterator) 113 | except: 114 | data_iterator = iter(data_loader) # reinitialize the iterator 115 | batch_data = next(data_iterator) 116 | 117 | batchify_cuda_device(batch_data, batch_size=batch_size, flatten_multiview=True, use_cuda=True) 118 | scaler_curr_scale = 1.0 119 | loss = 0 120 | with autocast(enable_FP16_training): 121 | net_outputs = model_net(batch_data) 122 | for ls_name, ls_wgh in LOSS_WEIGHTS.items(): 123 | loss += net_outputs.get(ls_name, 0) * LOSS_WEIGHTS[ls_name] 124 | assert (not torch.isnan(loss).any()) 125 | 126 | scaler.scale(loss).backward() 127 | 128 | with torch.no_grad(): 129 | scaler.unscale_(optimizer) 130 | grad_norm = torch.nn.utils.clip_grad_norm_(model_net.parameters(), 1.0) 131 | scaler.step(optimizer) 132 | scaler.update() 133 | scaler_curr_scale = scaler.state_dict()['scale'] 134 | 135 | if 'ls' not in losses_dict: 136 | losses_dict['ls'] = list() 137 | losses_dict['ls'].append(loss.item()) 138 | 139 | for key_, val_ in net_outputs.items(): 140 | if 'loss' in key_: 141 | if key_ not in losses_dict: 142 | losses_dict[key_] = list() 143 | ls = val_.item() 144 | if key_ in LOSS_WEIGHTS: 145 | ls *= LOSS_WEIGHTS[key_] 146 | losses_dict[key_].append(ls) 147 | 148 | if (iter_steps > TB_SKIP_STEPS) and (iter_steps % short_log_interval == 0): 149 | tb_writer.add_scalar("Other/lr", optimizer.param_groups[0]['lr'], iter_steps) 150 | 151 | for idx, (key_, val_) in enumerate(losses_dict.items()): 152 | tb_writer.add_scalar(f"Loss/{idx}_{key_}", val_[-1], iter_steps) 153 | 154 | if ((iter_steps > 5 and iter_steps < 2000 and iter_steps % short_log_interval == 0) 155 | or iter_steps % long_log_interval == 0): 156 | 157 | curr_lr = optimizer.param_groups[0]['lr'] 158 | time_stamp = time.strftime('%d-%H:%M:%S', time.localtime()) 159 | logging_str = "{:.1f}k".format(iter_steps/1000) 160 | 161 | for key_, val_ in losses_dict.items(): 162 | dis_str = key_.split('_')[0] 163 | logging_str += ', {}:{:.4f}'.format(dis_str, np.mean(val_[-2000:])) 164 | 165 | logging_str += ', {}'.format(time_stamp) 166 | logging_str += ', {:.1f}'.format(scaler_curr_scale) 167 | logging_str += ', {:.6f}'.format(curr_lr) 168 | 169 | print(logging_str) 170 | 171 | vis_num_views = np.minimum(8, refer_view_num) 172 | fig, ax = plt.subplots(4, vis_num_views+1, figsize=(12, 5), 173 | gridspec_kw={'width_ratios': [1.5] + [1 for _ in range(vis_num_views)]} 174 | ) 175 | 
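        # Visualization grid: rows show (RGB input, GT mask, predicted mask, GT - prediction);
        # column 0 is the full rescaled query image, column 1 the dzi-cropped query,
        # and the remaining columns are zoomed reference views.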
176 | rgb_que_image = batch_data['query_dict']['rescaled_image'][0].detach().cpu().permute(1, 2, 0).squeeze().float() 177 | gt_que_full_mask = batch_data['query_dict']['rescaled_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float() 178 | pd_que_full_mask = net_outputs['que_full_pd_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float() 179 | rgb_path = batch_data['query_dict']['rgb_path'][0].split('train_pbr/')[-1] 180 | ax[0, 0].imshow(rgb_que_image) 181 | ax[0, 0].set_title(rgb_path, fontsize=10) 182 | ax[1, 0].imshow(gt_que_full_mask) 183 | ax[2, 0].imshow(pd_que_full_mask) 184 | ax[3, 0].imshow((gt_que_full_mask - pd_que_full_mask)) 185 | ax[0, 0].axis(False) 186 | ax[1, 0].axis(False) 187 | ax[2, 0].axis(False) 188 | ax[3, 0].axis(False) 189 | 190 | rgb_que_image = batch_data['query_dict']['dzi_image'][0].detach().cpu().permute(1, 2, 0).squeeze().float() 191 | gt_que_que_mask = batch_data['query_dict']['dzi_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float() 192 | pd_que_que_mask = net_outputs['que_pd_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float() 193 | ax[0, 1].imshow(rgb_que_image) 194 | ax[1, 1].imshow(gt_que_que_mask) 195 | ax[2, 1].imshow(pd_que_que_mask) 196 | ax[3, 1].imshow((gt_que_que_mask - pd_que_que_mask)) 197 | ax[0, 1].axis(False) 198 | ax[1, 1].axis(False) 199 | ax[2, 1].axis(False) 200 | ax[3, 1].axis(False) 201 | 202 | for vix in range(vis_num_views-1): 203 | vjx = vix + 2 204 | rgb_ref_image = batch_data['refer_dict']['zoom_image'][vix].detach().cpu().permute(1, 2, 0).squeeze().float() 205 | gt_ref_mask = batch_data['refer_dict']['zoom_mask'][vix].detach().cpu().permute(1, 2, 0).squeeze().float() 206 | pd_ref_mask = net_outputs['ref_pd_mask'][vix].detach().cpu().permute(1, 2, 0).squeeze().float() 207 | ax[0, vjx].imshow(rgb_ref_image) 208 | ax[1, vjx].imshow(gt_ref_mask) 209 | ax[2, vjx].imshow(pd_ref_mask) 210 | ax[3, vjx].imshow((gt_ref_mask - pd_ref_mask)) 211 | ax[0, vjx].axis(False) 212 | ax[1, vjx].axis(False) 213 | ax[2, vjx].axis(False) 214 | ax[3, vjx].axis(False) 215 | plt.tight_layout() 216 | tb_writer.add_figure('visulize_refer', fig, iter_steps) 217 | fig.clear() 218 | 219 | Remb_logit = net_outputs['Remb_logit'][0].detach().cpu() 220 | delta_Rdeg = net_outputs['delta_Rdeg'][0].detach().cpu().float() 221 | delta_Rdeg = torch.acos(torch.clamp(delta_Rdeg, min=-1.0, max=1.0)) / torch.pi * 180 222 | rank_Rdegs, rank_Rinds = torch.topk(delta_Rdeg, dim=0, k=delta_Rdeg.shape[0], largest=False) 223 | 224 | fig, ax = plt.subplots(1, 1) 225 | ax.plot(rank_Rdegs, Remb_logit[rank_Rinds]) 226 | ax.grid() 227 | tb_writer.add_figure('Rotation probability distribution', fig, iter_steps) 228 | fig.clear() 229 | 230 | if iter_steps % checkpoint_interval == 0: 231 | if not os.path.exists(checkpoints): 232 | os.makedirs(checkpoints) 233 | time_stamp = time.strftime('%m%d_%H%M%S', time.localtime()) 234 | 235 | ckpt_name = 'model_{}_{}.pth'.format(iter_steps, time_stamp) 236 | ckpt_file = os.path.join(checkpoints, ckpt_name) 237 | try: 238 | torch.save(model_net.module.state_dict(), ckpt_file) 239 | except: 240 | torch.save(model_net.state_dict(), ckpt_file) 241 | 242 | # try: 243 | # state = { 244 | # 'model_net': model_net.module.state_dict(), 245 | # 'optimizer': optimizer.state_dict(), 246 | # 'lr_scheduler': lr_scheduler.state_dict(), 247 | # 'scaler': scaler.state_dict(), 248 | # 'iter_steps': iter_steps, 249 | # } 250 | # torch.save(state, ckpt_file) 251 | # except: 252 | # state = { 253 | # 'model_net': model_net.state_dict(), 254 | # 
'optimizer': optimizer.state_dict(), 255 | # 'lr_scheduler': lr_scheduler.state_dict(), 256 | # 'scaler': scaler.state_dict(), 257 | # 'iter_steps': iter_steps, 258 | # } 259 | # torch.save(state, ckpt_file) 260 | 261 | print('saving to ', ckpt_file) 262 | 263 | 264 | --------------------------------------------------------------------------------
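A minimal sketch (not part of the repository) of how a checkpoint written by training/training.py could be restored for evaluation. The checkpoint filename below is a placeholder; the bare state_dict format matches the torch.save(...) call in the training script.

import os
import torch

from model.network import model_arch as ModelNet

# Placeholder path: training.py writes checkpoints/checkpoints/model_<step>_<timestamp>.pth
ckpt_file = os.path.join('checkpoints', 'checkpoints', 'model_100000_0101_120000.pth')

model_net = ModelNet()                                  # same constructor as in training.py
state_dict = torch.load(ckpt_file, map_location='cpu')  # bare state_dict, as saved above
model_net.load_state_dict(state_dict)
model_net.eval()                                        # disable dropout / BN updates for inference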