├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   └── gspose_overview.png
├── config
│   └── inference_cfg.py
├── dataset
│   ├── demo_dataset.py
│   ├── extract_megapose_to_BOP.py
│   ├── inference_datasets.py
│   ├── megapose_dataset.py
│   ├── misc.py
│   └── parse_OnePoseCap_data.py
├── environment.yml
├── gaussian_object
│   ├── __init__.py
│   ├── arguments.py
│   ├── build_3DGaussianObject.py
│   ├── cameras.py
│   ├── dataset_readers.py
│   ├── gaussian_model.py
│   ├── gaussian_render.py
│   ├── gaussian_renderer
│   │   ├── __init__.py
│   │   └── network_gui.py
│   ├── loss_utils.py
│   ├── sh_utils.py
│   └── utils
│       ├── camera_utils.py
│       ├── general_utils.py
│       ├── graphics_utils.py
│       ├── image_utils.py
│       ├── loss_utils.py
│       ├── sh_utils.py
│       └── system_utils.py
├── inference.py
├── install_env.sh
├── misc_utils
│   ├── gs_utils.py
│   ├── loss_utils.py
│   ├── metric_utils.py
│   └── warmup_lr.py
├── model
│   ├── blocks.py
│   ├── curope
│   │   ├── __init__.py
│   │   ├── curope.cpp
│   │   ├── curope2d.py
│   │   ├── kernels.cu
│   │   └── setup.py
│   ├── dino_layers
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── dino_head.py
│   │   ├── drop_path.py
│   │   ├── efficient_attention.py
│   │   ├── layer_scale.py
│   │   ├── mlp.py
│   │   ├── patch_embed.py
│   │   └── swiglu_ffn.py
│   ├── generalized_mean_pooling.py
│   ├── network.py
│   └── position_encoding.py
├── notebook
│   └── Demo_Example_with_GS-Pose.ipynb
├── three
│   ├── __init__.py
│   ├── batchview.py
│   ├── core.py
│   ├── imutils.py
│   ├── meshutils.py
│   ├── orientation.py
│   ├── pytorch3d_rendering.py
│   ├── quaternion.py
│   ├── rigid.py
│   ├── stats.py
│   ├── torchutils.py
│   └── utils.py
└── training
    └── training.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 | *database
164 | database*
165 | demo_data
166 | *.pth
167 | *.mp4
168 | *.MP4
169 | checkpoints
170 | dataspace
171 | notebook/*.png
172 | notebook/*.json
173 | dataspace
174 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/Connected_components_PyTorch"]
2 | path = submodules/Connected_components_PyTorch
3 | url = https://github.com/zsef123/Connected_components_PyTorch.git
4 | [submodule "submodules/simple-knn"]
5 | path = submodules/simple-knn
6 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
7 | [submodule "submodules/diff-gaussian-rasterization"]
8 | path = submodules/diff-gaussian-rasterization
9 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Dingding
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GS-Pose: Generalizable Segmentation-Based 6D Object Pose Estimation With 3D Gaussian Splatting
2 | - [[Project Page](https://dingdingcai.github.io/gs-pose)]
3 | - [[Paper](https://arxiv.org/abs/2403.10683)]
4 |
5 |
6 |
7 |
8 | ``` BibTeX
9 | @article{cai_2024_GSPose,
10 | author = {Cai, Dingding and Heikkil\"a, Janne and Rahtu, Esa},
11 | title = {GS-Pose: Generalizable Segmentation-Based 6D Object Pose Estimation With 3D Gaussian Splatting},
12 | journal = {arXiv preprint arXiv:2403.10683},
13 | year = {2024},
14 | }
15 | ```
16 |
17 | ## Setup
18 | Please start by installing [Miniconda3](https://conda.io/projects/conda/en/latest/user-guide/install/linux.html).
19 | This repository contains submodules, and the default environment can be installed as below.
20 |
21 | ``` Bash
22 | git clone git@github.com:dingdingcai/GSPose.git --recursive
23 | cd GSPose
24 | conda env create -f environment.yml
25 | conda activate gspose
26 |
27 | bash install_env.sh
28 | ```
29 |
30 | ## Pre-trained Model
31 | Download the [pretrained weights](https://drive.google.com/file/d/1VgOAemCrEeW_nT6qQ3R12oz_3UZmQILy/view?usp=sharing) and store the file as ``checkpoints/model_wights.pth``.
32 |
33 |
34 |
35 | ## Demo Example
36 | An example of using GS-Pose for both pose estimation and tracking is provided in [``notebook``](./notebook/Demo_Example_with_GS-Pose.ipynb).
37 |
38 |
39 | ## Datasets
40 | Our evaluation is conducted on the LINEMOD and OnePose-LowTexture datasets.
41 | - For comparison with Gen6D, download [``LINEMOD_Gen6D``](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/yuanly_connect_hku_hk/EkWESLayIVdEov4YlVrRShQBkOVTJwgK0bjF7chFg2GrBg?e=Y8UpXu).
42 | - For comparison with OnePose++, download [``lm``](https://bop.felk.cvut.cz/datasets) and the YOLOv5 detection results [``lm_yolo_detection``](https://zjueducn-my.sharepoint.com/:u:/g/personal/12121064_zju_edu_cn/EdodUdKGwHpCuvw3Cio5DYoBTntYLQuc7vNg9DkytWuJAQ?e=sAXp4B).
43 | - Download the [OnePose-LowTexture](https://github.com/zju3dv/OnePose_Plus_Plus/blob/main/doc/dataset_document.md) dataset
44 |   and store it under the directory ``onepose_dataset``.
45 |
46 |
47 | All datasets are organised under the ``dataspace`` directory, as below,
48 | ```
49 | dataspace/
50 | ├── LINEMOD_Gen6D
51 | │
52 | ├── bop_dataset/
53 | │ ├── lm
54 | │ └── lm_yolo_detection
55 | │
56 | ├── onepose_dataset/
57 | │ ├── scanned_model
58 | │ └── lowtexture_test_data
59 | │
60 | └── README.md
61 | ```
62 |
63 | ## Evaluation
64 | Evaluation on the subset of LINEMOD (comparison with Gen6D, Cas6D, etc.).
65 | - ``python inference.py --dataset_name LINEMOD_SUBSET --database_dir LMSubSet_database --outpose_dir LMSubSet_pose``
66 |
67 | Evaluation on all objects of LINEMOD using the built-in detector.
68 | - ``python inference.py --dataset_name LINEMOD --database_dir LM_database --outpose_dir LM_pose``
69 |
70 | Evaluation on all objects of LINEMOD using the YOLOv5 detection (comparison with OnePose/OnePose++).
71 | - ``python inference.py --dataset_name LINEMOD --database_dir LM_database --outpose_dir LM_yolo_pose``
72 |
73 | Evaluation on the scanned objects of OnePose-LowTexture.
74 | - ``python inference.py --dataset_name LOWTEXTUREVideo --database_dir LTVideo_database --outpose_dir LTVideo_pose``
75 |
76 | ## Training
77 | We utilize a subset (``gso_1M``) of the MegaPose dataset for training.
78 | Please download [``MegaPose/gso_1M``](https://www.paris.inria.fr/archive_ylabbeprojectsdata/megapose/webdatasets/) and [``MegaPose/google_scanned_objects.zip``](https://www.paris.inria.fr/archive_ylabbeprojectsdata/megapose/tars/) to the directory ``dataspace``, and organize the data as follows:
79 | ```
80 | dataspace/
81 | ├── MegaPose/
82 | │ ├── webdatasets/gso_1M
83 | │ └── google_scanned_objects
84 | ...
85 | ```
86 | Execute the following script under the [``MegaPose``](https://github.com/megapose6d/megapose6d?tab=readme-ov-file) environment to prepare the training data.
87 | - ``python dataset/extract_megapose_to_BOP.py``
88 |
89 | Then, train the network via
90 | - ``python training/training.py``
91 |
92 | ## Acknowledgement
93 | - The code is partially based on [DINOv2](https://github.com/facebookresearch/dinov2), [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file), [MegaPose](https://github.com/megapose6d/megapose6d), and [SC6D](https://github.com/dingdingcai/SC6D-pose).
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/assets/gspose_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingdingcai/GSPose/dc88a8965e2b48436371f297f9288fba8313ab83/assets/gspose_overview.png
--------------------------------------------------------------------------------
/config/inference_cfg.py:
--------------------------------------------------------------------------------
1 | cosim_topk = -1
2 | refer_view_num = 8
3 | DINO_PATCH_SIZE = 14
4 | zoom_image_margin = 0
5 | zoom_image_scale = 224
6 | query_longside_scale = 672
7 | query_shortside_scale = query_longside_scale * 3 // 4
8 |
9 | coarse_threshold = 0.05
10 | coarse_bbox_padding = 1.25
11 | finer_threshold = 0.5
12 | finer_bbox_padding = 1.5
13 | enable_fine_detection = True
14 |
15 | save_reference_mask = True
16 | #### configure for 3D-GS-Refiner ####
17 | ROT_TOPK = 1 # single rotation proposal
18 |
19 | WARMUP = 10
20 | END_LR = 0
21 | START_LR = 5e-3
22 | MAX_STEPS = 400
23 | GS_RENDER_SIZE = 224
24 | EARLY_STOP_MIN_STEPS = 5
25 | EARLY_STOP_LOSS_GRAD_NORM = 1e-4
26 |
27 | USE_SSIM = True
28 | USE_MS_SSIM = True
29 |
30 | BINARIZE_MASK = False
31 | USE_YOLO_BBOX = True
32 | USE_ALLOCENTRIC = True
33 | APPLY_ZOOM_AND_CROP = True
34 | CC_INCLUDE_SUPMASK = False
35 |
36 |
--------------------------------------------------------------------------------
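
The constants above are consumed as plain module attributes. Below is a minimal, illustrative sketch of how the warm-up/decay values (`WARMUP`, `START_LR`, `END_LR`, `MAX_STEPS`) could drive the refiner's optimizer; the repository ships its own scheduler in `misc_utils/warmup_lr.py` (not shown in this section), so this is an assumption-laden stand-in rather than that implementation.

```python
# Illustrative only: linear warm-up followed by linear decay, built from the
# constants in config/inference_cfg.py. This is NOT misc_utils/warmup_lr.py.
import torch
from config import inference_cfg as cfg

def refiner_lr(step: int) -> float:
    """Learning rate for a given 3D-GS refinement step."""
    if step < cfg.WARMUP:  # linear warm-up from 0 to START_LR
        return cfg.START_LR * (step + 1) / cfg.WARMUP
    # linear decay from START_LR down to END_LR over the remaining steps
    frac = (step - cfg.WARMUP) / max(cfg.MAX_STEPS - cfg.WARMUP, 1)
    return cfg.START_LR + (cfg.END_LR - cfg.START_LR) * min(frac, 1.0)

pose_delta = torch.nn.Parameter(torch.zeros(6))        # dummy pose parameters
optimizer = torch.optim.Adam([pose_delta], lr=cfg.START_LR)
for step in range(cfg.MAX_STEPS):
    for group in optimizer.param_groups:
        group["lr"] = refiner_lr(step)
```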
/dataset/demo_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import mediapy as media
6 |
7 | from transforms3d import affines, quaternions
8 |
9 | from misc_utils import gs_utils
10 |
11 | class OnePoseCap_Dataset(torch.utils.data.Dataset):
12 | def __init__(self, obj_data_dir, num_grid_points=4096, extract_RGB=False, use_binarized_mask=False, obj_database_dir=None):
13 |
14 | self.extract_RGB = extract_RGB
15 | self.obj_data_dir = obj_data_dir
16 | self.num_grid_points = num_grid_points
17 | self.obj_database_dir = obj_database_dir
18 | self.use_binarized_mask = use_binarized_mask
19 |
20 | self.arkit_box_path = os.path.join(self.obj_data_dir, 'Box.txt')
21 | self.arkit_pose_path = os.path.join(self.obj_data_dir, 'ARposes.txt')
22 | self.arkit_video_path = os.path.join(self.obj_data_dir, 'Frames.m4v')
23 | self.arkit_intrin_path = os.path.join(self.obj_data_dir, 'Frames.txt')
24 |
25 | #### read the ARKit pose info
26 | with open(self.arkit_pose_path, 'r') as pf:
27 | self.arkit_poses = [row.strip() for row in pf.readlines() if len(row) > 0 and row[0] != '#']
28 |
29 | with open(self.arkit_intrin_path, 'r') as cf:
30 | self.arkit_camKs = [row.strip() for row in cf.readlines() if len(row) > 0 and row[0] != '#']
31 |
32 | ### read the video
33 | if self.extract_RGB:
34 | RGB_dir = os.path.join(self.obj_data_dir, 'RGB')
35 | if not os.path.exists(RGB_dir):
36 | os.makedirs(RGB_dir)
37 | cap = cv2.VideoCapture(self.arkit_video_path)
38 | index = 0
39 | while True:
40 | ret, image = cap.read()
41 | if not ret:
42 | break
43 | cv2.imwrite(os.path.join(RGB_dir, f'{index}.png'), image)
44 | index += 1
45 | else:
46 | self.video_frames = media.read_video(self.arkit_video_path) # NxHxWx3
47 |
48 |
49 | assert(len(self.arkit_poses) == len(self.arkit_camKs))
50 |
51 | #### preprocess the ARKit 3D object bounding box
52 | with open(self.arkit_box_path, 'r') as f:
53 | lines = f.readlines()
54 | box_data = [float(e) for e in lines[1].strip().split(',')]
55 | ex, ey, ez = box_data[3:6]
56 | self.obj_bbox3d = np.array([
57 | [-ex, -ey, -ez], # Front-top-left corner
58 | [ex, -ey, -ez], # Front-top-right corner
59 | [ex, ey, -ez], # Front-bottom-right corner
60 | [-ex, ey, -ez], # Front-bottom-left corner
61 | [-ex, -ey, ez], # Back-top-left corner
62 | [ex, -ey, ez], # Back-top-right corner
63 | [ex, ey, ez], # Back-bottom-right corner
64 | [-ex, ey, ez], # Back-bottom-left corner
65 | ]) * 0.5
66 | obj_bbox3D_dims = np.array([ex, ey, ez], dtype=np.float32)
67 | grid_cube_size = (np.prod(obj_bbox3D_dims, axis=0) / self.num_grid_points)**(1/3)
68 | xnum, ynum, znum = np.ceil(obj_bbox3D_dims / grid_cube_size).astype(np.int64)
69 | xmin, ymin, zmin = self.obj_bbox3d.min(axis=0)
70 | xmax, ymax, zmax = self.obj_bbox3d.max(axis=0)
71 | zgrid, ygrid, xgrid = np.meshgrid(np.linspace(zmin, zmax, znum),
72 | np.linspace(ymin, ymax, ynum),
73 | np.linspace(xmin, xmax, xnum),
74 | indexing='ij')
75 | self.bbox3d_grid_points = np.stack([xgrid, ygrid, zgrid], axis=-1).reshape(-1, 3)
76 | self.bbox3d_diameter = np.linalg.norm(obj_bbox3D_dims)
77 |
78 | bbox3d_position = np.array(box_data[0:3], dtype=np.float32)
79 | bbox3D_rot_quat = np.array(box_data[6:10], dtype=np.float32)
80 | bbox3D_rot_mat = quaternions.quat2mat(bbox3D_rot_quat)
81 | T_O2W = affines.compose(bbox3d_position, bbox3D_rot_mat, np.ones(3)) # object-to-world
82 |
83 | self.camKs = list()
84 | self.poses = list()
85 | self.allo_poses = list()
86 | self.image_IDs = list()
87 | Xaxis_Rmat = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=np.float32)
88 |
89 | for frame_idx, pose_info in enumerate(self.arkit_poses):
90 | camk_info = self.arkit_camKs[frame_idx] # [time, index, fx, fy, cx, cy]
91 | camk_dat = [float(c) for c in camk_info.split(',')]
92 | camk = np.eye(3)
93 | camk[0, 0] = camk_dat[-4]
94 | camk[1, 1] = camk_dat[-3]
95 | camk[0, 2] = camk_dat[-2]
96 | camk[1, 2] = camk_dat[-1]
97 | self.camKs.append(camk)
98 |
99 | pose_dat = [float(p) for p in pose_info.split(',')]
100 | bbox_pos = pose_dat[1:4]
101 | bbox_quat = pose_dat[4:]
102 | rot_mat = quaternions.quat2mat(bbox_quat)
103 | rot_mat = rot_mat @ Xaxis_Rmat.copy() # conversion (X-right, Y-up, Z-back) right-hand
104 | T_C2W = affines.compose(bbox_pos, rot_mat, np.ones(3)) # camera-to-world
105 | T_W2C = np.linalg.inv(T_C2W) # world-to-camera
106 | pose_RT = T_W2C @ T_O2W # object-to-camera
107 |
108 | allo_pose = pose_RT.copy()
109 | allo_pose[:3, :3] = gs_utils.egocentric_to_allocentric(allo_pose)[:3, :3]
110 | self.allo_poses.append(allo_pose)
111 |
112 | self.poses.append(pose_RT)
113 | self.image_IDs.append(frame_idx)
114 |
115 | def __len__(self):
116 | return len(self.poses)
117 |
118 | def __getitem__(self, idx):
119 | data_dict = dict()
120 | camK = self.camKs[idx]
121 | pose = self.poses[idx]
122 | allo_pose = self.allo_poses[idx]
123 | image_ID = self.image_IDs[idx]
124 |
125 | if self.extract_RGB:
126 | image = cv2.imread(os.path.join(self.obj_data_dir, 'RGB', f'{image_ID}.png'))
127 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
128 | else:
129 | image = np.array(self.video_frames[idx]) / 255.0
130 |
131 | data_dict['image_ID'] = image_ID
132 | data_dict['camK'] = torch.as_tensor(camK, dtype=torch.float32)
133 | data_dict['pose'] = torch.as_tensor(pose, dtype=torch.float32)
134 | data_dict['image'] = torch.as_tensor(image, dtype=torch.float32)
135 | data_dict['allo_pose'] = torch.as_tensor(allo_pose, dtype=torch.float32)
136 |
137 | if self.obj_database_dir is not None:
138 | data_dict['coseg_mask_path'] = os.path.join(self.obj_database_dir, 'pred_coseg_mask', '{:06d}.png'.format(image_ID))
139 | else:
140 | data_dict['coseg_mask_path'] = os.path.join(self.obj_data_dir, 'pred_coseg_mask', '{:06d}.png'.format(image_ID))
141 |
142 | return data_dict
143 |
144 | def collate_fn(self, batch):
145 | """
146 | batchify the data
147 | """
148 | new_batch = dict()
149 | for each_dat in batch:
150 | for key, val in each_dat.items():
151 | if key not in new_batch:
152 | new_batch[key] = list()
153 | new_batch[key].append(val)
154 |
155 | for key, val in new_batch.items():
156 | new_batch[key] = torch.stack(val, dim=0)
157 |
158 | return new_batch
159 |
--------------------------------------------------------------------------------
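
A minimal usage sketch for `OnePoseCap_Dataset`, assuming a OnePoseCap-style capture directory at a hypothetical path containing `Box.txt`, `ARposes.txt`, `Frames.m4v`, and `Frames.txt`:

```python
# Minimal sketch: loading a OnePoseCap capture and inspecting one sample.
# The data directory below is a hypothetical placeholder.
from dataset.demo_dataset import OnePoseCap_Dataset

dataset = OnePoseCap_Dataset(obj_data_dir="dataspace/demo_data/obj-refer",
                             num_grid_points=4096,
                             extract_RGB=False)        # frames are read from Frames.m4v

sample = dataset[0]
print(sample["image"].shape)       # H x W x 3 float tensor in [0, 1]
print(sample["pose"])              # 4 x 4 object-to-camera pose
print(sample["allo_pose"])         # 4 x 4 allocentric pose
print(sample["camK"])              # 3 x 3 camera intrinsics
print(sample["coseg_mask_path"])   # path to the per-frame co-segmentation mask
```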
/dataset/parse_OnePoseCap_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import tqdm
4 | import numpy as np
5 | import os.path as osp
6 | from pathlib import Path
7 | from transforms3d import affines, quaternions
8 |
9 |
10 | def parse_box(box_path):
11 | with open(box_path, 'r') as f:
12 | lines = f.readlines()
13 | data = [float(e) for e in lines[1].strip().split(',')]
14 | position = data[:3]
15 | quaternion = data[6:]
16 | rot_mat = quaternions.quat2mat(quaternion)
17 | T_ow = affines.compose(position, rot_mat, np.ones(3))
18 | return T_ow
19 |
20 | def get_bbox3d(box_path):
21 | assert Path(box_path).exists()
22 | with open(box_path, 'r') as f:
23 | lines = f.readlines()
24 | box_data = [float(e) for e in lines[1].strip().split(',')]
25 | ex, ey, ez = box_data[3: 6]
26 | bbox_3d = np.array([
27 | [-ex, -ey, -ez],
28 | [ex, -ey, -ez],
29 | [ex, -ey, ez],
30 | [-ex, -ey, ez],
31 | [-ex, ey, -ez],
32 | [ ex, ey, -ez],
33 | [ ex, ey, ez],
34 | [-ex, ey, ez]
35 | ]) * 0.5
36 | bbox_3d_homo = np.concatenate([bbox_3d, np.ones((8, 1))], axis=1)
37 | return bbox_3d, bbox_3d_homo
38 |
39 |
40 |
41 | def data_process_anno(data_dir):
42 | arkit_box_path = osp.join(data_dir, 'Box.txt')
43 | arkit_pose_path = osp.join(data_dir, 'ARposes.txt')
44 | arkit_intrin_path = osp.join(data_dir, 'Frames.txt')
45 |
46 | out_pose_dir = osp.join(data_dir, 'poses')
47 | Path(out_pose_dir).mkdir(parents=True, exist_ok=True)
48 | out_intrin_path = osp.join(data_dir, 'intrinsics.txt')
49 | out_bbox3D_path = osp.join(osp.dirname(data_dir), 'box3d_corners.txt')
50 |
51 | ##### read the ARKit 3D bounding box and convert to box corners
52 | bbox_3d, bbox_3d_homo = get_bbox3d(arkit_box_path)
53 | np.savetxt(out_bbox3D_path, bbox_3d)
54 |
55 | ##### read the ARKit camera intrinsics
56 | with open(arkit_intrin_path, 'r') as f:
57 | lines = [l.strip() for l in f.readlines() if len(l) > 0 and l[0] != '#']
58 | arkit_camk = np.array([[float(e) for e in l.split(',')] for l in lines])
59 | fx, fy, cx, cy = np.average(arkit_camk, axis=0)[2:]
60 | with open(out_intrin_path, 'w') as f:
61 | f.write('fx: {0}\nfy: {1}\ncx: {2}\ncy: {3}'.format(fx, fy, cx, cy))
62 |
63 | ##### read the ARKit camera poses
64 | T_O2W = parse_box(arkit_box_path) # 3D object bounding box is defined w.r.t. the world coordinate system
65 |
66 | with open(arkit_pose_path, 'r') as f:
67 | lines = [l.strip() for l in f.readlines()]
68 | index = 0
69 | for line in tqdm.tqdm(lines):
70 | if len(line) == 0 or line[0] == '#':
71 | continue
72 |
73 | eles = line.split(',')
74 | data = [float(e) for e in eles]
75 |
76 | position = data[1:4]
77 | quaternion = data[4:]
78 | rot_mat = quaternions.quat2mat(quaternion)
79 | rot_mat = rot_mat @ np.array([
80 | [1, 0, 0],
81 | [0, -1, 0],
82 | [0, 0, -1]])
83 | T_C2W = affines.compose(position, rot_mat, np.ones(3))
84 | T_W2C = np.linalg.inv(T_C2W)
85 | T_O2C = T_W2C @ T_O2W
86 |
87 | pose_save_path = osp.join(out_pose_dir, '{}.txt'.format(index))
88 | np.savetxt(pose_save_path, T_O2C)
89 | index += 1
90 |
91 |
92 |
93 |
94 | if __name__ == "__main__":
95 | args = parse_args()
96 | data_dir = args.scanned_object_path
97 |     assert osp.exists(data_dir), f"Scanned object path: {data_dir} does not exist!"
98 | seq_dirs = os.listdir(data_dir)
99 |
100 | for seq_dir in seq_dirs:
101 | if '-refer' in seq_dir:
102 | ################ Parse scanned reference sequence ################
103 | print('=> Processing train sequence: ', seq_dir)
104 | video_path = osp.join(data_dir, seq_dir, 'Frames.m4v')
105 | print('=> parse video: ', video_path)
106 | data_process_anno(osp.join(data_dir, seq_dir), downsample_rate=1, hw=512)
107 |
108 | elif '-query' in seq_dir:
109 | ################ Parse scanned test sequence ################
110 | print('=> Processing test sequence: ', seq_dir)
111 | data_process_test(osp.join(data_dir, seq_dir), downsample_rate=1)
112 | pass
113 |
114 | else:
115 |
116 | continue
117 |
118 |
119 | # python parse_scanned_data.py --scanned_object_path /home/dingding/Workspace/Others/OnePose_Plus_Plus/data/demo/teacan
--------------------------------------------------------------------------------
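
Note that the `__main__` block above calls `parse_args()` and `data_process_test()`, which are not defined in the file as shown. A minimal `parse_args` consistent with the usage comment at the bottom of the script might look like the following sketch (an assumption, not the repository's definition):

```python
# Hypothetical helper: a minimal parse_args() matching the usage comment
# "python parse_scanned_data.py --scanned_object_path <dir>".
from argparse import ArgumentParser

def parse_args():
    parser = ArgumentParser(description="Parse OnePoseCap scanned-object captures.")
    parser.add_argument("--scanned_object_path", type=str, required=True,
                        help="Directory containing the '-refer' and '-query' capture sequences.")
    return parser.parse_args()
```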
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gspose
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | - xformers
7 | dependencies:
8 | - python==3.8.5
9 | - plyfile==0.8.1
10 | - cudatoolkit==11.7
11 | - pip:
12 | - mmcv==1.7.1
13 | - tqdm==4.66.2
14 | - pillow==9.3.0
15 | - mediapy==1.1.4
16 | - open3d==0.16.0
17 | - trimesh==3.18.0
18 | - ninja==1.11.1.1
19 | - structlog==23.1.0
20 | - pycocotools==2.0.6
21 | - webdataset==0.2.48
22 | - transforms3d==0.4.1
23 | - tensorboard==2.11.2
24 | - scikit-image==0.19.3
25 | - opencv-python==4.8.0.76
26 | - pytorch-msssim==1.0.0
--------------------------------------------------------------------------------
/gaussian_object/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import json
14 | import random
15 |
16 | # import the customized modules
17 | from .arguments import ModelParams
18 | from .gaussian_model import GaussianModel
19 | from .dataset_readers import readObjectInfo
20 | from .utils.system_utils import searchForMaxIteration
21 | from .utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
22 |
23 |
24 | class Scene:
25 |
26 | gaussians : GaussianModel
27 |
28 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
29 | """b
30 | :param path: Path to colmap scene main folder.
31 | """
32 | self.model_path = args.model_path
33 | self.loaded_iter = None
34 | self.gaussians = gaussians
35 |
36 | if load_iteration:
37 | if load_iteration == -1:
38 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
39 | else:
40 | self.loaded_iter = load_iteration
41 | print("Loading trained model at iteration {}".format(self.loaded_iter))
42 |
43 | self.train_cameras = {}
44 | self.test_cameras = {}
45 |
46 | scene_info = readObjectInfo(train_dataset=args.referloader,
47 | test_dataset=args.queryloader,
48 | model_path=args.model_path,
49 | zoom_scale=args.zoom_scale, margin=args.margin,
50 | random_points3D=args.random_points3D,
51 | num_points=args.num_points,
52 | )
53 |
54 | # if os.path.exists(os.path.join(args.source_path, "sparse")):
55 | # scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
56 | # elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
57 | # print("Found transforms_train.json file, assuming Blender data set!")
58 | # scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
59 | # else:
60 | # assert False, "Could not recognize scene type!"
61 |
62 |
63 | # if not self.loaded_iter:
64 | # with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
65 | # dest_file.write(src_file.read())
66 | # json_cams = []
67 | # camlist = []
68 | # if scene_info.test_cameras:
69 | # camlist.extend(scene_info.test_cameras)
70 | # if scene_info.train_cameras:
71 | # camlist.extend(scene_info.train_cameras)
72 | # for id, cam in enumerate(camlist):
73 | # json_cams.append(camera_to_JSON(id, cam))
74 | # with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
75 | # json.dump(json_cams, file)
76 |
77 | if shuffle:
78 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
79 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
80 |
81 | self.cameras_extent = scene_info.nerf_normalization["radius"]
82 |
83 | for resolution_scale in resolution_scales:
84 | print("Loading Training Cameras")
85 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
86 | print("Loading Test Cameras")
87 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
88 |
89 | if self.loaded_iter:
90 | self.gaussians.load_ply(os.path.join(self.model_path,
91 | "point_cloud",
92 | "iteration_" + str(self.loaded_iter),
93 | "point_cloud.ply"))
94 | else:
95 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
96 |
97 | def save(self, iteration):
98 | # point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
99 | self.gaussians.save_ply(os.path.join(self.model_path, "3DGO_model.ply"))
100 |
101 | def getTrainCameras(self, scale=1.0):
102 | return self.train_cameras[scale]
103 |
104 | def getTestCameras(self, scale=1.0):
105 | return self.test_cameras[scale]
--------------------------------------------------------------------------------
/gaussian_object/arguments.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from argparse import ArgumentParser, Namespace
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self._source_path = ""
51 | self._model_path = ""
52 | self._images = "images"
53 | self._resolution = -1
54 | self._white_background = False
55 | self.data_device = "cuda"
56 | self.eval = False
57 |
58 | self.margin = 0.0
59 | self.zoom_scale = 512
60 | self.num_points = 4096
61 | self.random_points3D = False
62 | self.referloader = None
63 | self.queryloader = None
64 |
65 | super().__init__(parser, "Loading Parameters", sentinel)
66 |
67 | def extract(self, args):
68 | g = super().extract(args)
69 | g.source_path = os.path.abspath(g.source_path)
70 | return g
71 |
72 | class PipelineParams(ParamGroup):
73 | def __init__(self, parser):
74 | self.convert_SHs_python = False
75 | self.compute_cov3D_python = False
76 | self.debug = False
77 | super().__init__(parser, "Pipeline Parameters")
78 |
79 | class OptimizationParams(ParamGroup):
80 | def __init__(self, parser):
81 | self.iterations = 30_000
82 | self.position_lr_init = 0.00016
83 | self.position_lr_final = 0.0000016
84 | self.position_lr_delay_mult = 0.01
85 | self.position_lr_max_steps = 30_000
86 | self.feature_lr = 0.0025
87 | self.opacity_lr = 0.05
88 | self.scaling_lr = 0.005
89 | self.rotation_lr = 0.001
90 | self.percent_dense = 0.01
91 | self.lambda_dssim = 0.2
92 | self.densification_interval = 100
93 | self.opacity_reset_interval = 3000
94 | self.densify_from_iter = 500
95 | self.densify_until_iter = 15_000
96 | self.densify_grad_threshold = 0.0002
97 | self.random_background = False
98 | super().__init__(parser, "Optimization Parameters")
99 |
100 | def get_combined_args(parser : ArgumentParser):
101 | cmdlne_string = sys.argv[1:]
102 | cfgfile_string = "Namespace()"
103 | args_cmdline = parser.parse_args(cmdlne_string)
104 |
105 | try:
106 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
107 | print("Looking for config file in", cfgfilepath)
108 | with open(cfgfilepath) as cfg_file:
109 | print("Config file found: {}".format(cfgfilepath))
110 | cfgfile_string = cfg_file.read()
111 | except TypeError:
112 | print("Config file not found at")
113 | pass
114 | args_cfgfile = eval(cfgfile_string)
115 |
116 | merged_dict = vars(args_cfgfile).copy()
117 | for k,v in vars(args_cmdline).items():
118 | if v != None:
119 | merged_dict[k] = v
120 | return Namespace(**merged_dict)
121 |
--------------------------------------------------------------------------------
/gaussian_object/build_3DGaussianObject.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | import glob
15 | import uuid
16 | import torch
17 | from tqdm import tqdm
18 | from random import randint
19 | from argparse import ArgumentParser, Namespace
20 | try:
21 | from torch.utils.tensorboard import SummaryWriter
22 | TENSORBOARD_FOUND = True
23 | except ImportError:
24 | TENSORBOARD_FOUND = False
25 |
26 | from gaussian_object.utils.image_utils import psnr
27 | # from gaussian_object.utils.general_utils import safe_state
28 | # from gaussian_object.gaussian_renderer import network_gui
29 | from gaussian_object.gaussian_renderer import render
30 |
31 | # import the customized modules
32 | from gaussian_object import Scene
33 | from gaussian_object.loss_utils import l1_loss, ssim
34 | from gaussian_object.gaussian_model import GaussianModel
35 | from gaussian_object.arguments import ModelParams, PipelineParams, OptimizationParams
36 |
37 |
38 | # def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
39 | def create_3D_Gaussian_object(dataset, opt, pipe, testing_iterations=[30_000],
40 | saving_iterations=[30_000],
41 | checkpoint_iterations=[],
42 | checkpoint=None,
43 | debug_from=-1,
44 | return_gaussian=True,
45 | ):
46 | first_iter = 0
47 | tb_writer = prepare_output_and_logger(dataset)
48 | gaussians = GaussianModel(dataset.sh_degree)
49 | scene = Scene(dataset, gaussians)
50 | gaussians.training_setup(opt)
51 | if checkpoint:
52 | (model_params, first_iter) = torch.load(checkpoint)
53 | gaussians.restore(model_params, opt)
54 |
55 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
56 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
57 |
58 | iter_start = torch.cuda.Event(enable_timing = True)
59 | iter_end = torch.cuda.Event(enable_timing = True)
60 |
61 | viewpoint_stack = None
62 | ema_loss_for_log = 0.0
63 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="3DGO modeling progress")
64 | first_iter += 1
65 | for iteration in range(first_iter, opt.iterations + 1):
66 |
67 | # if network_gui.conn == None:
68 | # network_gui.try_connect()
69 | # while network_gui.conn != None:
70 | # try:
71 | # net_image_bytes = None
72 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
73 | # if custom_cam != None:
74 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
75 | # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
76 | # network_gui.send(net_image_bytes, dataset.source_path)
77 | # if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
78 | # break
79 | # except Exception as e:
80 | # network_gui.conn = None
81 |
82 | iter_start.record()
83 |
84 | gaussians.update_learning_rate(iteration)
85 |
86 | # Every 1000 its we increase the levels of SH up to a maximum degree
87 | if iteration % 1000 == 0:
88 | gaussians.oneupSHdegree()
89 |
90 | # Pick a random Camera
91 | if not viewpoint_stack:
92 | viewpoint_stack = scene.getTrainCameras().copy()
93 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
94 |
95 | # Render
96 | if (iteration - 1) == debug_from:
97 | pipe.debug = True
98 |
99 | bg = torch.rand((3), device="cuda") if opt.random_background else background
100 |
101 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
102 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
103 |
104 | # Loss
105 | gt_image = viewpoint_cam.original_image.cuda()
106 |
107 | trunc_FG_mask = (gt_image.sum(dim=0, keepdim=True) > 0).type(torch.float32)
108 |
109 | # Ll1 = l1_loss(image, gt_image)
110 | # loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
111 |
112 | Ll1 = (l1_loss(image, gt_image, size_average=True) * trunc_FG_mask).mean()
113 | ssim_score = (ssim(image, gt_image, size_average=True) * trunc_FG_mask).mean()
114 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_score)
115 |
116 | loss.backward()
117 | iter_end.record()
118 |
119 | with torch.no_grad():
120 | # Progress bar
121 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
122 | if iteration % 10 == 0:
123 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
124 | progress_bar.update(10)
125 | if iteration == opt.iterations:
126 | progress_bar.close()
127 |
128 | # Log and save
129 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
130 | if (iteration in saving_iterations):
131 | print("\n[ITER {}] Saving Gaussians".format(iteration))
132 | scene.save(iteration)
133 |
134 | # Densification
135 | if iteration < opt.densify_until_iter:
136 | # Keep track of max radii in image-space for pruning
137 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
138 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
139 |
140 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
141 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
142 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
143 |
144 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
145 | gaussians.reset_opacity()
146 |
147 | # Optimizer step
148 | if iteration < opt.iterations:
149 | gaussians.optimizer.step()
150 | gaussians.optimizer.zero_grad(set_to_none = True)
151 |
152 | if (iteration in checkpoint_iterations):
153 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
154 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
155 |
156 | if return_gaussian:
157 | return gaussians
158 |
159 | def prepare_output_and_logger(args):
160 | if not args.model_path:
161 | if os.getenv('OAR_JOB_ID'):
162 | unique_str=os.getenv('OAR_JOB_ID')
163 | else:
164 | unique_str = str(uuid.uuid4())
165 | args.model_path = os.path.join("./output/", unique_str[0:10])
166 |
167 | # Set up output folder
168 | print("Output folder: {}".format(args.model_path))
169 | logs_file = f'{args.model_path}/events.out.tfevents.*'
170 | if len(glob.glob(logs_file)) > 0:
171 | os.system(f"rm -r {logs_file}") # remove the old events
172 |
173 | os.makedirs(args.model_path, exist_ok = True)
174 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
175 | cfg_log_f.write(str(Namespace(**vars(args))))
176 |
177 | # Create Tensorboard writer
178 | tb_writer = None
179 | if TENSORBOARD_FOUND:
180 | tb_writer = SummaryWriter(args.model_path)
181 | else:
182 | print("Tensorboard not available: not logging progress")
183 | return tb_writer
184 |
185 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
186 | if tb_writer:
187 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
188 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
189 | tb_writer.add_scalar('iter_time', elapsed, iteration)
190 |
191 | # Report test and samples of training set
192 | if iteration in testing_iterations:
193 | torch.cuda.empty_cache()
194 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
195 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
196 |
197 | for config in validation_configs:
198 | if config['cameras'] and len(config['cameras']) > 0:
199 | l1_test = 0.0
200 | psnr_test = 0.0
201 | for idx, viewpoint in enumerate(config['cameras']):
202 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
203 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
204 | if tb_writer and (idx < 5):
205 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
206 | if iteration == testing_iterations[0]:
207 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
208 | l1_test += l1_loss(image, gt_image).mean().double()
209 | psnr_test += psnr(image, gt_image).mean().double()
210 | psnr_test /= len(config['cameras'])
211 | l1_test /= len(config['cameras'])
212 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
213 | if tb_writer:
214 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
215 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
216 |
217 | if tb_writer:
218 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
219 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
220 | torch.cuda.empty_cache()
221 |
222 |
--------------------------------------------------------------------------------
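
A minimal, illustrative driver for `create_3D_Gaussian_object` (not the repository's exact entry point): the parameter groups from `gaussian_object/arguments.py` are registered on one parser, the reference/query loaders that `readObjectInfo` expects are attached to the extracted `ModelParams` namespace, and the returned Gaussians are saved. The data directories and model path are hypothetical placeholders.

```python
# Illustrative sketch: building a 3D Gaussian object model from a reference capture.
# Paths are hypothetical; the refer/query loaders stand in for any dataset exposing
# obj_bbox3d, bbox3d_diameter, use_binarized_mask and per-frame camK/pose/image data.
from argparse import ArgumentParser
from dataset.demo_dataset import OnePoseCap_Dataset
from gaussian_object.arguments import ModelParams, OptimizationParams, PipelineParams
from gaussian_object.build_3DGaussianObject import create_3D_Gaussian_object

parser = ArgumentParser()
lp, op, pp = ModelParams(parser), OptimizationParams(parser), PipelineParams(parser)
args = parser.parse_args(["--model_path", "output/obj_demo"])   # hypothetical output dir

dataset = lp.extract(args)
dataset.referloader = OnePoseCap_Dataset("dataspace/demo_data/obj-refer")  # training cameras
dataset.queryloader = OnePoseCap_Dataset("dataspace/demo_data/obj-query")  # held-out cameras

gaussians = create_3D_Gaussian_object(dataset, op.extract(args), pp.extract(args))
gaussians.save_ply("output/obj_demo/3DGO_model.ply")   # Scene.save() also writes this file during training
```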
/gaussian_object/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import os
12 | import sys
13 | import torch
14 | import numpy as np
15 | from torch import nn
16 |
17 | # import the customized modules
18 | from gaussian_object.utils.graphics_utils import getWorld2View2, getProjectionMatrix
19 |
20 |
21 | class Camera(nn.Module):
22 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
23 | image_name, uid,
24 | cx_offset, cy_offset, mask=None,
25 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
26 | ):
27 | super(Camera, self).__init__()
28 |
29 | self.uid = uid
30 | self.colmap_id = colmap_id
31 | self.R = R
32 | self.T = T
33 | self.FoVx = FoVx
34 | self.FoVy = FoVy
35 | self.cx_offset = cx_offset
36 | self.cy_offset = cy_offset
37 | self.image_name = image_name
38 |
39 | try:
40 | self.data_device = torch.device(data_device)
41 | except Exception as e:
42 | print(e)
43 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
44 | self.data_device = torch.device("cuda")
45 |
46 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
47 | self.image_width = self.original_image.shape[2]
48 | self.image_height = self.original_image.shape[1]
49 |
50 | if gt_alpha_mask is not None:
51 | self.original_image *= gt_alpha_mask.to(self.data_device)
52 | else:
53 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
54 |
55 | self.zfar = 100.0
56 | self.znear = 0.01
57 |
58 | self.trans = trans
59 | self.scale = scale
60 |
61 | # self.world_view_transform = torch.tensor(getWorld2View2(self.R, self.T, self.trans, self.scale)).transpose(0, 1).cuda()
62 | # self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar,
63 | # fovX=self.FoVx, fovY=self.FoVy,
64 | # cx_offset=self.cx_offset,
65 | # cy_offset=self.cy_offset,
66 | # ).transpose(0,1).cuda()
67 | # self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
68 | # self.camera_center = self.world_view_transform.inverse()[3, :3]
69 |
70 | @property
71 | def world_view_transform(self):
72 | return torch.tensor(getWorld2View2(self.R, self.T, self.trans, self.scale)).transpose(0, 1).cuda()
73 |
74 | @property
75 | def projection_matrix(self):
76 | return getProjectionMatrix(znear=self.znear, zfar=self.zfar,
77 | fovX=self.FoVx, fovY=self.FoVy,
78 | cx_offset=self.cx_offset,
79 | cy_offset=self.cy_offset,
80 | ).transpose(0,1).cuda()
81 |
82 | @property
83 | def full_proj_transform(self):
84 | return (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
85 |
86 | @property
87 | def camera_center(self):
88 | return self.world_view_transform.inverse()[3, :3]
89 |
90 | class MiniCam:
91 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
92 | self.image_width = width
93 | self.image_height = height
94 | self.FoVy = fovy
95 | self.FoVx = fovx
96 | self.znear = znear
97 | self.zfar = zfar
98 | self.world_view_transform = world_view_transform
99 | self.full_proj_transform = full_proj_transform
100 | view_inv = torch.inverse(self.world_view_transform)
101 | self.camera_center = view_inv[3][:3]
102 |
103 |
--------------------------------------------------------------------------------
/gaussian_object/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import cv2
14 | import sys
15 | import json
16 | import torch
17 |
18 | import numpy as np
19 | from PIL import Image
20 | from pathlib import Path
21 | from typing import NamedTuple
22 | # from plyfile import PlyData, PlyElement
23 | from pytorch3d import transforms as py3d_transform
24 |
25 |
26 | # import the customized modules
27 | from misc_utils import gs_utils
28 | from gaussian_object.utils.sh_utils import SH2RGB
29 | from gaussian_object.gaussian_model import BasicPointCloud
30 | from gaussian_object.utils.graphics_utils import getWorld2View2, focal2fov
31 |
32 | class CameraInfo(NamedTuple):
33 | uid: int
34 | R: np.array
35 | T: np.array
36 | FovY: np.array
37 | FovX: np.array
38 | image: np.array
39 | image_path: str
40 | image_name: str
41 | width: int
42 | height: int
43 | cx_offset: np.array = 0
44 | cy_offset: np.array = 0
45 |
46 | class SceneInfo(NamedTuple):
47 | point_cloud: BasicPointCloud
48 | train_cameras: list
49 | test_cameras: list
50 | nerf_normalization: dict
51 | ply_path: str
52 |
53 | def getNerfppNorm(cam_info):
54 | def get_center_and_diag(cam_centers):
55 | cam_centers = np.hstack(cam_centers)
56 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
57 | center = avg_cam_center
58 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
59 | diagonal = np.max(dist)
60 | return center.flatten(), diagonal
61 |
62 | cam_centers = []
63 |
64 | for cam in cam_info:
65 | W2C = getWorld2View2(cam.R, cam.T)
66 | C2W = np.linalg.inv(W2C)
67 | cam_centers.append(C2W[:3, 3:4])
68 |
69 | center, diagonal = get_center_and_diag(cam_centers)
70 | radius = diagonal * 1.1
71 |
72 | translate = -center
73 |
74 | return {"translate": translate, "radius": radius}
75 |
76 | # def fetchPly(path):
77 | # plydata = PlyData.read(path)
78 | # vertices = plydata['vertex']
79 | # positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
80 | # colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
81 | # normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
82 | # return BasicPointCloud(points=positions, colors=colors, normals=normals)
83 |
84 | # def storePly(path, xyz, rgb):
85 | # # Define the dtype for the structured array
86 | # dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
87 | # ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
88 | # ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
89 |
90 | # normals = np.zeros_like(xyz)
91 |
92 | # elements = np.empty(xyz.shape[0], dtype=dtype)
93 | # attributes = np.concatenate((xyz, normals, rgb), axis=1)
94 | # elements[:] = list(map(tuple, attributes))
95 |
96 | # # Create the PlyData object and write to file
97 | # vertex_element = PlyElement.describe(elements, 'vertex')
98 | # ply_data = PlyData([vertex_element])
99 | # ply_data.write(path)
100 |
101 | def readCameras(dataloader, zoom_scale=512, margin=0.0, frame_sample_interval=1):
102 | cam_infos = []
103 | use_binarized_mask = dataloader.use_binarized_mask
104 | bbox3d_diameter = dataloader.bbox3d_diameter
105 | for frame_idx in range(len(dataloader)):
106 | if frame_idx % frame_sample_interval != 0:
107 | continue
108 | obj_data = dataloader[frame_idx]
109 | camK = np.array(obj_data['camK'])
110 | pose = np.array(obj_data['pose'])
111 | R = np.transpose(pose[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
112 | T = pose[:3, 3]
113 |
114 | if 'image_path' not in obj_data:
115 | image_path = None
116 | image = (obj_data['image'] * 255).numpy()
117 | image = Image.fromarray(image.astype(np.uint8))
118 | image_name = f'{frame_idx}.png'
119 | else:
120 | image_path = obj_data['image_path']
121 | image = Image.open(image_path)
122 | image_name = os.path.basename(image_path)
123 |
124 | image = torch.from_numpy(np.array(image))
125 | raw_height, raw_width = image.shape[:2]
126 |
127 | out = gs_utils.zoom_in_and_crop_with_offset(image, t=T, K=camK,
128 | radius=bbox3d_diameter/2,
129 | margin=margin, target_size=zoom_scale)
130 | image = out['zoom_image'].squeeze()
131 | height, width = image.shape[:2]
132 |
133 | # if 'coseg_mask_path' not in obj_data:
134 | # mask = np.ones((height, width, 1), dtype=np.float32)
135 | try:
136 | mask = Image.open(obj_data['coseg_mask_path'])
137 | mask = torch.from_numpy(np.array(mask, dtype=np.float32)) / 255.0
138 | mask = gs_utils.zoom_in_and_crop_with_offset(
139 | mask, t=T, K=camK, radius=bbox3d_diameter/2,
140 | margin=margin, target_size=zoom_scale
141 | )['zoom_image'].squeeze()
142 | if mask.dim() == 2:
143 | mask = mask[:, :, None]
144 | if use_binarized_mask:
145 | mask = mask.round()
146 | except Exception as e:
147 | print(e)
148 | mask = np.ones((height, width, 1), dtype=np.float32)
149 |
150 | image = (image * mask).type(torch.uint8).numpy()
151 | image = Image.fromarray(image.astype(np.uint8))
152 |
153 | zoom_camk = out['zoom_camK'].squeeze().numpy()
154 | zoom_offset = out['zoom_offset'].squeeze().numpy()
155 | cx_offset = zoom_offset[0]
156 | cy_offset = zoom_offset[1]
157 | cam_fx = zoom_camk[0, 0]
158 | cam_fy = zoom_camk[1, 1]
159 | FovX = focal2fov(cam_fx, width)
160 | FovY = focal2fov(cam_fy, height)
161 |
162 | cam_info = CameraInfo(R=R, T=T, FovY=FovY, FovX=FovX,
163 | cx_offset=cx_offset, cy_offset=cy_offset,
164 | uid=frame_idx, image=image,
165 | image_path=image_path, image_name=image_name,
166 | width=width, height=height)
167 | cam_infos.append(cam_info)
168 | return cam_infos
169 |
170 |
171 | def readObjectInfo(train_dataset, test_dataset, model_path, num_points=4096, zoom_scale=512, margin=0.0, random_points3D=False):
172 |
173 | print(f"Reading {len(train_dataset)} training image ...")
174 | train_cam_infos = readCameras(train_dataset, zoom_scale=zoom_scale, margin=margin, frame_sample_interval=1)
175 | num_training_samples = len(train_cam_infos)
176 | print(f"{num_training_samples} training samples")
177 | print(f"-----------------------------------------")
178 |
179 | test_interval = len(test_dataset) // 3
180 | test_cam_infos = readCameras(test_dataset, zoom_scale=zoom_scale, margin=margin, frame_sample_interval=test_interval)
181 | num_test_samples = len(test_cam_infos)
182 | print(f"{num_test_samples} testing samples")
183 | print(f"----------------------------------------")
184 |
185 | nerf_normalization = getNerfppNorm(train_cam_infos)
186 | ply_path = os.path.join(model_path, "3DGS_points3d.ply")
187 | if not random_points3D:
188 | obj_bbox3D = train_dataset.obj_bbox3d
189 | # obj_bbox3D = np.loadtxt(os.path.join(path, 'corners.txt'))
190 | min_3D_corner = obj_bbox3D.min(axis=0)
191 | max_3D_corner = obj_bbox3D.max(axis=0)
192 | obj_bbox3D_dims = max_3D_corner - min_3D_corner
193 | grid_cube_size = (np.prod(obj_bbox3D_dims, axis=0) / num_points)**(1/3)
194 |
195 | xnum, ynum, znum = np.ceil(obj_bbox3D_dims / grid_cube_size).astype(np.int64)
196 | xmin, ymin, zmin = min_3D_corner
197 | xmax, ymax, zmax = max_3D_corner
198 | zgrid, ygrid, xgrid = np.meshgrid(np.linspace(zmin, zmax, znum),
199 | np.linspace(ymin, ymax, ynum),
200 | np.linspace(xmin, xmax, xnum),
201 | indexing='ij')
202 | xyz = np.stack([xgrid, ygrid, zgrid], axis=-1).reshape(-1, 3)
203 | num_pts = xyz.shape[0]
204 | shs = np.random.random((num_pts, 3)) / 255.0
205 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
206 |
207 | # storePly(ply_path, xyz, SH2RGB(shs) * 255)
208 |
209 | if random_points3D:
210 | # Since this data set has no colmap data, we start with random points
211 | # num_pts = 100_000
212 | print(f"Generating random point cloud ({num_points})...")
213 |
214 | # We create random points inside the bounds of the synthetic Blender scenes
215 | xyz = np.random.random((num_points, 3)) #* 2.6 - 1.3
216 | shs = np.random.random((num_points, 3)) #/ 255.0
217 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_points, 3)))
218 |
219 | # storePly(ply_path, xyz, SH2RGB(shs) * 255)
220 |
221 | # try:
222 | # pcd = fetchPly(ply_path)
223 | # except:
224 | # pcd = None
225 |
226 | object_info = SceneInfo(point_cloud=pcd,
227 | train_cameras=train_cam_infos,
228 | test_cameras=test_cam_infos,
229 | nerf_normalization=nerf_normalization,
230 | ply_path=ply_path)
231 | return object_info
232 |
233 |
234 |
235 |
--------------------------------------------------------------------------------
/gaussian_object/gaussian_render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import math
12 | import torch
13 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
14 |
15 | from gaussian_object.utils.sh_utils import eval_sh
16 | from gaussian_object.gaussian_model import GaussianModel
17 |
18 |
19 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
20 | """
21 | Render the scene.
22 |
23 | Background tensor (bg_color) must be on GPU!
24 | """
25 |
26 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
27 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
28 | try:
29 | screenspace_points.retain_grad()
30 | except:
31 | pass
32 |
33 | # Set up rasterization configuration
34 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
35 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
36 |
37 | raster_settings = GaussianRasterizationSettings(
38 | image_height=int(viewpoint_camera.image_height),
39 | image_width=int(viewpoint_camera.image_width),
40 | tanfovx=tanfovx,
41 | tanfovy=tanfovy,
42 | bg=bg_color,
43 | scale_modifier=scaling_modifier,
44 | viewmatrix=viewpoint_camera.world_view_transform,
45 | projmatrix=viewpoint_camera.full_proj_transform,
46 | sh_degree=pc.active_sh_degree,
47 | campos=viewpoint_camera.camera_center,
48 | prefiltered=False,
49 | debug=pipe.debug
50 | )
51 |
52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
53 |
54 | means3D = pc.get_xyz
55 | means2D = screenspace_points
56 | opacity = pc.get_opacity
57 |
58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
59 | # scaling / rotation by the rasterizer.
60 | scales = None
61 | rotations = None
62 | cov3D_precomp = None
63 | if pipe.compute_cov3D_python:
64 | cov3D_precomp = pc.get_covariance(scaling_modifier)
65 | else:
66 | scales = pc.get_scaling
67 | rotations = pc.get_rotation
68 |
69 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
70 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
71 | shs = None
72 | colors_precomp = None
73 | if override_color is None:
74 | if pipe.convert_SHs_python:
75 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
76 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
77 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
78 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
79 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
80 | else:
81 | shs = pc.get_features
82 | else:
83 | colors_precomp = override_color
84 |
85 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
86 | rendered_image, radii = rasterizer(
87 | means3D = means3D,
88 | means2D = means2D,
89 | shs = shs,
90 | colors_precomp = colors_precomp,
91 | opacities = opacity,
92 | scales = scales,
93 | rotations = rotations,
94 | cov3D_precomp = cov3D_precomp)
95 |
96 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
97 | # They will be excluded from value updates used in the splitting criteria.
98 | return {"render": rendered_image,
99 | "viewspace_points": screenspace_points,
100 | "visibility_filter" : radii > 0,
101 | "radii": radii}
102 |
--------------------------------------------------------------------------------
/gaussian_object/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 |
16 | from gaussian_object.gaussian_model import GaussianModel
17 | from gaussian_object.utils.sh_utils import eval_sh
18 |
19 |
20 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
21 | """
22 | Render the scene.
23 |
24 | Background tensor (bg_color) must be on GPU!
25 | """
26 |
27 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
28 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
29 | try:
30 | screenspace_points.retain_grad()
31 | except:
32 | pass
33 |
34 | # Set up rasterization configuration
35 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
36 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
37 |
38 | raster_settings = GaussianRasterizationSettings(
39 | image_height=int(viewpoint_camera.image_height),
40 | image_width=int(viewpoint_camera.image_width),
41 | tanfovx=tanfovx,
42 | tanfovy=tanfovy,
43 | bg=bg_color,
44 | scale_modifier=scaling_modifier,
45 | viewmatrix=viewpoint_camera.world_view_transform,
46 | projmatrix=viewpoint_camera.full_proj_transform,
47 | sh_degree=pc.active_sh_degree,
48 | campos=viewpoint_camera.camera_center,
49 | prefiltered=False,
50 | debug=pipe.debug
51 | )
52 |
53 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
54 |
55 | means3D = pc.get_xyz
56 | means2D = screenspace_points
57 | opacity = pc.get_opacity
58 |
59 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
60 | # scaling / rotation by the rasterizer.
61 | scales = None
62 | rotations = None
63 | cov3D_precomp = None
64 | if pipe.compute_cov3D_python:
65 | cov3D_precomp = pc.get_covariance(scaling_modifier)
66 | else:
67 | scales = pc.get_scaling
68 | rotations = pc.get_rotation
69 |
70 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
71 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
72 | shs = None
73 | colors_precomp = None
74 | if override_color is None:
75 | if pipe.convert_SHs_python:
76 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
77 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
78 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
79 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
80 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
81 | else:
82 | shs = pc.get_features
83 | else:
84 | colors_precomp = override_color
85 |
86 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
87 | rendered_image, radii = rasterizer(
88 | means3D = means3D,
89 | means2D = means2D,
90 | shs = shs,
91 | colors_precomp = colors_precomp,
92 | opacities = opacity,
93 | scales = scales,
94 | rotations = rotations,
95 | cov3D_precomp = cov3D_precomp)
96 |
97 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
98 | # They will be excluded from value updates used in the splitting criteria.
99 | return {"render": rendered_image,
100 | "viewspace_points": screenspace_points,
101 | "visibility_filter" : radii > 0,
102 | "radii": radii}
103 |
--------------------------------------------------------------------------------
/gaussian_object/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from gaussian_object.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 | def init(wish_host, wish_port):
27 | global host, port, listener
28 | host = wish_host
29 | port = wish_port
30 | listener.bind((host, port))
31 | listener.listen()
32 | listener.settimeout(0)
33 |
34 | def try_connect():
35 | global conn, addr, listener
36 | try:
37 | conn, addr = listener.accept()
38 | print(f"\nConnected by {addr}")
39 | conn.settimeout(None)
40 | except Exception as inst:
41 | pass
42 |
43 | def read():
44 | global conn
45 | messageLength = conn.recv(4)
46 | messageLength = int.from_bytes(messageLength, 'little')
47 | message = conn.recv(messageLength)
48 | return json.loads(message.decode("utf-8"))
49 |
50 | def send(message_bytes, verify):
51 | global conn
52 |     if message_bytes is not None:

53 | conn.sendall(message_bytes)
54 | conn.sendall(len(verify).to_bytes(4, 'little'))
55 | conn.sendall(bytes(verify, 'ascii'))
56 |
57 | def receive():
58 | message = read()
59 |
60 | width = message["resolution_x"]
61 | height = message["resolution_y"]
62 |
63 | if width != 0 and height != 0:
64 | try:
65 | do_training = bool(message["train"])
66 | fovy = message["fov_y"]
67 | fovx = message["fov_x"]
68 | znear = message["z_near"]
69 | zfar = message["z_far"]
70 | do_shs_python = bool(message["shs_python"])
71 | do_rot_scale_python = bool(message["rot_scale_python"])
72 | keep_alive = bool(message["keep_alive"])
73 | scaling_modifier = message["scaling_modifier"]
74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75 | world_view_transform[:,1] = -world_view_transform[:,1]
76 | world_view_transform[:,2] = -world_view_transform[:,2]
77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78 | full_proj_transform[:,1] = -full_proj_transform[:,1]
79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80 | except Exception as e:
81 | print("")
82 | traceback.print_exc()
83 | raise e
84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85 | else:
86 | return None, None, None, None, None, None
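The read() helper above expects a 4-byte little-endian payload size followed by a UTF-8 JSON body, and send() uses the same length-prefix convention for the trailing verification string. A socket-free sketch of that framing, using only the standard library (the message fields are made-up values for illustration):

```python
import json

# Encode: 4-byte little-endian length, then the UTF-8 JSON payload.
payload = json.dumps({"resolution_x": 800, "resolution_y": 800, "train": 1}).encode("utf-8")
frame = len(payload).to_bytes(4, "little") + payload

# Decode (mirrors read()): read the length first, then parse exactly that many bytes.
length = int.from_bytes(frame[:4], "little")
message = json.loads(frame[4:4 + length].decode("utf-8"))
print(length, message["resolution_x"])   # payload size in bytes, 800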
--------------------------------------------------------------------------------
/gaussian_object/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 | def l1_loss(network_output, gt, size_average=True):
18 | loss = torch.abs((network_output - gt))
19 | if size_average:
20 | return loss.mean()
21 | return loss
22 |
23 | def l2_loss(network_output, gt):
24 | return ((network_output - gt) ** 2).mean()
25 |
26 | def gaussian(window_size, sigma):
27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
28 | return gauss / gauss.sum()
29 |
30 | def create_window(window_size, channel):
31 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
32 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
33 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
34 | return window
35 |
36 | def ssim(img1, img2, window_size=11, size_average=True):
37 | channel = img1.size(-3)
38 | window = create_window(window_size, channel)
39 |
40 | if img1.is_cuda:
41 | window = window.cuda(img1.get_device())
42 | window = window.type_as(img1)
43 |
44 | return _ssim(img1, img2, window, window_size, channel, size_average)
45 |
46 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
47 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
48 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
49 |
50 | mu1_sq = mu1.pow(2)
51 | mu2_sq = mu2.pow(2)
52 | mu1_mu2 = mu1 * mu2
53 |
54 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
55 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
56 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
57 |
58 | C1 = 0.01 ** 2
59 | C2 = 0.03 ** 2
60 |
61 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
62 |
63 | if size_average:
64 | return ssim_map.mean()
65 | else:
66 | return ssim_map#.mean(1).mean(1).mean(1)
67 |
68 |
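These are the photometric terms used when optimizing the 3D Gaussian object; the reference 3DGS recipe combines them as (1 - λ)·L1 + λ·(1 - SSIM) with λ = 0.2. A minimal sketch on random tensors, assuming the repository root is on PYTHONPATH (the images and the weight are illustrative):

```python
import torch
from gaussian_object.loss_utils import l1_loss, ssim

lambda_dssim = 0.2                      # weight used in the reference 3DGS recipe
pred = torch.rand(1, 3, 64, 64)         # rendered image in [0, 1]
gt   = torch.rand(1, 3, 64, 64)         # ground-truth image
loss = (1.0 - lambda_dssim) * l1_loss(pred, gt) + lambda_dssim * (1.0 - ssim(pred, gt))
print(float(loss))
```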
--------------------------------------------------------------------------------
/gaussian_object/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
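RGB2SH and SH2RGB are exact inverses: the degree-0 band stores colour as an offset around 0.5, scaled by the constant C0 ≈ 0.2820948 (the value of the 0th spherical-harmonic basis function). A quick round-trip check, assuming the repository root is importable:

```python
import numpy as np
from gaussian_object.sh_utils import RGB2SH, SH2RGB, C0

rgb = np.array([0.2, 0.6, 0.9])
sh0 = RGB2SH(rgb)                          # (rgb - 0.5) / C0
print(C0, np.allclose(SH2RGB(sh0), rgb))   # 0.28209479177387814, True
```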
--------------------------------------------------------------------------------
/gaussian_object/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import numpy as np
12 |
13 | from gaussian_object.cameras import Camera
14 | from gaussian_object.utils.general_utils import PILtoTorch
15 | from gaussian_object.utils.graphics_utils import fov2focal
16 |
17 | WARNED = False
18 |
19 | def loadCam(args, id, cam_info, resolution_scale):
20 | orig_w, orig_h = cam_info.image.size
21 |
22 | if args.resolution in [1, 2, 4, 8]:
23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
24 | else: # should be a type that converts to float
25 | if args.resolution == -1:
26 | if orig_w > 1600:
27 | global WARNED
28 | if not WARNED:
29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
31 | WARNED = True
32 | global_down = orig_w / 1600
33 | else:
34 | global_down = 1
35 | else:
36 | global_down = orig_w / args.resolution
37 |
38 | scale = float(global_down) * float(resolution_scale)
39 | resolution = (int(orig_w / scale), int(orig_h / scale))
40 |
41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
42 |
43 | gt_image = resized_image_rgb[:3, ...]
44 | loaded_mask = None
45 |
46 | # if cam_info.mask is not None:
47 | # mask = PILtoTorch(cam_info.mask, resolution)
48 | # else:
49 | # mask = torch.ones_like(gt_image)[..., 0][None]
50 |
51 | if resized_image_rgb.shape[1] == 4:
52 | loaded_mask = resized_image_rgb[3:4, ...]
53 |
54 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
55 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
56 | # mask=mask,
57 | cx_offset=cam_info.cx_offset,
58 | cy_offset=cam_info.cy_offset,
59 | image=gt_image, gt_alpha_mask=loaded_mask,
60 | image_name=cam_info.image_name, uid=id, data_device=args.data_device)
61 |
62 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
63 | camera_list = []
64 |
65 | for id, c in enumerate(cam_infos):
66 | camera_list.append(loadCam(args, id, c, resolution_scale))
67 |
68 | return camera_list
69 |
70 | def camera_to_JSON(id, camera : Camera):
71 | Rt = np.zeros((4, 4))
72 | Rt[:3, :3] = camera.R.transpose()
73 | Rt[:3, 3] = camera.T
74 | Rt[3, 3] = 1.0
75 |
76 | W2C = np.linalg.inv(Rt)
77 | pos = W2C[:3, 3]
78 | rot = W2C[:3, :3]
79 | serializable_array_2d = [x.tolist() for x in rot]
80 | camera_entry = {
81 | 'id' : id,
82 | 'img_name' : camera.image_name,
83 | 'width' : camera.width,
84 | 'height' : camera.height,
85 | 'position': pos.tolist(),
86 | 'rotation': serializable_array_2d,
87 | 'fy' : fov2focal(camera.FovY, camera.height),
88 | 'fx' : fov2focal(camera.FovX, camera.width)
89 | }
90 | return camera_entry
91 |
--------------------------------------------------------------------------------
/gaussian_object/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 |
18 | def inverse_sigmoid(x):
19 | return torch.log(x/(1-x))
20 |
21 | def PILtoTorch(pil_image, resolution):
22 | resized_image_PIL = pil_image.resize(resolution)
23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24 | if len(resized_image.shape) == 3:
25 | return resized_image.permute(2, 0, 1)
26 | else:
27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28 |
29 | def get_expon_lr_func(
30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31 | ):
32 | """
33 | Copied from Plenoxels
34 |
35 | Continuous learning rate decay function. Adapted from JaxNeRF
36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
37 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
39 | function of lr_delay_mult, such that the initial learning rate is
40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
41 | to the normal learning rate when steps>lr_delay_steps.
42 |     :param lr_init, lr_final: learning rates at step 0 and at step max_steps, respectively.
43 | :param max_steps: int, the number of steps during optimization.
44 | :return HoF which takes step as input
45 | """
46 |
47 | def helper(step):
48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
49 | # Disable this parameter
50 | return 0.0
51 | if lr_delay_steps > 0:
52 | # A kind of reverse cosine decay.
53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
55 | )
56 | else:
57 | delay_rate = 1.0
58 | t = np.clip(step / max_steps, 0, 1)
59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
60 | return delay_rate * log_lerp
61 |
62 | return helper
63 |
64 | def strip_lowerdiag(L):
65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66 |
67 | uncertainty[:, 0] = L[:, 0, 0]
68 | uncertainty[:, 1] = L[:, 0, 1]
69 | uncertainty[:, 2] = L[:, 0, 2]
70 | uncertainty[:, 3] = L[:, 1, 1]
71 | uncertainty[:, 4] = L[:, 1, 2]
72 | uncertainty[:, 5] = L[:, 2, 2]
73 | return uncertainty
74 |
75 | def strip_symmetric(sym):
76 | return strip_lowerdiag(sym)
77 |
78 | def build_rotation(r):
79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
80 |
81 | q = r / norm[:, None]
82 |
83 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
84 |
85 | r = q[:, 0]
86 | x = q[:, 1]
87 | y = q[:, 2]
88 | z = q[:, 3]
89 |
90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91 | R[:, 0, 1] = 2 * (x*y - r*z)
92 | R[:, 0, 2] = 2 * (x*z + r*y)
93 | R[:, 1, 0] = 2 * (x*y + r*z)
94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95 | R[:, 1, 2] = 2 * (y*z - r*x)
96 | R[:, 2, 0] = 2 * (x*z - r*y)
97 | R[:, 2, 1] = 2 * (y*z + r*x)
98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99 | return R
100 |
101 | def build_scaling_rotation(s, r):
102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103 | R = build_rotation(r)
104 |
105 | L[:,0,0] = s[:,0]
106 | L[:,1,1] = s[:,1]
107 | L[:,2,2] = s[:,2]
108 |
109 | L = R @ L
110 | return L
111 |
112 | def safe_state(silent):
113 | old_f = sys.stdout
114 | class F:
115 | def __init__(self, silent):
116 | self.silent = silent
117 |
118 | def write(self, x):
119 | if not self.silent:
120 | if x.endswith("\n"):
121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122 | else:
123 | old_f.write(x)
124 |
125 | def flush(self):
126 | old_f.flush()
127 |
128 | sys.stdout = F(silent)
129 |
130 | random.seed(0)
131 | np.random.seed(0)
132 | torch.manual_seed(0)
133 | torch.cuda.set_device(torch.device("cuda:0"))
134 |
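get_expon_lr_func returns a closure that interpolates log-linearly between lr_init and lr_final. A small sketch with illustrative values (the 3DGS-style position learning-rate schedule has this shape), assuming the repository root is on PYTHONPATH:

```python
from gaussian_object.utils.general_utils import get_expon_lr_func

# Illustrative schedule: decay from 1.6e-4 to 1.6e-6 over 30k steps.
lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
for step in (0, 15_000, 30_000):
    print(step, lr_fn(step))   # 1.6e-4, 1.6e-5 (geometric midpoint), 1.6e-6
```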
--------------------------------------------------------------------------------
/gaussian_object/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points : np.array
19 | colors : np.array
20 | normals : np.array
21 |
22 | def geom_transform_points(points, transf_matrix):
23 | P, _ = points.shape
24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25 | points_hom = torch.cat([points, ones], dim=1)
26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27 |
28 | denom = points_out[..., 3:] + 0.0000001
29 | return (points_out[..., :3] / denom).squeeze(dim=0)
30 |
31 | def getWorld2View(R, t):
32 | Rt = np.zeros((4, 4))
33 | Rt[:3, :3] = R.transpose()
34 | Rt[:3, 3] = t
35 | Rt[3, 3] = 1.0
36 | return np.float32(Rt)
37 |
38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
39 | Rt = np.zeros((4, 4))
40 | Rt[:3, :3] = R.transpose()
41 | Rt[:3, 3] = t
42 | Rt[3, 3] = 1.0
43 |
44 | C2W = np.linalg.inv(Rt)
45 | cam_center = C2W[:3, 3]
46 | cam_center = (cam_center + translate) * scale
47 | C2W[:3, 3] = cam_center
48 | Rt = np.linalg.inv(C2W)
49 | return np.float32(Rt)
50 |
51 | def getProjectionMatrix(znear, zfar, fovX, fovY, cx_offset=0, cy_offset=0):
52 | tanHalfFovY = math.tan((fovY / 2))
53 | tanHalfFovX = math.tan((fovX / 2))
54 |
55 | top = tanHalfFovY * znear
56 | bottom = -top
57 | right = tanHalfFovX * znear
58 | left = -right
59 |
60 | P = torch.zeros(4, 4)
61 |
62 | z_sign = 1.0
63 |
64 | P[0, 0] = 2.0 * znear / (right - left)
65 | P[1, 1] = 2.0 * znear / (top - bottom)
66 | P[0, 2] = (right + left) / (right - left) + cx_offset
67 | P[1, 2] = (top + bottom) / (top - bottom) + cy_offset
68 | P[3, 2] = z_sign
69 | P[2, 2] = z_sign * zfar / (zfar - znear)
70 | P[2, 3] = -(zfar * znear) / (zfar - znear)
71 | return P
72 |
73 | def fov2focal(fov, pixels):
74 | return pixels / (2 * math.tan(fov / 2))
75 |
76 | def focal2fov(focal, pixels):
77 | return 2*math.atan(pixels/(2*focal))
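fov2focal and focal2fov convert between a field of view in radians and a focal length in pixels, and invert each other exactly. For example, a 90° horizontal FoV over a 640-pixel-wide image corresponds to a 320 px focal length:

```python
import math
from gaussian_object.utils.graphics_utils import fov2focal, focal2fov

fov = math.pi / 2                        # 90 degrees
focal = fov2focal(fov, 640)              # 640 / (2 * tan(45 deg)) = 320.0
print(focal, math.isclose(focal2fov(focal, 640), fov))   # 320.0 True
```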
--------------------------------------------------------------------------------
/gaussian_object/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 | def mse(img1, img2):
15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16 |
17 | def psnr(img1, img2):
18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
20 |
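psnr assumes images scaled to [0, 1] (peak value 1.0) and returns one value per image in the batch. A quick sanity check: two images differing by a uniform 0.1 offset have MSE = 0.01, i.e. 20 dB:

```python
import torch
from gaussian_object.utils.image_utils import psnr

img1 = torch.zeros(1, 3, 8, 8)
img2 = torch.full((1, 3, 8, 8), 0.1)
print(psnr(img1, img2))   # tensor([[20.]])
```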
--------------------------------------------------------------------------------
/gaussian_object/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 | def l1_loss(network_output, gt):
18 | return torch.abs((network_output - gt)).mean()
19 |
20 | def l2_loss(network_output, gt):
21 | return ((network_output - gt) ** 2).mean()
22 |
23 | def gaussian(window_size, sigma):
24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 | return gauss / gauss.sum()
26 |
27 | def create_window(window_size, channel):
28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31 | return window
32 |
33 | def ssim(img1, img2, window_size=11, size_average=True):
34 | channel = img1.size(-3)
35 | window = create_window(window_size, channel)
36 |
37 | if img1.is_cuda:
38 | window = window.cuda(img1.get_device())
39 | window = window.type_as(img1)
40 |
41 | return _ssim(img1, img2, window, window_size, channel, size_average)
42 |
43 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46 |
47 | mu1_sq = mu1.pow(2)
48 | mu2_sq = mu2.pow(2)
49 | mu1_mu2 = mu1 * mu2
50 |
51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54 |
55 | C1 = 0.01 ** 2
56 | C2 = 0.03 ** 2
57 |
58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59 |
60 | if size_average:
61 | return ssim_map.mean()
62 | else:
63 | return ssim_map.mean(1).mean(1).mean(1)
64 |
65 |
--------------------------------------------------------------------------------
/gaussian_object/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
/gaussian_object/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 | # Creates a directory. equivalent to using mkdir -p on the command line
18 | try:
19 | makedirs(folder_path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == EEXIST and path.isdir(folder_path):
22 | pass
23 | else:
24 | raise
25 |
26 | def searchForMaxIteration(folder):
27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 | return max(saved_iters)
29 |
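searchForMaxIteration assumes the folder contains entries whose names end in an underscore-separated iteration number (e.g. "iteration_30000") and returns the largest one. A throwaway sketch using a temporary directory, assuming the repository root is importable:

```python
import os, tempfile
from gaussian_object.utils.system_utils import searchForMaxIteration

root = tempfile.mkdtemp()
for it in (1000, 7000, 30000):
    os.makedirs(os.path.join(root, f"iteration_{it}"))
print(searchForMaxIteration(root))   # 30000
```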
--------------------------------------------------------------------------------
/install_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
2 | pip install "git+https://github.com/facebookresearch/pytorch3d.git"
3 | pip install xformers==0.0.22
4 |
5 | cd submodules/Connected_components_PyTorch
6 | python setup.py install
7 |
8 | cd ../diff-gaussian-rasterization
9 | python setup.py install
10 |
11 | cd ../simple-knn
12 | python setup.py install
13 |
14 | cd ../../model/curope
15 | python setup.py install
16 | cd ../..
17 |
--------------------------------------------------------------------------------
/misc_utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from math import exp
14 | import torch.nn.functional as F
15 | from torch.autograd import Variable
16 |
17 | def l1_loss(network_output, gt, size_average=True):
18 | loss = torch.abs((network_output - gt))
19 | if size_average:
20 | return loss.mean()
21 | return loss
22 |
23 | def l2_loss(network_output, gt):
24 | return ((network_output - gt) ** 2).mean()
25 |
26 | def gaussian(window_size, sigma):
27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
28 | return gauss / gauss.sum()
29 |
30 | def create_window(window_size, channel):
31 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
32 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
33 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
34 | return window
35 |
36 | def ssim(img1, img2, window_size=11, size_average=True):
37 | channel = img1.size(-3)
38 | window = create_window(window_size, channel)
39 |
40 | if img1.is_cuda:
41 | window = window.cuda(img1.get_device())
42 | window = window.type_as(img1)
43 |
44 | return _ssim(img1, img2, window, window_size, channel, size_average)
45 |
46 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
47 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
48 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
49 |
50 | mu1_sq = mu1.pow(2)
51 | mu2_sq = mu2.pow(2)
52 | mu1_mu2 = mu1 * mu2
53 |
54 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
55 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
56 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
57 |
58 | C1 = 0.01 ** 2
59 | C2 = 0.03 ** 2
60 |
61 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
62 |
63 | if size_average:
64 | return ssim_map.mean()
65 | else:
66 | return ssim_map#.mean(1).mean(1).mean(1)
67 |
68 |
--------------------------------------------------------------------------------
/misc_utils/metric_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def calc_pose_error(pred_RT, gt_RT, unit='cm'):
5 | pred_R = pred_RT[:3, :3]
6 | pred_t = pred_RT[:3, 3]
7 | gt_R = gt_RT[:3, :3]
8 | gt_t = gt_RT[:3, 3]
9 | R_err = np.arccos(np.clip(np.trace(pred_R @ gt_R.T) / 2 - 0.5, -1, 1.0)) / np.pi * 180
10 | t_err = np.linalg.norm(pred_t - gt_t)
11 |
12 | if unit == 'm':
13 | unit_factor = 1
14 | elif unit == 'cm':
15 | unit_factor = 100
16 | elif unit == 'mm':
17 | unit_factor = 1000
18 | else:
19 | raise ValueError('Invalid unit')
20 |
21 | t_err *= unit_factor
22 | return R_err, t_err
23 |
24 | def calc_add_metric(model_3D_pts, diameter, pose_pred, pose_target, percentage=0.1, return_error=False, syn=False, model_unit='m'):
25 | from scipy import spatial
26 | # Dim check:
27 | if pose_pred.shape[0] == 4:
28 | pose_pred = pose_pred[:3]
29 | if pose_target.shape[0] == 4:
30 | pose_target = pose_target[:3]
31 |
32 | diameter_thres = diameter * percentage
33 | model_pred = np.dot(model_3D_pts, pose_pred[:, :3].T) + pose_pred[:, 3]
34 | model_target = np.dot(model_3D_pts, pose_target[:, :3].T) + pose_target[:, 3]
35 |
36 | if syn:
37 | mean_dist_index = spatial.cKDTree(model_pred)
38 | mean_dist, _ = mean_dist_index.query(model_target, k=1)
39 | mean_dist = np.mean(mean_dist)
40 | else:
41 | mean_dist = np.mean(np.linalg.norm(model_pred - model_target, axis=-1))
42 |
43 | if return_error:
44 | return mean_dist
45 | elif mean_dist < diameter_thres:
46 | return True
47 | else:
48 | return False
49 |
50 | def calc_projection_2d_error(model_3D_pts, pose_pred, pose_targets, K, pixels=5, return_error=True):
51 | def project(xyz, K, RT):
52 | """
53 | NOTE: need to use original K
54 | xyz: [N, 3]
55 | K: [3, 3]
56 | RT: [3, 4]
57 | """
58 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T
59 | xyz = np.dot(xyz, K.T)
60 | xy = xyz[:, :2] / xyz[:, 2:]
61 | return xy
62 |
63 | # Dim check:
64 | if pose_pred.shape[0] == 4:
65 | pose_pred = pose_pred[:3]
66 | if pose_targets.shape[0] == 4:
67 | pose_targets = pose_targets[:3]
68 |
69 | model_2d_pred = project(model_3D_pts, K, pose_pred) # pose_pred: 3*4
70 | model_2d_targets = project(model_3D_pts, K, pose_targets)
71 | proj_mean_diff = np.mean(np.linalg.norm(model_2d_pred - model_2d_targets, axis=-1))
72 | if return_error:
73 | return proj_mean_diff
74 | elif proj_mean_diff < pixels:
75 | return True
76 | else:
77 | return False
78 |
79 | def calc_bbox_IOU(pd_bbox, gt_bbox, iou_threshold=0.5, return_iou=False):
80 | px1, py1, px2, py2 = pd_bbox.squeeze()
81 | gx1, gy1, gx2, gy2 = gt_bbox.squeeze()
82 | inter_wid = np.maximum(np.minimum(px2, gx2) - np.maximum(px1, gx1), 0)
83 | inter_hei = np.maximum(np.minimum(py2, gy2) - np.maximum(py1, gy1), 0)
84 | inter_area = inter_wid * inter_hei
85 |     pd_area = np.maximum(px2 - px1, 0) * np.maximum(py2 - py1, 0)
86 |     gt_area = np.maximum(gx2 - gx1, 0) * np.maximum(gy2 - gy1, 0)
87 |     union_area = pd_area + gt_area - inter_area  # IoU = intersection / union
88 |     iou = inter_area / union_area
89 | if return_iou:
90 | return iou
91 | elif iou > iou_threshold:
92 | return True
93 | else:
94 | return False
95 |
96 | def aggregate_metrics(metrics, pose_thres=[1, 3, 5], proj2d_thres=5):
97 | """ Aggregate metrics for the whole dataset:
98 | (This method should be called once per dataset)
99 | """
100 | R_errs = metrics["R_errs"]
101 | t_errs = metrics["t_errs"]
102 |
103 | agg_metric = {}
104 | for pose_threshold in pose_thres:
105 | agg_metric[f"{pose_threshold}˚@{pose_threshold}cm"] = np.mean(
106 | (np.array(R_errs) < pose_threshold) & (np.array(t_errs) < pose_threshold)
107 | )
108 | agg_metric[f"{pose_threshold}cm"] = np.mean((np.array(t_errs) < pose_threshold))
109 | agg_metric[f"{pose_threshold}˚"] = np.mean((np.array(R_errs) < pose_threshold))
110 |
111 | if "ADD_metric" in metrics:
112 | ADD_metric = metrics['ADD_metric']
113 | agg_metric["ADD"] = np.mean(ADD_metric)
114 |
115 | if "Proj2D" in metrics:
116 | proj2D_metric = metrics['Proj2D']
117 | agg_metric[f"pix@{proj2d_thres}"] = np.mean(np.array(proj2D_metric) < proj2d_thres)
118 |
119 | return agg_metric
120 |
121 |
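calc_pose_error returns the geodesic rotation error in degrees and the translation error converted to the requested unit (translations are assumed to be in metres). A worked example with a pose rotated 10° about Z and offset 2 cm along X from the ground truth, assuming the repository root is on PYTHONPATH:

```python
import numpy as np
from misc_utils.metric_utils import calc_pose_error

theta = np.deg2rad(10)
pred_RT = np.eye(4)
pred_RT[:3, :3] = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                            [np.sin(theta),  np.cos(theta), 0.0],
                            [0.0,            0.0,           1.0]])
pred_RT[0, 3] = 0.02                      # 2 cm along X, in metres
gt_RT = np.eye(4)

R_err, t_err = calc_pose_error(pred_RT, gt_RT, unit='cm')
print(R_err, t_err)                       # ~10.0 (degrees), ~2.0 (cm)
```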
--------------------------------------------------------------------------------
/misc_utils/warmup_lr.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 | class CosineAnnealingWarmupRestarts(_LRScheduler):
6 | """
7 | optimizer (Optimizer): Wrapped optimizer.
8 | first_cycle_steps (int): First cycle step size.
9 |     cycle_mult(float): Cycle steps magnification. Default: 1.
10 | max_lr(float): First cycle's max learning rate. Default: 0.1.
11 | min_lr(float): Min learning rate. Default: 0.001.
12 | warmup_steps(int): Linear warmup step size. Default: 0.
13 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
14 | last_epoch (int): The index of last epoch. Default: -1.
15 | """
16 |
17 | def __init__(self,
18 | optimizer : torch.optim.Optimizer,
19 | first_cycle_steps : int,
20 | cycle_mult : float = 1.,
21 | max_lr : float = 0.1,
22 | min_lr : float = 0.001,
23 | warmup_steps : int = 0,
24 | gamma : float = 1.,
25 | last_epoch : int = -1
26 | ):
27 | assert warmup_steps < first_cycle_steps
28 |
29 | self.first_cycle_steps = first_cycle_steps # first cycle step size
30 | self.cycle_mult = cycle_mult # cycle steps magnification
31 | self.base_max_lr = max_lr # first max learning rate
32 | self.max_lr = max_lr # max learning rate in the current cycle
33 | self.min_lr = min_lr # min learning rate
34 | self.warmup_steps = warmup_steps # warmup step size
35 | self.gamma = gamma # decrease rate of max learning rate by cycle
36 |
37 | self.cur_cycle_steps = first_cycle_steps # first cycle step size
38 | self.cycle = 0 # cycle count
39 | self.step_in_cycle = last_epoch # step size of the current cycle
40 |
41 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
42 |
43 | # set learning rate min_lr
44 | self.init_lr()
45 |
46 | def init_lr(self):
47 | self.base_lrs = []
48 | for param_group in self.optimizer.param_groups:
49 | param_group['lr'] = self.min_lr
50 | self.base_lrs.append(self.min_lr)
51 |
52 | def get_lr(self):
53 | if self.step_in_cycle == -1:
54 | return self.base_lrs
55 | elif self.step_in_cycle < self.warmup_steps:
56 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
57 | else:
58 | return [base_lr + (self.max_lr - base_lr) \
59 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
60 | / (self.cur_cycle_steps - self.warmup_steps))) / 2
61 | for base_lr in self.base_lrs]
62 |
63 | def step(self, epoch=None):
64 | if epoch is None:
65 | epoch = self.last_epoch + 1
66 | self.step_in_cycle = self.step_in_cycle + 1
67 | if self.step_in_cycle >= self.cur_cycle_steps:
68 | self.cycle += 1
69 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
70 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
71 | else:
72 | if epoch >= self.first_cycle_steps:
73 | if self.cycle_mult == 1.:
74 | self.step_in_cycle = epoch % self.first_cycle_steps
75 | self.cycle = epoch // self.first_cycle_steps
76 | else:
77 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
78 | self.cycle = n
79 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
80 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
81 | else:
82 | self.cur_cycle_steps = self.first_cycle_steps
83 | self.step_in_cycle = epoch
84 |
85 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
86 | self.last_epoch = math.floor(epoch)
87 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
88 | param_group['lr'] = lr
89 |
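A minimal usage sketch of the scheduler, assuming a throwaway model and AdamW optimizer (both hypothetical) and the repository root on PYTHONPATH: linear warmup towards max_lr over the first 100 steps, then cosine decay to min_lr across each 1000-step cycle.

```python
import torch
from misc_utils.warmup_lr import CosineAnnealingWarmupRestarts

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=1000,
                                          max_lr=1e-3, min_lr=1e-5, warmup_steps=100)
for step in range(200):
    optimizer.step()        # normally preceded by a forward/backward pass
    scheduler.step()
print(optimizer.param_groups[0]['lr'])   # past warmup, now on the cosine branch
```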
--------------------------------------------------------------------------------
/model/curope/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from .curope2d import cuRoPE2D
5 |
--------------------------------------------------------------------------------
/model/curope/curope.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) 2022-present Naver Corporation. All rights reserved.
3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | */
5 |
6 | #include <torch/extension.h>
7 |
8 | // forward declaration
9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
10 |
11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
12 | {
13 | const int B = tokens.size(0);
14 | const int N = tokens.size(1);
15 | const int H = tokens.size(2);
16 | const int D = tokens.size(3) / 4;
17 |
18 |     auto tok = tokens.accessor<float, 4>();
19 |     auto pos = positions.accessor<int64_t, 3>();
20 |
21 | for (int b = 0; b < B; b++) {
22 | for (int x = 0; x < 2; x++) { // y and then x (2d)
23 | for (int n = 0; n < N; n++) {
24 |
25 | // grab the token position
26 | const int p = pos[b][n][x];
27 |
28 | for (int h = 0; h < H; h++) {
29 | for (int d = 0; d < D; d++) {
30 | // grab the two values
31 | float u = tok[b][n][h][d+0+x*2*D];
32 | float v = tok[b][n][h][d+D+x*2*D];
33 |
34 | // grab the cos,sin
35 | const float inv_freq = fwd * p / powf(base, d/float(D));
36 | float c = cosf(inv_freq);
37 | float s = sinf(inv_freq);
38 |
39 | // write the result
40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s;
41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s;
42 | }
43 | }
44 | }
45 | }
46 | }
47 | }
48 |
49 | void rope_2d( torch::Tensor tokens, // B,N,H,D
50 | const torch::Tensor positions, // B,N,2
51 | const float base,
52 | const float fwd )
53 | {
54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
60 |
61 | if (tokens.is_cuda())
62 | rope_2d_cuda( tokens, positions, base, fwd );
63 | else
64 | rope_2d_cpu( tokens, positions, base, fwd );
65 | }
66 |
67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
69 | }
70 |
--------------------------------------------------------------------------------
/model/curope/curope2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | import torch
5 |
6 | try:
7 | import curope as _kernels # run `python setup.py install`
8 | except ModuleNotFoundError:
9 | from . import curope as _kernels # run `python setup.py build_ext --inplace`
10 |
11 |
12 | class cuRoPE2D_func (torch.autograd.Function):
13 |
14 | @staticmethod
15 | def forward(ctx, tokens, positions, base, F0=1):
16 | ctx.save_for_backward(positions)
17 | ctx.saved_base = base
18 | ctx.saved_F0 = F0
19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work
20 | _kernels.rope_2d( tokens, positions, base, F0 )
21 | ctx.mark_dirty(tokens)
22 | return tokens
23 |
24 | @staticmethod
25 | def backward(ctx, grad_res):
26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
27 | _kernels.rope_2d( grad_res, positions, base, -F0 )
28 | ctx.mark_dirty(grad_res)
29 | return grad_res, None, None, None
30 |
31 |
32 | class cuRoPE2D(torch.nn.Module):
33 | def __init__(self, freq=100.0, F0=1.0):
34 | super().__init__()
35 | self.base = freq
36 | self.F0 = F0
37 |
38 | def forward(self, tokens, positions):
39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
40 | return tokens
--------------------------------------------------------------------------------
/model/curope/kernels.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) 2022-present Naver Corporation. All rights reserved.
3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | */
5 |
6 | #include <torch/extension.h>
7 | #include <cuda.h>
8 | #include <cuda_runtime.h>
9 | #include <vector>
10 |
11 | #define CHECK_CUDA(tensor) {\
12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15 |
16 |
17 | template < typename scalar_t >
18 | __global__ void rope_2d_cuda_kernel(
19 | //scalar_t* __restrict__ tokens,
20 |         torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21 | const int64_t* __restrict__ pos,
22 | const float base,
23 | const float fwd )
24 | // const int N, const int H, const int D )
25 | {
26 | // tokens shape = (B, N, H, D)
27 | const int N = tokens.size(1);
28 | const int H = tokens.size(2);
29 | const int D = tokens.size(3);
30 |
31 | // each block update a single token, for all heads
32 | // each thread takes care of a single output
33 | extern __shared__ float shared[];
34 | float* shared_inv_freq = shared + D;
35 |
36 | const int b = blockIdx.x / N;
37 | const int n = blockIdx.x % N;
38 |
39 | const int Q = D / 4;
40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41 | // u_Y v_Y u_X v_X
42 |
43 | // shared memory: first, compute inv_freq
44 | if (threadIdx.x < Q)
45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46 | __syncthreads();
47 |
48 | // start of X or Y part
49 | const int X = threadIdx.x < D/2 ? 0 : 1;
50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51 |
52 | // grab the cos,sin appropriate for me
53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54 | const float cos = cosf(freq);
55 | const float sin = sinf(freq);
56 | /*
57 | float* shared_cos_sin = shared + D + D/4;
58 | if ((threadIdx.x % (D/2)) < Q)
59 | shared_cos_sin[m+0] = cosf(freq);
60 | else
61 | shared_cos_sin[m+Q] = sinf(freq);
62 | __syncthreads();
63 | const float cos = shared_cos_sin[m+0];
64 | const float sin = shared_cos_sin[m+Q];
65 | */
66 |
67 | for (int h = 0; h < H; h++)
68 | {
69 | // then, load all the token for this head in shared memory
70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71 | __syncthreads();
72 |
73 | const float u = shared[m];
74 | const float v = shared[m+Q];
75 |
76 | // write output
77 | if ((threadIdx.x % (D/2)) < Q)
78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79 | else
80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81 | }
82 | }
83 |
84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85 | {
86 | const int B = tokens.size(0); // batch size
87 | const int N = tokens.size(1); // sequence length
88 | const int H = tokens.size(2); // number of heads
89 | const int D = tokens.size(3); // dimension per head
90 |
91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95 |
96 | // one block for each layer, one thread per local-max
97 | const int THREADS_PER_BLOCK = D;
98 | const int N_BLOCKS = B * N; // each block takes care of H*D values
99 | const int SHARED_MEM = sizeof(float) * (D + D/4);
100 |
101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102 |         rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103 |             //tokens.data_ptr<scalar_t>(),
104 |             tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105 |             pos.data_ptr<int64_t>(),
106 | base, fwd); //, N, H, D );
107 | }));
108 | }
109 |
--------------------------------------------------------------------------------
/model/curope/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from setuptools import setup
5 | from torch import cuda
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | # compile for all possible CUDA architectures
9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
10 | # alternatively, you can list cuda archs that you want, eg:
11 | # all_cuda_archs = [
12 | # '-gencode', 'arch=compute_70,code=sm_70',
13 | # '-gencode', 'arch=compute_75,code=sm_75',
14 | # '-gencode', 'arch=compute_80,code=sm_80',
15 | # '-gencode', 'arch=compute_86,code=sm_86'
16 | # ]
17 |
18 | setup(
19 | name = 'curope',
20 | ext_modules = [
21 | CUDAExtension(
22 | name='curope',
23 | sources=[
24 | "curope.cpp",
25 | "kernels.cu",
26 | ],
27 | extra_compile_args = dict(
28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
29 | cxx=['-O3'])
30 | )
31 | ],
32 | cmdclass = {
33 | 'build_ext': BuildExtension
34 | })
35 |
--------------------------------------------------------------------------------
/model/dino_layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | # from .patch_embed import PatchEmbed
10 | # from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | # from .block import NestedTensorBlock
12 | # from .attention import Attention, MemEffAttention
13 | from .efficient_attention import *
14 |
--------------------------------------------------------------------------------
/model/dino_layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
--------------------------------------------------------------------------------
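Both attention variants above take a (batch, tokens, channels) tensor and return one of the same shape; MemEffAttention only differs by dispatching to xFormers' memory_efficient_attention when the library is installed. A minimal usage sketch (dimensions are illustrative):

    import torch
    attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
    x = torch.randn(2, 197, 384)   # (batch, tokens, channels)
    y = attn(x)                    # (2, 197, 384); uses the plain Attention path without xFormers
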
/model/dino_layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
--------------------------------------------------------------------------------
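The batch-level stochastic depth in drop_add_residual_stochastic_depth computes the residual branch for only a random subset of the batch and index-adds it back with a compensating scale, so the expected update matches the full-batch case. With illustrative numbers:

    b, sample_drop_ratio = 16, 0.25
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)   # 12 of the 16 samples get a residual
    residual_scale_factor = b / sample_subset_size                  # 16 / 12 ~ 1.33 rescales that residual
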
/model/dino_layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
--------------------------------------------------------------------------------
/model/dino_layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
/model/dino_layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/model/dino_layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/model/dino_layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
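A quick sanity check of the shapes produced by PatchEmbed (values are illustrative):

    import torch
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = embed(torch.randn(1, 3, 224, 224))   # (1, 196, 768): a flattened 14x14 patch grid
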
/model/dino_layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
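SwiGLUFFNFused shrinks the requested hidden width by a factor of 2/3 (the SwiGLU branch uses three weight matrices instead of two, so this keeps the parameter count roughly comparable to a plain MLP) and rounds up to a multiple of 8. For example, hidden_features=1536 maps to int(1536 * 2 / 3) = 1024, which is already a multiple of 8, while hidden_features=1000 maps to (666 + 7) // 8 * 8 = 672.
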
/model/generalized_mean_pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class GeM3D(nn.Module):
6 | def __init__(self, p=1, eps=1e-6):
7 | super(GeM3D,self).__init__()
8 | self.p = nn.Parameter(torch.ones(1) * p)
9 | self.eps = eps
10 |
11 | def forward(self, x):
12 | return self.gem(x.float(), p=self.p, eps=self.eps)
13 |
14 | def gem(self, x, p=1, eps=1e-6):
15 | return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(-3), x.size(-2), x.size(-1))).pow(1./p)
16 |
17 | def __repr__(self):
18 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
19 |
20 |
21 | class GeM2D(nn.Module):
22 | def __init__(self, p=3, eps=1e-6):
23 | super(GeM2D,self).__init__()
24 | self.p = nn.Parameter(torch.ones(1) * p)
25 | self.eps = eps
26 |
27 | def forward(self, x):
28 | return self.gem(x.float(), p=self.p, eps=self.eps)
29 |
30 | def gem(self, x, p=1, eps=1e-6):
31 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
32 |
33 | def __repr__(self):
34 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
35 |
36 |
37 | class GeM1D(nn.Module):
38 | def __init__(self, p=3, channel_last=True, eps=1e-6):
39 | super(GeM1D, self).__init__()
40 | self.p = nn.Parameter(torch.ones(1) * p)
41 | self.eps = eps
42 | self.channel_last = channel_last
43 |
44 | def forward(self, x):
45 | return self.gem(x.float(), p=self.p, eps=self.eps)
46 |
47 | def gem(self, x, p=1, eps=1e-6):
48 | """
49 | x: (B, C, L) -> (B, C)
50 | """
51 | if self.channel_last:
52 | x = x.permute(0, 2, 1)
53 |
54 | # return F.avg_pool1d(x.clamp(min=eps), x.size(-1))
55 | return F.avg_pool1d(x.clamp(min=eps).pow(p), x.size(-1)).pow(1./p)
56 |
57 | def __repr__(self):
58 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
59 |
60 |
--------------------------------------------------------------------------------
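These generalized-mean (GeM) poolings collapse the full spatial (or temporal) extent into one value per channel: inputs are clamped, raised to the learnable power p, average-pooled, then taken to the 1/p power, so p = 1 recovers average pooling and large p approaches max pooling. A small sketch:

    import torch
    pool = GeM2D(p=3)
    feats = torch.rand(2, 256, 16, 16)
    desc = pool(feats)   # (2, 256, 1, 1): generalized-mean descriptor over the 16x16 grid
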
/model/position_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | try:
7 | from .curope import cuRoPE2D
8 | RoPE2D = cuRoPE2D
9 | except ImportError:
10 | print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
11 |
12 | class RoPE2D(torch.nn.Module):
13 |
14 | def __init__(self, freq=100.0, F0=1.0):
15 | super().__init__()
16 | self.base = freq
17 | self.F0 = F0
18 | self.cache = {}
19 |
20 | def get_cos_sin(self, D, seq_len, device, dtype):
21 | if (D,seq_len,device,dtype) not in self.cache:
22 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
23 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
24 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
25 | freqs = torch.cat((freqs, freqs), dim=-1)
26 | cos = freqs.cos() # (Seq, Dim)
27 | sin = freqs.sin()
28 | self.cache[D,seq_len,device,dtype] = (cos,sin)
29 | return self.cache[D,seq_len,device,dtype]
30 |
31 | @staticmethod
32 | def rotate_half(x):
33 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
34 | return torch.cat((-x2, x1), dim=-1)
35 |
36 | def apply_rope1d(self, tokens, pos1d, cos, sin):
37 | assert pos1d.ndim==2
38 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
39 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
40 | return (tokens * cos) + (self.rotate_half(tokens) * sin)
41 |
42 | def forward(self, tokens, positions):
43 | """
44 | input:
45 | * tokens: batch_size x nheads x ntokens x dim
46 | * positions: batch_size x ntokens x 2 (y and x position of each token)
47 | output:
48 |             * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
49 | """
50 | assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
51 | D = tokens.size(3) // 2
52 | assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
53 | cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
54 | # split features into two along the feature dimension, and apply rope1d on each half
55 | y, x = tokens.chunk(2, dim=-1)
56 | y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
57 | x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
58 | tokens = torch.cat((y, x), dim=-1)
59 | return tokens
60 |
61 |
62 | class PositionEncodingSine1D(nn.Module):
63 | def __init__(self, d_model, max_seq_len=20_000):
64 | super(PositionEncodingSine1D, self).__init__()
65 | self.d_model = d_model
66 | # Create positional encoding matrix
67 | pe = torch.zeros(max_seq_len, d_model)
68 | position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
69 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
70 | pe[:, 0::2] = torch.sin(position * div_term)
71 | pe[:, 1::2] = torch.cos(position * div_term)
72 | pe = pe.unsqueeze(0)
73 | self.register_buffer('pe', pe) # 1 x max_seq_len x d_model
74 |
75 | def forward(self, x):
76 | """
77 | x: BxLxC
78 | """
79 | x = x + self.pe[:, :x.size(1)]
80 | return x
81 |
82 |
83 | # Position encoding for query image
84 | class PositionEncodingSine2D(nn.Module):
85 | """
86 |     This is a sinusoidal position encoding that generalizes to 2-dimensional images
87 | """
88 |
89 | def __init__(self, d_model, max_shape=(1280, 960)):
90 | """
91 | Args:
92 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
93 | """
94 | super().__init__()
95 |
96 | max_shape = tuple(max_shape)
97 |
98 | pe = torch.zeros((d_model, *max_shape))
99 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
100 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
101 | div_term = torch.exp(
102 | torch.arange(0, d_model // 2, 2).float()
103 | * (-math.log(10000.0) / d_model // 2)
104 | )
105 | div_term = div_term[:, None, None] # [C//4, 1, 1]
106 | pe[0::4, :, :] = torch.sin(x_position * div_term)
107 | pe[1::4, :, :] = torch.cos(x_position * div_term)
108 | pe[2::4, :, :] = torch.sin(y_position * div_term)
109 | pe[3::4, :, :] = torch.cos(y_position * div_term)
110 |
111 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
112 |
113 | def forward(self, x):
114 | """
115 | Args:
116 | x: [N, C, H, W]
117 | """
118 | return x + self.pe[:, :, :x.size(2), :x.size(3)]
119 |
120 |
121 | class PositionEncodingSine3D(nn.Module):
122 | """
123 |     This is a sinusoidal position encoding that generalizes to 3-dimensional cubes
124 | """
125 |
126 | def __init__(self, d_model, max_shape=(128, 128, 128)):
127 | """
128 | Args:
129 | max_shape (tuple): for DxHxW cube
130 | """
131 | super().__init__()
132 |
133 | assert(d_model % 6 == 0), "d_model must be divisible by 6 for 3D sinusoidal position encoding"
134 | max_shape = tuple(max_shape)
135 |
136 | pe = torch.zeros((d_model, *max_shape)) # CxDxHxW
137 | z_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
138 | y_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
139 | x_position = torch.ones(max_shape).cumsum(2).float().unsqueeze(0)
140 | div_term = torch.exp(
141 | torch.arange(0, d_model // 2, 3).float()
142 | * (-math.log(10000.0) / d_model // 2)
143 | )
144 | div_term = div_term[:, None, None, None] # [C//6, 1, 1, 1]
145 |
146 | pe[0::6, :, :, :] = torch.sin(x_position * div_term)
147 | pe[1::6, :, :, :] = torch.cos(x_position * div_term)
148 | pe[2::6, :, :, :] = torch.sin(y_position * div_term)
149 | pe[3::6, :, :, :] = torch.cos(y_position * div_term)
150 | pe[4::6, :, :, :] = torch.sin(z_position * div_term)
151 | pe[5::6, :, :, :] = torch.cos(z_position * div_term)
152 |
153 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, D, H, W]
154 |
155 | def forward(self, x):
156 | """
157 | Args:
158 | x: [N, C, D, H, W]
159 | """
160 | assert(x.dim() == 5), "Input must be 5-dimensional"
161 | return x + self.pe[:, :, :x.size(-3), :x.size(-2), :x.size(-1)]
162 |
163 | # Position encoding for 3D points
164 | class PositionEncodingLinear3D(nn.Module):
165 | """ Joint encoding of visual appearance and location using MLPs """
166 |
167 | def __init__(self, inp_dim, feature_dim, layers, norm_method="batchnorm"):
168 | super().__init__()
169 | self.encoder = self.MLP([inp_dim] + list(layers) + [feature_dim], norm_method)
170 |
171 | 
172 | nn.init.constant_(self.encoder[-1].bias, 0.0)
173 |
174 | def forward(self, kpts, descriptors):
175 | """
176 | kpts: B*L*3 or B*L*4
177 | descriptors: B*C*L
178 | """
179 | # inputs = kpts # B*L*3
180 |
181 | return descriptors + self.encoder(kpts).transpose(2, 1).expand_as(descriptors) # B*C*L
182 |
183 | def MLP(self, channels: list, norm_method="batchnorm"):
184 | """ Multi-layer perceptron"""
185 | n = len(channels)
186 | layers = []
187 | for i in range(1, n):
188 | layers.append(nn.Linear(channels[i - 1], channels[i], bias=True))
189 | if i < n - 1:
190 | if norm_method == "batchnorm":
191 | layers.append(nn.BatchNorm1d(channels[i]))
192 | elif norm_method == "layernorm":
193 | layers.append(nn.LayerNorm(channels[i]))
194 | elif norm_method == "instancenorm":
195 | layers.append(nn.InstanceNorm1d(channels[i]))
196 | else:
197 | raise NotImplementedError
198 | # layers.append(nn.GroupNorm(channels[i], channels[i])) # group norm
199 | layers.append(nn.ReLU())
200 | return nn.Sequential(*layers)
201 |
202 |
203 | # Position encoding for 3D points
204 | class KeypointEncoding_linear(nn.Module):
205 | """ Joint encoding of visual appearance and location using MLPs """
206 |
207 | def __init__(self, inp_dim, feature_dim, layers, norm_method="batchnorm"):
208 | super().__init__()
209 | self.encoder = self.MLP([inp_dim] + list(layers) + [feature_dim], norm_method)
210 | nn.init.constant_(self.encoder[-1].bias, 0.0)
211 |
212 | def forward(self, kpts, descriptors):
213 | """
214 | kpts: B*L*3 or B*L*4
215 | descriptors: B*C*L
216 | """
217 | # inputs = kpts # B*L*3
218 | return descriptors + self.encoder(kpts).transpose(2, 1) # B*C*L
219 |
220 | def MLP(self, channels: list, norm_method="batchnorm"):
221 | """ Multi-layer perceptron"""
222 | n = len(channels)
223 | layers = []
224 | for i in range(1, n):
225 | layers.append(nn.Linear(channels[i - 1], channels[i], bias=True))
226 | if i < n - 1:
227 | if norm_method == "batchnorm":
228 | layers.append(nn.BatchNorm1d(channels[i]))
229 | elif norm_method == "layernorm":
230 | layers.append(nn.LayerNorm(channels[i]))
231 | elif norm_method == "instancenorm":
232 | layers.append(nn.InstanceNorm1d(channels[i]))
233 | else:
234 | raise NotImplementedError
235 | # layers.append(nn.GroupNorm(channels[i], channels[i])) # group norm
236 | layers.append(nn.ReLU())
237 | return nn.Sequential(*layers)
--------------------------------------------------------------------------------
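For the RoPE2D fallback above, tokens are expected as (batch, heads, tokens, dim) and positions as (batch, tokens, 2) integer (y, x) grid coordinates; the first half of the channels is rotated according to the y position and the second half according to the x position. A minimal sketch using the pure-PyTorch path:

    import torch
    rope = RoPE2D(freq=100.0)
    B, H, N, D = 2, 8, 196, 64
    tokens = torch.randn(B, H, N, D)
    ys, xs = torch.meshgrid(torch.arange(14), torch.arange(14), indexing='ij')
    positions = torch.stack([ys, xs], dim=-1).reshape(1, N, 2).expand(B, -1, -1)
    out = rope(tokens, positions)   # same shape, rotary position encoding applied
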
/three/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import *
2 | from . import stats
3 | from . import quaternion
4 | from .rigid import *
5 | from .batchview import *
6 | from . import orientation
7 | from . import utils
8 |
--------------------------------------------------------------------------------
/three/batchview.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | @torch.jit.script
5 | def bvmm(a, b):
6 | if a.shape[0] != b.shape[0]:
7 | raise ValueError("batch dimension must match")
8 | if a.shape[1] != b.shape[1]:
9 | raise ValueError("view dimension must match")
10 |
11 | nbatch, nview, nrow, ncol = a.shape
12 | a = a.view(-1, nrow, ncol)
13 |     b = b.view(-1, b.shape[2], b.shape[3])  # b may have a different matrix shape than a
14 | out = torch.bmm(a, b)
15 | out = out.view(nbatch, nview, out.shape[1], out.shape[2])
16 | return out
17 |
18 |
19 | def bv2b(x):
20 | if not x.is_contiguous():
21 | return x.reshape(-1, *x.shape[2:])
22 | return x.view(-1, *x.shape[2:])
23 |
24 |
25 | def b2bv(x, num_view=-1, batch_size=-1):
26 | if num_view == -1 and batch_size == -1:
27 | raise ValueError('One of num_view or batch_size must be non-negative.')
28 | return x.view(batch_size, num_view, *x.shape[1:])
29 |
30 |
31 | def vcat(tensors, batch_size):
32 | tensors = [b2bv(t, batch_size=batch_size) for t in tensors]
33 | return bv2b(torch.cat(tensors, dim=1))
34 |
35 |
36 | def vsplit(tensor, sections):
37 | num_view = sum(sections)
38 | tensor = b2bv(tensor, num_view=num_view)
39 | return tuple(bv2b(t) for t in torch.split(tensor, sections, dim=1))
40 |
--------------------------------------------------------------------------------
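These helpers flatten a (batch, views, ...) tensor into (batch*views, ...) for per-view processing and restore it afterwards, e.g.:

    import torch
    x = torch.randn(4, 6, 3, 128, 128)    # (batch, views, C, H, W)
    flat = bv2b(x)                         # (24, 3, 128, 128)
    restored = b2bv(flat, num_view=6)      # (4, 6, 3, 128, 128)
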
/three/core.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | @torch.jit.script
5 | def acos_safe(t, eps: float = 1e-7):
6 | return torch.acos(torch.clamp(t, min=-1.0 + eps, max=1.0 - eps))
7 |
8 |
9 | @torch.jit.script
10 | def ensure_batch_dim(tensor, num_dims: int):
11 | unsqueezed = False
12 | if len(tensor.shape) == num_dims:
13 | tensor = tensor.unsqueeze(0)
14 | unsqueezed = True
15 |
16 | return tensor, unsqueezed
17 |
18 |
19 | @torch.jit.script
20 | def normalize(vector, dim: int = -1):
21 | """
22 |     Normalizes the vector to a unit vector using the L2 norm.
23 | Args:
24 | vector (tensor): the vector to normalize of size (*, 3)
25 |         dim (int): the dimension along which to normalize (default: -1)
26 |
27 | Returns:
28 | (tensor): A unit normalized vector of size (*, 3)
29 | """
30 | return vector / torch.norm(vector, p=2.0, dim=dim, keepdim=True)
31 |
32 |
33 | @torch.jit.script
34 | def uniform(n: int, min_val: float, max_val: float):
35 | return (max_val - min_val) * torch.rand(n) + min_val
36 |
37 |
38 | def uniform_unit_vector(n):
39 | return normalize(torch.randn(n, 3), dim=1)
40 |
41 |
42 | def inner_product(a, b):
43 | return (a * b).sum(dim=-1)
44 |
45 |
46 | @torch.jit.script
47 | def homogenize(coords):
48 | ones = torch.ones_like(coords[..., 0, None])
49 | return torch.cat((coords, ones), dim=-1)
50 |
51 |
52 | @torch.jit.script
53 | def dehomogenize(coords):
54 | return coords[..., :coords.size(-1) - 1] / coords[..., -1, None]
55 |
56 |
57 | def transform_coord_grid(grid, transform):
58 | if transform.size(0) != grid.size(0):
59 | raise ValueError('Batch dimensions must match.')
60 |
61 | out_shape = (*grid.shape[:-1], transform.size(1))
62 |
63 | grid = homogenize(grid)
64 | coords = grid.view(grid.size(0), -1, grid.size(-1))
65 | coords = transform @ coords.transpose(1, 2)
66 | coords = coords.transpose(1, 2)
67 | return dehomogenize(coords.view(*out_shape))
68 |
69 |
70 | @torch.jit.script
71 | def transform_coords(coords, transform):
72 | coords, unsqueezed = ensure_batch_dim(coords, 2)
73 |
74 | coords = homogenize(coords)
75 | coords = transform @ coords.transpose(1, 2)
76 | coords = coords.transpose(1, 2)
77 | coords = dehomogenize(coords)
78 | if unsqueezed:
79 | coords = coords.squeeze(0)
80 |
81 | return coords
82 |
83 |
84 | @torch.jit.script
85 | def grid_to_coords(grid):
86 | return grid.view(grid.size(0), -1, grid.size(-1))
87 |
88 |
89 | def spherical_to_cartesian(theta, phi, r=1.0):
90 | x = r * torch.cos(theta) * torch.sin(phi)
91 | y = r * torch.sin(theta) * torch.sin(phi)
92 |     z = r * torch.cos(phi)   # phi: polar angle, theta: azimuth
93 | return torch.stack((x, y, z), dim=-1)
94 |
95 |
96 | def points_bound(points):
97 | min_dim = torch.min(points, dim=0)[0]
98 | max_dim = torch.max(points, dim=0)[0]
99 | return torch.stack((min_dim, max_dim), dim=1)
100 |
101 |
102 | def points_radius(points):
103 | bounds = points_bound(points)
104 | centroid = bounds.mean(dim=1).unsqueeze(0)
105 | max_radius = torch.norm(points - centroid, dim=1).max()
106 | return max_radius
107 |
108 |
109 | def points_diameter(points):
110 |     return 2 * points_radius(points)
111 |
112 |
113 | def points_centroid(points):
114 | return points_bound(points).mean(dim=1)
115 |
116 |
117 | def points_bounding_size(points):
118 | bounds = points_bound(points)
119 | return torch.norm(bounds[:, 1] - bounds[:, 0])
120 |
--------------------------------------------------------------------------------
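transform_coords applies a homogeneous transform to a point set by appending a 1 to each point, multiplying, and dividing the homogeneous coordinate back out. A small sketch:

    import torch
    pts = torch.randn(100, 3)                    # 3D points
    T = torch.eye(4).unsqueeze(0)                # (1, 4, 4) rigid transform
    T[:, :3, 3] = torch.tensor([0.0, 0.0, 1.0])  # translate by 1 along z
    moved = transform_coords(pts, T)             # (100, 3)
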
/three/imutils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from skimage import morphology
5 |
6 |
7 | def keep_largest_object(mask):
8 | labels, num_labels = morphology.label(mask, return_num=True)
9 | best_mask = None
10 | best_count = -1
11 | for i in range(1, num_labels + 1):
12 | cur_mask = (labels == i)
13 | cur_count = cur_mask.sum()
14 | if cur_count > best_count:
15 | best_mask = cur_mask
16 | best_count = cur_count
17 |
18 | if best_mask is None:
19 | return np.zeros_like(mask)
20 |
21 | return best_mask
22 |
23 |
24 | def mask_chroma(image, hue_min=(40, 65, 65), hue_max=(180, 255, 255),
25 | use_bgr=False):
26 | image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV if use_bgr else cv2.COLOR_RGB2HSV)
27 | hue_min = np.array(hue_min)
28 | hue_max = np.array(hue_max)
29 | mask = ~cv2.inRange(image_hsv, hue_min, hue_max)
30 | mask = morphology.binary_closing(mask, selem=morphology.disk(5))
31 | return mask
32 |
33 |
34 | def grabcut(image, fg_init_mask, bg_init_mask=None):
35 | if image.dtype == np.float32 or image.dtype == np.double:
36 | image = (image * 255.0).astype(np.uint8)
37 | # Initialize mask based on sparse pointcloud.
38 | mask = np.full(image.shape[:2], fill_value=cv2.GC_PR_BGD, dtype=np.uint8)
39 | mask[fg_init_mask] = cv2.GC_PR_FGD
40 | if bg_init_mask is not None:
41 | mask[bg_init_mask] = cv2.GC_BGD
42 |
43 | # Perform grab cut.
44 | bg_model = np.zeros((1, 65), np.float64)
45 | fg_model = np.zeros((1, 65), np.float64)
46 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
47 | cv2.grabCut(image, mask, None, bg_model, fg_model, 3, cv2.GC_INIT_WITH_MASK)
48 |
49 | # Post process mask.
50 | out_mask = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
51 | out_mask = morphology.binary_closing(out_mask, selem=morphology.disk(5))
52 |
53 | return out_mask
54 |
55 |
56 | def mean_color(image, mask):
57 | return (image * mask).sum(dim=(-2, -1)) / mask.sum(dim=(-2, -1))
58 |
59 |
60 | def dilate(labels, iters, kernel_size):
61 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
62 | morphed_labels = []
63 | for label in labels:
64 | morphed_labels.append(
65 | cv2.dilate(label.squeeze(0).numpy(), kernel, iterations=iters))
66 | return torch.tensor(np.stack(morphed_labels, axis=0),
67 | dtype=torch.float32).unsqueeze(1)
68 |
69 |
70 | def erode(labels, iters, kernel_size):
71 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
72 | morphed_labels = []
73 | for label in labels:
74 | morphed_labels.append(
75 | cv2.erode(label.squeeze(0).numpy(), kernel, iterations=iters))
76 | return torch.tensor(np.stack(morphed_labels, axis=0),
77 | dtype=torch.float32).unsqueeze(1)
78 |
--------------------------------------------------------------------------------
/three/meshutils.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import numpy as np
4 | from scipy import linalg
5 |
6 | import trimesh
7 | import trimesh.remesh
8 | from trimesh.visual.material import SimpleMaterial
9 |
10 | EPS = 10e-10
11 |
12 |
13 | def compute_vertex_normals(vertices, faces):
14 | normals = np.ones_like(vertices)
15 | triangles = vertices[faces]
16 | triangle_normals = np.cross(triangles[:, 1] - triangles[:, 0],
17 | triangles[:, 2] - triangles[:, 0])
18 | triangle_normals /= (linalg.norm(triangle_normals, axis=1)[:, None] + EPS)
19 | normals[faces[:, 0]] += triangle_normals
20 | normals[faces[:, 1]] += triangle_normals
21 | normals[faces[:, 2]] += triangle_normals
22 |     normals /= (linalg.norm(normals, axis=1)[:, None] + EPS)
23 |
24 | return normals
25 |
26 |
27 | def are_trimesh_normals_corrupt(trimesh):
28 | corrupt_normals = linalg.norm(trimesh.vertex_normals, axis=1) == 0.0
29 | return corrupt_normals.sum() > 0
30 |
31 |
32 | def subdivide_mesh(mesh):
33 | attributes = {}
34 | if hasattr(mesh.visual, 'uv'):
35 | attributes = {'uv': mesh.visual.uv}
36 | vertices, faces, attributes = trimesh.remesh.subdivide(
37 | mesh.vertices, mesh.faces, attributes=attributes)
38 | mesh.vertices = vertices
39 | mesh.faces = faces
40 | if 'uv' in attributes:
41 | mesh.visual.uv = attributes['uv']
42 |
43 | return mesh
44 |
45 |
46 | class Object3D(object):
47 | """Represents a graspable object."""
48 |
49 | def __init__(self, path, load_materials=False):
50 | scene = trimesh.load(str(path))
51 | if isinstance(scene, trimesh.Trimesh):
52 | scene = trimesh.Scene(scene)
53 |
54 | self.meshes: typing.List[trimesh.Trimesh] = list(scene.dump())
55 |
56 | self.path = path
57 | self.scale = 1.0
58 |
59 | def to_scene(self):
60 | return trimesh.Scene(self.meshes)
61 |
62 | def are_normals_corrupt(self):
63 | for mesh in self.meshes:
64 | if are_trimesh_normals_corrupt(mesh):
65 | return True
66 |
67 | return False
68 |
69 | def recompute_normals(self):
70 | for mesh in self.meshes:
71 | mesh.vertex_normals = compute_vertex_normals(mesh.vertices, mesh.faces)
72 |
73 | return self
74 |
75 | def rescale(self, scale=1.0):
76 | """Set scale of object mesh.
77 |
78 | :param scale
79 | """
80 | self.scale = scale
81 | for mesh in self.meshes:
82 | mesh.apply_scale(self.scale)
83 |
84 | return self
85 |
86 | def resize(self, size, ref='diameter'):
87 |         """Rescale the mesh so the chosen reference length (bounding diameter by default, else longest extent) equals `size`.
88 |
89 | :param size
90 | """
91 | if ref == 'diameter':
92 | ref_scale = self.bounding_diameter
93 | else:
94 | ref_scale = self.bounding_size
95 |
96 | self.scale = size / ref_scale
97 | for mesh in self.meshes:
98 | mesh.apply_scale(self.scale)
99 |
100 | return self
101 |
102 | @property
103 | def centroid(self):
104 | return self.bounds.mean(axis=0)
105 |
106 | @property
107 | def bounding_size(self):
108 | return max(self.extents)
109 |
110 | @property
111 | def bounding_diameter(self):
112 | centroid = self.bounds.mean(axis=0)
113 | max_radius = linalg.norm(self.vertices - centroid, axis=1).max()
114 | return max_radius * 2
115 |
116 | @property
117 | def bounding_radius(self):
118 | return self.bounding_diameter / 2.0
119 |
120 | @property
121 | def extents(self):
122 | min_dim = np.min(self.vertices, axis=0)
123 | max_dim = np.max(self.vertices, axis=0)
124 | return max_dim - min_dim
125 |
126 | @property
127 | def bounds(self):
128 | min_dim = np.min(self.vertices, axis=0)
129 | max_dim = np.max(self.vertices, axis=0)
130 | return np.stack((min_dim, max_dim), axis=0)
131 |
132 | def recenter(self, method='bounds'):
133 | if method == 'mean':
134 | # Center the mesh.
135 | vertex_mean = np.mean(self.vertices, 0)
136 | translation = -vertex_mean
137 | elif method == 'bounds':
138 | center = self.bounds.mean(axis=0)
139 | translation = -center
140 | else:
141 | raise ValueError(f"Unknown method {method!r}")
142 |
143 | for mesh in self.meshes:
144 | mesh.apply_translation(translation)
145 |
146 | return self
147 |
148 | @property
149 | def vertices(self):
150 | return np.concatenate([mesh.vertices for mesh in self.meshes])
151 |
--------------------------------------------------------------------------------
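Object3D wraps a trimesh scene; a common normalisation pattern is to centre the mesh and scale it to a unit reference length (the path below is illustrative):

    obj = Object3D('path/to/model.obj')
    obj.recenter(method='bounds').resize(1.0, ref='diameter')   # origin-centred, unit-diameter mesh
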
/three/orientation.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 | # from latentfusion import three
6 | # from latentfusion.three import quaternion as q
7 |
8 |
9 | import three
10 | from three import quaternion as q
11 |
12 |
13 | def spiral_orbit(n, c=16):
14 | phi = torch.linspace(0, math.pi, n)
15 | theta = c * phi
16 | rot_quat = q.from_spherical(phi, theta)
17 | return rot_quat
18 |
19 |
20 | def _check_up(up, n):
21 | if not torch.is_tensor(up):
22 | up = torch.tensor(up, dtype=torch.float32)
23 | if len(up.shape) == 1:
24 | up = up.expand(n, -1)
25 | return three.normalize(up)
26 |
27 |
28 | def _is_ray_in_segment(ray, up, min_angle, max_angle):
29 | angle = torch.acos(three.inner_product(up, ray))
30 | return (min_angle <= angle) & (angle <= max_angle)
31 |
32 |
33 | def sample_segment_rays(n, up, min_angle, max_angle):
34 | up = _check_up(up, n)
35 |
36 | # Sample random new 'up' orientation.
37 | rays = three.normalize(torch.randn(n, 3))
38 | num_invalid = n
39 | while num_invalid > 0:
40 | valid = _is_ray_in_segment(rays, up, min_angle, max_angle)
41 | num_invalid = (~valid).sum().item()
42 | rays[~valid] = three.normalize(torch.randn(num_invalid, 3))
43 |
44 | return three.normalize(rays)
45 |
46 |
47 | def sample_hemisphere_rays(n, up):
48 | """
49 | Samples a ray in the upper hemisphere (defined by `up`).
50 |
51 | Implemented by sampling a uniform random ray on the unit sphere and reflecting
52 | the vector to be on the same side as the up vector.
53 |
54 | Args:
55 | n (int): number of rays to sample
56 | up (tensor, tuple): the up direction defining the hemisphere
57 |
58 | Returns:
59 | (tensor): `n` rays uniformly sampled on the hemisphere
60 |
61 | """
62 | up = _check_up(up, n)
63 |
64 | # Sample random new 'up' orientation.
65 | rays = three.normalize(torch.randn(n, 3))
66 |
67 | # Reflect to upper hemisphere.
68 | dot = (up * rays).sum(dim=-1)
69 | rays[dot < 0] = rays[dot < 0] - 2 * dot[dot < 0, None] * up[dot < 0]
70 |
71 | return rays
72 |
73 |
74 | def random_quat_from_ray(forward, up=None):
75 | """
76 |     Sample uniformly random quaternions that orient the camera forward direction.
77 | 
78 |     Args:
79 |         forward: a vector representing the forward direction.
80 |         up: optional approximate up direction used to fix the camera roll; a random direction is used if None.
81 |     Returns:
82 |         (tensor): a batch of unit quaternions, one per forward vector.
83 | """
84 | n = forward.shape[0]
85 | if up is None:
86 | down = three.uniform_unit_vector(n)
87 | else:
88 | up = torch.tensor(up).unsqueeze(0).expand(n, 3)
89 | up = up + forward
90 | down = -up
91 | right = three.normalize(torch.cross(down, forward))
92 | down = three.normalize(torch.cross(forward, right))
93 |
94 | mat = torch.stack([right, down, forward], dim=1)
95 |
96 | return three.quaternion.mat_to_quat(mat)
97 |
98 |
99 | def sample_segment_quats(n, up, min_angle, max_angle):
100 | """
101 | Sample a quaternion where the resulting `up` direction is constrained to a segment of the sphere.
102 |
103 | This is performed by first sampling a 'yaw' angle and then sampling a random 'up' direction in the segment.
104 |
105 | The sphere segment is defined as being [min_angle,max_angle] radians away from the 'up' direction.
106 |
107 | Args:
108 | n (int): number of rays to sample
109 | up (tensor, tuple): the up direction defining sphere segment
110 | min_angle (float): the min angle from the up direction defining the sphere segment
111 | max_angle (float): the max angle from the up direction defining the sphere segment
112 |
113 | Returns:
114 | (tensor): a batch of sampled quaternions
115 | """
116 | up = _check_up(up, n)
117 |
118 | yaw_angle = torch.rand(n) * math.pi * 2.0
119 | yaw_quat = q.from_axis_angle(up, yaw_angle)
120 |
121 | rays = sample_segment_rays(n, up, min_angle, max_angle)
122 |
123 | pivot = torch.cross(up, rays)
124 | angles = torch.acos(three.inner_product(up, rays))
125 | quat = q.from_axis_angle(pivot, angles)
126 |
127 | return q.qmul(quat, yaw_quat)
128 |
129 |
130 | def evenly_distributed_points(n: int, hemisphere=False, pole=(0.0, 0.0, 1.0)):
131 | """
132 | Uses the sunflower method to sample points on a sphere that are
133 | roughly evenly distributed.
134 |
135 | Reference:
136 | https://stackoverflow.com/questions/9600801/evenly-distributing-n-points-on-a-sphere/44164075#44164075
137 | """
138 | indices = torch.arange(0, n, dtype=torch.float32) + 0.5
139 |
140 | if hemisphere:
141 | phi = torch.acos(1 - 2 * indices / n / 2)
142 | else:
143 | phi = torch.acos(1 - 2 * indices / n)
144 | theta = math.pi * (1 + 5 ** 0.5) * indices
145 |
146 | points = torch.stack([
147 | torch.cos(theta) * torch.sin(phi),
148 | torch.sin(theta) * torch.sin(phi),
149 | torch.cos(phi),
150 | ], dim=1)
151 |
152 | if hemisphere:
153 | default_pole = torch.tensor([(0.0, 0.0, 1.0)]).expand(n, 3)
154 | pole = torch.tensor([pole]).expand(n, 3)
155 | if (default_pole[0] + pole[0]).abs().sum() < 1e-5:
156 | # If the pole is the opposite side just flip.
157 | points = -points
158 | elif (default_pole[0] - pole[0]).abs().sum() < 1e-5:
159 | points = points
160 | else:
161 | # Otherwise take the cross product as the rotation axis.
162 | rot_axis = torch.cross(pole, default_pole)
163 | rot_angle = torch.acos(three.inner_product(pole, default_pole))
164 | rot_quat = three.quaternion.from_axis_angle(rot_axis, rot_angle)
165 | points = three.quaternion.rotate_vector(rot_quat, points)
166 |
167 | return points
168 |
169 |
170 | def evenly_distributed_quats(n: int, hemisphere=False, hemisphere_pole=(0.0, 0.0, 1.0),
171 | upright=False, upright_up=(0.0, 0.0, 1.0)):
172 | rays = evenly_distributed_points(n, hemisphere, hemisphere_pole)
173 | return random_quat_from_ray(-rays, upright_up if upright else None)
174 |
175 |
176 | @torch.jit.script
177 | def disk_sample_quats(n: int, min_angle: float, max_tries: int = 64):
178 |
179 | quats = q.random(1)
180 |
181 | num_tries = 0
182 | while quats.shape[0] < n:
183 | new_quat = q.random(1)
184 | angles = q.angular_distance(quats, new_quat)
185 | if torch.all(angles >= min_angle) or num_tries > max_tries:
186 | quats = torch.cat((quats, new_quat), dim=0)
187 | num_tries = 0
188 | else:
189 | num_tries += 1
190 |
191 | return quats
192 |
--------------------------------------------------------------------------------
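evenly_distributed_quats is useful for generating a roughly uniform set of viewpoints, e.g. when rendering reference views of an object:

    quats = evenly_distributed_quats(64, hemisphere=True, hemisphere_pole=(0.0, 0.0, 1.0))
    # one unit quaternion (shape (64, 4)) per roughly evenly spaced viewpoint on the upper hemisphere
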
/three/pytorch3d_rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def batchfy_mv_meshes(meshes, num_views, batch_size=1):
4 | from pytorch3d import structures as py3d_struct
5 | mv_meshes = list()
6 | for mesh in meshes.split([1 for _ in range(batch_size)]):
7 | mv_meshes.append(mesh.extend(num_views))
8 | mv_meshes = py3d_struct.join_meshes_as_batch(mv_meshes)
9 | return mv_meshes
10 |
11 | def generate_mesh_model_from_3Dbbox(bbox3D):
12 | from pytorch3d import structures as py3d_struct
13 | bbox3D = bbox3D.squeeze()
14 | assert(bbox3D.dim() == 2 and bbox3D.shape[1] == 3), bbox3D.shape
15 | if len(bbox3D) == 2: # 2x3
16 | xmin, ymin, zmin = bbox3D[0]
17 | xmax, ymax, zmax = bbox3D[1]
18 | bbox3D_corners = torch.tensor([
19 | [xmin, ymin, zmin],
20 | [xmax, ymin, zmin],
21 | [xmax, ymax, zmin],
22 | [xmin, ymax, zmin],
23 | [xmin, ymin, zmax],
24 | [xmax, ymin, zmax],
25 | [xmax, ymax, zmax],
26 | [xmin, ymax, zmax]
27 | ])
28 | elif len(bbox3D) == 8: # 8x3
29 | bbox3D_corners = bbox3D
30 | else:
31 | print(bbox3D)
32 | raise NotImplementedError
33 |
34 | assert(bbox3D_corners.dim() == 2 and len(bbox3D_corners) == 8), bbox3D_corners.shape
35 | bbox3D_faces = torch.tensor([
36 | [0, 1, 2, 3],
37 | [4, 5, 6, 7],
38 | [0, 1, 5, 4],
39 | [1, 2, 6, 5],
40 | [2, 3, 7, 6],
41 | [3, 0, 4, 7],])
42 | # Convert quadrilateral faces to triangles
43 | bbox3D_faces = torch.cat([
44 | bbox3D_faces[:, [0, 1, 2]],
45 | bbox3D_faces[:, [2, 3, 0]]
46 | ], dim=0)
47 |
48 | # Create the mesh
49 | bbox3D_mesh = py3d_struct.Meshes(verts=[bbox3D_corners], faces=[bbox3D_faces])
50 | return bbox3D_mesh
51 |
52 | def compute_ROI_camera_intrinsic(camK, img_size, camera_dist=None, scaled_T=None,
53 | bbox_center=None, bbox_scale=None):
54 | """
55 | camK: Bx3x3
56 | scaled_T: Bx3
57 | bbox_center: Bx2
58 | bbox_scale: B
59 | camera_dist: B
60 | """
61 | squeeze = False
62 | if not isinstance(camK, torch.Tensor):
63 | camK = torch.tensor(camK)
64 |
65 | if scaled_T is not None and not isinstance(scaled_T, torch.Tensor):
66 | scaled_T = torch.tensor(scaled_T)
67 | if bbox_center is not None and not isinstance(bbox_center, torch.Tensor):
68 | bbox_center = torch.tensor(bbox_center)
69 | if bbox_scale is not None and not isinstance(bbox_scale, torch.Tensor):
70 | bbox_scale = torch.tensor(bbox_scale)
71 |
72 | device = camK.device
73 | assert(scaled_T is not None or bbox_center is not None)
74 | if scaled_T is not None:
75 | assert(camera_dist is not None)
76 | if scaled_T.dim() == 1:
77 | squeeze = True
78 | scaled_T = scaled_T[None, :]
79 | assert(scaled_T.dim() == 2), scaled_T.shape
80 | obj_2D_center = torch.einsum('bij,bj->bi', camK, scaled_T.to(device))
81 | bbox_center = obj_2D_center[:, :2] / obj_2D_center[:, 2:3] #
82 | bbox_scale = camera_dist * img_size / scaled_T[:, 2].to(device)
83 | # print(bbox_center, bbox_scale)
84 | elif bbox_center is not None:
85 | assert(bbox_scale is not None)
86 |
87 | if bbox_center.dim() == 1:
88 | bbox_center = bbox_center[None, :]
89 | squeeze = True
90 | if bbox_scale.dim() == 0:
91 | bbox_scale = bbox_scale[None]
92 | if camK.dim() == 2:
93 | camK = camK[None, :, :]
94 | squeeze = True
95 |
96 | assert(bbox_center.dim() == 2 and bbox_scale.dim() == 1), (bbox_center.shape, bbox_scale.shape)
97 |
98 | bbox_x1y1 = bbox_center - bbox_scale[:, None] / 2
99 | bbox_rescaling_factor = img_size / bbox_scale
100 | T_cam2roi = torch.eye(3)[None, :, :].repeat(len(bbox_rescaling_factor), 1, 1).to(device)
101 | T_cam2roi[:, :2, 2] = -bbox_x1y1.to(device)
102 | T_cam2roi[:, :2, :] *= bbox_rescaling_factor[:, None, None].to(device)
103 | new_camK = torch.einsum('bij,bjk->bik', T_cam2roi, camK)
104 | if squeeze:
105 | new_camK = new_camK.squeeze(0)
106 | return new_camK
107 |
108 | def generate_3D_coordinate_map_from_depth(depth, camK, obj_RT):
109 | if depth.squeeze().dim() == 2:
110 | depth = depth.squeeze()[None, None, :, :]
111 | elif depth.squeeze().dim() == 3:
112 | depth = depth.squeeze()[:, None, :, :]
113 |
114 | if obj_RT.dim() == 2:
115 | obj_RT = obj_RT[None, :, :]
116 | if camK.dim() == 2:
117 | camK = camK[None, :, :]
118 |
119 | if len(camK) != len(obj_RT):
120 | camK = camK.repeat(len(obj_RT), 1, 1)
121 |
122 |     # depth and obj_RT must share the same batch size
123 |     assert(len(depth) == len(obj_RT)), (depth.shape, obj_RT.shape)
124 |
125 | device = depth.device
126 | im_hei, im_wid = depth.shape[-2:]
127 | YY, XX = torch.meshgrid(torch.arange(im_hei), torch.arange(im_wid), indexing='ij')
128 | XYZ_map = torch.stack([XX, YY, torch.ones_like(XX)], dim=0).to(device) # 3xHxW
129 | XYZ_map = XYZ_map[None, :, :, :] * depth # 1x3xHxW, Bx1xHxW
130 |
131 | XYZ_map = torch.einsum('bij,bjhw->bihw', torch.inverse(camK).to(device), XYZ_map)
132 | Rs = obj_RT[:, :3, :3].to(device)
133 | Ts = obj_RT[:, :3, 3].to(device)
134 |
135 | XYZ_map = torch.einsum('bij,bjhw->bihw', torch.inverse(Rs), XYZ_map - Ts[:, :, None, None])
136 |
137 | return XYZ_map
138 |
139 | def render_depth_from_mesh_model(mesh, obj_RT, camK, img_hei, img_wid, return_coordinate_map=False):
140 | """
141 | Pytorch3D: K_4x4 = [
142 | [fx, 0, px, 0],
143 | [0, fy, py, 0],
144 | [0, 0, 0, 1],
145 | [0, 0, 1, 0],
146 | ]
147 | """
148 | from pytorch3d import renderer as py3d_renderer
149 | from pytorch3d import transforms as py3d_transform
150 | from pytorch3d.transforms import euler_angles_to_matrix
151 |
152 | device = obj_RT.device
153 | if obj_RT.dim() == 2:
154 | obj_RT = obj_RT[None, :, :]
155 |
156 | if camK.dim() == 2:
157 | camK = camK[None, :, :]
158 |
159 | if len(mesh) != len(obj_RT):
160 | mesh = mesh.extend(len(obj_RT))
161 |
162 | assert(len(mesh) == len(obj_RT)), (len(mesh), obj_RT.shape)
163 |
164 | Rz_mat = torch.eye(4).to(device)
165 | Rz_mat[:3, :3] = py3d_transform.euler_angles_to_matrix(torch.as_tensor([0, 0, torch.pi]), 'XYZ')
166 | py3d_RT = torch.einsum('ij,bjk->bik', Rz_mat, obj_RT)
167 | cam_R = py3d_RT[:, :3, :3].transpose(-2, -1)
168 | cam_T = py3d_RT[:, :3, 3]
169 |
170 | fxfy = torch.stack([camK[:, 0, 0], camK[:, 1, 1]], dim=1)
171 | pxpy = torch.stack([camK[:, 0, 2], camK[:, 1, 2]], dim=1)
172 | cameras = py3d_renderer.PerspectiveCameras(R=cam_R,
173 | T=cam_T,
174 | image_size=((img_hei, img_wid),),
175 | focal_length=fxfy,
176 | principal_point=pxpy,
177 | in_ndc=False,
178 | device=device)
179 | # Define rasterizer settings
180 | raster_settings = py3d_renderer.RasterizationSettings(
181 | image_size=(img_hei, img_wid),
182 | blur_radius=0.0,
183 | faces_per_pixel=1,
184 | bin_size=0,
185 | )
186 |
187 | rasterizer = py3d_renderer.MeshRasterizer(
188 | cameras=cameras,
189 | raster_settings=raster_settings,
190 | )
191 |
192 | fragments = rasterizer(mesh.to(device))
193 | depth_map = fragments.zbuf[..., 0].unsqueeze(1) # Bx1xHxW
194 | depth_mask = torch.zeros_like(depth_map)
195 | depth_mask[depth_map>0] = 1
196 | depth_map *= depth_mask
197 |
198 | if return_coordinate_map:
199 | XYZ_map = generate_3D_coordinate_map_from_depth(depth_map, camK, obj_RT)
200 | return depth_map, depth_mask, XYZ_map * depth_mask
201 |
202 | return depth_map, depth_mask
203 |
204 |
--------------------------------------------------------------------------------
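A minimal usage sketch of the helpers above (assuming they live in three/pytorch3d_rendering.py, as the repository tree suggests; the camera values and box size are purely illustrative): render the depth map of a 10 cm cube placed half a metre in front of a pinhole camera, then compute the intrinsics of a 224x224 ROI crop.

import torch
from three.pytorch3d_rendering import (   # assumed module path for the helpers above
    generate_mesh_model_from_3Dbbox,
    compute_ROI_camera_intrinsic,
    render_depth_from_mesh_model,
)

# Axis-aligned 3D bounding box given as (min corner, max corner), in metres.
bbox3D = torch.tensor([[-0.05, -0.05, -0.05],
                       [ 0.05,  0.05,  0.05]])
bbox_mesh = generate_mesh_model_from_3Dbbox(bbox3D)  # 8 vertices, 12 triangles

camK = torch.tensor([[600.0,   0.0, 320.0],
                     [  0.0, 600.0, 240.0],
                     [  0.0,   0.0,   1.0]])
obj_RT = torch.eye(4)
obj_RT[2, 3] = 0.5  # place the box 0.5 m in front of the camera

depth, mask = render_depth_from_mesh_model(bbox_mesh, obj_RT, camK,
                                           img_hei=480, img_wid=640)
print(depth.shape, mask.sum())  # (1, 1, 480, 640) depth map and its binary mask

# Intrinsics of a 224x224 ROI crop centred on a (hypothetical) detection box.
roi_K = compute_ROI_camera_intrinsic(camK, img_size=224,
                                     bbox_center=torch.tensor([320.0, 240.0]),
                                     bbox_scale=torch.tensor(200.0))
print(roi_K)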
/three/quaternion.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from . import core
6 |
7 | __all__ = ['normalize', 'quat_to_mat']
8 |
9 |
10 | def identity(n: int, device: str = 'cpu'):
11 | return torch.tensor((1.0, 0.0, 0.0, 0.0), device=device).view(1, 4).expand(n, 4)
12 |
13 |
14 | def normalize(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
15 | r"""Normalizes a quaternion.
16 |     The quaternion should be in (w, x, y, z) format.
17 |
18 | Args:
19 | quaternion (torch.Tensor): a tensor containing a quaternion to be
20 | normalized. The tensor can be of shape :math:`(*, 4)`.
21 |         eps (float): small value to avoid division by zero.
22 | Default: 1e-12.
23 |
24 | Return:
25 | torch.Tensor: the normalized quaternion of shape :math:`(*, 4)`.
26 |
27 | """
28 | if not isinstance(quaternion, torch.Tensor):
29 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
30 | type(quaternion)))
31 |
32 | if not quaternion.shape[-1] == 4:
33 | raise ValueError(
34 | "Input must be a tensor of shape (*, 4). Got {}".format(
35 | quaternion.shape))
36 | return F.normalize(quaternion, p=2.0, dim=-1, eps=eps)
37 |
38 |
39 | def quat_to_mat(quaternion: torch.Tensor) -> torch.Tensor:
40 | """
41 | Converts a quaternion to a rotation matrix.
42 | The quaternion should be in (w, x, y, z) format.
43 | Adapted from:
44 | https://github.com/kornia/kornia/blob/d729d7c4357ca73e4915a42285a0771bca4436ce/kornia/geometry/conversions.py#L235
45 | Args:
46 | quaternion (torch.Tensor): a tensor containing a quaternion to be
47 | converted. The tensor can be of shape :math:`(*, 4)`.
48 | Return:
49 | torch.Tensor: the rotation matrix of shape :math:`(*, 3, 3)`.
50 | Example:
51 | >>> quaternion = torch.tensor([0., 0., 1., 0.])
52 | >>> quat_to_mat(quaternion)
53 |         tensor([[-1.,  0.,  0.],
54 |                 [ 0.,  1.,  0.],
55 |                 [ 0.,  0., -1.]])
56 | """
57 | quaternion, unsqueezed = core.ensure_batch_dim(quaternion, 1)
58 |
59 | if not quaternion.shape[-1] == 4:
60 | raise ValueError(
61 | "Input must be a tensor of shape (*, 4). Got {}".format(
62 | quaternion.shape))
63 | # normalize the input quaternion
64 | quaternion_norm = normalize(quaternion)
65 |
66 | # unpack the normalized quaternion components
67 | w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
68 |
69 | # compute the actual conversion
70 | tx: torch.Tensor = 2.0 * x
71 | ty: torch.Tensor = 2.0 * y
72 | tz: torch.Tensor = 2.0 * z
73 | twx: torch.Tensor = tx * w
74 | twy: torch.Tensor = ty * w
75 | twz: torch.Tensor = tz * w
76 | txx: torch.Tensor = tx * x
77 | txy: torch.Tensor = ty * x
78 | txz: torch.Tensor = tz * x
79 | tyy: torch.Tensor = ty * y
80 | tyz: torch.Tensor = tz * y
81 | tzz: torch.Tensor = tz * z
82 | one: torch.Tensor = torch.tensor(1.)
83 |
84 | matrix: torch.Tensor = torch.stack([
85 | one - (tyy + tzz), txy - twz, txz + twy,
86 | txy + twz, one - (txx + tzz), tyz - twx,
87 | txz - twy, tyz + twx, one - (txx + tyy)
88 | ], dim=-1).view(-1, 3, 3)
89 |
90 | if unsqueezed:
91 | matrix = matrix.squeeze(0)
92 |
93 | return matrix
94 |
95 |
96 | def mat_to_quat(rotation_matrix: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
97 | """
98 | Convert 3x3 rotation matrix to 4d quaternion vector.
99 | The quaternion vector has components in (w, x, y, z) format.
100 | Adapted From:
101 | https://github.com/kornia/kornia/blob/d729d7c4357ca73e4915a42285a0771bca4436ce/kornia/geometry/conversions.py#L235
102 | Args:
103 | rotation_matrix (torch.Tensor): the rotation matrix to convert.
104 | eps (float): small value to avoid zero division. Default: 1e-8.
105 | Return:
106 | torch.Tensor: the rotation in quaternion.
107 | Shape:
108 | - Input: :math:`(*, 3, 3)`
109 | - Output: :math:`(*, 4)`
110 | """
111 | rotation_matrix, unsqueezed = core.ensure_batch_dim(rotation_matrix, 2)
112 |
113 | if not isinstance(rotation_matrix, torch.Tensor):
114 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
115 | type(rotation_matrix)))
116 |
117 | if not rotation_matrix.shape[-2:] == (3, 3):
118 | raise ValueError(
119 | "Input size must be a (*, 3, 3) tensor. Got {}".format(
120 | rotation_matrix.shape))
121 |
122 | def safe_zero_division(numerator: torch.Tensor,
123 | denominator: torch.Tensor) -> torch.Tensor:
124 | eps = torch.finfo(numerator.dtype).tiny
125 | return numerator / torch.clamp(denominator, min=eps)
126 |
127 | if not rotation_matrix.is_contiguous():
128 | rotation_matrix_vec: torch.Tensor = rotation_matrix.reshape(
129 | *rotation_matrix.shape[:-2], 9)
130 | else:
131 | rotation_matrix_vec: torch.Tensor = rotation_matrix.view(
132 | *rotation_matrix.shape[:-2], 9)
133 |
134 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.chunk(
135 | rotation_matrix_vec, chunks=9, dim=-1)
136 |
137 | trace: torch.Tensor = m00 + m11 + m22
138 |
139 | def trace_positive_cond():
140 | sq = torch.sqrt(trace + 1.0) * 2. # sq = 4 * qw.
141 | qw = 0.25 * sq
142 | qx = safe_zero_division(m21 - m12, sq)
143 | qy = safe_zero_division(m02 - m20, sq)
144 | qz = safe_zero_division(m10 - m01, sq)
145 | return torch.cat([qw, qx, qy, qz], dim=-1)
146 |
147 | def cond_1():
148 | sq = torch.sqrt(1.0 + m00 - m11 - m22 + eps) * 2. # sq = 4 * qx.
149 | qw = safe_zero_division(m21 - m12, sq)
150 | qx = 0.25 * sq
151 | qy = safe_zero_division(m01 + m10, sq)
152 | qz = safe_zero_division(m02 + m20, sq)
153 | return torch.cat([qw, qx, qy, qz], dim=-1)
154 |
155 | def cond_2():
156 | sq = torch.sqrt(1.0 + m11 - m00 - m22 + eps) * 2. # sq = 4 * qy.
157 | qw = safe_zero_division(m02 - m20, sq)
158 | qx = safe_zero_division(m01 + m10, sq)
159 | qy = 0.25 * sq
160 | qz = safe_zero_division(m12 + m21, sq)
161 | return torch.cat([qw, qx, qy, qz], dim=-1)
162 |
163 | def cond_3():
164 | sq = torch.sqrt(1.0 + m22 - m00 - m11 + eps) * 2. # sq = 4 * qz.
165 | qw = safe_zero_division(m10 - m01, sq)
166 | qx = safe_zero_division(m02 + m20, sq)
167 | qy = safe_zero_division(m12 + m21, sq)
168 | qz = 0.25 * sq
169 | return torch.cat([qw, qx, qy, qz], dim=-1)
170 |
171 | where_2 = torch.where(m11 > m22, cond_2(), cond_3())
172 | where_1 = torch.where((m00 > m11) & (m00 > m22), cond_1(), where_2)
173 |
174 | quaternion: torch.Tensor = torch.where(
175 | trace > 0., trace_positive_cond(), where_1)
176 |
177 | if unsqueezed:
178 | quaternion = quaternion.squeeze(0)
179 |
180 | return quaternion
181 |
182 |
183 | @torch.jit.script
184 | def random(k: int = 1, device: str = 'cpu'):
185 |     """Return `k` uniformly distributed random unit quaternions.
186 | 
187 |     The quaternions are in (w, x, y, z) format and are generated from three
188 |     independent random variables uniformly distributed between 0 and 1.
189 | 
190 |     """
191 | rand = torch.rand(k, 3, device=device)
192 | r1 = torch.sqrt(1.0 - rand[:, 0])
193 | r2 = torch.sqrt(rand[:, 0])
194 | pi2 = math.pi * 2.0
195 | t1 = pi2 * rand[:, 1]
196 | t2 = pi2 * rand[:, 2]
197 |
198 | return torch.stack([
199 | torch.cos(t2) * r2,
200 | torch.sin(t1) * r1,
201 | torch.cos(t1) * r1,
202 | torch.sin(t2) * r2
203 | ], dim=1)
204 |
205 |
206 | def qmul(q1, q2):
207 | """
208 | Quaternion multiplication.
209 |
210 | Use the Hamilton product to perform quaternion multiplication.
211 |
212 | References:
213 | http://en.wikipedia.org/wiki/Quaternions#Hamilton_product
214 | https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/quaternions.py
215 | """
216 | assert q1.shape[-1] == 4
217 | assert q2.shape[-1] == 4
218 |
219 | ham_prod = torch.bmm(q2.view(-1, 4, 1), q1.view(-1, 1, 4))
220 |
221 | w = ham_prod[:, 0, 0] - ham_prod[:, 1, 1] - ham_prod[:, 2, 2] - ham_prod[:, 3, 3]
222 | x = ham_prod[:, 0, 1] + ham_prod[:, 1, 0] - ham_prod[:, 2, 3] + ham_prod[:, 3, 2]
223 | y = ham_prod[:, 0, 2] + ham_prod[:, 1, 3] + ham_prod[:, 2, 0] - ham_prod[:, 3, 1]
224 | z = ham_prod[:, 0, 3] - ham_prod[:, 1, 2] + ham_prod[:, 2, 1] + ham_prod[:, 3, 0]
225 |
226 | return torch.stack((w, x, y, z), dim=1).view(q1.shape)
227 |
228 |
229 | def rotate_vector(quat, vector):
230 | """
231 | References:
232 | https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/quaternions.py#L419
233 | """
234 | assert quat.shape[-1] == 4
235 | assert vector.shape[-1] == 3
236 | assert quat.shape[:-1] == vector.shape[:-1]
237 |
238 | original_shape = list(vector.shape)
239 | quat = quat.view(-1, 4)
240 | vector = vector.view(-1, 3)
241 |
242 | pure_quat = quat[:, 1:]
243 | uv = torch.cross(pure_quat, vector, dim=1)
244 | uuv = torch.cross(pure_quat, uv, dim=1)
245 | return (vector + 2 * (quat[:, :1] * uv + uuv)).view(original_shape)
246 |
247 |
248 | def from_spherical(theta, phi, r=1.0):
249 |     x = r * torch.cos(theta) * torch.sin(phi)
250 |     y = r * torch.sin(theta) * torch.sin(phi)
251 | z = r * torch.cos(phi)
252 | w = torch.zeros_like(x)
253 |
254 | return torch.stack((w, x, y, z), dim=-1)
255 |
256 |
257 | def from_axis_angle(axis, angle):
258 | """
259 | Compute a quaternion from the axis angle representation.
260 |
261 | Reference:
262 | https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
263 | Args:
264 | axis: axis to rotate about
265 | angle: angle to rotate by
266 |
267 | Returns:
268 | Tensor of shape (*, 4) representing a quaternion.
269 | """
270 | if torch.is_tensor(axis) and isinstance(angle, float):
271 | angle = torch.tensor(angle, dtype=axis.dtype, device=axis.device)
272 | angle = angle.expand(axis.shape[0])
273 |
274 | axis = axis / torch.norm(axis, dim=-1, keepdim=True)
275 |
276 | c = torch.cos(angle / 2.0)
277 | s = torch.sin(angle / 2.0)
278 |
279 | w = c
280 | x = s * axis[..., 0]
281 | y = s * axis[..., 1]
282 | z = s * axis[..., 2]
283 |
284 | return torch.stack((w, x, y, z), dim=-1)
285 |
286 |
287 | def qexp(q, eps=1e-8):
288 | """
289 | Computes the quaternion exponent.
290 |
291 | Reference:
292 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
293 | Args:
294 | q (tensor): the quaternion to compute the exponent of
295 | Returns:
296 | (tensor): Tensor of shape (*, 4) representing exp(q)
297 | """
298 |
299 |     if q.shape[-1] == 4:
300 | # Let q = (s; v).
301 | s, v = torch.split(q, (1, 3), dim=-1)
302 | else:
303 | s = torch.zeros_like(q[:, :1])
304 | v = q
305 |
306 | theta = torch.norm(v, dim=-1, keepdim=True)
307 | exp_s = torch.exp(s)
308 | w = torch.cos(theta)
309 | xyz = 1.0 / theta.clamp(min=eps) * torch.sin(theta) * v
310 |
311 | return exp_s * torch.cat((w, xyz), dim=-1)
312 |
313 |
314 | def qlog(q, eps=1e-8):
315 | """
316 | Computes the quaternion logarithm.
317 |
318 | Reference:
319 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
320 | https://users.aalto.fi/~ssarkka/pub/quat.pdf
321 | Args:
322 | q (tensor): the quaternion to compute the logarithm of
323 | Returns:
324 | (tensor): Tensor of shape (*, 4) representing ln(q)
325 | """
326 |
327 | mag = torch.norm(q, dim=-1, keepdim=True)
328 | # Let q = (s; v).
329 | s, v = torch.split(q, (1, 3), dim=-1)
330 | w = torch.log(mag)
331 | xyz = (v / torch.norm(v, dim=-1, keepdim=True).clamp(min=eps)
332 | * core.acos_safe(s / mag.clamp(min=eps)))
333 |
334 | return torch.cat((w, xyz), dim=-1)
335 |
336 |
337 | def qdelta(n, std, device=None):
338 | omega = torch.cat((torch.zeros(n, 1, device=device),
339 | torch.randn(n, 3, device=device)), dim=-1)
340 | delta_q = qexp(std / 2.0 * omega)
341 | return delta_q
342 |
343 |
344 | def perturb(q, std):
345 | """
346 | Perturbs the unit quaternion `q`.
347 |
348 | References:
349 | https://math.stackexchange.com/questions/2992016/how-to-linearize-quaternions
350 | http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_aa10_appendix.pdf
351 | https://math.stackexchange.com/questions/473736/small-angular-displacements-with-a-quaternion-representation
352 |
353 | Args:
354 | q (tensor): the quaternion to perturb (the mean of the perturbation)
355 |         std (float): the standard deviation of the perturbation
356 |
357 | Returns:
358 | (tensor): Tensor of shape (*, 4), the perturbed quaternion
359 | """
360 | q, unsqueezed = core.ensure_batch_dim(q, num_dims=1)
361 |
362 | n = q.shape[0]
363 | delta_q = qdelta(n, std, device=q.device)
364 | q_out = qmul(delta_q, q)
365 |
366 | if unsqueezed:
367 | q_out = q_out.squeeze(0)
368 |
369 | return q_out
370 |
371 |
372 | def angular_distance(q1, q2, eps: float = 1e-7):
373 | q1 = normalize(q1)
374 | q2 = normalize(q2)
375 | dot = q1 @ q2.t()
376 | dist = 2 * core.acos_safe(dot.abs(), eps=eps)
377 | return dist
378 |
--------------------------------------------------------------------------------
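A minimal round-trip sketch for the quaternion utilities above (this module uses the (w, x, y, z) convention; the axes and angles below are illustrative):

import math
import torch
from three import quaternion as q

axis = torch.tensor([[0.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]])
angle = torch.tensor([math.pi / 2, 1.2])
quat = q.from_axis_angle(axis, angle)   # (2, 4) unit quaternions, (w, x, y, z)

R = q.quat_to_mat(quat)                 # (2, 3, 3) rotation matrices
quat_back = q.mat_to_quat(R)            # back to quaternions (sign may flip: q and -q are the same rotation)

# angular_distance returns a pairwise matrix; its diagonal is the per-sample distance (~0 here).
dist = q.angular_distance(quat, quat_back)
print(torch.diag(dist))

# Randomly perturb each quaternion by a small rotation; `std` controls the magnitude.
noisy = q.perturb(quat, std=0.05)
print(q.angular_distance(quat, noisy).diag())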
/three/rigid.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | # from latentfusion import three
7 | # from latentfusion.three import ensure_batch_dim
8 |
9 | import three
10 | from three import quaternion as q
11 | from three import ensure_batch_dim
12 |
13 | def intrinsic_to_3x4(matrix):
14 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2)
15 |
16 | zeros = torch.zeros(1, 3, 1,
17 | dtype=matrix.dtype,
18 | device=matrix.device
19 | ).expand(matrix.shape[0], -1, -1)
20 | mat = torch.cat((matrix, zeros), dim=-1)
21 |
22 | if unsqueezed:
23 | mat = mat.squeeze(0)
24 |
25 | return mat
26 |
27 |
28 | @torch.jit.script
29 | def matrix_3x3_to_4x4(matrix):
30 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2)
31 |
32 | mat = F.pad(matrix, [0, 1, 0, 1])
33 | mat[:, -1, -1] = 1.0
34 |
35 | if unsqueezed:
36 | mat = mat.squeeze(0)
37 |
38 | return mat
39 |
40 |
41 | @torch.jit.script
42 | def rotation_to_4x4(matrix):
43 | return matrix_3x3_to_4x4(matrix)
44 |
45 |
46 | @torch.jit.script
47 | def translation_to_4x4(translation):
48 | translation, unsqueezed = ensure_batch_dim(translation, num_dims=1)
49 |
50 | eye = torch.eye(4, device=translation.device)
51 | mat = F.pad(translation.unsqueeze(2), [3, 0, 0, 1]) + eye
52 |
53 | if unsqueezed:
54 | mat = mat.squeeze(0)
55 |
56 | return mat
57 |
58 |
59 | def translate_matrix(matrix, offset):
60 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2)
61 |
62 | out = inverse_transform(matrix)
63 | out[:, :3, 3] += offset
64 | out = inverse_transform(out)
65 |
66 | if unsqueezed:
67 | out = out.squeeze(0)
68 |
69 | return out
70 |
71 |
72 | def scale_matrix(matrix, scale):
73 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2)
74 |
75 | out = inverse_transform(matrix)
76 | out[:, :3, 3] *= scale
77 | out = inverse_transform(out)
78 |
79 | if unsqueezed:
80 | out = out.squeeze(0)
81 |
82 | return out
83 |
84 |
85 | def decompose(matrix):
86 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2)
87 |
88 | # Extract rotation matrix.
89 | origin = (torch.tensor([0.0, 0.0, 0.0, 1.0], device=matrix.device, dtype=matrix.dtype)
90 | .unsqueeze(1)
91 | .unsqueeze(0))
92 | origin = origin.expand(matrix.size(0), -1, -1)
93 | R = torch.cat((matrix[:, :, :3], origin), dim=-1)
94 |
95 | # Extract translation matrix.
96 | eye = torch.eye(4, 3, device=matrix.device).unsqueeze(0).expand(matrix.size(0), -1, -1)
97 | T = torch.cat((eye, matrix[:, :, 3].unsqueeze(-1)), dim=-1)
98 |
99 | if unsqueezed:
100 | R = R.squeeze(0)
101 | T = T.squeeze(0)
102 |
103 | return R, T
104 |
105 |
106 | def inverse_transform(matrix):
107 | matrix, unsqueezed = ensure_batch_dim(matrix, num_dims=2)
108 |
109 | R, T = decompose(matrix)
110 | R_inv = R.transpose(1, 2)
111 | t = T[:, :4, 3].unsqueeze(2)
112 | t_inv = (R_inv @ t)[:, :3].squeeze(2)
113 |
114 | out = torch.zeros_like(matrix)
115 | out[:, :3, :3] = R_inv[:, :3, :3]
116 | out[:, :3, 3] = -t_inv
117 | out[:, 3, 3] = 1
118 |
119 | if unsqueezed:
120 | out = out.squeeze(0)
121 |
122 | return out
123 |
124 |
125 | def extrinsic_to_position(extrinsic):
126 | extrinsic, unsqueezed = ensure_batch_dim(extrinsic, num_dims=2)
127 |
128 | rot_mat, trans_mat = decompose(extrinsic)
129 | position = rot_mat.transpose(2, 1) @ trans_mat[:, :, 3, None]
130 | position = three.dehomogenize(position.squeeze(-1))
131 |
132 | if unsqueezed:
133 | position = position.squeeze(0)
134 | return position
135 |
136 |
137 | @torch.jit.script
138 | def random_translation(n: int,
139 | x_bound: Tuple[float, float],
140 | y_bound: Tuple[float, float],
141 | z_bound: Tuple[float, float]):
142 | trans_x = three.uniform(n, *x_bound)
143 | trans_y = three.uniform(n, *y_bound)
144 | trans_z = three.uniform(n, *z_bound)
145 | translation = torch.stack((trans_x, trans_y, trans_z), dim=-1)
146 | return translation
147 |
148 |
149 | @torch.jit.script
150 | def to_extrinsic_matrix(translation, quaternion):
151 | rot_mat = three.quaternion.quat_to_mat(quaternion)
152 | rot_mat = rotation_to_4x4(rot_mat)
153 | trans_mat = translation_to_4x4(translation)
154 | extrinsic = trans_mat @ rot_mat
155 | return extrinsic
156 |
157 |
158 | def extrinsic_to_quat(extrinsic):
159 | rot_mat, _ = decompose(extrinsic)
160 | rot_mat = rot_mat[..., :3, :3]
161 | return three.quaternion.mat_to_quat(rot_mat)
162 |
--------------------------------------------------------------------------------
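A small sketch of how the rigid-transform helpers above compose: build a 4x4 extrinsic from a translation and a quaternion, then invert it (the pose values are illustrative):

import math
import torch
from three import quaternion as q
from three import rigid

translation = torch.tensor([[0.0, 0.0, 1.5]])
quat = q.from_axis_angle(torch.tensor([[0.0, 1.0, 0.0]]),
                         torch.tensor([math.pi / 2]))     # 90 degrees about Y

extrinsic = rigid.to_extrinsic_matrix(translation, quat)  # (1, 4, 4): rotation followed by translation
position = rigid.extrinsic_to_position(extrinsic)         # as defined by extrinsic_to_position
inverse = rigid.inverse_transform(extrinsic)

# inverse_transform really is the matrix inverse of the rigid transform.
print(position)
print(torch.allclose(inverse @ extrinsic, torch.eye(4), atol=1e-5))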
/three/stats.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mad(tensor, dim=0):
5 | median, _ = tensor.median(dim=dim)
6 | return torch.median(torch.abs(tensor - median), dim=dim)[0]
7 |
8 |
9 | def mask_outliers_mad(data, m=2.0):
10 | median = data.median()
11 | mad = torch.median(torch.abs(data - median))
12 | mask = torch.abs(data - median) / mad < m
13 | return mask
14 |
15 |
16 | def reject_outliers_mad(data, m=2.0):
17 | return data[mask_outliers_mad(data, m)]
18 |
19 |
20 | def mask_outliers(data, m=2.0):
21 | mean = data.mean()
22 | std = torch.std(data)
23 | mask = torch.abs(data - mean) / std < m
24 | return mask
25 |
26 |
27 | def reject_outliers(data, m=2.0):
28 | return data[mask_outliers(data, m)]
29 |
30 |
31 | def robust_mean(data, m=2.0):
32 | return torch.mean(reject_outliers(data, m))
33 |
34 |
35 | def robust_mean_mad(data, m=2.0):
36 | return torch.mean(reject_outliers_mad(data, m))
37 |
--------------------------------------------------------------------------------
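A quick sketch contrasting the plain mean with the outlier-robust means above (the data are illustrative):

import torch
from three import stats

data = torch.tensor([1.0, 1.1, 0.9, 1.05, 0.95, 50.0])  # one gross outlier

print(data.mean())                   # ~9.2, dragged up by the outlier
print(stats.robust_mean(data))       # ~1.0, points beyond 2 standard deviations are rejected
print(stats.robust_mean_mad(data))   # ~1.0, rejection based on the median absolute deviation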
/three/torchutils.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import abc
4 | import copy
5 | import random
6 | from contextlib import contextmanager
7 | from functools import partial
8 | from itertools import chain
9 |
10 | import structlog
11 | import torch
12 | import torch.autograd
13 | from torch import nn
14 | from torch.nn.parallel.scatter_gather import scatter_kwargs
15 | from torch.utils.data import Sampler, DataLoader
16 |
17 | logger = structlog.get_logger(__name__)
18 |
19 |
20 | def dict_to(d, device):
21 | return {k: v.to(device) for k, v in d.items()}
22 |
23 |
24 | def deparameterize_module(module):
25 |     # Work on a deep copy (nn.Module has no .clone()) and replace every
26 |     # nn.Parameter with a detached plain tensor.
27 |     module = copy.deepcopy(module)
28 |     for submodule in module.modules():
29 |         for name, param in list(submodule.named_parameters(recurse=False)):
30 |             delattr(submodule, name)
31 |             setattr(submodule, name, param.detach())
32 |     return module
30 |
31 |
32 | @contextmanager
33 | def manual_seed(seed):
34 | torch_state = torch.get_rng_state()
35 | torch.manual_seed(seed)
36 | yield
37 | torch.set_rng_state(torch_state)
38 |
39 |
40 | def module_device(module):
41 | return next(module.parameters()).device
42 |
43 |
44 | def save_checkpoint(save_dir, name, state):
45 | if not save_dir.exists():
46 | logger.info("creating directory", path=save_dir)
47 | save_dir.mkdir(parents=True)
48 |
49 | path = save_dir / f'{name}.pth'
50 | logger.info("saving checkpoint", name=name, path=path)
51 |
52 | with path.open('wb') as f:
53 | torch.save(state, f)
54 |
55 |
56 | def save_if_better(save_dir, state, meters, key, bigger_is_better=False):
57 | best_key = f'best-{key}'.replace('/', '-')
58 | if best_key not in state:
59 | state[best_key] = -1 if bigger_is_better else float('inf')
60 |
61 | if bigger_is_better:
62 | better = meters[key].mean >= state[best_key]
63 | else:
64 | better = meters[key].mean <= state[best_key]
65 |
66 | if better:
67 | state[best_key] = meters[key].mean
68 | save_checkpoint(save_dir, best_key, state)
69 |
70 |
71 | class DeterministicShuffledSampler(Sampler):
72 | """Shuffles the dataset once and then samples deterministically."""
73 |
74 | def __init__(self, data_source, replacement=False, num_samples=None):
75 | super().__init__(data_source)
76 |
77 | self.data_source = data_source
78 | self.replacement = replacement
79 | self.num_samples = num_samples
80 |
81 | if self.num_samples is not None and replacement is False:
82 | raise ValueError(
83 | "With replacement=False, num_samples should not be specified, "
84 | "since a random permute will be performed.")
85 |
86 | if self.num_samples is None:
87 | self.num_samples = len(self.data_source)
88 |
89 | if not isinstance(self.num_samples, int) or self.num_samples <= 0:
90 |             raise ValueError("num_samples should be a positive integer "
91 | "value, but got num_samples={}".format(
92 | self.num_samples))
93 | if not isinstance(self.replacement, bool):
94 | raise ValueError("replacement should be a boolean value, but got "
95 | "replacement={}".format(self.replacement))
96 |
97 | n = len(self.data_source)
98 | if self.replacement:
99 | self.permutation = torch.randint(
100 | high=n, size=(self.num_samples,), dtype=torch.int64).tolist()
101 | else:
102 | self.permutation = torch.randperm(n).tolist()
103 |
104 | def __iter__(self):
105 | return iter(self.permutation)
106 |
107 | def __len__(self):
108 | return len(self.data_source)
109 |
110 |
111 | class Scatterable(abc.ABC):
112 | """
113 | A mixin to make an object scatterable across GPUs for data-parallelism.
114 |
115 | The object should be serialized to a dictionary in the `to_kwargs` method. Each entry should be
116 | a scatterable tensor object.
117 |
118 | The `from_kwargs` method should be able to take the dictionary format and reconstruct the original
119 | class object.
120 | """
121 |
122 |     # Called on instances (see MyDataParallel.scatter), so this is a regular instance method.
123 |     @abc.abstractmethod
124 |     def to_kwargs(self):
125 | pass
126 |
127 | @classmethod
128 | @abc.abstractmethod
129 | def from_kwargs(cls, kwargs):
130 | pass
131 |
132 |
133 | class MyDataParallel(nn.DataParallel):
134 | """
135 | A Scatterable-aware data parallel class.
136 | """
137 |
138 | def scatter(self, inputs, kwargs, device_ids):
139 | _inputs = []
140 | _kwargs = {}
141 | input_constructors = {}
142 | kwargs_constructors = {}
143 | for i, item in enumerate(inputs):
144 | if isinstance(item, Scatterable):
145 | input_constructors[i] = item.from_kwargs
146 | _inputs.append(item.to_kwargs())
147 | else:
148 | input_constructors[i] = lambda x: x
149 | _inputs.append(item)
150 |
151 | for key, item in kwargs.items():
152 | if isinstance(item, Scatterable):
153 | kwargs_constructors[key] = item.from_kwargs
154 | _kwargs[key] = item.to_kwargs()
155 | else:
156 | kwargs_constructors[key] = lambda x: x
157 | _kwargs[key] = item
158 |
159 | _inputs, _kwargs = scatter_kwargs(_inputs, _kwargs, device_ids, dim=self.dim)
160 |
161 | _inputs = [
162 | [input_constructors[i](item) for i, item in enumerate(_input)]
163 | for _input in _inputs
164 | ]
165 | _kwargs = [
166 | {k: kwargs_constructors[k](item) for k, item in _kwarg.items()}
167 | for _kwarg in _kwargs
168 | ]
169 |
170 | return _inputs, _kwargs
171 |
172 |
173 | class ListSampler(Sampler):
174 | r"""Samples given elements sequentially, always in the same order.
175 | Arguments:
176 | data_source (Dataset): dataset to sample from
177 | indices (iterable): list of indices to sample
178 | """
179 |
180 | def __init__(self, data_source, indices):
181 | super().__init__(data_source)
182 | if indices is None:
183 | self.indices = list(range(len(data_source)))
184 | else:
185 | self.indices = indices
186 |
187 | def __iter__(self):
188 | return iter(self.indices)
189 |
190 | def __len__(self):
191 | return len(self.indices)
192 |
193 |
194 | class ShuffledSubsetSampler(Sampler):
195 | r"""Samples given elements sequentially, always in the same order.
196 | Arguments:
197 | data_source (Dataset): dataset to sample from
198 | indices (iterable): list of indices to sample
199 | """
200 |
201 | def __init__(self, data_source, indices):
202 | super().__init__(data_source)
203 | self.indices = copy.deepcopy(indices)
204 | random.shuffle(self.indices)
205 |
206 | def __iter__(self):
207 | return iter(self.indices)
208 |
209 | def __len__(self):
210 | return len(self.indices)
211 |
212 |
213 | class IndexedDataLoader(DataLoader):
214 |
215 | def __init__(self, dataset, batch_size, num_workers=4, indices=None, drop_last=False,
216 | pin_memory=False, shuffle=False):
217 | if shuffle:
218 | sampler = ShuffledSubsetSampler(dataset, indices)
219 | else:
220 | sampler = ListSampler(dataset, indices)
221 |
222 | super().__init__(
223 | dataset,
224 | batch_size=batch_size,
225 | num_workers=num_workers,
226 | sampler=sampler,
227 | drop_last=drop_last,
228 | pin_memory=pin_memory)
229 |
230 |
231 | class SequentialDataLoader(IndexedDataLoader):
232 |
233 | def __init__(self, *args, **kwargs):
234 | super().__init__(*args, **kwargs, shuffle=False)
235 |
236 |
237 | class _InfiniteSampler(object):
238 | """ Sampler that repeats forever.
239 |
240 | Hack to force dataloader to use same workers.
241 |
242 | Args:
243 | sampler (Sampler)
244 | """
245 |
246 | def __init__(self, sampler):
247 | self.sampler = sampler
248 |
249 | def __len__(self):
250 | return len(self.sampler)
251 |
252 | def __iter__(self):
253 | while True:
254 | yield from iter(self.sampler)
255 |
256 |
257 | class WorkerPreservingDataLoader(DataLoader):
258 | """
259 | Hack to force dataloader to use same workers.
260 | """
261 |
262 | def __init__(self, *args, **kwargs):
263 | super().__init__(*args, **kwargs)
264 | self.batch_sampler = _InfiniteSampler(self.batch_sampler)
265 | self.iterator = super().__iter__()
266 |
267 | def __iter__(self):
268 | for i in range(len(self)):
269 | yield next(self.iterator)
270 |
271 |
272 | @contextmanager
273 | def profile():
274 | with torch.autograd.profiler.profile(use_cuda=True) as prof:
275 | yield
276 | print(prof.key_averages().table(sort_by="self_cpu_time_total"))
277 |
278 |
279 | @contextmanager
280 | def measure_time(s: str):
281 | torch.cuda.synchronize()
282 | start = time.time()
283 | yield
284 | torch.cuda.synchronize()
285 | print(f"[{s}] elapsed: {time.time() - start:.02f}")
286 |
287 |
--------------------------------------------------------------------------------
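A brief sketch of a few of the helpers above (the checkpoint directory, checkpoint name, and tiny model are illustrative):

from pathlib import Path

import torch
from torch import nn

from three import torchutils

model = nn.Linear(4, 2)
print(torchutils.module_device(model))  # cpu, unless the model has been moved

# Draw the same random numbers twice without disturbing the global RNG state.
with torchutils.manual_seed(0):
    a = torch.rand(3)
with torchutils.manual_seed(0):
    b = torch.rand(3)
assert torch.equal(a, b)

# Checkpoints are written as <save_dir>/<name>.pth.
torchutils.save_checkpoint(Path('/tmp/gs_pose_ckpts'), 'demo',
                           {'model': model.state_dict()})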
/three/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def farthest_points(data, n_clusters: int, dist_func, return_center_indexes=False,
5 | return_distances=False, verbose=False):
6 | """
7 | Performs farthest point sampling on data points.
8 |
9 | Args:
10 | data (torch.tensor): data points.
11 | n_clusters (int): number of clusters.
12 | dist_func (Callable): distance function that is used to compare two data points.
13 | return_center_indexes (bool): if True, returns the indexes of the center of clusters.
14 | return_distances (bool): if True, return distances of each point from centers.
15 |
16 | Returns clusters, [centers, distances]:
17 | clusters (torch.tensor): the cluster index for each element in data.
18 | centers (torch.tensor): the integer index of each center.
19 | distances (torch.tensor): closest distances of each point to any of the cluster centers.
20 | """
21 | if n_clusters >= data.shape[0]:
22 | if return_center_indexes:
23 | return (torch.arange(data.shape[0], dtype=torch.long),
24 | torch.arange(data.shape[0], dtype=torch.long))
25 |
26 | return torch.arange(data.shape[0], dtype=torch.long)
27 |
28 | clusters = torch.full((data.shape[0],), fill_value=-1, dtype=torch.long)
29 | distances = torch.full((data.shape[0],), fill_value=1e7, dtype=torch.float32)
30 | centers = torch.zeros(n_clusters, dtype=torch.long)
31 | for i in range(n_clusters):
32 | center_idx = torch.argmax(distances)
33 | centers[i] = center_idx
34 |
35 | broadcasted_data = data[center_idx].unsqueeze(0).expand(data.shape[0], -1)
36 | new_distances = dist_func(broadcasted_data, data)
37 | distances = torch.min(distances, new_distances)
38 | clusters[distances == new_distances] = i
39 | if verbose:
40 | print('farthest points max distance : {}'.format(torch.max(distances)))
41 |
42 | if return_center_indexes:
43 | if return_distances:
44 | return clusters, centers, distances
45 | return clusters, centers
46 |
47 | return clusters
48 |
--------------------------------------------------------------------------------
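A short sketch of farthest point sampling with the helper above (the random point cloud and L2 distance are illustrative):

import torch
from three.utils import farthest_points

points = torch.rand(500, 3)

def l2_dist(a, b):
    return torch.norm(a - b, dim=-1)

clusters, centers = farthest_points(points, n_clusters=16, dist_func=l2_dist,
                                    return_center_indexes=True)
print(centers.shape)   # (16,) indices of the selected, well-spread points
print(clusters.shape)  # (500,) index of the nearest selected center for each point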
/training/training.py:
--------------------------------------------------------------------------------
1 | import os
2 | gpu_id = 0
3 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
4 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
5 |
6 | import sys
7 | import time
8 | import glob
9 | import torch
10 | import shutil
11 | import numpy as np
12 | from torch import optim
13 | import matplotlib.pyplot as plt
14 |
15 | from torch.cuda.amp import autocast, GradScaler
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | PROJ_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # ./../../
19 | sys.path.append(PROJ_ROOT)
20 |
21 | from dataset import misc
22 | from misc_utils import warmup_lr
23 | from model.network import model_arch as ModelNet
24 | from dataset.megapose_dataset import MegaPose_Dataset as Dataset
25 | device = torch.device('cuda:0')
26 |
27 | def batchify_cuda_device(data_dict, batch_size, flatten_multiview=True, use_cuda=True):
28 | for key, val in data_dict.items():
29 | for sub_key, sub_val in val.items():
30 | if use_cuda:
31 | try:
32 | data_dict[key][sub_key] = sub_val.cuda(non_blocking=True)
33 | except:
34 | pass
35 | if flatten_multiview:
36 | try:
37 | if data_dict[key][sub_key].shape[0] == batch_size:
38 | data_dict[key][sub_key] = data_dict[key][sub_key].flatten(0, 1)
39 | except:
40 | pass
41 |
42 | img_size = 224
43 | batch_size = 2
44 | que_view_num = 4
45 | refer_view_num = 8
46 | random_view_num = 24 # 8 + 24 = 32
47 | nnb_Rmat_threshold = 30
48 | num_train_iters = 100_000
49 |
50 | DATA_DIR = os.path.join(PROJ_ROOT, 'dataspace', 'MegaPose')
51 |
52 | dataset = Dataset(data_dir=DATA_DIR,
53 | query_view_num=que_view_num,
54 | refer_view_num=refer_view_num,
55 | rand_view_num=random_view_num,
56 | nnb_Rmat_threshold=nnb_Rmat_threshold,
57 | )
58 |
59 | print('num_objects: ', len(dataset.selected_objIDs))
60 |
61 | model_net = ModelNet().to(device)
62 | CKPT_ROOT = os.path.join(PROJ_ROOT, 'checkpoints')
63 | checkpoints = os.path.join(CKPT_ROOT, 'checkpoints')
64 | tb_dir = os.path.join(checkpoints, 'tb')
65 | tb_old = tb_dir.replace('tb', 'tb_old')
66 | # Archive any existing TensorBoard logs to `tb_old`; SummaryWriter recreates `tb_dir`.
67 | if os.path.exists(tb_old):
68 |     shutil.rmtree(tb_old)
69 | if os.path.exists(tb_dir):
70 |     shutil.move(tb_dir, tb_old)
71 | tb_writer = SummaryWriter(tb_dir)
72 |
73 | data_loader = torch.utils.data.DataLoader(dataset,
74 | shuffle=True,
75 | num_workers=8,
76 | batch_size=batch_size,
77 | collate_fn=dataset.collate_fn,
78 | pin_memory=False, drop_last=False)
79 |
80 | END_LR = 1e-6
81 | START_LR = 1e-4
82 | max_steps = num_train_iters
83 |
84 | iter_steps = 0
85 | TB_SKIP_STEPS = 5
86 | short_log_interval = 100
87 | long_log_interval = 1_000
88 | checkpoint_interval = 10_000
89 | enable_FP16_training = True
90 | LOSS_WEIGHTS = {
91 | 'rm_loss': 1.0,
92 | 'cm_loss': 10.0,
93 | 'qm_loss': 10.0,
94 | 'Remb_loss': 1.0,
95 | }
96 |
97 | optimizer = optim.AdamW(model_net.parameters(), lr=START_LR)
98 | lr_scheduler = warmup_lr.CosineAnnealingWarmupRestarts(optimizer, max_steps, max_lr=START_LR, min_lr=END_LR)
99 |
100 | losses_dict = {}
101 | model_net.train()
102 | scaler = GradScaler()
103 | start_timer = time.time()
104 | data_iterator = iter(data_loader)
105 |
106 | print('total training max_steps: {}'.format(max_steps))
107 | print('enable_FP16_training: ', enable_FP16_training)
108 | for iter_steps in range(1, max_steps+1):
109 | lr_scheduler.step()
110 | optimizer.zero_grad()
111 | try:
112 | batch_data = next(data_iterator)
113 |     except StopIteration:
114 | data_iterator = iter(data_loader) # reinitialize the iterator
115 | batch_data = next(data_iterator)
116 |
117 | batchify_cuda_device(batch_data, batch_size=batch_size, flatten_multiview=True, use_cuda=True)
118 | scaler_curr_scale = 1.0
119 | loss = 0
120 | with autocast(enable_FP16_training):
121 | net_outputs = model_net(batch_data)
122 | for ls_name, ls_wgh in LOSS_WEIGHTS.items():
123 |             loss += net_outputs.get(ls_name, 0) * ls_wgh
124 | assert (not torch.isnan(loss).any())
125 |
126 | scaler.scale(loss).backward()
127 |
128 | with torch.no_grad():
129 | scaler.unscale_(optimizer)
130 | grad_norm = torch.nn.utils.clip_grad_norm_(model_net.parameters(), 1.0)
131 | scaler.step(optimizer)
132 | scaler.update()
133 | scaler_curr_scale = scaler.state_dict()['scale']
134 |
135 | if 'ls' not in losses_dict:
136 | losses_dict['ls'] = list()
137 | losses_dict['ls'].append(loss.item())
138 |
139 | for key_, val_ in net_outputs.items():
140 | if 'loss' in key_:
141 | if key_ not in losses_dict:
142 | losses_dict[key_] = list()
143 | ls = val_.item()
144 | if key_ in LOSS_WEIGHTS:
145 | ls *= LOSS_WEIGHTS[key_]
146 | losses_dict[key_].append(ls)
147 |
148 | if (iter_steps > TB_SKIP_STEPS) and (iter_steps % short_log_interval == 0):
149 | tb_writer.add_scalar("Other/lr", optimizer.param_groups[0]['lr'], iter_steps)
150 |
151 | for idx, (key_, val_) in enumerate(losses_dict.items()):
152 | tb_writer.add_scalar(f"Loss/{idx}_{key_}", val_[-1], iter_steps)
153 |
154 | if ((iter_steps > 5 and iter_steps < 2000 and iter_steps % short_log_interval == 0)
155 | or iter_steps % long_log_interval == 0):
156 |
157 | curr_lr = optimizer.param_groups[0]['lr']
158 | time_stamp = time.strftime('%d-%H:%M:%S', time.localtime())
159 | logging_str = "{:.1f}k".format(iter_steps/1000)
160 |
161 | for key_, val_ in losses_dict.items():
162 | dis_str = key_.split('_')[0]
163 | logging_str += ', {}:{:.4f}'.format(dis_str, np.mean(val_[-2000:]))
164 |
165 | logging_str += ', {}'.format(time_stamp)
166 | logging_str += ', {:.1f}'.format(scaler_curr_scale)
167 | logging_str += ', {:.6f}'.format(curr_lr)
168 |
169 | print(logging_str)
170 |
171 | vis_num_views = np.minimum(8, refer_view_num)
172 | fig, ax = plt.subplots(4, vis_num_views+1, figsize=(12, 5),
173 | gridspec_kw={'width_ratios': [1.5] + [1 for _ in range(vis_num_views)]}
174 | )
175 |
176 | rgb_que_image = batch_data['query_dict']['rescaled_image'][0].detach().cpu().permute(1, 2, 0).squeeze().float()
177 | gt_que_full_mask = batch_data['query_dict']['rescaled_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float()
178 | pd_que_full_mask = net_outputs['que_full_pd_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float()
179 | rgb_path = batch_data['query_dict']['rgb_path'][0].split('train_pbr/')[-1]
180 | ax[0, 0].imshow(rgb_que_image)
181 | ax[0, 0].set_title(rgb_path, fontsize=10)
182 | ax[1, 0].imshow(gt_que_full_mask)
183 | ax[2, 0].imshow(pd_que_full_mask)
184 | ax[3, 0].imshow((gt_que_full_mask - pd_que_full_mask))
185 | ax[0, 0].axis(False)
186 | ax[1, 0].axis(False)
187 | ax[2, 0].axis(False)
188 | ax[3, 0].axis(False)
189 |
190 | rgb_que_image = batch_data['query_dict']['dzi_image'][0].detach().cpu().permute(1, 2, 0).squeeze().float()
191 | gt_que_que_mask = batch_data['query_dict']['dzi_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float()
192 | pd_que_que_mask = net_outputs['que_pd_mask'][0].detach().cpu().permute(1, 2, 0).squeeze().float()
193 | ax[0, 1].imshow(rgb_que_image)
194 | ax[1, 1].imshow(gt_que_que_mask)
195 | ax[2, 1].imshow(pd_que_que_mask)
196 | ax[3, 1].imshow((gt_que_que_mask - pd_que_que_mask))
197 | ax[0, 1].axis(False)
198 | ax[1, 1].axis(False)
199 | ax[2, 1].axis(False)
200 | ax[3, 1].axis(False)
201 |
202 | for vix in range(vis_num_views-1):
203 | vjx = vix + 2
204 | rgb_ref_image = batch_data['refer_dict']['zoom_image'][vix].detach().cpu().permute(1, 2, 0).squeeze().float()
205 | gt_ref_mask = batch_data['refer_dict']['zoom_mask'][vix].detach().cpu().permute(1, 2, 0).squeeze().float()
206 | pd_ref_mask = net_outputs['ref_pd_mask'][vix].detach().cpu().permute(1, 2, 0).squeeze().float()
207 | ax[0, vjx].imshow(rgb_ref_image)
208 | ax[1, vjx].imshow(gt_ref_mask)
209 | ax[2, vjx].imshow(pd_ref_mask)
210 | ax[3, vjx].imshow((gt_ref_mask - pd_ref_mask))
211 | ax[0, vjx].axis(False)
212 | ax[1, vjx].axis(False)
213 | ax[2, vjx].axis(False)
214 | ax[3, vjx].axis(False)
215 | plt.tight_layout()
216 |         tb_writer.add_figure('visualize_refer', fig, iter_steps)
217 | fig.clear()
218 |
219 | Remb_logit = net_outputs['Remb_logit'][0].detach().cpu()
220 | delta_Rdeg = net_outputs['delta_Rdeg'][0].detach().cpu().float()
221 | delta_Rdeg = torch.acos(torch.clamp(delta_Rdeg, min=-1.0, max=1.0)) / torch.pi * 180
222 | rank_Rdegs, rank_Rinds = torch.topk(delta_Rdeg, dim=0, k=delta_Rdeg.shape[0], largest=False)
223 |
224 | fig, ax = plt.subplots(1, 1)
225 | ax.plot(rank_Rdegs, Remb_logit[rank_Rinds])
226 | ax.grid()
227 | tb_writer.add_figure('Rotation probability distribution', fig, iter_steps)
228 | fig.clear()
229 |
230 | if iter_steps % checkpoint_interval == 0:
231 | if not os.path.exists(checkpoints):
232 | os.makedirs(checkpoints)
233 | time_stamp = time.strftime('%m%d_%H%M%S', time.localtime())
234 |
235 | ckpt_name = 'model_{}_{}.pth'.format(iter_steps, time_stamp)
236 | ckpt_file = os.path.join(checkpoints, ckpt_name)
237 | try:
238 | torch.save(model_net.module.state_dict(), ckpt_file)
239 | except:
240 | torch.save(model_net.state_dict(), ckpt_file)
241 |
242 | # try:
243 | # state = {
244 | # 'model_net': model_net.module.state_dict(),
245 | # 'optimizer': optimizer.state_dict(),
246 | # 'lr_scheduler': lr_scheduler.state_dict(),
247 | # 'scaler': scaler.state_dict(),
248 | # 'iter_steps': iter_steps,
249 | # }
250 | # torch.save(state, ckpt_file)
251 | # except:
252 | # state = {
253 | # 'model_net': model_net.state_dict(),
254 | # 'optimizer': optimizer.state_dict(),
255 | # 'lr_scheduler': lr_scheduler.state_dict(),
256 | # 'scaler': scaler.state_dict(),
257 | # 'iter_steps': iter_steps,
258 | # }
259 | # torch.save(state, ckpt_file)
260 |
261 | print('saving to ', ckpt_file)
262 |
263 |
264 |
--------------------------------------------------------------------------------
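A hedged sketch of how a checkpoint written by this script could be loaded back for inference (the checkpoint filename is illustrative; run from the repository root so the `model` package is importable):

import torch
from model.network import model_arch as ModelNet

# training.py saves checkpoints as checkpoints/checkpoints/model_<step>_<timestamp>.pth
ckpt_file = 'checkpoints/checkpoints/model_100000_0101_120000.pth'  # illustrative name

model_net = ModelNet()
model_net.load_state_dict(torch.load(ckpt_file, map_location='cpu'))
model_net.eval()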